7.2.1. Example: timm Model Inference
An application example that retrieves a model from timm and performs inference on the image (beignets-task-guide.png)
Fig. 7.1 beignets-task-guide.png
Execution Method (resnet50.a1h_in1k)
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_timm.sh --model_name resnet50.a1h_in1k --batch_size 16
Expected Output (resnet50.a1h_in1k)
MNCore2 top-5 classes:
- espresso (967)
- cup (968)
- chocolate sauce, chocolate syrup (960)
- consomme (925)
- eggnog (969)
Torch top-5 classes:
- espresso (967)
- cup (968)
- chocolate sauce, chocolate syrup (960)
- eggnog (969)
- consomme (925)
Execution Method (mobilenetv3_small_050.lamb_in1k)
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./run_timm.sh --model_name mobilenetv3_small_050.lamb_in1k --batch_size 16
Expected Output (mobilenetv3_small_050.lamb_in1k)
MNCore2 top-5 classes:
- cup (968)
- trifle (927)
- face powder (551)
- ice cream, icecream (928)
- coffee mug (504)
Torch top-5 classes:
- cup (968)
- trifle (927)
- ice cream, icecream (928)
- face powder (551)
- coffee mug (504)
Script
#!/bin/bash
# Launcher for the timm example: prepares (or reuses) a dedicated virtualenv,
# sources the codegen environment, and executes run_timm.py with all CLI args.

set -eux -o pipefail

EXAMPLE_NAME=run_timm
VENVDIR="/tmp/${EXAMPLE_NAME}_venv"

# Quote every expansion so the script works when its path contains spaces.
CURRENT_DIR=$(realpath "$(dirname "$0")")
CODEGEN_DIR=$(realpath "${CURRENT_DIR}/../../")
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

if [[ ! -d ${VENVDIR} ]]; then
  python3 -m venv --system-site-packages "${VENVDIR}"
  source "${VENVDIR}/bin/activate"
  # PFN_INTERNAL_START
  # Installing huggingface-hub==0.28.1 resolves CI issues. See
  # https://github.com/pfn-qfin/hfscs/issues/21 for more details.
  # PFN_INTERNAL_END
  pip3 install timm==1.0.14 huggingface-hub==0.28.1
else
  source "${VENVDIR}/bin/activate"
fi

source "${BUILD_DIR}/codegen_preloads.sh"
source "${BUILD_DIR}/codegen_pythonpath.sh"

exec python3 "${CURRENT_DIR}/${EXAMPLE_NAME}.py" "$@"
1import argparse
2import os
3from pathlib import Path
4from typing import Any, Optional, Union
5
6import mncore # noqa: F401
7import timm
8import torch
9from mlsdk import (
10 Context,
11 MNCoreSGD,
12 MNDevice,
13 set_buffer_name_in_optimizer,
14 set_tensor_name_in_module,
15 storage,
16)
17from PIL import Image
18
# Sample image fed to every example entry point in this script.
# NOTE(review): the surrounding documentation references
# beignets-task-guide.png, but the code loads mncore2_chip.png — confirm
# which image actually ships under ./datasets/.
SAMPLE_IMAGE_PATH = os.path.join(
    os.path.dirname(__file__), "./datasets/mncore2_chip.png"
)
22
23
def escape_path(path: str) -> str:
    """Sanitize *path* into a filesystem-friendly directory name.

    Every character that is neither alphanumeric (per ``str.isalnum``) nor
    one of ``_`` / ``-`` is replaced with ``_``.

    Args:
        path: Arbitrary string (e.g. a versioned model identifier).

    Returns:
        The sanitized string, same length as the input.
    """
    # Single join pass instead of quadratic `escaped += c` concatenation.
    return "".join(c if c.isalnum() or c in "_-" else "_" for c in path)
32
33
def create_model_with_cache(
    model_name: str, model_cache_dir: Optional[str] = None, **kwargs: Any
) -> Any:
    """Create a timm model, optionally caching its weights locally.

    When ``model_cache_dir`` is falsy the model comes straight from the hub.
    Otherwise the weights are kept under a directory keyed by the torch
    version, timm version, and model name, so repeated calls load identical
    weights (the cache is populated on first use).
    """
    if not model_cache_dir:
        return timm.create_model(model_name, **kwargs)

    cache_tag = escape_path(
        f"torch_version{torch.__version__}"
        f"_timm_version{timm.__version__}"
        f"_{model_name}"
    )
    # Always go through the cache so every call returns the same model object.
    return timm.create_model(
        model_name, **kwargs, cache_dir=os.path.join(model_cache_dir, cache_tag)
    )
