7.2.2. Example: Image Generation From Text Prompt

Application Example: Generating Images Based on Prompt “Fujisan”

Note: VAE decoder support is currently unavailable.

Execution Method

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_stable_diffusion.sh --skip_vae_decoder_compilation --prompt "Fujisan" --device mncore2:auto

Expected Output

The generated image will appear in /tmp/mlsdk_stable_diffusion_out/output.png.

"Fujisan" image generated on MN-Core 2

Fig. 7.2 “Fujisan” image generated on MN-Core 2

Script

Listing 7.13 /opt/pfn/pfcomp/codegen/MLSDK/examples/run_stable_diffusion.sh
 1#!/bin/bash
 2set -ex -o pipefail
 3
 4CURRENT_DIR=$(realpath $(dirname $0))
 5CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
 6BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}
 7
 8venv_dir=/tmp/run_stable_diffusion_venv
 9if [[ ! -d $venv_dir ]]; then
10    python3 -m venv /tmp/run_stable_diffusion_venv --system-site-packages
11    source /tmp/run_stable_diffusion_venv/bin/activate
12    # Fix versions to avoid breaking changes
13    pip install diffusers==0.8.0 transformers==4.44.0 huggingface-hub==0.24.7
14fi
15source /tmp/run_stable_diffusion_venv/bin/activate
16
17source "${BUILD_DIR}/codegen_preloads.sh"
18source "${BUILD_DIR}/codegen_pythonpath.sh"
19
20export MNCORE_USE_EXTERNAL_DATA_FORMAT=1
21
22# PFN_INTERNAL_START
23# We cannot expose this part because it depends on PFS.
24source ${CODEGEN_DIR}/python_trainer/examples/find_dataset.sh
25
26MODEL_DIR=/tmp/stable-diffusion-v1-4
27mkdir -p $MODEL_DIR
28aws s3 --endpoint=${S3_ENDPOINT} cp --no-progress ${STABLE_DIFFUSION_V1_4} /tmp/ --cli-read-timeout 1800 --cli-connect-timeout 1800
29unzip -o /tmp/stable-diffusion-v1-4.zip -d /tmp
30python3 ${CURRENT_DIR}/stable_diffusion.py --model $MODEL_DIR "$@"
31exit 0
32# PFN_INTERNAL_END
33
34python3 ${CURRENT_DIR}/stable_diffusion.py "$@"
Listing 7.14 /opt/pfn/pfcomp/codegen/MLSDK/examples/stable_diffusion.py
  1import argparse
  2import inspect
  3from typing import List, Optional, Union
  4
  5import torch
  6from diffusers import (
  7    AutoencoderKL,
  8    DDIMScheduler,
  9    LMSDiscreteScheduler,
 10    PNDMScheduler,
 11    StableDiffusionPipeline,
 12    UNet2DConditionModel,
 13)
 14from diffusers.pipelines.stable_diffusion.safety_checker import (
 15    StableDiffusionSafetyChecker,
 16)
 17from mlsdk import CacheOptions, Context, MNDevice, storage
 18from tqdm.auto import tqdm
 19from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 20
 21
 22class StableDiffusionMNCorePipeline(StableDiffusionPipeline):
 23    def __init__(  # noqa: CFQ002
 24        self,
 25        vae: AutoencoderKL,
 26        text_encoder: CLIPTextModel,
 27        tokenizer: CLIPTokenizer,
 28        unet: UNet2DConditionModel,
 29        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
 30        safety_checker: StableDiffusionSafetyChecker,
 31        feature_extractor: CLIPFeatureExtractor,
 32    ):
 33        super().__init__(
 34            vae,
 35            text_encoder,
 36            tokenizer,
 37            unet,
 38            scheduler,
 39            safety_checker,
 40            feature_extractor,
 41        )
 42        self.compiled_text_encoder = None
 43        self.compiled_unet = None
 44        self.compiled_vae_decoder = None
 45
 46    def compile_encoder(
 47        self,
 48        context,
 49        batch_size: int,
 50        out_dir: str,
 51        num_compiler_threads: Optional[int] = None,
 52    ):
 53        seq_len = self.tokenizer.model_max_length
 54
 55        def text_encoder_fn(inp):
 56            input_ids = inp["input_ids"]
 57            position_ids = inp["position_ids"]
 58            embeddings = self.text_encoder(input_ids, position_ids=position_ids)[0]
 59            return {"embeddings": embeddings}
 60
 61        context.registry.register("text_encoder", self.text_encoder)
 62
 63        return context.compile(
 64            text_encoder_fn,
 65            {
 66                "input_ids": torch.zeros((batch_size, seq_len), dtype=torch.int64).to(
 67                    self.device
 68                ),
 69                # fx2onnx failed to process the buffer which is created by view
 70                # So pass position_ids explicitly
 71                "position_ids": torch.arange(seq_len).expand((1, -1)).to(self.device),
 72            },
 73            storage.path(out_dir + "/text_encoder"),
 74            export_kwargs={"use_fx2onnx": True},
 75            cache_options=CacheOptions(out_dir + "/encoder_cache"),
 76            num_compiler_threads=num_compiler_threads,
 77        )
 78
 79    def compile_unet(  # noqa: CFQ002
 80        self,
 81        context,
 82        batch_size: int,
 83        height: int,
 84        width: int,
 85        guidance_scale: float,
 86        out_dir: str,
 87        num_compiler_threads: Optional[int] = None,
 88    ):
 89        seq_len = self.tokenizer.model_max_length
 90        do_classifier_free_guidance = guidance_scale > 1.0
 91
 92        def unet_fn(inp):
 93            latents = inp["latents"]
 94            timesteps = inp["timesteps"]
 95            text_embeddings = inp["text_embeddings"]
 96            noise_pred = self.unet(
 97                latents, timesteps, encoder_hidden_states=text_embeddings
 98            ).sample
 99            return {"sample": noise_pred}
