7.1.9. Example: Inference MNIST

A sample program that strips the MLSDK API calls out of mnist.py and runs inference with plain PyTorch.

This example assumes you have already run Example: MNIST on MN-Core 2 at least once, so that its checkpoint file exists (by default /tmp/mlsdk_mnist/checkpoint.pt).
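If the checkpoint is missing, the script fails inside torch.load with a generic error. A quick pre-flight check can make that failure clearer. The sketch below is a hypothetical helper (find_checkpoint is not part of MLSDK or of the example scripts). It uses only the standard library and assumes the default path mentioned above.

```python
from pathlib import Path
from typing import Optional

# Default checkpoint location written by Example: MNIST on MN-Core 2.
DEFAULT_CHECKPOINT = Path("/tmp/mlsdk_mnist/checkpoint.pt")


def find_checkpoint(path: Optional[str] = None) -> Path:
    """Hypothetical helper: return the checkpoint path or fail with a clear message."""
    candidate = Path(path) if path is not None else DEFAULT_CHECKPOINT
    if not candidate.is_file():
        raise FileNotFoundError(
            f"{candidate} does not exist; run Example: MNIST on MN-Core 2 first"
        )
    return candidate
```

Calling find_checkpoint() checks the default location. Passing an explicit path checks a checkpoint stored elsewhere.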

Execution Method

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist_infer.py /tmp/mlsdk_mnist/checkpoint.pt
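The command above passes only the positional checkpoint path, so every option falls back to its default. The sketch below rebuilds the same argparse parser as the listing to show which values main() receives. build_parser is a hypothetical name used only for this illustration. In this PyTorch-only version, --outdir, --option_json and --device are parsed but not used.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    # Same arguments and defaults as the __main__ block of mnist_infer.py.
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_path", type=str)
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_infer")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    return parser


# Equivalent of the command line shown above.
args = build_parser().parse_args(["/tmp/mlsdk_mnist/checkpoint.pt"])
print(args.checkpoint_path, args.outdir, args.option_json, args.device)
# prints: /tmp/mlsdk_mnist/checkpoint.pt /tmp/mlsdk_mnist_infer None mncore2:auto
```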

Expected Output

Correct: 9609 / 10000. Accuracy: 0.9609
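The reported accuracy is the number of correctly classified test images divided by the size of the test set. The following worked arithmetic uses only the figures from the expected output and the listing. It assumes that mnist_loaders uses eval_batch_size as the batch size of the evaluation loader. Under that assumption, the 10000 images arrive in 80 full batches, and the result clears the script's 0.95 assertion threshold.

```python
# Figures taken from the expected output and from main() in the listing.
dataset_size = 10000      # len(eval_loader.dataset)
eval_batch_size = 125     # eval_batch_size in main() (assumed loader batch size)
correct = 9609            # "Correct: 9609 / 10000"

num_batches = dataset_size // eval_batch_size  # 10000 // 125, no partial batch
accuracy = correct / dataset_size
print(num_batches, accuracy, accuracy > 0.95)  # 80 0.9609 True
```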

Related Links

  • Migration Tutorial

    • This material serves as a reference for introducing the MLSDK API step by step.

Sample Program

Listing 7.9 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist_infer.py
import argparse
import random
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch

from mnist_common import mnist_loaders, MNCoreClassifier

# Fix all random seeds for reproducible runs.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def main(checkpoint_path: str, outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    # outdir, option_json_path and device_str are parsed from the command line
    # but are not used in this PyTorch-only version.
    batch_size = 64
    eval_batch_size = 125

    # Only the evaluation loader is needed for inference.
    _, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    # Load the checkpoint saved by the MNIST training example onto the CPU.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.load_state_dict(checkpoint["model_state_dict"])
    model_with_loss_fn.eval()  # put the model in evaluation mode

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    correct = 0
    with torch.no_grad():  # gradients are not needed for inference
        for sample in eval_loader:
            # .item() converts the 0-dim tensor to a Python int.
            correct += eval_step(sample)["correct"].item()
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {correct / len(eval_loader.dataset)}"
    )
    assert 0.95 < correct / len(eval_loader.dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_path", type=str)
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_infer")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.checkpoint_path, args.outdir, args.option_json, args.device)
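eval_step() takes the index of the largest logit in each row and counts how many of those indices match the labels t. torch.max(y, 1) returns both values and indices, and only the indices are kept. The sketch below reproduces that counting logic with plain Python lists so it can run without PyTorch. The names argmax and count_correct are hypothetical, and the logits in the usage line are made-up illustrative values, not model output.

```python
from typing import Sequence


def argmax(row: Sequence[float]) -> int:
    # Index of the largest value; max() keeps the first occurrence on ties.
    return max(range(len(row)), key=lambda i: row[i])


def count_correct(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> int:
    # Mirrors: _, predicted = torch.max(y, 1); (predicted == t).sum()
    predicted = [argmax(row) for row in logits]
    return sum(int(p == t) for p, t in zip(predicted, labels))


# Illustrative batch of 3 samples with 3 classes each.
batch_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
batch_labels = [1, 2, 2]
print(count_correct(batch_logits, batch_labels))  # 2
```

Summing count_correct over every batch and dividing by the dataset size gives the same accuracy figure that the script prints.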