7.2.1. Example: timm Model Inference

An application example that retrieves a model from timm and performs inference on the image (beignets-task-guide.png).

beignets-task-guide.png

Fig. 7.1 beignets-task-guide.png

Execution Method (resnet50.a1h_in1k)

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_timm.sh --model_name resnet50.a1h_in1k --batch_size 16

Expected Output (resnet50.a1h_in1k)

MNCore2 top-5 classes:
- espresso (967)
- cup (968)
- chocolate sauce, chocolate syrup (960)
- consomme (925)
- eggnog (969)
Torch top-5 classes:
- espresso (967)
- cup (968)
- chocolate sauce, chocolate syrup (960)
- eggnog (969)
- consomme (925)

Execution Method (mobilenetv3_small_050.lamb_in1k)

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_timm.sh --model_name mobilenetv3_small_050.lamb_in1k --batch_size 16

Expected Output (mobilenetv3_small_050.lamb_in1k)

MNCore2 top-5 classes:
- cup (968)
- trifle (927)
- face powder (551)
- ice cream, icecream (928)
- coffee mug (504)
Torch top-5 classes:
- cup (968)
- trifle (927)
- ice cream, icecream (928)
- face powder (551)
- coffee mug (504)

Script

Listing 7.11 /opt/pfn/pfcomp/codegen/MLSDK/examples/run_timm.sh
 1#! /bin/bash
 2
 3set -eux -o pipefail
 4
 5EXAMPLE_NAME=run_timm
 6VENVDIR=/tmp/${EXAMPLE_NAME}_venv
 7
 8CURRENT_DIR=$(realpath $(dirname $0))
 9CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
10BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}
11
12if [[ ! -d ${VENVDIR} ]]; then
13    python3 -m venv --system-site-packages ${VENVDIR}
14    source ${VENVDIR}/bin/activate
15    # PFN_INTERNAL_START
16    # Install huggingface-hub==0.28.1 resolves CI issues. See
17    # https://github.com/pfn-qfin/hfscs/issues/21 for more details.
18    # PFN_INTERNAL_END
19    pip3 install timm==1.0.14 huggingface-hub==0.28.1
20else
21    source ${VENVDIR}/bin/activate
22fi
23
24source "${BUILD_DIR}/codegen_preloads.sh"
25source "${BUILD_DIR}/codegen_pythonpath.sh"
26
27exec python3 ${CURRENT_DIR}/${EXAMPLE_NAME}.py "$@"
Listing 7.12 /opt/pfn/pfcomp/codegen/MLSDK/examples/run_timm.py
  1import argparse
  2import os
  3from pathlib import Path
  4from typing import Any, Optional, Union
  5
  6import mncore  # noqa: F401
  7import timm
  8import torch
  9from mlsdk import (
 10    Context,
 11    MNCoreSGD,
 12    MNDevice,
 13    set_buffer_name_in_optimizer,
 14    set_tensor_name_in_module,
 15    storage,
 16)
 17from PIL import Image
 18
