7.1.7. Example: Inference With Multiple Models

Sample program demonstrating how to handle multiple inference models within a single mlsdk.Context()

Execution Method

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 infer_multi.py

Expected Output

  • The exit status of the program is 0.

Sample Program

Listing 7.7 /opt/pfn/pfcomp/codegen/examples/infer_multi.py
 1import torch
 2from mlsdk import Context, MNDevice, set_tensor_name_in_module, storage
 3
 4
 5def run_infer():
 6    device = MNDevice("mncore2:auto")
 7    context = Context(device)
 8    Context.switch_context(context)
 9
10    model0 = torch.nn.Linear(4, 4, bias=False)
11    model0.weight = torch.nn.Parameter(torch.ones(4, 4))
12    model0.eval()
13    model1 = torch.nn.Linear(4, 4, bias=False)
14    model1.weight = torch.nn.Parameter(torch.ones(4, 4) * 2)
15    model1.eval()
16
17    # To differentiate each model, Context uses the name specified by
18    # set_tensor_name_in_module, so these names must be set appropriately.
19    # Similarly, during training, set the name in set_buffer_name_in_optimizer
20    # as well.
21    set_tensor_name_in_module(model0, "model0")
22    set_tensor_name_in_module(model1, "model1")
23    for p in model0.parameters():
24        context.register_param(p)
25    for b in model0.buffers():
26        context.register_buffer(b)
27    for p in model1.parameters():
28        context.register_param(p)
29    for b in model1.buffers():
30        context.register_buffer(b)
31
32    def infer0(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
33        x = input["x"]
34        y = model0(x)
35        return {"out": y}
36
37    def infer1(input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
38        x = input["x"]
39        y = model1(x)
40        return {"out": y}
41
42    sample = {"x": torch.ones(4, 4)}
43
44    compiled_infer0 = context.compile(
45        infer0,
46        sample,
47        storage.path("/tmp/infer0"),
48    )
49    compiled_infer1 = context.compile(
50        infer1,
51        sample,
52        storage.path("/tmp/infer1"),
53    )
54    result0 = compiled_infer0({"x": torch.ones(4, 4)})
55    result_on_cpu0 = result0["out"].cpu()
56    assert torch.allclose(result_on_cpu0, torch.ones(4, 4) @ torch.ones(4, 4))
57    result1 = compiled_infer1({"x": torch.ones(4, 4)})
58    result_on_cpu1 = result1["out"].cpu()
59    assert torch.allclose(result_on_cpu1, torch.ones(4, 4) @ (torch.ones(4, 4) * 2))
60
61
# Script entry point: run the multi-model inference demo; exits 0 on success.
if __name__ == "__main__":
    run_infer()