7.1.10. Example: Training MNIST

A sample program that removes the MLSDK API calls from mnist.py and performs training with plain PyTorch.

Similar to Example: MNIST on MN-Core 2, except that it only writes checkpoint.pt to the directory specified by --outdir (default: /tmp/mlsdk_mnist_train, so the file is written to /tmp/mlsdk_mnist_train/checkpoint.pt).

Execution Method

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist_train.py

Expected Output

epoch 0, iter    0, loss 2.29758358001709
epoch 0, iter  100, loss 0.6065061688423157
...
epoch 9, iter  900, loss 0.12388602644205093
epoch 9, loss 0.12544165551662445
  • Checkpoint file (checkpoint.pt)

    • Whether training ran correctly can be verified by running mnist_infer.py with this checkpoint

    • The accuracy metric should be greater than 0.95
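The checkpoint written by this script is a plain dict holding the model and optimizer state_dicts. As a hedged sketch of the save/load round trip — using a tiny stand-in `torch.nn.Linear` instead of the real `MNCoreClassifier`, and a temporary path of our own choosing — restoring it looks like:

```python
import os
import tempfile

import torch

# Stand-in model and optimizer (hypothetical; the sample uses MNCoreClassifier).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

path = os.path.join(tempfile.gettempdir(), "checkpoint_demo.pt")

# Save in the same dict layout as mnist_train.py.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optimizer.state_dict(),
    },
    path,
)

# Load the checkpoint back and restore both states.
ckpt = torch.load(path)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optim_state_dict"])
print(sorted(ckpt.keys()))  # ['model_state_dict', 'optim_state_dict']
```

An inference script such as mnist_infer.py would restore only the model state_dict before evaluating; the optimizer state is needed only when resuming training.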

Related Links

  • Migration Tutorial

    • This material serves as a reference for gradually introducing the MLSDK API.

Sample Program

Listing 7.10 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist_train.py
import argparse
import random
import os
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
from mlsdk import storage

from mnist_common import mnist_loaders, MNCoreClassifier

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    batch_size = 64
    eval_batch_size = 125

    train_loader, _ = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()

    optimizer = torch.optim.SGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = train_step(sample)["loss"]
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    os.makedirs(outdir, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_train")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.outdir, args.option_json, args.device)
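The loss printed at the end of each epoch is an incremental (running) mean: the update `loss += (curr_loss - loss) / (i + 1)` folds each step's loss into the average without storing the whole history. A minimal pure-Python sketch of the same update rule:

```python
def running_mean(values):
    """Incremental mean; gives the same result as sum(values) / len(values)."""
    mean = 0.0
    for i, v in enumerate(values):
        # Fold the i-th value into the mean of the first i values.
        mean += (v - mean) / (i + 1)
    return mean


losses = [2.0, 4.0, 6.0, 8.0]
print(running_mean(losses))  # 5.0, same as sum(losses) / len(losses)
```

This form uses O(1) memory per epoch, which is why the training loop prefers it over accumulating a list of per-step losses.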