7.2.3. Example: Large Language Model (LLM) Inference

This example runs inference for a 1.1B-parameter Llama model (TinyLlama/TinyLlama-1.1B-Chat-v1.0).

We have prepared two configurations, selected via certain environment variables and command-line flags:

1. Running Prefill on MN-Core 2 while executing Decode on the CPU (Prefill on MN-Core 2)
2. Running Prefill on the CPU while executing Decode on MN-Core 2 (Decode on MN-Core 2)
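The split between the two phases can be sketched in miniature. This is an illustrative pure-Python toy, not an MLSDK API: a real model returns logits and a KV cache, whereas here the "cache" is just the token list and "argmax over logits" is replaced by adding 1.

```python
# Toy sketch of the prefill/decode split. Illustrative only; names and
# behavior are stand-ins, not MLSDK or transformers APIs.

def toy_forward(tokens, cache):
    # A real model would return logits plus an updated KV cache; here the
    # "cache" is just the tokens seen so far.
    cache = cache + tokens
    next_token = cache[-1] + 1  # stand-in for argmax over logits
    return next_token, cache

def generate(prompt_tokens, max_new_tokens):
    # Prefill: process the whole prompt in one call (this is the function
    # compiled for MN-Core 2 under --compile_prefill).
    next_token, cache = toy_forward(prompt_tokens, [])
    out = [next_token]
    # Decode: one token per call, reusing the cache (compiled under
    # --compile_decode).
    for _ in range(max_new_tokens - 1):
        next_token, cache = toy_forward([next_token], cache)
        out.append(next_token)
    return out

print(generate([1, 2, 3], 4))  # → [4, 5, 6, 7]
```

The real script below follows the same shape: step 0 is prefill (full prompt, no `past_key_values`), and every later step is decode (one token plus the cache).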

For each configuration, you can supply your own prompt with `--prompt`, e.g. `--prompt 'What is the meaning of life?'`. You can also limit the number of threads used during compilation with the `--num_compiler_threads` option. For details, please refer to Compilation Errors.
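For example, the decode configuration from Listing 7.16 could be invoked with both options added (the prompt text and thread count here are illustrative values, not defaults):

```shell
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 \
    ./examples/run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu \
    --device mncore2:auto \
    --prompt 'What is the meaning of life?' \
    --num_compiler_threads 8
```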

Execution Method (Prefill on MN-Core 2)

Listing 7.15 Prefill on MN-Core 2
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 CODEGEN_TIME_SLICE_SCATTERED_INDEXING_BCAST=1 CODEGEN_OP_DEF=Gather=GatherBcast ./examples/run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto

Expected Output (Prefill on MN-Core 2)

=========== Generated with compilation ==========
 </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
 <s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
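The leading run of `</s>` tokens in the compiled output comes from left padding: the script pads the prompt to `--max_length` and (when no pad token is defined) uses `eos_token_id` as the pad token, so the pad positions decode as `</s>`. Left padding keeps the token window a fixed size for the static KV cache; each decode step then rolls the attention mask left by one and marks the newest position as attended. A minimal sketch of that mask update, assuming torch is available (it is part of this example's stack):

```python
import torch

# Left-padded attention mask: 2 pad positions, then 2 prompt tokens.
attention_mask = torch.tensor([[0, 0, 1, 1]])

# One decode step: slide the fixed-size window left by one and mark the
# newly generated token (rightmost position) as attended, mirroring the
# torch.roll update in llm_infer.py.
next_mask = torch.roll(attention_mask, shifts=-1, dims=-1)
next_mask[:, -1] = 1
print(next_mask)  # tensor([[0, 1, 1, 1]])
```

The oldest (pad) position falls out of the window while the real tokens stay right-aligned, which is why the KV cache can keep a constant shape across steps.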

Execution Method (Decode on MN-Core 2)

Listing 7.16 Decode on MN-Core 2
$ cd /opt/pfn/pfcomp/codegen/examples/
$ MNCORE_USE_LEGACY_ONNX_EXPORTER=1 MNCORE_USE_EXTERNAL_DATA_FORMAT=1 ./examples/run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto

Expected Output (Decode on MN-Core 2)

=========== Generated with compilation ==========
 </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
 <s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.

Script

Listing 7.17 /opt/pfn/pfcomp/codegen/MLSDK/examples/run_llm_infer.sh
#! /bin/bash

set -eux -o pipefail

EXAMPLE_NAME=run_llm_infer
VENVDIR=/tmp/${EXAMPLE_NAME}_venv

CURRENT_DIR=$(realpath $(dirname $0))
CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

if [[ ! -d ${VENVDIR} ]]; then
    python3 -m venv --system-site-packages ${VENVDIR}
fi

source ${VENVDIR}/bin/activate
# See https://discuss.huggingface.co/t/cas-bridge-xethub-hf-co-broke/158626/8 for hf_xet.
pip3 install transformers==4.44.0 huggingface-hub==0.34.4 hf_xet==v1.1.5

source "${BUILD_DIR}/codegen_preloads.sh"
source "${BUILD_DIR}/codegen_pythonpath.sh"

exec python3 ${CURRENT_DIR}/llm_infer.py "$@"
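Note that `run_llm_infer.sh` reuses the virtualenv at `/tmp/run_llm_infer_venv` across runs and only creates it when the directory is missing. If the pinned dependencies get into a bad state, removing the directory forces a clean rebuild on the next invocation:

```shell
# Remove the cached virtualenv so run_llm_infer.sh recreates it and
# reinstalls the pinned packages on its next invocation.
rm -rf /tmp/run_llm_infer_venv
```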
Listing 7.18 /opt/pfn/pfcomp/codegen/MLSDK/examples/llm_infer.py
import argparse
from typing import Mapping

import torch
from mlsdk import CacheOptions, Context, MNDevice, storage
from mlsdk.experimental.llm.attention_mask import (
    prepare_4d_causal_attention_mask_with_cache_position,
)
from mlsdk.experimental.llm.kv_cache import (
    kv_cache_to_legacy,
    kv_cache_to_plamo,
    kv_cache_to_tensor,
)
from transformers import AutoModelForCausalLM, AutoTokenizer


def prepare_prompt(tokenizer, prompt, system_prompt):
    if tokenizer.chat_template:
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {"role": "user", "content": prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def infer_with_generate(
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding for simplicity for comparing the results with the compiled version.
    output_ids = model.generate(
        inputs["input_ids"], do_sample=False, max_new_tokens=max_new_tokens
    )
    assert isinstance(output_ids, torch.Tensor)
    return output_ids


def infer_with_compilation(  # NOQA: CFQ002, CFQ001
    *,
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_length: int,
    max_new_tokens: int,
    compile_prefill: bool,
    compile_decode: bool,
    device_name: str,
    outdir: str,
    check_intermediate_outputs: bool,
    prepare_attention_mask_on_cpu: bool,
    disable_cache: bool,
    num_compiler_threads: int,
) -> torch.Tensor:
    is_plamo_model = any("plamo" in a.lower() for a in model.config.architectures)

    def forward(inputs: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        assert all(isinstance(v, torch.Tensor) for v in inputs.values()), {
            k: type(v) for k, v in inputs.items()
        }
        if "past_key_values" in inputs:
            if is_plamo_model:
                kv_cache_func = kv_cache_to_plamo
            else:
                # @todo (hvy): Stop using the deprecated legacy KV cache format of tuples.
                kv_cache_func = kv_cache_to_legacy
            past_key_values = kv_cache_func(inputs["past_key_values"])
        else:
            past_key_values = None

        outputs = model.forward(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            past_key_values=past_key_values,
            use_cache=True,
        )
        return {
            "logits": outputs.logits,
            "next_past_key_values": kv_cache_to_tensor(outputs.past_key_values)[
                :, :, :, :, 1:, :
            ],  # Do every operation, including the shifting, for the KV cache on device.
        }

    assert tokenizer.padding_side == "left"
    inputs = tokenizer(
        prompt, return_tensors="pt", padding="max_length", max_length=max_length
    )
    # @todo (hvy): Consider subtracting 1 from the position_ids to match modeling_llama.py.
    assert "position_ids" not in inputs
    inputs["position_ids"] = inputs["attention_mask"].cumsum(1)
    if prepare_attention_mask_on_cpu:
        inputs["attention_mask"] = prepare_4d_causal_attention_mask_with_cache_position(
            inputs["attention_mask"], inputs["position_ids"], model.dtype
        )
    output_ids = inputs["input_ids"]

    device = MNDevice(device_name)
    context = Context(device)
    Context.switch_context(context)
    context.registry.register("model", model)

    compiled_funcs = {}

    for step in range(max_new_tokens):
        if step == 0:
            if compile_prefill and "prefill" not in compiled_funcs:
                compiled_funcs["prefill"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/prefill"),
                    cache_options=(
                        CacheOptions(outdir + "/prefill_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("prefill", forward)
        else:
            if compile_decode and "decode" not in compiled_funcs:
                compiled_funcs["decode"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/decode"),
                    cache_options=(
                        CacheOptions(outdir + "/decode_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("decode", forward)

        outputs = forward_for_step(inputs)

        if check_intermediate_outputs:
            # @todo (hvy): Consider using a more sophisticated check for the outputs.
            if "mncore" in device_name:
                atol = 1.0
            else:
                assert device_name == "pfvm:cpu"
                atol = 5e-3
            n_tokens = inputs["position_ids"].max()
            outputs_expected = forward(inputs)
            logits = outputs["logits"][:, -n_tokens:]
            logits_expected = outputs_expected["logits"][:, -n_tokens:]
            next_past_key_values = outputs["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]
            next_past_key_values_expected = outputs_expected["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]

            assert torch.allclose(logits, logits_expected, atol=atol), (
                step,
                (logits - logits_expected).abs().max(),
            )
            assert torch.allclose(
                next_past_key_values, next_past_key_values_expected, atol=atol
            ), (
                step,
                (next_past_key_values - next_past_key_values_expected).abs().max(),
            )

        next_input_ids = (
            outputs["logits"].cpu().argmax(dim=2)[:, -1:]
        )  # Greedy decoding.
        if prepare_attention_mask_on_cpu:
            next_attention_mask = inputs["attention_mask"][:, :, -1:, :]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, :, :, -1] = 0
        else:
            next_attention_mask = inputs["attention_mask"]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, -1] = 1
        next_position_ids = inputs["position_ids"][:, -1:] + 1
        next_past_key_values = outputs["next_past_key_values"].cpu()
        inputs = {
            "input_ids": next_input_ids.detach(),
            "attention_mask": next_attention_mask.detach(),
            "position_ids": next_position_ids.detach(),
            "past_key_values": next_past_key_values.detach(),
        }

        output_ids = torch.cat([output_ids, next_input_ids], dim=1)

        if next_input_ids.item() == tokenizer.eos_token_id:
            break

    return output_ids[:, max_new_tokens:]


def main(args):
    prompt = args.prompt
    system_prompt = args.system_prompt
    model_name = args.model_name
    max_length = args.max_length
    max_new_tokens = args.max_new_tokens
    compile_prefill = args.compile_prefill
    compile_decode = args.compile_decode
    device_name = args.device_name
    outdir = args.outdir
    check_intermediate_outputs = args.check_intermediate_outputs
    prepare_attention_mask_on_cpu = args.prepare_attention_mask_on_cpu
    disable_cache = args.disable_cache
    num_compiler_threads = args.num_compiler_threads

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # Some models do not return the KV cache in training mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"  # For static KV caching
    tokenizer.truncation_side = "left"

    prompt = prepare_prompt(tokenizer, prompt, system_prompt)

    outputs = infer_with_compilation(
        prompt=prompt,
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        compile_prefill=compile_prefill,
        compile_decode=compile_decode,
        device_name=device_name,
        outdir=outdir,
        check_intermediate_outputs=check_intermediate_outputs,
        prepare_attention_mask_on_cpu=prepare_attention_mask_on_cpu,
        disable_cache=disable_cache,
        num_compiler_threads=num_compiler_threads,
    )
    print(
        "=========== Generated with compilation ==========\n",
        tokenizer.decode(outputs[0]),
    )

    outputs_expected = infer_with_generate(prompt, model, tokenizer, max_new_tokens)
    print(
        "========== Generated with model.generate ==========\n",
        tokenizer.decode(outputs_expected[0]),
    )

    # @todo (hvy): Do not rely on `max_new_tokens` tokens always being generated?
    assert torch.equal(
        outputs[:, -max_new_tokens:], outputs_expected[:, -max_new_tokens:]
    ), "Outputs differed. Check generated outputs above."
    print("Generated outputs matched.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default='The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',  # NOQA
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="You are a friendly chatbot who is an expert on MN-Core.",
    )
    parser.add_argument(
        "--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--num_compiler_threads", type=int, default=-1)
    parser.add_argument("--compile_prefill", action="store_true")
    parser.add_argument("--compile_decode", action="store_true")
    parser.add_argument("--device_name", type=str, default="mncore2:auto")
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_llm_infer")
    parser.add_argument("--check_intermediate_outputs", action="store_true")
    parser.add_argument("--prepare_attention_mask_on_cpu", action="store_true")
    parser.add_argument("--disable_cache", action="store_true")
    args = parser.parse_args()
    main(args)
288    main(args)
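One detail of the input preparation in llm_infer.py worth calling out: `position_ids` are derived as the cumulative sum of the attention mask, so padded positions stay at 0 and real tokens count up from 1 (the in-code `@todo` notes a possible off-by-one relative to modeling_llama.py's 0-based positions). A small sketch, assuming torch is available:

```python
import torch

# position_ids as the cumulative sum of a left-padded attention mask,
# mirroring `inputs["position_ids"] = inputs["attention_mask"].cumsum(1)`
# in llm_infer.py.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two pad positions
position_ids = attention_mask.cumsum(1)
print(position_ids)  # tensor([[0, 0, 1, 2, 3]])
```

Because the cumulative sum ignores the 0-valued pad positions, the same expression works for any amount of left padding, which is what lets the script use a fixed `--max_length` window regardless of prompt length.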