7.1.7. Example: Inference With Multiple Models
Sample program demonstrating how to handle multiple inference models within a single mlsdk.Context()
Execution Method
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 infer_multi.py
Expected Output
The exit status of the program is 0.
Sample Program
import torch
from mlsdk import Context, MNDevice, set_tensor_name_in_module, storage
def run_infer():
    """Compile and run two independent inference models in one mlsdk Context.

    Builds two bias-free 4x4 Linear models with distinct constant weights,
    registers both with a single Context, compiles one inference function per
    model, and asserts each compiled function reproduces the expected
    matrix product on all-ones input.
    """
    device = MNDevice("mncore2:auto")
    context = Context(device)
    Context.switch_context(context)

    # model0 weights are all ones; model1 weights are all twos, so the two
    # models produce distinguishable outputs on the same input.
    model0 = torch.nn.Linear(4, 4, bias=False)
    model0.weight = torch.nn.Parameter(torch.ones(4, 4))
    model0.eval()
    model1 = torch.nn.Linear(4, 4, bias=False)
    model1.weight = torch.nn.Parameter(torch.ones(4, 4) * 2)
    model1.eval()

    # To differentiate each model, Context uses the name specified by
    # set_tensor_name_in_module, so these names must be set appropriately.
    # Similarly, during training, set the name in set_buffer_name_in_optimizer
    # as well.
    set_tensor_name_in_module(model0, "model0")
    set_tensor_name_in_module(model1, "model1")
    # Register every parameter and buffer of both models with the context.
    for p in model0.parameters():
        context.register_param(p)
    for b in model0.buffers():
        context.register_buffer(b)
    for p in model1.parameters():
        context.register_param(p)
    for b in model1.buffers():
        context.register_buffer(b)

    def infer0(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Inference entry point for model0: reads "x", returns {"out": y}.
        x = input["x"]
        y = model0(x)
        return {"out": y}

    def infer1(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Inference entry point for model1: reads "x", returns {"out": y}.
        x = input["x"]
        y = model1(x)
        return {"out": y}

    # Sample input used by the compiler to trace the inference functions.
    sample = {"x": torch.ones(4, 4)}

    # Each inference function is compiled separately, with its own cache path.
    compiled_infer0 = context.compile(
        infer0,
        sample,
        storage.path("/tmp/infer0"),
    )
    compiled_infer1 = context.compile(
        infer1,
        sample,
        storage.path("/tmp/infer1"),
    )
    # With symmetric all-constant weights, Linear's x @ W.T equals x @ W,
    # so the expected outputs are the plain matrix products below.
    result0 = compiled_infer0({"x": torch.ones(4, 4)})
    result_on_cpu0 = result0["out"].cpu()
    assert torch.allclose(result_on_cpu0, torch.ones(4, 4) @ torch.ones(4, 4))
    result1 = compiled_infer1({"x": torch.ones(4, 4)})
    result_on_cpu1 = result1["out"].cpu()
    assert torch.allclose(result_on_cpu1, torch.ones(4, 4) @ (torch.ones(4, 4) * 2))
61
if __name__ == "__main__":
    run_infer()