100
101        context.registry.register("unet", self.unet)
102
103        return context.compile(
104            unet_fn,
105            {
106                "latents": torch.zeros(
107                    (
108                        batch_size * 2 if do_classifier_free_guidance else 1,
109                        self.unet.in_channels,
110                        height // 8,
111                        width // 8,
112                    )
113                ).to(self.device),
114                "timesteps": torch.tensor([0], dtype=torch.long),
115                "text_embeddings": torch.zeros(
116                    (
117                        batch_size * 2 if do_classifier_free_guidance else 1,
118                        seq_len,
119                        self.text_encoder.config.hidden_size,
120                    )
121                ),
122            },
123            storage.path(out_dir + "/unet"),
124            export_kwargs={"use_fx2onnx": True},
125            cache_options=CacheOptions(out_dir + "/unet_cache"),
126            num_compiler_threads=num_compiler_threads,
127        )
128
129    def compile_vae_decoder(  # noqa: CFQ002
130        self,
131        context,
132        batch_size: int,
133        height: int,
134        width: int,
135        out_dir: str,
136        num_compiler_threads: Optional[int] = None,
137    ):
138        def vae_decoder_fn(inp):
139            z = inp["z"]
140            image = self.vae.decode(z).sample
141            return {"image": image}
142
143        context.registry.register("vae_post_quant_conv", self.vae.post_quant_conv)
144        context.registry.register("vae_decoder", self.vae.decoder)
145
146        return context.compile(
147            vae_decoder_fn,
148            {
149                "z": torch.zeros(
150                    (batch_size, self.unet.in_channels, height // 8, width // 8),
151                ).to(self.device),
152            },
153            storage.path(out_dir + "/vae_decoder"),
154            export_kwargs={"use_fx2onnx": True},
155            cache_options=CacheOptions(out_dir + "/vae_decoder_cache"),
156            num_compiler_threads=num_compiler_threads,
157        )
158
159    def compile(  # noqa: CFQ002
160        self,
161        *,
162        batch_size: int,
163        device: str,
164        height: int = 512,
165        width: int = 512,
166        guidance_scale: float = 7.5,
167        out_dir: str = "/tmp/mlsdk_stable_diffusion_out",
168        skip_text_encoder_compilation: bool = False,
169        skip_unet_compilation: bool = False,
170        skip_vae_decoder_compilation: bool = False,
171        num_compiler_threads: Optional[int] = None,
172    ):
173        device = MNDevice(device)
174        context = Context(device)
175        Context.switch_context(context)
176
177        if not skip_text_encoder_compilation:
178            self.compiled_text_encoder = self.compile_encoder(
179                context,
180                batch_size,
181                out_dir=out_dir,
182                num_compiler_threads=num_compiler_threads,
183            )
184        if not skip_unet_compilation:
185            self.compiled_unet = self.compile_unet(
186                context,
187                batch_size,
188                height,
189                width,
190                guidance_scale,
191                out_dir=out_dir,
192                num_compiler_threads=num_compiler_threads,
193            )
194        if not skip_vae_decoder_compilation:
195            self.compiled_vae_decoder = self.compile_vae_decoder(
196                context,
197                batch_size,
198                height,
199                width,
200                out_dir=out_dir,
201                num_compiler_threads=num_compiler_threads,
202            )
203
204    def infer_text_encoder(self, input_ids, position_ids=None):
205        if self.compiled_text_encoder is not None:
206            if position_ids is None:
207                position_ids = (
208                    torch.arange(self.tokenizer.model_max_length)
209                    .expand((1, -1))
210                    .to(self.device)
211                )
212            return self.compiled_text_encoder(
213                {"input_ids": input_ids, "position_ids": position_ids}
214            )["embeddings"]
215        else:
216            return self.text_encoder(input_ids)[0]
217
218    def infer_unet(self, latent_model_input, t, text_embeddings):
219        if self.compiled_unet is not None:
220            return self.compiled_unet(
221                {
222                    "latents": latent_model_input,
223                    "timesteps": torch.tensor([t], dtype=torch.long),
224                    "text_embeddings": text_embeddings,
225                }
226            )["sample"]
227        else:
228            return self.unet(
229                latent_model_input, t, encoder_hidden_states=text_embeddings
230            ).sample
231
232    def infer_vae_decode(self, z):
233        if self.compiled_vae_decoder is not None:
234            return self.compiled_vae_decoder({"z": z})["image"]
235        else:
236            return self.vae.decode(z).sample
237
238    # Ref https://github.com/huggingface/diffusers/blob/v0.2.4/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # noqa: B950
239    # Ref https://github.com/huggingface/diffusers/blob/v0.8.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # noqa: B950
240    @torch.no_grad()
241    def __call__(  # noqa: CFQ002,CFQ001
242        self,
243        prompt: Union[str, List[str]],
244        height: Optional[int] = 512,
245        width: Optional[int] = 512,
246        num_inference_steps: Optional[int] = 50,
247        guidance_scale: Optional[float] = 7.5,
248        eta: Optional[float] = 0.0,
249        generator: Optional[torch.Generator] = None,
250        output_type: Optional[str] = "pil",
251        **kwargs,
252    ):
253        if isinstance(prompt, str):
254            batch_size = 1
255        elif isinstance(prompt, list):
256            batch_size = len(prompt)
257        else:
258            raise ValueError(
259                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
260            )
261
262        if height % 8 != 0 or width % 8 != 0:
263            raise ValueError(
264                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
265            )
266
267        # get prompt text embeddings
268        text_input = self.tokenizer(
269            prompt,
270            padding="max_length",
271            max_length=self.tokenizer.model_max_length,
272            truncation=True,
273            return_tensors="pt",
274        )
275
276        text_embeddings = self.infer_text_encoder(text_input.input_ids.to(self.device))
277
278        do_classifier_free_guidance = guidance_scale > 1.0
279        if do_classifier_free_guidance:
280            max_length = text_input.input_ids.shape[-1]
281            uncond_input = self.tokenizer(
282                [""] * batch_size,
283                padding="max_length",
284                max_length=max_length,
285                return_tensors="pt",
286            )
287            uncond_embeddings = self.infer_text_encoder(
288                uncond_input.input_ids.to(self.device)
289            )
290            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
291
292        # get the intial random noise
293        latents = torch.randn(
294            (batch_size, self.unet.in_channels, height // 8, width // 8),
295            generator=generator,
296            device=self.device,
297        )
298        latents = latents * self.scheduler.init_noise_sigma
299
300        # set timesteps
301        accepts_offset = "offset" in set(
302            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
303        )
304        extra_set_kwargs = {}
305        if accepts_offset:
306            extra_set_kwargs["offset"] = 1
307
308        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
309
310        accepts_eta = "eta" in set(
311            inspect.signature(self.scheduler.step).parameters.keys()
312        )
313        extra_step_kwargs = {}
314        if accepts_eta:
315            extra_step_kwargs["eta"] = eta
316
317        for t in tqdm(self.scheduler.timesteps):
318            # expand the latents if we are doing classifier free guidance
319            latent_model_input = (
320                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
321            )
322            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
323
324            # predict the noise residual
325            noise_pred = self.infer_unet(latent_model_input, t, text_embeddings)
326
327            # perform guidance
328            if do_classifier_free_guidance:
329                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330                noise_pred = noise_pred_uncond + guidance_scale * (
331                    noise_pred_text - noise_pred_uncond
332                )
333
334            # compute the previous noisy sample x_t -> x_t-1
335            latents = self.scheduler.step(
336                noise_pred, t, latents, **extra_step_kwargs
337            ).prev_sample
338
339        # scale and decode the image latents with vae
340        latents = 1 / 0.18215 * latents
341        image = self.infer_vae_decode(latents)
342        image = (image.cpu() / 2 + 0.5).clamp(0, 1)
343        image = image.permute(0, 2, 3, 1).numpy()
344
345        # run safety checker
346        image, has_nsfw_concept = self.run_safety_checker(
347            image, self.device, text_embeddings.dtype
348        )
349
350        if output_type == "pil":
351            image = self.numpy_to_pil(image)
352
353        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
354
355
def main(args):
    """Compile the pipeline for ``args.device``, generate one image for
    ``args.prompt``, and save it as ``<outdir>/output.png``.
    """
    pipe = StableDiffusionMNCorePipeline.from_pretrained(args.model)
    pipe.compile(
        batch_size=1,
        device=args.device,
        out_dir=args.outdir,
        skip_text_encoder_compilation=args.skip_text_encoder_compilation,
        skip_unet_compilation=args.skip_unet_compilation,
        skip_vae_decoder_compilation=args.skip_vae_decoder_compilation,
        num_compiler_threads=args.num_compiler_threads,
    )

    # NOTE(review): presumably device strings like "mncore2:cuda:0" embed a
    # torch device for the eager parts of the pipeline — confirm the format.
    if ":cuda:" in args.device:
        pipe.to(args.device[args.device.find(":cuda:") + 1 :])

    prompt = args.prompt
    image = pipe(prompt)["sample"][0]

    # Ensure the output directory exists: nothing above guarantees it was
    # created (e.g. when all compilations are skipped), and image.save()
    # would otherwise fail.
    import os

    os.makedirs(args.outdir, exist_ok=True)
    image.save(f"{args.outdir}/output.png")
    print(f"Output image saved at {args.outdir}/output.png")
377
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model",
        type=str,
        default="CompVis/stable-diffusion-v1-4",
    )
    arg_parser.add_argument(
        "--outdir",
        type=str,
        default="/tmp/mlsdk_stable_diffusion_out",
        help="Path to store the outputs",
    )
    arg_parser.add_argument("--device", default="mncore2:auto")
    arg_parser.add_argument(
        "--prompt", type=str, default="a photo of an astronaut riding a horse on mars"
    )
    arg_parser.add_argument(
        "--num_compiler_threads",
        type=int,
        default=-1,
        help="Number of threads to use for compilation",
    )
    # The three skip flags share the same shape; register them in one loop.
    for skip_flag in (
        "--skip_text_encoder_compilation",
        "--skip_unet_compilation",
        "--skip_vae_decoder_compilation",
    ):
        arg_parser.add_argument(skip_flag, action="store_true")

    main(arg_parser.parse_args())