7.1.10. Example: Training MNIST

mnist.py から MLSDK API を取り除き、PyTorch で学習を行うサンプルプログラム

Example: MNIST on MN-Core 2 と同様ですが、単に checkpoint.pt--output で指定された場所に保存するだけです。 (デフォルトでは /tmp/mlsdk_mnist_train/checkpoint.pt)

実行方法

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist_train.py

想定出力

  • 学習中のログ

    • Loss curve が Example: MNIST on MN-Core 2 のものと異なることがありますが、これは異なるバックエンドが使用されているためです。

epoch 0, iter    0, loss 2.29758358001709
epoch 0, iter  100, loss 0.6065061688423157
...
epoch 9, iter  900, loss 0.12388602644205093
epoch 9, loss 0.12544165551662445
  • チェックポイント (checkpoint.pt)

    • 学習が正常に完了したかは mnist_infer.py を使ってチェックします。

    • Accuracy 指標が 0.95 よりも大きければ良いです。

関連リンク

サンプルプログラム

リスト 7.10 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist_train.py
 1import argparse
 2import random
 3import os
 4from pathlib import Path
 5from typing import Mapping, Optional
 6
 7import numpy as np
 8import torch
 9from mlsdk import storage
10
11from mnist_common import mnist_loaders, MNCoreClassifier
12
13torch.manual_seed(0)
14random.seed(0)
15np.random.seed(0)
16
17
18def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
19    batch_size = 64
20    eval_batch_size = 125
21
22    train_loader, _ = mnist_loaders(batch_size, eval_batch_size)
23
24    model_with_loss_fn = MNCoreClassifier()
25    model_with_loss_fn.train()
26
27    optimizer = torch.optim.SGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
28
29    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
30        x = inp["x"]
31        t = inp["t"]
32        optimizer.zero_grad()
33        output = model_with_loss_fn(x, t)
34        loss = output["loss"]
35        loss.backward()
36        optimizer.step()
37        return {"loss": loss}
38
39    for epoch in range(10):
40        loss = 0.0
41        for i, sample in enumerate(train_loader):
42            curr_loss = train_step(sample)["loss"]
43            loss += (curr_loss - loss) / (i + 1)
44            if i % 100 == 0:
45                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
46        print(f"epoch {epoch}, loss {loss}")
47
48    os.makedirs(outdir, exist_ok=True)
49    torch.save(
50        {
51            "model_state_dict": model_with_loss_fn.state_dict(),
52            "optim_state_dict": optimizer.state_dict(),
53        },
54        storage.path(outdir) / "checkpoint.pt",
55    )
56
57
58if __name__ == "__main__":
59    parser = argparse.ArgumentParser()
60    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_train")
61    parser.add_argument("--option_json", type=Path, default=None)
62    parser.add_argument("--device", type=str, default="mncore2:auto")
63    args = parser.parse_args()
64    main(args.outdir, args.option_json, args.device)