By Sayak Paul
A diffusion system is not a single monolithic model; it is a series of interconnected models. For example, the Stable Diffusion family of models (and similar systems) has a text encoder, a denoiser, and a VAE. There is also a non-parametric component: the noise scheduler.
Modern diffusion systems like Flux involve multiple text encoders, a large diffusion transformer as the denoiser, and a VAE. Naturally, performing inference with such a system can be daunting on consumer GPUs.
In this document, we show how to decouple the different stages of a diffusion system (text encoding, denoising, and decoding) and split their computation across multiple (consumer) GPUs when available. Here, we assume access to two 16GB GPUs.
Our testbed is the Flux.1-Dev model. It has:

- Two text encoders (CLIP and T5) along with their tokenizers
- A diffusion transformer (the denoiser) with 12.5B parameters
- A VAE for decoding the denoised latents into images
With only two 16GB GPUs, we won’t be able to perform inference naively. We will have to either use some form of quantization or some model-sharding tricks. This document does the latter.
Given the input text prompt, we first compute the text embeddings. For this, we only need the two text encoders and their respective tokenizers. We distribute the two text encoders across the two GPUs. The code for this is fairly simple:
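Below is a minimal sketch of this step, assuming the `black-forest-labs/FLUX.1-dev` checkpoint and diffusers' pipeline-level `device_map="balanced"` placement; the prompt and memory budget are placeholders:

```python
import torch
from diffusers import FluxPipeline

prompt = "a photo of a dog with cat-like look"

# Load only the text encoders and tokenizers; skip the transformer and VAE.
# device_map="balanced" spreads the loaded components across the available GPUs.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16,
)

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```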
Once the text embeddings are computed, we can free the text encoders and reclaim the GPU memory to load our diffusion transformer. As mentioned earlier, it has 12.5B parameters, making it the largest open diffusion transformer available at the time of writing. Even with a reduced precision like bfloat16, we won’t be able to load it in 16GB. So, we will split it across the two 16GB GPUs.
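One way to free the text encoders, sketched here with a hypothetical `flush()` helper that clears the CUDA cache:

```python
import gc
import torch

def flush():
    # Hypothetical helper: run garbage collection and clear cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Drop references to the text encoders (and the pipeline holding them).
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```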
We first load the transformer with `device_map="auto"` to let accelerate figure out how to best split the model across the GPUs, CPU, and disk. We’d want to keep the CPU and disk offloading to a minimum, though.
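A sketch of this step, constraining placement with a `max_memory` budget (an assumption matching our two 16GB GPUs):

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16,
)

# Inspect how accelerate placed the submodules across devices.
print(transformer.hf_device_map)
```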
If we print `transformer.hf_device_map` from the code above, it will show the device-wise split of each module of `transformer`.
And then we can incorporate the `transformer` for denoising. We load the `FluxPipeline` with our `transformer`, but we set the other model-level components (such as the text encoders and the VAE) to `None`:
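A sketch of this step, reusing the embeddings computed earlier and keeping the output as latents for a later decoding step (the resolution and sampling settings are assumptions):

```python
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",  # return latents; decode with the VAE afterwards
).images
```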