By Sayak Paul

A diffusion system is not a single monolithic model; it is a series of interconnected models. For example, the Stable Diffusion family of models (and the like) has a text encoder, a denoiser, and a VAE. There is another, non-parametric component: the noise scheduler.

Modern diffusion systems like Flux involve multiple text encoders, a large diffusion transformer as the denoiser, and a VAE. Naturally, performing inference with such a system on consumer GPUs is a daunting task.

In this document, we show how to decouple the different stages of a diffusion system (text encoding, denoising, and decoding) and split their computation across multiple (consumer) GPUs when available. Here, we assume access to two 16GB GPUs.

Our testbed is the Flux.1-Dev model. It has:

- two text encoders (and their respective tokenizers),
- a large diffusion transformer with 12.5B parameters that serves as the denoiser, and
- a VAE.

With only two 16GB GPUs, we won’t be able to perform inference naively. We will have to either use some form of quantization or some other model-sharding trick. This document does the latter.


Text encoding

Given the input text prompt, we first compute the text embeddings. For this, we only need the two text encoders and their respective tokenizers. We keep the two text encoders on separate GPUs, one on each. The code for this is fairly simple:
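Here is a minimal sketch of this step, assuming the black-forest-labs/FLUX.1-dev checkpoint and that diffusers is installed together with accelerate; the prompt and the max_memory values are illustrative. Loading the pipeline with the transformer and VAE set to None and device_map="balanced" lets the two text encoders land on the two available GPUs:

```python
import torch
from diffusers import FluxPipeline

prompt = "a photo of a dog with cat-like look"  # illustrative prompt

# Load only the text encoders and tokenizers; the transformer and VAE are
# explicitly set to None so they are never loaded into memory.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16,
)

# Compute the prompt embeddings once; they are reused during denoising.
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```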

Denoising

Once the text embeddings are computed, we can free the text encoders and reclaim the GPU memory to load our diffusion transformer. As mentioned earlier, it has 12.5B parameters, making it the largest open diffusion transformer available at the time of writing. Even with reduced precision like BFloat16, we won’t be able to load it in 16GB. So, we will split it across the two 16GB GPUs.
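One way to sketch this cleanup, assuming the encoding pipeline from the previous snippet is still in scope, is to drop the references to the text encoders and flush the CUDA caches:

```python
import gc
import torch

def flush():
    # Run Python garbage collection and release cached GPU memory.
    gc.collect()
    torch.cuda.empty_cache()

# Drop the text encoders, tokenizers, and the encoding pipeline itself.
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```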

We first load the transformer with device_map="auto" to let accelerate figure out how best to split the model across the GPUs, CPU, and disk. We’d want to keep the CPU and disk offloading to a minimum, though.
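A sketch of that loading step, again assuming the black-forest-labs/FLUX.1-dev checkpoint:

```python
import torch
from diffusers import FluxTransformer2DModel

# Let accelerate place the transformer's modules across the available
# GPUs (and, if necessary, CPU and disk).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```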

If we print transformer.hf_device_map from the code above, it shows how each module of the transformer was assigned to a device.
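The attribute is a plain dictionary populated by accelerate:

```python
# Maps top-level module names of the transformer to the device they were
# placed on (a GPU index, "cpu", or "disk").
print(transformer.hf_device_map)
```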

We can then use this transformer for denoising. We load the FluxPipeline with our transformer but set the other model-level components (such as the text encoders and the VAE) to None.
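A sketch of this step, reusing the sharded transformer and the prompt embeddings computed earlier; the resolution, guidance scale, and number of steps are illustrative:

```python
import torch
from diffusers import FluxPipeline

# Build a pipeline that only contains the sharded transformer; every other
# model-level component is set to None.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

# Run denoising with the precomputed embeddings and keep the output in
# latent space; VAE decoding happens as a separate, later stage.
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=768,
    width=1360,
    output_type="latent",
).images
```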