7.1.2. Example: MNIST on MN-Core 2

A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.

Training results are saved to the checkpoint.pt file located in the directory specified by the --outdir flag (with the default --outdir, this is /tmp/mlsdk_mnist/checkpoint.pt).

Execution Method

$ cd /opt/pfn/pfcomp/codegen/examples/
$ ./exec_with_env.sh python3 mnist.py

Expected Output

  • Training log output

epoch 0, iter    0, loss 2.3125
epoch 0, iter  100, loss 0.6226431969368814
...
epoch 9, iter  900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
  • Inference results

Correct: 9609 / 10000. Accuracy: 0.9609

Related Links

Sample Program

Listing 7.2 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist.py
  1import argparse
  2import random
  3from pathlib import Path
  4from typing import Mapping, Optional
  5
  6import numpy as np
  7import torch
  8from mlsdk import (
  9    Context,
 10    MNCoreSGD,
 11    MNDevice,
 12    set_buffer_name_in_optimizer,
 13    set_tensor_name_in_module,
 14    storage,
 15)
 16
 17from mnist_common import mnist_loaders, MNCoreClassifier
 18
 19torch.manual_seed(0)
 20random.seed(0)
 21np.random.seed(0)
 22
 23
 24def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
 25    batch_size = 64
 26    eval_batch_size = 125
 27
 28    device = MNDevice(device_str)
 29    context = Context(device)
 30    Context.switch_context(context)
 31
 32    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)
 33
 34    model_with_loss_fn = MNCoreClassifier()
 35    model_with_loss_fn.train()
 36    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
 37    for p in model_with_loss_fn.parameters():
 38        context.register_param(p)
 39
 40    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
 41    set_buffer_name_in_optimizer(optimizer, "optimizer")
 42    context.register_optimizer_buffers(optimizer)
 43
 44    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
 45        x = inp["x"]
 46        t = inp["t"]
 47        optimizer.zero_grad()
 48        output = model_with_loss_fn(x, t)
 49        loss = output["loss"]
 50        loss.backward()
 51        optimizer.step()
 52        return {"loss": loss}
 53
 54    compile_options = {}
 55    if option_json_path is not None:
 56        compile_options["option_json"] = str(option_json_path)
 57
 58    sample = next(iter(train_loader))
 59    compiled_train_step = context.compile(
 60        train_step,
 61        sample,
 62        storage.path(outdir) / "train_step",
 63        options=compile_options,
 64    )
 65
 66    for epoch in range(10):
 67        loss = 0.0
 68        for i, sample in enumerate(train_loader):
 69            curr_loss = compiled_train_step(sample)["loss"].item()
 70            loss += (curr_loss - loss) / (i + 1)
 71            if i % 100 == 0:
 72                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
 73        print(f"epoch {epoch}, loss {loss}")
 74
 75    context.synchronize()
 76
 77    torch.save(
 78        {
 79            "model_state_dict": model_with_loss_fn.state_dict(),
 80            "optim_state_dict": optimizer.state_dict(),
 81        },
 82        storage.path(outdir) / "checkpoint.pt",
 83    )
 84
 85    model_with_loss_fn.eval()
 86
 87    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
 88        x = inp["x"]
 89        t = inp["t"]
 90        output = model_with_loss_fn(x, t)
 91        y = output["y"]
 92        _, predicted = torch.max(y, 1)
 93        correct = (predicted == t).sum()
 94        return {"correct": correct}
 95
 96    sample = next(iter(eval_loader))
 97    compiled_eval_step = context.compile(
 98        eval_step,
 99        sample,
100        storage.path(outdir) / "eval_step",
101        options=compile_options,
102    )
103    correct = 0
104    for sample in eval_loader:
105        correct += compiled_eval_step(sample)["correct"].item()
106    print(
107        f"Correct: {correct} / {len(eval_loader.dataset)}. "
108        f"Accuracy: {correct / len(eval_loader.dataset)}"
109    )
110    assert 0.95 < correct / len(eval_loader.dataset)
111
112
if __name__ == "__main__":
    # CLI entry point: parse flags and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist")
    cli.add_argument("--option_json", type=Path, default=None)
    cli.add_argument("--device", type=str, default="mncore2:auto")
    parsed = cli.parse_args()
    main(parsed.outdir, parsed.option_json, parsed.device)