7.1.10. Example: Training MNIST
A sample program that removes the MLSDK API from mnist.py and performs training using plain PyTorch.
It is similar to Example: MNIST on MN-Core 2, but it only writes checkpoint.pt to the directory specified by --outdir. The default directory is /tmp/mlsdk_mnist_train, so the default checkpoint path is /tmp/mlsdk_mnist_train/checkpoint.pt.
Execution Method
$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist_train.py
Expected Output
Training log output
The loss curve can differ from that of Example: MNIST on MN-Core 2 because a different backend is used.
epoch 0, iter 0, loss 2.29758358001709
epoch 0, iter 100, loss 0.6065061688423157
...
epoch 9, iter 900, loss 0.12388602644205093
epoch 9, loss 0.12544165551662445
Checkpoint file (checkpoint.pt)
To check that training was performed properly, use mnist_infer.py. The Accuracy metric should be larger than 0.95.
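Before running inference, it can be useful to confirm that a checkpoint has the expected layout. The following is a minimal sketch, not part of the shipped examples. It assumes only the two keys that the sample program writes ("model_state_dict" and "optim_state_dict"). A small torch.nn.Linear stands in for MNCoreClassifier, which is defined in mnist_common and is not reproduced here, and a temporary directory stands in for --outdir.

```python
import tempfile
from pathlib import Path

import torch

# Stand-in model (assumption): the real sample uses MNCoreClassifier from mnist_common.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

with tempfile.TemporaryDirectory() as outdir:
    path = Path(outdir) / "checkpoint.pt"

    # Save with the same dictionary layout as the sample program.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        path,
    )

    # Load the checkpoint back onto the CPU and inspect it.
    checkpoint = torch.load(path, map_location="cpu")

checkpoint_keys = sorted(checkpoint.keys())

# Restore the weights into a fresh model of the same shape.
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model_state_dict"])
weights_match = torch.equal(restored.weight, model.weight)

print(checkpoint_keys, weights_match)
```

With a real checkpoint, the same torch.load call can be pointed at the checkpoint.pt file under the directory given by --outdir, and the state dict can be loaded into an MNCoreClassifier instance instead of the stand-in model.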
Related Links
- This material serves as a reference for gradually introducing the MLSDK API.
Sample Program
import argparse
import random
import os
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
from mlsdk import storage

from mnist_common import mnist_loaders, MNCoreClassifier

# Seed every random number generator used by the script.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


# option_json_path and device_str are accepted but not used in this PyTorch-only version.
def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    batch_size = 64
    eval_batch_size = 125

    # Only the training loader is needed here; the evaluation loader is discarded.
    train_loader, _ = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()

    # Positional arguments: lr=0.1, momentum=0.9, dampening=0.0.
    optimizer = torch.optim.SGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = train_step(sample)["loss"]
            # Running mean of the per-batch loss within the current epoch.
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    # Save the model and optimizer state to <outdir>/checkpoint.pt.
    os.makedirs(outdir, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_train")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.outdir, args.option_json, args.device)
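
The loss values printed by the sample program are not raw per-batch losses. The update loss += (curr_loss - loss) / (i + 1) keeps a running mean over the batches seen so far in the epoch. The sketch below illustrates that update in plain Python with made-up loss values (the numbers and the running_mean helper are illustrative only) and compares it with the ordinary arithmetic mean.

```python
def running_mean(values):
    """Incrementally average values using the same update as the training loop."""
    mean = 0.0
    for i, value in enumerate(values):
        # Move the current mean toward the new value by 1 / (number of values seen).
        mean += (value - mean) / (i + 1)
    return mean


# Hypothetical per-batch losses, chosen only for illustration.
batch_losses = [2.0, 1.0, 0.5, 0.5]

incremental = running_mean(batch_losses)
direct = sum(batch_losses) / len(batch_losses)

print(incremental, direct)
```

Both values agree up to floating-point rounding, which is why the final line of each epoch in the log can be read as the mean training loss for that epoch.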