# BSD 3-Clause License
#
# Copyright (c) 2017-2022, Pytorch contributors
# All rights reserved.
# Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# This file is aggregated and lightly adapted from the PyTorch introductory tutorial,
# https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
# with additional sections to export the model to ONNX format.
#
# Running this file requires a Python environment with the following packages installed:
#     torch torchvision onnx onnxscript onnxruntime
# and their dependencies.
#
# It downloads the MNIST dataset and trains a 5-layer convolutional neural network
# (two convolutional layers followed by three fully connected layers).

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 4*4 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, input):
        # Convolution layer C1: 1 input image channel, 6 output channels,
        # 5x5 square convolution, it uses RELU activation function, and
        # outputs a Tensor with size (N, 6, 24, 24), where N is the size of the batch
        c1 = F.relu(self.conv1(input))
        # Subsampling layer S2: 2x2 grid, purely functional,
        # this layer does not have any parameter, and outputs a (N, 6, 12, 12) Tensor
        s2 = F.max_pool2d(c1, (2, 2))
        # Convolution layer C3: 6 input channels, 16 output channels,
        # 5x5 square convolution, it uses RELU activation function, and
        # outputs a (N, 16, 8, 8) Tensor
        c3 = F.relu(self.conv2(s2))
        # Subsampling layer S4: 2x2 grid, purely functional,
        # this layer does not have any parameter, and outputs a (N, 16, 4, 4) Tensor
        s4 = F.max_pool2d(c3, 2)
        # Flatten operation: purely functional, outputs a (N, 256) Tensor
        s4 = torch.flatten(s4, 1)
        # Fully connected layer F5: (N, 256) Tensor input,
        # and outputs a (N, 120) Tensor, it uses RELU activation function
        f5 = F.relu(self.fc1(s4))
        # Fully connected layer F6: (N, 120) Tensor input,
        # and outputs a (N, 84) Tensor, it uses RELU activation function
        f6 = F.relu(self.fc2(f5))
        # Logits layer OUTPUT: (N, 84) Tensor input, and
        # outputs a (N, 10) Tensor
        output = self.fc3(f6)
        return output


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device('cuda:0' if torch.cuda.is_available()
                          else 'mps' if torch.backends.mps.is_available() else 'cpu')
    print(device)
    net = Net()
    net = net.to(device)
    print(net)

    batch_size = 16
    num_epochs = 2
    transform = transforms.ToTensor()

    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                          download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='./data', train=False,
                                         download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    print("loaded data")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 500 == 499:  # print every 500 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 500:.3f}')
                running_loss = 0.0

    print('Finished Training')

    torch_path = './mnist_net.pth'
    torch.save(net.state_dict(), torch_path)
    print("Saved pytorch model")

    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            # calculate outputs by running images through the network
            outputs = net(inputs)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

    itr = iter(testloader)
    test_input, test_label = next(itr)
    test_input = test_input.to(device)

    onnx_program = torch.onnx.dynamo_export(net, test_input)

    onnx_program.save("lenet-dynamo.onnx")
    print("Saved ONNX dynamo model")

    torch.onnx.export(net, test_input, "lenet-torchscript.onnx", verbose=True,
                      input_names=["image"], output_names=["logits"],
                      dynamic_axes={'image': {0: 'batch_size'},  # variable length axes
                                    'logits': {0: 'batch_size'}})
    print("Saved ONNX torchscript model")

    # Strip out node docstrings which contain file paths
    onnx_model = onnx.load_model("lenet-torchscript.onnx")
    for n in onnx_model.graph.node:
        n.doc_string = ""
    onnx.save_model(onnx_model, "lenet-torchscript.onnx")
    print("Tidied up ONNX torchscript model")
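
    # Optional sanity check, a minimal sketch that is not part of the original tutorial:
    # the header lists onnxruntime as a required package, so we can load the exported
    # torchscript-based model back and compare its logits against PyTorch on the same
    # test batch. The tolerance values below are assumptions, not prescribed anywhere.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("lenet-torchscript.onnx",
                                   providers=["CPUExecutionProvider"])
    # "image" and "logits" match the input_names/output_names passed to torch.onnx.export above
    onnx_logits = session.run(["logits"], {"image": test_input.cpu().numpy()})[0]
    with torch.no_grad():
        torch_logits = net(test_input).cpu().numpy()
    if np.allclose(onnx_logits, torch_logits, rtol=1e-3, atol=1e-5):
        print("ONNX and PyTorch outputs match")
    else:
        print("ONNX and PyTorch outputs differ beyond the assumed tolerance")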