7.2.2. Example: Image Generation From Text Prompt
Application Example: Generating Images Based on Prompt “Fujisan”
Note: VAE decoder support on the device is currently unavailable; the example therefore passes --skip_vae_decoder_compilation, and the VAE decoder runs on the host as a fallback.
Execution Method
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_stable_diffusion.sh --skip_vae_decoder_compilation --prompt "Fujisan" --device mncore2:auto
Expected Output
The generated image will appear in /tmp/mlsdk_stable_diffusion_out/output.png.
Fig. 7.2 “Fujisan” image generated on MN-Core 2
Script
#!/bin/bash
# Launcher for the Stable Diffusion example: prepares a pinned virtualenv,
# loads the codegen runtime environment, fetches the model (internal builds
# only), and runs stable_diffusion.py forwarding all user arguments.
set -ex -o pipefail

CURRENT_DIR=$(realpath "$(dirname "$0")")
CODEGEN_DIR=$(realpath "${CURRENT_DIR}/../../")
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

# Reuse an existing venv; create and populate it on first run only.
venv_dir=/tmp/run_stable_diffusion_venv
if [[ ! -d "${venv_dir}" ]]; then
    python3 -m venv "${venv_dir}" --system-site-packages
    source "${venv_dir}/bin/activate"
    # Fix versions to avoid breaking changes
    pip install diffusers==0.8.0 transformers==4.44.0 huggingface-hub==0.24.7
fi
source "${venv_dir}/bin/activate"

source "${BUILD_DIR}/codegen_preloads.sh"
source "${BUILD_DIR}/codegen_pythonpath.sh"

export MNCORE_USE_EXTERNAL_DATA_FORMAT=1

# PFN_INTERNAL_START
# We cannot expose this part because it depends on PFS.
source "${CODEGEN_DIR}/python_trainer/examples/find_dataset.sh"

MODEL_DIR=/tmp/stable-diffusion-v1-4
mkdir -p "${MODEL_DIR}"
aws s3 --endpoint=${S3_ENDPOINT} cp --no-progress ${STABLE_DIFFUSION_V1_4} /tmp/ --cli-read-timeout 1800 --cli-connect-timeout 1800
unzip -o /tmp/stable-diffusion-v1-4.zip -d /tmp
python3 "${CURRENT_DIR}/stable_diffusion.py" --model "${MODEL_DIR}" "$@"
# Internal path ends here; the public invocation below runs only when this
# section is stripped from the release build.
exit 0
# PFN_INTERNAL_END

