扩散模型能批处理吗?为什么单次推理单批生成速度会线性增长?(Diffusion in Parallel)
如果大家使用扩散模型进行推理会发现一个现象:似乎输入多个prompt和输入一个prompt需要的时间差距很大,这不符合Batch常理。
这个表是我们在单卡3090上进行测试不同prompt的结果,最初也很反直觉,似乎batch在Unet里失效了。起初我以前是pipline内部是单线程的, 于是我去拆解了Pipline里面的包发现没有问题,之后在模型预测里面去测试了一下时间。
import time
start_time = time.time()
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
end_time = time.time()
发现随着prompt的增多,Unet预测的时间都在不停增加。What ?
查找了很多资料,反复测试代码,最终在 generate images in parallel 找到了答案。
不能并行的原因是受限于显卡的计算性能,在芯片进行处理的时候到极限的情况下,不能同时处理多个数据。
最后进行测试,生成8张图像,
8批次,每批1张
1批次,每批8张
最终证明,扩散模型满足批次处理规律。但受限于gpu的性能影响,不能同时处理成批数据的能力下,将会变成串行,但批量处理的速度还是会稍微快。
批量
非批量