扩散模型能批处理吗?为什么单次推理单批生成速度会线性增长?(Diffusion in Parallel)

如果大家使用扩散模型进行推理会发现一个现象:似乎输入多个prompt和输入一个prompt需要的时间差距很大,这不符合Batch常理。

image-20240728191635281

这个表是我们在单卡3090上进行测试不同prompt的结果,最初也很反直觉,似乎batch在Unet里失效了。起初我以前是pipline内部是单线程的, 于是我去拆解了Pipline里面的包发现没有问题,之后在模型预测里面去测试了一下时间。

import time

start_time = time.time()

                # predict the noise residual
noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
end_time = time.time()

发现随着prompt的增多,Unet预测的时间都在不停增加。What ?

查找了很多资料,反复测试代码,最终在 generate images in parallel 找到了答案。

不能并行的原因是受限于显卡的计算性能,在芯片进行处理的时候到极限的情况下,不能同时处理多个数据。

最后进行测试,生成8张图像,

8批次,每批1张

image-20240728193240549

1批次,每批8张

image-20240728193317955

最终证明,扩散模型满足批次处理规律。但受限于gpu的性能影响,不能同时处理成批数据的能力下,将会变成串行,但批量处理的速度还是会稍微快。

批量

image-20240728193453934

非批量

image-20240728193750949