python3 "${CURRENT_DIR}/stable_diffusion.py" "$@"
1import argparse
2import inspect
3from typing import List, Optional, Union
4
5import torch
6from diffusers import (
7 AutoencoderKL,
8 DDIMScheduler,
9 LMSDiscreteScheduler,
10 PNDMScheduler,
11 StableDiffusionPipeline,
12 UNet2DConditionModel,
13)
14from diffusers.pipelines.stable_diffusion.safety_checker import (
15 StableDiffusionSafetyChecker,
16)
17from mlsdk import CacheOptions, Context, MNDevice, storage
18from tqdm.auto import tqdm
19from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
20
21
class StableDiffusionMNCorePipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline whose sub-models can run on MN-Core.

    The CLIP text encoder, the U-Net denoiser, and the VAE decoder can each
    be compiled independently for an MN-Core device via :meth:`compile`.
    During inference every compiled sub-model is preferred; any sub-model
    that was skipped falls back to its original PyTorch module, so any
    subset of the three may be compiled.
    """

    def __init__(  # noqa: CFQ002
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Forward all standard components to the diffusers base pipeline."""
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
        )
        # Populated by compile(); None means "use the PyTorch module instead".
        self.compiled_text_encoder = None
        self.compiled_unet = None
        self.compiled_vae_decoder = None

    def compile_encoder(
        self,
        context,
        batch_size: int,
        out_dir: str,
        num_compiler_threads: Optional[int] = None,
    ):
        """Compile the CLIP text encoder for a fixed (batch_size, seq_len) input.

        Returns the compiled callable produced by ``context.compile``; it maps
        ``{"input_ids", "position_ids"}`` to ``{"embeddings"}``.
        """
        seq_len = self.tokenizer.model_max_length

        def text_encoder_fn(inp):
            input_ids = inp["input_ids"]
            position_ids = inp["position_ids"]
            # [0] selects the last hidden state from the CLIP output tuple.
            embeddings = self.text_encoder(input_ids, position_ids=position_ids)[0]
            return {"embeddings": embeddings}

        context.registry.register("text_encoder", self.text_encoder)

        return context.compile(
            text_encoder_fn,
            {
                "input_ids": torch.zeros((batch_size, seq_len), dtype=torch.int64).to(
                    self.device
                ),
                # fx2onnx failed to process the buffer which is created by view
                # So pass position_ids explicitly
                "position_ids": torch.arange(seq_len).expand((1, -1)).to(self.device),
            },
            storage.path(out_dir + "/text_encoder"),
            export_kwargs={"use_fx2onnx": True},
            cache_options=CacheOptions(out_dir + "/encoder_cache"),
            num_compiler_threads=num_compiler_threads,
        )

    def compile_unet(  # noqa: CFQ002
        self,
        context,
        batch_size: int,
        height: int,
        width: int,
        guidance_scale: float,
        out_dir: str,
        num_compiler_threads: Optional[int] = None,
    ):
        """Compile the U-Net denoiser for the shapes used by ``__call__``.

        With classifier-free guidance (guidance_scale > 1.0) the U-Net is fed
        the unconditional and conditional batches concatenated, i.e.
        ``2 * batch_size`` rows; without guidance it sees ``batch_size`` rows.
        """
        seq_len = self.tokenizer.model_max_length
        do_classifier_free_guidance = guidance_scale > 1.0
        # Must match the runtime batching in __call__.  The previous code
        # compiled for a fixed batch of 1 when guidance was disabled, which
        # mismatched the runtime shape for batch_size > 1.
        unet_batch = batch_size * 2 if do_classifier_free_guidance else batch_size

        def unet_fn(inp):
            latents = inp["latents"]
            timesteps = inp["timesteps"]
            text_embeddings = inp["text_embeddings"]
            noise_pred = self.unet(
                latents, timesteps, encoder_hidden_states=text_embeddings
            ).sample
            return {"sample": noise_pred}

        context.registry.register("unet", self.unet)

        return context.compile(
            unet_fn,
            {
                # Latents are spatially downsampled by the VAE factor of 8.
                "latents": torch.zeros(
                    (
                        unet_batch,
                        self.unet.in_channels,
                        height // 8,
                        width // 8,
                    )
                ).to(self.device),
                "timesteps": torch.tensor([0], dtype=torch.long),
                "text_embeddings": torch.zeros(
                    (
                        unet_batch,
                        seq_len,
                        self.text_encoder.config.hidden_size,
                    )
                ),
            },
            storage.path(out_dir + "/unet"),
            export_kwargs={"use_fx2onnx": True},
            cache_options=CacheOptions(out_dir + "/unet_cache"),
            num_compiler_threads=num_compiler_threads,
        )

    def compile_vae_decoder(  # noqa: CFQ002
        self,
        context,
        batch_size: int,
        height: int,
        width: int,
        out_dir: str,
        num_compiler_threads: Optional[int] = None,
    ):
        """Compile the VAE decoder mapping latents ``z`` to the output image."""

        def vae_decoder_fn(inp):
            z = inp["z"]
            image = self.vae.decode(z).sample
            return {"image": image}

        # vae.decode uses both the post-quant conv and the decoder module, so
        # register both with the compilation registry.
        context.registry.register("vae_post_quant_conv", self.vae.post_quant_conv)
        context.registry.register("vae_decoder", self.vae.decoder)

        return context.compile(
            vae_decoder_fn,
            {
                # Latent channel count equals the U-Net input channel count.
                "z": torch.zeros(
                    (batch_size, self.unet.in_channels, height // 8, width // 8),
                ).to(self.device),
            },
            storage.path(out_dir + "/vae_decoder"),
            export_kwargs={"use_fx2onnx": True},
            cache_options=CacheOptions(out_dir + "/vae_decoder_cache"),
            num_compiler_threads=num_compiler_threads,
        )

    def compile(  # noqa: CFQ002
        self,
        *,
        batch_size: int,
        device: str,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
        out_dir: str = "/tmp/mlsdk_stable_diffusion_out",
        skip_text_encoder_compilation: bool = False,
        skip_unet_compilation: bool = False,
        skip_vae_decoder_compilation: bool = False,
        num_compiler_threads: Optional[int] = None,
    ):
        """Compile the selected sub-models for ``device``.

        Each ``skip_*`` flag leaves the corresponding compiled slot as None so
        that inference falls back to the PyTorch module for that sub-model.
        """
        device = MNDevice(device)
        context = Context(device)
        Context.switch_context(context)

        if not skip_text_encoder_compilation:
            self.compiled_text_encoder = self.compile_encoder(
                context,
                batch_size,
                out_dir=out_dir,
                num_compiler_threads=num_compiler_threads,
            )
        if not skip_unet_compilation:
            self.compiled_unet = self.compile_unet(
                context,
                batch_size,
                height,
                width,
                guidance_scale,
                out_dir=out_dir,
                num_compiler_threads=num_compiler_threads,
            )
        if not skip_vae_decoder_compilation:
            self.compiled_vae_decoder = self.compile_vae_decoder(
                context,
                batch_size,
                height,
                width,
                out_dir=out_dir,
                num_compiler_threads=num_compiler_threads,
            )

    def infer_text_encoder(self, input_ids, position_ids=None):
        """Run the text encoder (compiled if available) and return embeddings."""
        if self.compiled_text_encoder is not None:
            if position_ids is None:
                # Default position ids: 0..seq_len-1, same as at compile time.
                position_ids = (
                    torch.arange(self.tokenizer.model_max_length)
                    .expand((1, -1))
                    .to(self.device)
                )
            return self.compiled_text_encoder(
                {"input_ids": input_ids, "position_ids": position_ids}
            )["embeddings"]
        else:
            return self.text_encoder(input_ids)[0]

    def infer_unet(self, latent_model_input, t, text_embeddings):
        """Predict the noise residual with the compiled U-Net or the fallback."""
        if self.compiled_unet is not None:
            return self.compiled_unet(
                {
                    "latents": latent_model_input,
                    "timesteps": torch.tensor([t], dtype=torch.long),
                    "text_embeddings": text_embeddings,
                }
            )["sample"]
        else:
            return self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

    def infer_vae_decode(self, z):
        """Decode latents ``z`` to image space (compiled decoder if available)."""
        if self.compiled_vae_decoder is not None:
            return self.compiled_vae_decoder({"z": z})["image"]
        else:
            return self.vae.decode(z).sample

    # Ref https://github.com/huggingface/diffusers/blob/v0.2.4/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # noqa: B950
    # Ref https://github.com/huggingface/diffusers/blob/v0.8.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py # noqa: B950
    @torch.no_grad()
    def __call__(  # noqa: CFQ002,CFQ001
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        **kwargs,
    ):
        """Generate images for ``prompt``.

        Returns a dict with ``"sample"`` (the images, PIL by default) and
        ``"nsfw_content_detected"`` (the safety-checker verdicts).

        Raises ValueError if ``prompt`` is neither str nor list, or if
        ``height``/``width`` are not multiples of 8.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embeddings = self.infer_text_encoder(text_input.input_ids.to(self.device))

        # Classifier-free guidance needs embeddings for the empty prompt too;
        # both batches are concatenated and denoised in one U-Net pass.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.infer_text_encoder(
                uncond_input.input_ids.to(self.device)
            )
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=self.device,
        )
        latents = latents * self.scheduler.init_noise_sigma

        # set timesteps; only some schedulers accept an `offset` argument
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # only some schedulers (e.g. DDIM) accept `eta` in step()
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for t in tqdm(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.infer_unet(latent_model_input, t, text_embeddings)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs
            ).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.infer_vae_decode(latents)
        # map from [-1, 1] to [0, 1] and move channels last for numpy/PIL
        image = (image.cpu() / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).numpy()

        # run safety checker
        image, has_nsfw_concept = self.run_safety_checker(
            image, self.device, text_embeddings.dtype
        )

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
354
355
def main(args):
    """Build and compile the pipeline, then render one image for the prompt."""
    pipeline = StableDiffusionMNCorePipeline.from_pretrained(args.model)
    pipeline.compile(
        batch_size=1,
        device=args.device,
        out_dir=args.outdir,
        skip_text_encoder_compilation=args.skip_text_encoder_compilation,
        skip_unet_compilation=args.skip_unet_compilation,
        skip_vae_decoder_compilation=args.skip_vae_decoder_compilation,
        num_compiler_threads=args.num_compiler_threads,
    )

    # A device string such as "mncore2:cuda:0" requests that the PyTorch parts
    # live on the trailing CUDA device.
    marker_pos = args.device.find(":cuda:")
    if marker_pos != -1:
        pipeline.to(args.device[marker_pos + 1 :])

    result = pipeline(args.prompt)
    image = result["sample"][0]

    output_path = f"{args.outdir}/output.png"
    image.save(output_path)
    print(f"Output image saved at {output_path}")
376
377
if __name__ == "__main__":
    # Command-line interface for the Stable Diffusion example.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="CompVis/stable-diffusion-v1-4",
    )
    cli.add_argument(
        "--outdir",
        type=str,
        default="/tmp/mlsdk_stable_diffusion_out",
        help="Path to store the outputs",
    )
    cli.add_argument("--device", default="mncore2:auto")
    cli.add_argument(
        "--prompt", type=str, default="a photo of an astronaut riding a horse on mars"
    )
    cli.add_argument(
        "--num_compiler_threads",
        type=int,
        default=-1,
        help="Number of threads to use for compilation",
    )
    # The three skip flags share the same shape: boolean, off by default.
    for skip_flag in (
        "--skip_text_encoder_compilation",
        "--skip_unet_compilation",
        "--skip_vae_decoder_compilation",
    ):
        cli.add_argument(skip_flag, action="store_true")

    main(cli.parse_args())