7.2.3. Example: Large Language Model (LLM) Inference
Example application for Llama 1B inference
We have prepared two configurations, selected via command-line flags and environment variables:
1. Running Prefill on MN-Core 2 while executing Decode on CPU (Prefill on MN-Core 2)
2. Running Prefill on CPU while executing Decode on MN-Core 2 (Decode on MN-Core 2)
For each configuration, you can supply a custom prompt with, for example, --prompt 'What is the meaning of life?'.
Additionally, you can limit the number of threads used during compilation with the --num_compiler_threads option.
For detailed information, please refer to Compilation Errors.
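As an illustration, the two options above can be combined with either configuration's command line (the flag values here are arbitrary examples, not recommendations):

```
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 ./examples/run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto --prompt 'What is the meaning of life?' --num_compiler_threads 4
```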
Execution Method (Prefill on MN-Core 2)
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 CODEGEN_TIME_SLICE_SCATTERED_INDEXING_BCAST=1 CODEGEN_OP_DEF=Gather=GatherBcast ./examples/run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto
Expected Output (Prefill on MN-Core 2)
=========== Generated with compilation ==========
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
<s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
Execution Method (Decode on MN-Core 2)
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 ./examples/run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto
Expected Output (Decode on MN-Core 2)
=========== Generated with compilation ==========
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
<s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
Script
#! /bin/bash

set -eux -o pipefail

EXAMPLE_NAME=run_llm_infer
VENVDIR=/tmp/${EXAMPLE_NAME}_venv

CURRENT_DIR=$(realpath $(dirname $0))
CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

if [[ ! -d ${VENVDIR} ]]; then
    python3 -m venv --system-site-packages ${VENVDIR}
fi

source ${VENVDIR}/bin/activate
# See https://discuss.huggingface.co/t/cas-bridge-xethub-hf-co-broke/158626/8 for hf_xet.
pip3 install transformers==4.44.0 huggingface-hub==0.34.4 hf_xet==v1.1.5

source "${BUILD_DIR}/codegen_preloads.sh"
source "${BUILD_DIR}/codegen_pythonpath.sh"

exec python3 ${CURRENT_DIR}/llm_infer.py "$@"
import argparse
from typing import Mapping

import torch
from mlsdk import CacheOptions, Context, MNDevice, storage
from mlsdk.experimental.llm.attention_mask import (
    prepare_4d_causal_attention_mask_with_cache_position,
)
from mlsdk.experimental.llm.kv_cache import (
    kv_cache_to_legacy,
    kv_cache_to_plamo,
    kv_cache_to_tensor,
)
from transformers import AutoModelForCausalLM, AutoTokenizer


def prepare_prompt(tokenizer, prompt, system_prompt):
    if tokenizer.chat_template:
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {"role": "user", "content": prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def infer_with_generate(
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding for simplicity for comparing the results with the compiled version.
    output_ids = model.generate(
        inputs["input_ids"], do_sample=False, max_new_tokens=max_new_tokens
    )
    assert isinstance(output_ids, torch.Tensor)
    return output_ids


def infer_with_compilation(  # NOQA: CFQ002, CFQ001
    *,
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_length: int,
    max_new_tokens: int,
    compile_prefill: bool,
    compile_decode: bool,
    device_name: str,
    outdir: str,
    check_intermediate_outputs: bool,
    prepare_attention_mask_on_cpu: bool,
    disable_cache: bool,
    num_compiler_threads: int,
) -> torch.Tensor:
    is_plamo_model = any("plamo" in a.lower() for a in model.config.architectures)

    def forward(inputs: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        assert all(isinstance(v, torch.Tensor) for v in inputs.values()), {
            k: type(v) for k, v in inputs.items()
        }
        if "past_key_values" in inputs:
            if is_plamo_model:
                kv_cache_func = kv_cache_to_plamo
            else:
                # @todo (hvy): Stop using the deprecated legacy KV cache format of tuples.
                kv_cache_func = kv_cache_to_legacy
            past_key_values = kv_cache_func(inputs["past_key_values"])
        else:
            past_key_values = None

        outputs = model.forward(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            past_key_values=past_key_values,
            use_cache=True,
        )
        return {
            "logits": outputs.logits,
            "next_past_key_values": kv_cache_to_tensor(outputs.past_key_values)[
                :, :, :, :, 1:, :
            ],  # Do every operation, including the shifting, for the KV cache on device.
        }

    assert tokenizer.padding_side == "left"
    inputs = tokenizer(
        prompt, return_tensors="pt", padding="max_length", max_length=max_length
    )
    # @todo (hvy): Consider subtracting 1 from the position_ids to match modeling_llama.py.
    assert "position_ids" not in inputs
    inputs["position_ids"] = inputs["attention_mask"].cumsum(1)
    if prepare_attention_mask_on_cpu:
        inputs["attention_mask"] = prepare_4d_causal_attention_mask_with_cache_position(
            inputs["attention_mask"], inputs["position_ids"], model.dtype
        )
    output_ids = inputs["input_ids"]

    device = MNDevice(device_name)
    context = Context(device)
    Context.switch_context(context)
    context.registry.register("model", model)

    compiled_funcs = {}

    for step in range(max_new_tokens):
        if step == 0:
            if compile_prefill and "prefill" not in compiled_funcs:
                compiled_funcs["prefill"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/prefill"),
                    cache_options=(
                        CacheOptions(outdir + "/prefill_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("prefill", forward)
        else:
            if compile_decode and "decode" not in compiled_funcs:
                compiled_funcs["decode"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/decode"),
                    cache_options=(
                        CacheOptions(outdir + "/decode_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("decode", forward)

        outputs = forward_for_step(inputs)

        if check_intermediate_outputs:
            # @todo (hvy): Consider using a more sophisticated check for the outputs.
            if "mncore" in device_name:
                atol = 1.0
            else:
                assert device_name == "pfvm:cpu"
                atol = 5e-3
            n_tokens = inputs["position_ids"].max()
            outputs_expected = forward(inputs)
            logits = outputs["logits"][:, -n_tokens:]
            logits_expected = outputs_expected["logits"][:, -n_tokens:]
            next_past_key_values = outputs["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]
            next_past_key_values_expected = outputs_expected["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]

            assert torch.allclose(logits, logits_expected, atol=atol), (
                step,
                (logits - logits_expected).abs().max(),
            )
            assert torch.allclose(
                next_past_key_values, next_past_key_values_expected, atol=atol
            ), (
                step,
                (next_past_key_values - next_past_key_values_expected).abs().max(),
            )

        next_input_ids = (
            outputs["logits"].cpu().argmax(dim=2)[:, -1:]
        )  # Greedy decoding.
        if prepare_attention_mask_on_cpu:
            next_attention_mask = inputs["attention_mask"][:, :, -1:, :]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, :, :, -1] = 0
        else:
            next_attention_mask = inputs["attention_mask"]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, -1] = 1
        next_position_ids = inputs["position_ids"][:, -1:] + 1
        next_past_key_values = outputs["next_past_key_values"].cpu()
        inputs = {
            "input_ids": next_input_ids.detach(),
            "attention_mask": next_attention_mask.detach(),
            "position_ids": next_position_ids.detach(),
            "past_key_values": next_past_key_values.detach(),
        }

        output_ids = torch.cat([output_ids, next_input_ids], dim=1)

        if next_input_ids.item() == tokenizer.eos_token_id:
            break

    return output_ids[:, max_new_tokens:]


def main(args):
    prompt = args.prompt
    system_prompt = args.system_prompt
    model_name = args.model_name
    max_length = args.max_length
    max_new_tokens = args.max_new_tokens
    compile_prefill = args.compile_prefill
    compile_decode = args.compile_decode
    device_name = args.device_name
    outdir = args.outdir
    check_intermediate_outputs = args.check_intermediate_outputs
    prepare_attention_mask_on_cpu = args.prepare_attention_mask_on_cpu
    disable_cache = args.disable_cache
    num_compiler_threads = args.num_compiler_threads

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # Some models do not return the KV cache in training mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"  # For static KV caching
    tokenizer.truncation_side = "left"

    prompt = prepare_prompt(tokenizer, prompt, system_prompt)

    outputs = infer_with_compilation(
        prompt=prompt,
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        compile_prefill=compile_prefill,
        compile_decode=compile_decode,
        device_name=device_name,
        outdir=outdir,
        check_intermediate_outputs=check_intermediate_outputs,
        prepare_attention_mask_on_cpu=prepare_attention_mask_on_cpu,
        disable_cache=disable_cache,
        num_compiler_threads=num_compiler_threads,
    )
    print(
        "=========== Generated with compilation ==========\n",
        tokenizer.decode(outputs[0]),
    )

    outputs_expected = infer_with_generate(prompt, model, tokenizer, max_new_tokens)
    print(
        "========== Generated with model.generate ==========\n",
        tokenizer.decode(outputs_expected[0]),
    )

    # @todo (hvy): Do not rely on `max_new_tokens` tokens always being generated?
    assert torch.equal(
        outputs[:, -max_new_tokens:], outputs_expected[:, -max_new_tokens:]
    ), "Outputs differed. Check generated outputs above."
    print("Generated outputs matched.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default='The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',  # NOQA
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="You are a friendly chatbot who is an expert on MN-Core.",
    )
    parser.add_argument(
        "--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--num_compiler_threads", type=int, default=-1)
    parser.add_argument("--compile_prefill", action="store_true")
    parser.add_argument("--compile_decode", action="store_true")
    parser.add_argument("--device_name", type=str, default="mncore2:auto")
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_llm_infer")
    parser.add_argument("--check_intermediate_outputs", action="store_true")
    parser.add_argument("--prepare_attention_mask_on_cpu", action="store_true")
    parser.add_argument("--disable_cache", action="store_true")
    args = parser.parse_args()
    main(args)
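The per-step bookkeeping in the decode loop of infer_with_compilation (without --prepare_attention_mask_on_cpu) can be illustrated with a toy sketch. This is not part of the example and uses plain Python lists in place of tensors; it only mirrors the idea that the fixed-length attention mask is rolled left by one with the newest slot marked valid, and the position id advances by one, so the compiled decode function always receives inputs of the same static shape:

```python
# Toy illustration of the static-window decode bookkeeping (plain Python,
# no tensors). Mirrors torch.roll(mask, shifts=-1, dims=-1) followed by
# mask[:, -1] = 1 and position_ids[:, -1:] + 1 from the example above.

def roll_left(xs):
    """Shift a list left by one, wrapping the first element to the end."""
    return xs[1:] + xs[:1]

def decode_step(attention_mask, position_ids):
    """Advance one decode step: roll the 1-D mask left, mark the last slot
    (the newly generated token) as valid, and compute the next position id."""
    mask = roll_left(attention_mask)
    mask[-1] = 1  # the new token occupies the rightmost slot of the window
    next_position = position_ids[-1] + 1
    return mask, next_position

# A max_length=8 window holding a 3-token prompt; left padding means the
# leading slots are 0 and position_ids is the cumulative sum of the mask.
mask = [0, 0, 0, 0, 0, 1, 1, 1]
positions = [0, 0, 0, 0, 0, 1, 2, 3]
mask, pos = decode_step(mask, positions)
print(mask, pos)  # [0, 0, 0, 0, 1, 1, 1, 1] 4
```

Note that the window length never changes; the oldest (padding) slot falls off the left edge, which is also why the example slices the KV cache with [:, :, :, :, 1:, :] on device.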