7.1.9. Example: Inference MNIST
A sample program that strips the MLSDK API calls out of mnist.py and performs inference with plain PyTorch.
This example assumes you have already run Example: MNIST on MN-Core 2 at least once, so that checkpoint.pt exists (by default at /tmp/mlsdk_mnist/checkpoint.pt).
Execution Method
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist_infer.py /tmp/mlsdk_mnist/checkpoint.pt
Expected Output
Inference results (these should be identical to those from Example: MNIST on MN-Core 2):
Correct: 9609 / 10000. Accuracy: 0.9609
Related Links
-
This material serves as a reference for gradually introducing the MLSDK API.
Sample Program
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8
9from mnist_common import mnist_loaders, MNCoreClassifier
10
# Fix every RNG the example touches (PyTorch, Python, NumPy) to seed 0 so runs are repeatable.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
14
15
def main(checkpoint_path: str, outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Evaluate a trained MNCoreClassifier checkpoint on the MNIST eval split with plain PyTorch.

    Prints the number of correctly classified samples and the accuracy, then
    checks that the accuracy is above 0.95.

    Args:
        checkpoint_path: Path to a checkpoint file; it must contain a
            ``"model_state_dict"`` entry loadable into ``MNCoreClassifier``.
        outdir: Accepted for CLI compatibility; not used by this function.
        option_json_path: Accepted for CLI compatibility; not used by this function.
        device_str: Accepted for CLI compatibility; not used (the model is never
            moved off its default device).

    Raises:
        ValueError: If the evaluation dataset is empty.
        AssertionError: If the accuracy is not above 0.95 (skipped under ``python -O``).
    """
    batch_size = 64
    eval_batch_size = 125

    # Only the evaluation loader is needed; the training loader is discarded.
    _, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    # map_location="cpu" lets a checkpoint whose tensors were saved from another
    # device load in a plain PyTorch environment; the model below stays on its
    # default device anyway.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.load_state_dict(checkpoint["model_state_dict"])
    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        """Return ``{"correct": n}``, the count of correctly classified samples in one batch."""
        x = inp["x"]
        t = inp["t"]
        # The model takes the targets as well; only its "y" output (class scores) is used here.
        output = model_with_loss_fn(x, t)
        _, predicted = torch.max(output["y"], 1)
        return {"correct": (predicted == t).sum()}

    num_samples = len(eval_loader.dataset)
    if num_samples == 0:
        raise ValueError("evaluation dataset is empty")

    correct = 0
    # Inference only: disable autograd so no computation graph is recorded per batch.
    with torch.no_grad():
        for sample in eval_loader:
            # .item() keeps the running total a Python int, so the count and the
            # accuracy below are exact Python numbers rather than float32 tensors.
            correct += eval_step(sample)["correct"].item()

    accuracy = correct / num_samples
    print(
        f"Correct: {correct} / {num_samples}. "
        f"Accuracy: {accuracy}"
    )
    assert 0.95 < accuracy
45
46
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the PyTorch-only evaluation.
    parser = argparse.ArgumentParser(
        description="Evaluate an MNIST checkpoint using plain PyTorch (no MLSDK API)."
    )
    parser.add_argument(
        "checkpoint_path",
        type=str,
        help="checkpoint file to evaluate, e.g. /tmp/mlsdk_mnist/checkpoint.pt",
    )
    # The remaining options are passed to main() but not used by it.
    parser.add_argument(
        "--outdir",
        type=str,
        default="/tmp/mlsdk_mnist_infer",
        help="accepted but ignored by this PyTorch-only script",
    )
    parser.add_argument(
        "--option_json",
        type=Path,
        default=None,
        help="accepted but ignored by this PyTorch-only script",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="mncore2:auto",
        help="accepted but ignored by this PyTorch-only script",
    )
    args = parser.parse_args()
    main(args.checkpoint_path, args.outdir, args.option_json, args.device)