# BSD 3-Clause License
#
# Copyright (c) 2017-2022, Pytorch contributors
# All rights reserved.
# Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# This file is aggregated and lightly adapted from the PyTorch introductory tutorial,
# https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
# with additional sections to export the model to ONNX format.
#
# Running this file requires a Python environment with the following packages installed:
# torch torchvision onnx onnxscript onnxruntime
# and their dependencies.
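# For example, with pip: pip install torch torchvision onnx onnxscript onnxruntime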
#
# It downloads the MNIST dataset and trains a 5-layer convolutional neural network.

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 4*4 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, input):
        # Convolution layer C1: 1 input image channel, 6 output channels,
        # 5x5 square convolution, uses a ReLU activation function, and
        # outputs a Tensor with size (N, 6, 24, 24), where N is the size of the batch
        c1 = F.relu(self.conv1(input))
        # Subsampling layer S2: 2x2 grid, purely functional,
        # this layer has no parameters and outputs a (N, 6, 12, 12) Tensor
        s2 = F.max_pool2d(c1, (2, 2))
        # Convolution layer C3: 6 input channels, 16 output channels,
        # 5x5 square convolution, uses a ReLU activation function, and
        # outputs a (N, 16, 8, 8) Tensor
        c3 = F.relu(self.conv2(s2))
        # Subsampling layer S4: 2x2 grid, purely functional,
        # this layer has no parameters and outputs a (N, 16, 4, 4) Tensor
        s4 = F.max_pool2d(c3, 2)
        # Flatten operation: purely functional, outputs a (N, 256) Tensor
        s4 = torch.flatten(s4, 1)
        # Fully connected layer F5: (N, 256) Tensor input,
        # and outputs a (N, 120) Tensor, uses a ReLU activation function
        f5 = F.relu(self.fc1(s4))
        # Fully connected layer F6: (N, 120) Tensor input,
        # and outputs a (N, 84) Tensor, uses a ReLU activation function
        f6 = F.relu(self.fc2(f5))
        # Logits layer OUTPUT: (N, 84) Tensor input, and
        # outputs a (N, 10) Tensor
        output = self.fc3(f6)
        return output


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device('cuda:0' if torch.cuda.is_available()
                          else 'mps' if torch.backends.mps.is_available()
                          else 'cpu')
    print(device)
    net = Net()
    net = net.to(device)
    print(net)
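
    # Quick sanity check (a minimal sketch, not part of the original tutorial):
    # a single MNIST-sized dummy image should produce a (1, 10) logits tensor,
    # matching the shapes documented in Net.forward above.
    with torch.no_grad():
        dummy = torch.zeros(1, 1, 28, 28, device=device)
        assert net(dummy).shape == (1, 10)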

    batch_size = 16
    num_epochs = 2
    transform = transforms.ToTensor()

    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                          download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='./data', train=False,
                                         download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    print("loaded data")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 500 == 499:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 500:.3f}')
                running_loss = 0.0

    print('Finished Training')

    torch_path = './mnist_net.pth'
    torch.save(net.state_dict(), torch_path)
    print("Saved pytorch model")
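
    # Reload sketch (a minimal example, not part of the original tutorial):
    # a state_dict checkpoint is restored into a freshly constructed Net and
    # moved to the same device. The `reloaded` name is illustrative only.
    reloaded = Net().to(device)
    reloaded.load_state_dict(torch.load(torch_path, map_location=device))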

    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            # calculate outputs by running images through the network
            outputs = net(inputs)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

    itr = iter(testloader)
    test_input, test_label = next(itr)
    test_input = test_input.to(device)

    onnx_program = torch.onnx.dynamo_export(net, test_input)

    onnx_program.save("lenet-dynamo.onnx")
    print("Saved ONNX dynamo model")

    torch.onnx.export(net, test_input, "lenet-torchscript.onnx", verbose=True,
                      input_names=["image"], output_names=["logits"],
                      dynamic_axes={'image': {0: 'batch_size'},  # variable length axes
                                    'logits': {0: 'batch_size'}})
    print("Saved ONNX torchscript model")

    # Strip out node docstrings which contain file paths
    onnx_model = onnx.load_model("lenet-torchscript.onnx")
    for n in onnx_model.graph.node:
        n.doc_string = ""
    onnx.save_model(onnx_model, "lenet-torchscript.onnx")
    print("Tidied up ONNX torchscript model")
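
    # Optional sanity check (a minimal sketch, not part of the original tutorial):
    # load the cleaned-up torchscript ONNX model with onnxruntime (listed in the
    # required packages above) and compare its output with the PyTorch model on
    # the same test batch. Names like `session` and `onnx_logits` are illustrative.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("lenet-torchscript.onnx",
                                   providers=["CPUExecutionProvider"])
    onnx_logits = session.run(["logits"], {"image": test_input.cpu().numpy()})[0]
    with torch.no_grad():
        torch_logits = net(test_input).cpu().numpy()
    np.testing.assert_allclose(onnx_logits, torch_logits, rtol=1e-3, atol=1e-5)
    print("ONNX and PyTorch outputs match")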