7.1.2. Example: MNIST on MN-Core 2
A sample program that performs training and inference on MN-Core 2 using the MNIST dataset.
The training result is saved as a checkpoint.pt file under the directory given by --outdir (by default, /tmp/mlsdk_mnist/checkpoint.pt).
How to run
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist.py
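The script also accepts a few command-line options, handled by the argparse block at the end of the sample program. For example, the output directory and target device can be set explicitly; the values shown below are simply the defaults:

$ ./exec_with_env.sh python3 mnist.py --outdir /tmp/mlsdk_mnist --device mncore2:auto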
Expected output
Training log
epoch 0, iter 0, loss 2.3125
epoch 0, iter 100, loss 0.6226431969368814
...
epoch 9, iter 900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
Inference result
Correct: 9609 / 10000. Accuracy: 0.9609
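After the run finishes, the saved checkpoint can be inspected with plain PyTorch. The snippet below is a minimal sketch, not part of the sample: it assumes the key names written by mnist.py ("model_state_dict" and "optim_state_dict") and the default --outdir, and it should be run via exec_with_env.sh so that mnist_common is importable. Depending on how the tensors were saved, torch.load may additionally need a map_location argument.

import torch

from mnist_common import MNCoreClassifier

# Load the checkpoint written by mnist.py (default --outdir).
checkpoint = torch.load("/tmp/mlsdk_mnist/checkpoint.pt")

# Restore the trained weights into a fresh model instance.
model = MNCoreClassifier()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()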
Related links
Sample program
import argparse
import random
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
from mlsdk import (
    Context,
    MNCoreSGD,
    MNDevice,
    set_buffer_name_in_optimizer,
    set_tensor_name_in_module,
    storage,
)

from mnist_common import mnist_loaders, MNCoreClassifier

# Fix the random seeds so that runs are reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    batch_size = 64
    eval_batch_size = 125

    # Create the MN-Core 2 device and switch to its context.
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    # Register the model parameters with the context.
    for p in model_with_loss_fn.parameters():
        context.register_param(p)

    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    compile_options = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

    # Compile the training step against one sample batch.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
            # Update the running mean of the loss over the epoch.
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    context.synchronize()

    # Save the model and optimizer state as a checkpoint.
    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    # Compile the evaluation step in the same way.
    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {correct / len(eval_loader.dataset)}"
    )
    assert 0.95 < correct / len(eval_loader.dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.outdir, args.option_json, args.device)
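mnist_common is imported by the sample but not shown here. For reference, the sketch below illustrates the interface mnist.py relies on: mnist_loaders(batch_size, eval_batch_size) must return two loaders whose batches are mappings with keys "x" and "t", and MNCoreClassifier's forward(x, t) must return a mapping containing "loss" (used during training) and "y" (the logits used during evaluation). This is an illustrative stand-in built on plain PyTorch and torchvision, with an arbitrary data path; the module shipped with the examples may be implemented differently.

# Illustrative sketch of mnist_common; not the module shipped with the examples.
from typing import List, Mapping, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class MNCoreClassifier(nn.Module):
    # Small MLP whose forward returns the "loss"/"y" mapping that mnist.py reads.
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> Mapping[str, torch.Tensor]:
        h = F.relu(self.fc1(x.reshape(x.shape[0], -1)))
        y = self.fc2(h)
        return {"loss": F.cross_entropy(y, t), "y": y}


def _collate(batch: List[Tuple[torch.Tensor, int]]) -> Mapping[str, torch.Tensor]:
    # mnist.py indexes each batch with "x" (images) and "t" (labels).
    xs = torch.stack([x for x, _ in batch])
    ts = torch.tensor([t for _, t in batch])
    return {"x": xs, "t": ts}


def mnist_loaders(batch_size: int, eval_batch_size: int) -> Tuple[DataLoader, DataLoader]:
    tf = transforms.ToTensor()
    train = datasets.MNIST("/tmp/mnist_data", train=True, download=True, transform=tf)
    test = datasets.MNIST("/tmp/mnist_data", train=False, download=True, transform=tf)
    train_loader = DataLoader(
        train, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=_collate
    )
    eval_loader = DataLoader(test, batch_size=eval_batch_size, collate_fn=_collate)
    return train_loader, eval_loader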