# Path to the sample image used by all inference/training entry points below.
# NOTE(review): the surrounding documentation shows beignets-task-guide.png as
# the sample image, but the code references mncore2_chip.png — confirm which
# file actually ships under datasets/.
SAMPLE_IMAGE_PATH = os.path.join(
    os.path.dirname(__file__), "./datasets/mncore2_chip.png"
)
 22
 23
 24def escape_path(path: str) -> str:
 25    escaped = ""
 26    for c in path:
 27        if c.isalnum() or c in "_-":
 28            escaped += c
 29        else:
 30            escaped += "_"
 31    return escaped
 32
 33
 34def create_model_with_cache(
 35    model_name: str, model_cache_dir: Optional[str] = None, **kwargs: Any
 36) -> Any:
 37    if not model_cache_dir:
 38        return timm.create_model(model_name, **kwargs)
 39    else:
 40        timm_version = "timm_version" + timm.__version__
 41        torch_version = "torch_version" + torch.__version__
 42        cache_dir = os.path.join(
 43            model_cache_dir,
 44            escape_path(f"{torch_version}_{timm_version}_{model_name}"),
 45        )
 46        # Load the model always from the cache to return the same model object always.
 47        # This should also create the cache if it does not exist.
 48        return timm.create_model(model_name, **kwargs, cache_dir=cache_dir)
 49
 50
 51def imagenet_classes() -> list[str]:
 52    script_dir = os.path.dirname(__file__)
 53    imagenet_classes_path = os.path.join(script_dir, "imagenet_classes.txt")
 54    with open(imagenet_classes_path) as f:
 55        return [line.strip() for line in f]
 56
 57
 58def run_inference(
 59    model_name: str,
 60    batch_size: int,
 61    outdir: str,
 62    option_json_path: Optional[Path],
 63    device_str: str,
 64    model_cache_dir: Optional[str],
 65) -> None:
 66    img = Image.open(SAMPLE_IMAGE_PATH)
 67    model = create_model_with_cache(
 68        model_name,
 69        pretrained=True,
 70        model_cache_dir=model_cache_dir,
 71    )
 72    model = model.eval()
 73
 74    def infer(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
 75        with torch.no_grad():
 76            x = input["images"]
 77            return {"out": model(x)}
 78
 79    data_config = timm.data.resolve_model_data_config(model)
 80    transforms = timm.data.create_transform(**data_config, is_training=False)
 81    images = transforms(img).unsqueeze(0).expand(batch_size, -1, -1, -1)
 82    sample = {"images": images}
 83
 84    device = MNDevice(device_str)
 85    context = Context(device)
 86    Context.switch_context(context)
 87    context.registry.register("model", model)
 88
 89    compile_options: dict[str, str] = {}
 90    if option_json_path is not None:
 91        compile_options = {"option_json": str(option_json_path)}
 92
 93    compiled_infer = context.compile(
 94        infer,
 95        sample,
 96        storage.path(outdir) / "infer",
 97        options=compile_options,
 98    )
 99    result_on_mncore2 = compiled_infer(sample)
100    result_on_torch = infer(sample)
101
102    torch.allclose(result_on_mncore2["out"].cpu(), result_on_torch["out"], atol=1e-5)
103
104    if "in1k" in model_name:
105        classes = imagenet_classes()
106        mncore_top5_classes = torch.topk(
107            result_on_mncore2["out"].cpu()[0], 5
108        ).indices.cpu()
109        print("MNCore2 top-5 classes:")
110        for i in mncore_top5_classes:
111            print(f"- {classes[i]} ({i.item()})")
112        torch_top5_classes = torch.topk(result_on_torch["out"][0], 5).indices
113        print("Torch top-5 classes:")
114        for i in torch_top5_classes:
115            print(f"- {classes[i]} ({i.item()})")
116
117
def run_training_torch_onnx(
    model_name: str,
    batch_size: int,
    outdir: str,
    option_json_path: Optional[Path],
    device: str,
    model_cache_dir: Optional[str],
) -> None:
    """Compile a training step with the legacy torch ONNX exporter and run it.

    Args:
        model_name: timm model identifier.
        batch_size: the single sample image is replicated to this batch size.
        outdir: directory where compilation artifacts are written.
        option_json_path: optional compiler option JSON file.
        device: MNDevice spec string, e.g. ``mncore2:auto``.
        model_cache_dir: optional weight-cache directory.
    """
    # FIX: do not rebind the ``device: str`` parameter to an MNDevice — the
    # old code shadowed the annotated parameter (a typing error) and was
    # inconsistent with the ``device_str`` convention of the sibling functions.
    mn_device = MNDevice(device)
    context = Context(mn_device)
    Context.switch_context(context)

    img = Image.open(SAMPLE_IMAGE_PATH)

    model = create_model_with_cache(
        model_name,
        pretrained=True,
        num_classes=1000,
        model_cache_dir=model_cache_dir,
    )
    # Build a batch by replicating the single sample image; labels are random.
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)
    images = transforms(img).unsqueeze(0).expand(batch_size, -1, -1, -1)
    labels = torch.randint(0, 1000, (batch_size,))
    sample = {"images": images, "labels": labels}

    model = model.train()
    context.registry.register("model", model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    context.registry.register("optimizer", optimizer)
    loss_fn = torch.nn.CrossEntropyLoss()

    def f(inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"loss": loss_fn(model(inputs["images"]), inputs["labels"])}

    compile_options: dict[str, Union[str, bool]] = {}
    if option_json_path is not None:
        compile_options = {"option_json": str(option_json_path)}
    # backprop=True makes the compiler emit the backward pass; the parameter
    # update itself is driven through ``optimizers=[...]`` below.
    compile_options["backprop"] = True
    compiled_f = context.compile(
        f,
        sample,
        storage.path(outdir) / "train_step_torch_onnx",
        optimizers=[optimizer],
        options=compile_options,
    )

    first_loss = compiled_f(sample)["loss"].cpu()
    for _ in range(10):
        compiled_f(sample)
    context.synchronize()
    last_loss = compiled_f(sample)["loss"].cpu()

    # Training repeatedly on a fixed batch should reduce the loss.
    assert last_loss < first_loss
172
173
def run_training_fx2onnx(
    model_name: str,
    batch_size: int,
    outdir: str,
    option_json_path: Optional[Path],
    device_str: str,
    model_cache_dir: Optional[str],
) -> None:
    """Compile and run a full training step using the fx2onnx exporter.

    Unlike ``run_training_torch_onnx``, the optimizer step is traced as part
    of the compiled ``train_step`` function, so parameters and optimizer
    buffers must be named and registered with the context before compiling.
    """
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    img = Image.open(SAMPLE_IMAGE_PATH)

    model = create_model_with_cache(
        model_name,
        pretrained=True,
        num_classes=1000,
        model_cache_dir=model_cache_dir,
    )
    model = model.train()
    # Names are assigned before registration — presumably so the compiler can
    # match graph tensors to runtime parameters; keep this ordering intact.
    set_tensor_name_in_module(model, "model0")
    for p in model.parameters():
        context.register_param(p)

    # MNCoreSGD(params, 0.1, 0.9, 0.0) — presumably (lr, momentum,
    # weight_decay) mirroring torch.optim.SGD; confirm against the MLSDK API.
    optimizer = MNCoreSGD(model.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer0")
    context.register_optimizer_buffers(optimizer)
    loss_fn = torch.nn.CrossEntropyLoss()

    def train_step(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Forward + backward + optimizer update, all captured for compilation.
        x = input["images"]
        t = input["labels"]
        optimizer.zero_grad()
        y = model(x)
        loss = loss_fn(y, t)
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    # Build a batch by replicating the single sample image; labels are random.
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)
    images = transforms(img).unsqueeze(0).expand(batch_size, -1, -1, -1)
    labels = torch.randint(0, 1000, (batch_size,))
    sample = {"images": images, "labels": labels}

    compile_options: dict[str, str] = {}
    if option_json_path is not None:
        compile_options = {"option_json": str(option_json_path)}

    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step_fx2onnx",
        options=compile_options,
        export_kwargs={"use_fx2onnx": True},
    )

    first_loss = compiled_train_step(sample)["loss"].cpu()
    for _ in range(10):
        compiled_train_step(sample)
    context.synchronize()
    last_loss = compiled_train_step(sample)["loss"].cpu()

    # Training repeatedly on a fixed batch should reduce the loss.
    assert last_loss < first_loss
239
240
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE: --batch_size is required, so the default is never used; it is
    # kept only to preserve the existing CLI definition.
    parser.add_argument("--batch_size", type=int, default=1, required=True)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_timm")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--is_training", action="store_true")
    parser.add_argument(
        "--device",
        type=str,
        default="mncore2:auto",
        choices=["mncore2:auto", "pfvm:cpu"],
    )
    parser.add_argument(
        "--model_cache_dir",
        type=str,
        default=None,
        help="Directory to cache the model weights. "
        "If not set, weights are always downloaded from the hub. default: None",
    )
    args = parser.parse_args()

    # Derive a per-model output directory when none was given.  (With the
    # current non-None default of /tmp/mlsdk_timm this branch is unreachable;
    # it is kept so the fallback works if the default is ever set to None.)
    outdir = args.outdir
    if outdir is None:
        outdir = f"/tmp/MLSDK_codegen_dir_{args.model_name}"
        if args.is_training:
            outdir += "_training"
        else:
            outdir += "_inference"

    # TODO (akirakawata): Should we make this argument?
    # "0" as the default keeps int() valid whether or not the variable is set.
    use_fx2onnx = not bool(
        int(os.environ.get("MNCORE_USE_LEGACY_ONNX_EXPORTER", "0"))
    )
    if args.is_training:
        if use_fx2onnx:
            run_training_fx2onnx(
                args.model_name,
                args.batch_size,
                outdir,
                args.option_json,
                args.device,
                args.model_cache_dir,
            )
        else:
            # BUG FIX: pass the computed ``outdir`` (not ``args.outdir``) so
            # all three entry points honor the fallback directory logic.
            run_training_torch_onnx(
                args.model_name,
                args.batch_size,
                outdir,
                args.option_json,
                args.device,
                args.model_cache_dir,
            )
    else:
        # BUG FIX: same as above — use the computed ``outdir``.
        run_inference(
            args.model_name,
            args.batch_size,
            outdir,
            args.option_json,
            args.device,
            args.model_cache_dir,
        )