49
50
def imagenet_classes() -> list[str]:
    """Return the ImageNet-1k class labels shipped next to this script.

    Reads ``imagenet_classes.txt`` (one label per line) and strips
    surrounding whitespace from each entry.
    """
    labels_path = Path(__file__).parent / "imagenet_classes.txt"
    with open(labels_path) as fh:
        return [raw.strip() for raw in fh]
56
57
def run_inference(
    model_name: str,
    batch_size: int,
    outdir: str,
    option_json_path: Optional[Path],
    device_str: str,
    model_cache_dir: Optional[str],
) -> None:
    """Compile a pretrained timm model for MNCore2 and compare inference
    against eager PyTorch, printing top-5 ImageNet classes for in1k models.

    Args:
        model_name: timm model identifier (e.g. ``resnet50.a1h_in1k``).
        batch_size: The single sample image is expanded to this batch size.
        outdir: Directory for compiler artifacts.
        option_json_path: Optional compiler option JSON file.
        device_str: Device specifier passed to ``MNDevice``.
        model_cache_dir: Optional local weight-cache directory.
    """
    img = Image.open(SAMPLE_IMAGE_PATH)
    model = create_model_with_cache(
        model_name,
        pretrained=True,
        model_cache_dir=model_cache_dir,
    )
    model = model.eval()

    # Renamed parameter from `input` to avoid shadowing the builtin.
    def infer(inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            return {"out": model(inputs["images"])}

    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)
    images = transforms(img).unsqueeze(0).expand(batch_size, -1, -1, -1)
    sample = {"images": images}

    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)
    context.registry.register("model", model)

    compile_options: dict[str, str] = {}
    if option_json_path is not None:
        compile_options = {"option_json": str(option_json_path)}

    compiled_infer = context.compile(
        infer,
        sample,
        storage.path(outdir) / "infer",
        options=compile_options,
    )
    result_on_mncore2 = compiled_infer(sample)
    result_on_torch = infer(sample)

    # BUGFIX: the comparison result used to be silently discarded; assert it
    # so a numerical mismatch actually fails the example (consistent with the
    # loss assertions in the training entry points).
    assert torch.allclose(
        result_on_mncore2["out"].cpu(), result_on_torch["out"], atol=1e-5
    )

    if "in1k" in model_name:
        classes = imagenet_classes()
        mncore_top5_classes = torch.topk(
            result_on_mncore2["out"].cpu()[0], 5
        ).indices.cpu()
        print("MNCore2 top-5 classes:")
        for i in mncore_top5_classes:
            print(f"- {classes[i]} ({i.item()})")
        torch_top5_classes = torch.topk(result_on_torch["out"][0], 5).indices
        print("Torch top-5 classes:")
        for i in torch_top5_classes:
            print(f"- {classes[i]} ({i.item()})")
