7.1.2. Example: MNIST on MN-Core 2
A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.
Training results are saved to a checkpoint.pt file in the directory specified by the --outdir flag (the default directory is /tmp/mlsdk_mnist, so the default checkpoint path is /tmp/mlsdk_mnist/checkpoint.pt).
Execution Method
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist.py
Expected Output
Training log output
epoch 0, iter 0, loss 2.3125
epoch 0, iter 100, loss 0.6226431969368814
...
epoch 9, iter 900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
Inference results
Correct: 9609 / 10000. Accuracy: 0.9609
Related Links
Sample Program
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8from mlsdk import (
9 Context,
10 MNCoreSGD,
11 MNDevice,
12 set_buffer_name_in_optimizer,
13 set_tensor_name_in_module,
14 storage,
15)
16
17from mnist_common import mnist_loaders, MNCoreClassifier
18
# Fix every RNG source the example touches (PyTorch, Python stdlib, NumPy)
# so that training runs are reproducible.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(0)
22
23
def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Train and evaluate an MNIST classifier on an MN-Core 2 device.

    Compiled step artifacts and a ``checkpoint.pt`` file are written under
    ``outdir``. After evaluation, the script fails unless accuracy exceeds 0.95.

    Args:
        outdir: Directory for compiled artifacts and the saved checkpoint.
        option_json_path: Optional path to a compiler option JSON file.
        device_str: Device specifier passed to ``MNDevice``
            (e.g. ``"mncore2:auto"``).

    Raises:
        RuntimeError: If the final evaluation accuracy is not above 0.95.
    """
    batch_size = 64
    eval_batch_size = 125

    # Create a context bound to the MN-Core device and make it current.
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    # Parameters must be registered with the context before compilation.
    for p in model_with_loss_fn.parameters():
        context.register_param(p)

    # Positional args are lr=0.1, momentum=0.9, weight_decay=0.0
    # — presumably matching MNCoreSGD's signature; confirm against mlsdk docs.
    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # One optimization step: forward, backward, parameter update.
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    compile_options: dict = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

    # Compile the training step once, using a sample batch for tracing.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
            # Incremental (running) mean of the loss over the epoch.
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    # Wait for all queued device work to finish before reading state dicts.
    context.synchronize()

    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # Count correct predictions in one evaluation batch.
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    accuracy = correct / len(eval_loader.dataset)
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {accuracy}"
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not accuracy > 0.95:
        raise RuntimeError(f"accuracy {accuracy} did not exceed 0.95")
111
112
if __name__ == "__main__":
    # CLI entry point: parse the output directory, optional compiler option
    # JSON, and device specifier, then run training + evaluation.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist")
    arg_parser.add_argument("--option_json", type=Path, default=None)
    arg_parser.add_argument("--device", type=str, default="mncore2:auto")
    cli_args = arg_parser.parse_args()
    main(cli_args.outdir, cli_args.option_json, cli_args.device)