116
117
def run_training_torch_onnx(
    model_name: str,
    batch_size: int,
    outdir: str,
    option_json_path: Optional[Path],
    device: str,
    model_cache_dir: Optional[str],
) -> None:
    """Compile and run a timm training step via the legacy torch ONNX
    exporter, asserting that the loss decreases over ten optimizer steps.

    Args:
        model_name: timm model identifier.
        batch_size: The sample image is expanded to this batch size.
        outdir: Directory for compiler artifacts.
        option_json_path: Optional compiler option JSON file.
        device: Device specifier string passed to ``MNDevice``.
        model_cache_dir: Optional local weight-cache directory.
    """
    # FIX: bind the MNDevice to a fresh local instead of rebinding the
    # `device: str` parameter with a value of a different type.
    mn_device = MNDevice(device)
    context = Context(mn_device)
    Context.switch_context(context)

    img = Image.open(SAMPLE_IMAGE_PATH)

    model = create_model_with_cache(
        model_name,
        pretrained=True,
        num_classes=1000,
        model_cache_dir=model_cache_dir,
    )
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)
    images = transforms(img).unsqueeze(0).expand(batch_size, -1, -1, -1)
    labels = torch.randint(0, 1000, (batch_size,))
    sample = {"images": images, "labels": labels}

    model = model.train()
    context.registry.register("model", model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    context.registry.register("optimizer", optimizer)
    loss_fn = torch.nn.CrossEntropyLoss()

    def f(inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"loss": loss_fn(model(inputs["images"]), inputs["labels"])}

    compile_options: dict[str, Union[str, bool]] = {}
    if option_json_path is not None:
        compile_options = {"option_json": str(option_json_path)}
    # Backprop must be enabled so the compiled function trains the model.
    compile_options["backprop"] = True
    compiled_f = context.compile(
        f,
        sample,
        storage.path(outdir) / "train_step_torch_onnx",
        optimizers=[optimizer],
        options=compile_options,
    )

    first_loss = compiled_f(sample)["loss"].cpu()
    for _ in range(10):
        compiled_f(sample)
    context.synchronize()
    last_loss = compiled_f(sample)["loss"].cpu()

    assert last_loss < first_loss
172
173
def run_training_fx2onnx(
    model_name: str,
    batch_size: int,
    outdir: str,
    option_json_path: Optional[Path],
    device_str: str,
    model_cache_dir: Optional[str],
) -> None:
    """Compile and run a timm training step via the fx2onnx exporter,
    asserting that the loss decreases over ten optimizer steps.

    Args:
        model_name: timm model identifier.
        batch_size: The sample image is expanded to this batch size.
        outdir: Directory for compiler artifacts.
        option_json_path: Optional compiler option JSON file.
        device_str: Device specifier passed to ``MNDevice``.
        model_cache_dir: Optional local weight-cache directory.
    """
    context = Context(MNDevice(device_str))
    Context.switch_context(context)

    image = Image.open(SAMPLE_IMAGE_PATH)

    net = create_model_with_cache(
        model_name,
        pretrained=True,
        num_classes=1000,
        model_cache_dir=model_cache_dir,
    ).train()
    # Name tensors/buffers and register them so the compiler can track state.
    set_tensor_name_in_module(net, "model0")
    for param in net.parameters():
        context.register_param(param)

    sgd = MNCoreSGD(net.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(sgd, "optimizer0")
    context.register_optimizer_buffers(sgd)
    criterion = torch.nn.CrossEntropyLoss()

    def train_step(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        sgd.zero_grad()
        loss = criterion(net(batch["images"]), batch["labels"])
        loss.backward()
        sgd.step()
        return {"loss": loss}

    cfg = timm.data.resolve_model_data_config(net)
    preprocess = timm.data.create_transform(**cfg, is_training=False)
    batched_images = preprocess(image).unsqueeze(0).expand(batch_size, -1, -1, -1)
    targets = torch.randint(0, 1000, (batch_size,))
    sample = {"images": batched_images, "labels": targets}

    compile_options: dict[str, str] = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step_fx2onnx",
        options=compile_options,
        export_kwargs={"use_fx2onnx": True},
    )

    loss_before = compiled_train_step(sample)["loss"].cpu()
    for _ in range(10):
        compiled_train_step(sample)
    context.synchronize()
    loss_after = compiled_train_step(sample)["loss"].cpu()

    assert loss_after < loss_before
239
240
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1, required=True)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_timm")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--is_training", action="store_true")
    parser.add_argument(
        "--device",
        type=str,
        default="mncore2:auto",
        choices=["mncore2:auto", "pfvm:cpu"],
    )
    parser.add_argument(
        "--model_cache_dir",
        type=str,
        default=None,
        help="Directory to cache the model weights. "
        "If not set, weights are always downloaded from the hub. default: None",
    )
    args = parser.parse_args()

    # NOTE(review): --outdir defaults to "/tmp/mlsdk_timm", so this fallback
    # is only reachable if that default is ever changed to None.
    outdir = args.outdir
    if outdir is None:
        outdir = f"/tmp/MLSDK_codegen_dir_{args.model_name}"
        if args.is_training:
            outdir += "_training"
        else:
            outdir += "_inference"

    # TODO (akirakawata): Should we make this argument?
    # Default "0" keeps the fx2onnx exporter enabled unless the env var is
    # explicitly set to a non-zero value.
    use_fx2onnx = not bool(
        int(os.environ.get("MNCORE_USE_LEGACY_ONNX_EXPORTER", "0"))
    )
    if args.is_training:
        # BUGFIX: every branch now receives the resolved `outdir`; previously
        # only the fx2onnx path did while the torch_onnx and inference paths
        # passed args.outdir directly.
        train_fn = (
            run_training_fx2onnx if use_fx2onnx else run_training_torch_onnx
        )
        train_fn(
            args.model_name,
            args.batch_size,
            outdir,
            args.option_json,
            args.device,
            args.model_cache_dir,
        )
    else:
        run_inference(
            args.model_name,
            args.batch_size,
            outdir,
            args.option_json,
            args.device,
            args.model_cache_dir,
        )