1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx;
 25 
import java.io.*;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

import jdk.incubator.code.Block;
import jdk.incubator.code.Op;
import jdk.incubator.code.Reflect;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.core.CoreType;
import jdk.incubator.code.dialect.core.FunctionType;
import jdk.incubator.code.dialect.core.TupleType;
import jdk.incubator.code.extern.OpWriter;
import oracle.code.onnx.compiler.OnnxTransformer;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import static java.util.Optional.empty;
import static java.util.Optional.of;
import static oracle.code.onnx.OnnxOperators.Cast;
import static oracle.code.onnx.OnnxOperators.Constant;
import static oracle.code.onnx.OnnxOperators.Conv;
import static oracle.code.onnx.OnnxOperators.Div;
import static oracle.code.onnx.OnnxOperators.Flatten;
import static oracle.code.onnx.OnnxOperators.Gemm;
import static oracle.code.onnx.OnnxOperators.MaxPool;
import static oracle.code.onnx.OnnxOperators.Relu;
import static oracle.code.onnx.OnnxOperators.Softmax;
 64 
// A rough CNN implementation which expects an input of shape [batch_size, 1, 28, 28].
// Over time we will improve the operator expressions to reduce their
// verbosity, e.g., especially the scalar constant expressions.
 68 public class CNNTest {
 69 
 70     private static final String IMAGES_PATH = CNNTest.class.getResource("images-ubyte").getPath();
 71     private static final String LABELS_PATH = CNNTest.class.getResource("labels-ubyte").getPath();
 72     private static final int IMAGES_HEADER_SIZE = 0;
 73     private static final int LABELS_HEADER_SIZE = 0;
 74 
 75 //    static final String IMAGES_PATH = CNNTest.class.getResource("t10k-images-idx3-ubyte").getPath();
 76 //    static final String LABELS_PATH = CNNTest.class.getResource("t10k-labels-idx1-ubyte").getPath();
 77 //    static final int IMAGES_HEADER_SIZE = 16;
 78 //    static final int LABELS_HEADER_SIZE = 8;
 79 
 80     private static final String GREY_SCALE = " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$";
 81     private static final int PIXEL_DEPTH = 255;
 82     private static final int NUM_CHANNELS = 1;
 83     private static final int IMAGE_SIZE = 28;
 84 
 85     @Reflect
 86     public static Tensor<Float> cnn(
 87             // Weights and biases
 88             // [6, 1, 5, 5]
 89             Tensor<Float> conv1Weights,
 90             // [6]
 91             Tensor<Float> conv1Biases,
 92             // [16, 6, 5, 5]
 93             Tensor<Float> conv2Weights,
 94             // [16]
 95             Tensor<Float> conv2Biases,
 96             // [120, 256]
 97             Tensor<Float> fc1Weights,
 98             // [120]
 99             Tensor<Float> fc1Biases,
100             // [84, 120]
101             Tensor<Float> fc2Weights,
102             // [84]
103             Tensor<Float> fc2Biases,
104             // [NUM_LABELS, 84]
105             Tensor<Float> fc3Weights,
106             // [NUM_LABELS]
107             Tensor<Float> fc3Biases,
108             // Inputs
109             Tensor<Byte> ubyteImage) {
110 
111         Tensor<Float> inputImage = Cast(ubyteImage, empty(), Tensor.ElementType.FLOAT.id, empty());
112 
113         // Scaling the features to 0-1
114         var scalingFactor = Constant((float) PIXEL_DEPTH);
115         var scaledInput = Div(inputImage, scalingFactor);
116 
117         // First conv layer
118         var conv1 = Conv(scaledInput, conv1Weights, of(conv1Biases), of(new long[4]),
119                 of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
120                 of(1L), of(new long[]{5,5}));
121         var relu1 = Relu(conv1);
122 
123         // First pooling layer
124         var pool1 = MaxPool(relu1, of(new long[4]), of(new long[]{1,1}), empty(),
125                 of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});
126 
127         // Second conv layer
128         var conv2 = Conv(pool1.Y(), conv2Weights, of(conv2Biases), of(new long[4]),
129                 of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
130                 of(1L), of(new long[]{5,5}));
131         var relu2 = Relu(conv2);
132 
133         // Second pooling layer
134         var pool2 = MaxPool(relu2, of(new long[4]), of(new long[]{1,1}), empty(),
135                 of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});
136 
137         // Flatten inputs
138         var flatten = Flatten(pool2.Y(), of(1L));
139 
140         // First fully connected layer
141         var fc1 = Gemm(flatten, fc1Weights, of(fc1Biases), of(1f), of(1L), of(1f), empty());
142         var relu3 = Relu(fc1);
143 
144         // Second fully connected layer
145         var fc2 = Gemm(relu3, fc2Weights, of(fc2Biases), of(1f), of(1L), of(1f), empty());
146         var relu4 = Relu(fc2);
147 
148         // Softmax layer
149         var fc3 = Gemm(relu4, fc3Weights, of(fc3Biases), of(1f), of(1L), of(1f), empty());
150         var prediction = Softmax(fc3, of(1L));
151 
152         return prediction;
153     }
154 
155     static CoreOp.ModuleOp cnnModel() {
156         // @@@ function type and result types with correct tensor element and shape
157 
158         FunctionType functionType = CoreType.functionType(
159                 OnnxType.TENSOR_FLOAT32, // return
160                 OnnxType.TENSOR_FLOAT32, // conv1Weights
161                 OnnxType.TENSOR_FLOAT32, // conv1Biases
162                 OnnxType.TENSOR_FLOAT32, // conv2Weights
163                 OnnxType.TENSOR_FLOAT32, // conv2Biases
164                 OnnxType.TENSOR_FLOAT32, // fc1Weights
165                 OnnxType.TENSOR_FLOAT32, // fc1Biases
166                 OnnxType.TENSOR_FLOAT32, // fc2Weights
167                 OnnxType.TENSOR_FLOAT32, // fc2Biases
168                 OnnxType.TENSOR_FLOAT32, // fc3Weights
169                 OnnxType.TENSOR_FLOAT32,  // fc3Biases
170                 OnnxType.TENSOR_UINT8 // input
171         );
172 
173         return CoreOp.module(CoreOp.func("cnn", functionType).body(b -> {
174             // weights & biases
175             Block.Parameter conv1Weights = b.parameters().get(0);
176             Block.Parameter conv1Biases = b.parameters().get(1);
177             Block.Parameter conv2Weights = b.parameters().get(2);
178             Block.Parameter conv2Biases = b.parameters().get(3);
179             Block.Parameter fc1Weights = b.parameters().get(4);
180             Block.Parameter fc1Biases = b.parameters().get(5);
181             Block.Parameter fc2Weights = b.parameters().get(6);
182             Block.Parameter fc2Biases = b.parameters().get(7);
183             Block.Parameter fc3Weights = b.parameters().get(8);
184             Block.Parameter fc3Biases = b.parameters().get(9);
185             Block.Parameter ubyteImage = b.parameters().get(10);
186 
187             var inputImage = b.op(OnnxOps.Cast(OnnxType.TENSOR_FLOAT32,
188                     ubyteImage,
189                     empty(),
190                     OnnxType.TENSOR_FLOAT32.eType().id(),
191                     empty()
192             ));
193 
194             // Scaling the features
195             var scalingFactor = b.op(OnnxOps.Constant(OnnxType.TENSOR_FLOAT32,
196                     empty(),
197                     empty(),
198                     empty(),
199                     of((float) PIXEL_DEPTH),
200                     empty(),
201                     empty(),
202                     empty(),
203                     empty()));
204             var scaledInput = b.op(OnnxOps.Div(inputImage.type(), inputImage, scalingFactor));
205 
206             // First conv layer
207             var conv1 = b.op(OnnxOps.Conv(scaledInput.type(),
208                     scaledInput,
209                     conv1Weights,
210                     of(conv1Biases),
211                     of(new long[4]),
212                     of(new long[]{1,1}),
213                     empty(),
214                     of(new long[]{1, 1, 1, 1}),
215                     of(1L),
216                     of(new long[]{5,5})));
217             var relu1 = b.op(OnnxOps.Relu(conv1.type(),
218                     conv1));
219 
220             // First pooling layer
221             // @@@ multiple results?
222             var pool1Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu1.type(), OnnxType.TENSOR_INT64),
223                     Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
224                     relu1,
225                     of(new long[4]),
226                     of(new long[]{1,1}),
227                     empty(),
228                     of(0L),
229                     empty(),
230                     of(new long[]{2, 2}),
231                     new long[]{2, 2}));
232 
233             // Second conv layer
234             var pool1 = b.op(CoreOp.tupleLoad(pool1Result, 0));
235             var conv2 = b.op(OnnxOps.Conv(pool1.type(),
236                     pool1,
237                     conv2Weights,
238                     of(conv2Biases),
239                     of(new long[4]),
240                     of(new long[]{1,1}),
241                     empty(),
242                     of(new long[]{1, 1, 1, 1}),
243                     of(1L),
244                     of(new long[]{5,5})));
245             var relu2 = b.op(OnnxOps.Relu(conv2.type(),
246                     conv2));
247 
248             // Second pooling layer
249             // @@@ multiple results?
250             var pool2Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu2.type(), OnnxType.TENSOR_INT64),
251                     Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
252                     relu2,
253                     of(new long[4]),
254                     of(new long[]{1,1}),
255                     empty(),
256                     of(0L),
257                     empty(),
258                     of(new long[]{2, 2}),
259                     new long[]{2, 2}));
260 
261             // Flatten inputs
262             var pool2 = b.op(CoreOp.tupleLoad(pool2Result, 0));
263             var flatten = b.op(OnnxOps.Flatten(pool2.type(),
264                     pool2,
265                     of(1L)));
266 
267             // First fully connected layer
268             var fc1 = b.op(OnnxOps.Gemm(flatten.type(),
269                     flatten,
270                     fc1Weights,
271                     of(fc1Biases),
272                     of(1f),
273                     of(1L),
274                     of(1f),
275                     empty()));
276             var relu3 = b.op(OnnxOps.Relu(fc1.type(),
277                     fc1));
278 
279             // Second fully connected layer
280             var fc2 = b.op(OnnxOps.Gemm(relu3.type(),
281                     relu3,
282                     fc2Weights,
283                     of(fc2Biases),
284                     of(1f),
285                     of(1L),
286                     of(1f),
287                     empty()));
288             var relu4 = b.op(OnnxOps.Relu(fc2.type(),
289                     fc2));
290 
291             // Softmax layer
292             var fc3 = b.op(OnnxOps.Gemm(relu4.type(),
293                     relu4,
294                     fc3Weights,
295                     of(fc3Biases),
296                     of(1f),
297                     of(1L),
298                     of(1f),
299                     empty()));
300             var prediction = b.op(OnnxOps.Softmax(fc3.type(),
301                     fc3,
302                     of(1L)));
303 
304             b.op(CoreOp.return_(prediction));
305         }));
306     }
307 
308     static void printImage(int imageIndex, MemorySegment data) {
309         System.out.println("Image #" + imageIndex + " :");
310         int offset = imageIndex * 28 * 28;
311         for (int y = 0; y < 28; y++) {
312             for (int x = 0; x < 28; x++) {
313                 System.out.print(GREY_SCALE.charAt(GREY_SCALE.length() * (0xff & data.get(ValueLayout.JAVA_BYTE, offset + y * 28 + x)) / 256));
314             }
315             System.out.println();
316         }
317     }
318 
319     private Tensor<Float> floatTensor(Arena arena, String resource, long... shape) throws IOException {
320         try (var file = new RandomAccessFile(CNNTest.class.getResource(resource).getPath(), "r")) {
321             return new Tensor(arena, file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length(), arena), Tensor.ElementType.FLOAT, shape);
322         }
323     }
324 
325     @Test
326     public void testModels() {
327         try (var arena = Arena.ofConfined()) {
328             CoreOp.FuncOp f = getFuncOp("cnn");
329             CoreOp.ModuleOp onnxModel = OnnxTransformer.transform(MethodHandles.lookup(), f).module();
330             System.out.println(onnxModel.toText());
331 
332             var expectedOnnxModel = cnnModel();
333             System.out.println(expectedOnnxModel.toText());
334 
335             Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel));
336         }
337     }
338 
339     @Test
340     public void testInterpreter() throws Exception {
341         try (var arena = Arena.ofConfined()) {
342             var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
343             var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
344             var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
345             var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
346             var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
347             var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
348             var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
349             var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
350             var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
351             var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
352             test(arena, inputImage -> cnn(conv1Weight, conv1Bias,
353                                           conv2Weight, conv2Bias,
354                                           fc1Weight, fc1Bias,
355                                           fc2Weight, fc2Bias,
356                                           fc3Weight, fc3Bias,
357                                           inputImage));
358         }
359     }
360 
361     @Test
362     public void testProtobufModel() throws Exception {
363         try (var arena = Arena.ofConfined()) {
364             var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
365             var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
366             var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
367             var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
368             var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
369             var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
370             var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
371             var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
372             var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
373             var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
374             test(arena, inputImage -> OnnxRuntime.execute(arena, MethodHandles.lookup(), (@Reflect Supplier<Tensor<Float>>) () ->
375                     cnn(conv1Weight, conv1Bias, conv2Weight, conv2Bias,
376                         fc1Weight, fc1Bias, fc2Weight, fc2Bias, fc3Weight, fc3Bias,
377                         inputImage)));
378         }
379     }
380 
381     private static final long[] IMAGE_SHAPE = new long[]{-1, 1, 28, 28};
382 
383     private void test(Arena arena, Function<Tensor<Byte>, Tensor<Float>> executor) throws Exception {
384         try (RandomAccessFile imagesF = new RandomAccessFile(IMAGES_PATH, "r");
385              RandomAccessFile labelsF = new RandomAccessFile(LABELS_PATH, "r")) {
386 
387             MemorySegment imagesIn = imagesF.getChannel().map(FileChannel.MapMode.READ_ONLY, IMAGES_HEADER_SIZE, imagesF.length() - IMAGES_HEADER_SIZE, arena);
388             MemorySegment labelsIn = labelsF.getChannel().map(FileChannel.MapMode.READ_ONLY, LABELS_HEADER_SIZE, labelsF.length() - LABELS_HEADER_SIZE, arena);
389 
390             Tensor<Byte> inputImage = new Tensor(arena, imagesIn, Tensor.ElementType.UINT8, IMAGE_SHAPE);
391 
392             MemorySegment result = executor.apply(inputImage).data();
393 
394             int matched = 0, mismatched = 0;
395             int i = 0;
396             int resultSize = (int)result.byteSize() / 4;
397             while (i < resultSize) {
398                 int expected = labelsIn.get(ValueLayout.JAVA_BYTE, i / 10);
399 
400                 int actual = 0;
401                 float maxW = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
402                 for (int j = 1; j < 10; j++) {
403                     float w = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
404                     if (w > maxW) {
405                         maxW = w;
406                         actual = j;
407                     }
408                 }
409 
410                 if (expected == actual) {
411                     matched++;
412                 } else {
413                     int imageIndex = i / 10 - 1;
414                     printImage(imageIndex, imagesIn);
415                     System.out.println("expected: " + expected + " actual: " + actual);
416                     System.out.println("-".repeat(28));
417                     mismatched++;
418                 }
419             }
420             System.out.println("matched: " + matched + " mismatched: " + mismatched);
421             Assertions.assertTrue(mismatched / matched < 0.05);
422         }
423     }
424 
425     static String serialize(Op o) {
426         StringWriter w = new StringWriter();
427         OpWriter.writeTo(w, o, OpWriter.LocationOption.DROP_LOCATION);
428         return w.toString();
429     }
430 
431     static CoreOp.FuncOp getFuncOp(String name) {
432         Optional<Method> om = Stream.of(CNNTest.class.getDeclaredMethods())
433                 .filter(m -> m.getName().equals(name))
434                 .findFirst();
435 
436         Method m = om.get();
437         return Op.ofMethod(m).get();
438     }
439 
440 //    @Reflect
441 //    public Tensor<Float> loadWeight(Initializer init) {
442 //        var buf = ByteBuffer.allocate(init.values().length).order(ByteOrder.nativeOrder());
443 //        buf.put(init.values());
444 //        buf.rewind();
445 //        var floatBuf = buf.asFloatBuffer();
446 //        var floatArr = new float[floatBuf.remaining()];
447 //        floatBuf.get(floatArr);
448 //        Tensor<Long> shape = Constant(
449 //                empty(), empty(), empty(), empty(), empty(), of(init.shape()), empty(), empty()
450 //        );
451 //        Tensor<Float> floats = Constant(
452 //                empty(), of(floatArr), empty(), empty(), empty(), empty(), empty(), empty()
453 //        );
454 //        var shaped = Reshape(floats, shape, empty());
455 //        return shaped;
456 //    }
457 //
458 //    public static void extractWeights(Path inputOnnx, Path outputSerialized) throws IOException  {
459 //        try (InputStream is = Files.newInputStream(inputOnnx)) {
460 //            OnnxMl.ModelProto model = OnnxMl.ModelProto.parseFrom(is);
461 //            OnnxMl.GraphProto graph = model.getGraph();
462 //            List<Initializer> initList = new ArrayList<>();
463 //            for (var init : graph.getInitializerList()) {
464 //                var name = init.getName();
465 //                var type = init.getDataType();
466 //                var shape = init.getDimsList().stream().mapToLong(a -> a).toArray();
467 //                var valuesBuf = init.getRawData().asReadOnlyByteBuffer();
468 //                var valuesArr = new byte[valuesBuf.remaining()];
469 //                valuesBuf.get(valuesArr);
470 //                var initializer = new Initializer(name, type, shape, valuesArr);
471 //                System.out.println(initializer);
472 //                initList.add(initializer);
473 //            }
474 //            try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outputSerialized))) {
475 //                oos.writeObject(initList);
476 //            }
477 //        }
478 //    }
479 //
480 //    public record Initializer(String name, int type, long[] shape, byte[] values) implements java.io.Serializable {
481 //        @Override
482 //        public String toString() {
483 //            return "Initializer{" +
484 //                    "name='" + name + '\'' +
485 //                    ", type=" + type +
486 //                    ", shape=" + Arrays.toString(shape) +
487 //                    ", values.length=" + values.length +
488 //                    '}';
489 //        }
490 //    }
491 //
492 //    public static void main(String[] args) throws IOException {
493 //        Path inputPath = Path.of(args[0]);
494 //
495 //        Path outputPath = Path.of(args[1]);
496 //
497 //        extractWeights(inputPath, outputPath);
498 //    }
499 
500     /*
501     ONNX code model
502 
503 func @"cnn" (
504 %0 : tensor<float32>,
505 %1 : tensor<float32>,
506 %2 : tensor<float32>,
507 %3 : tensor<float32>,
508 %4 : tensor<float32>,
509 %5 : tensor<float32>,
510 %6 : tensor<float32>,
511 %7 : tensor<float32>,
512 %8 : tensor<float32>,
513 %9 : tensor<float32>,
514 %10 : tensor<float32>)tensor<float32> -> {
515     %11 : tensor<int64> = Constant @value_ints="[I@32910148";
516     %12 : tensor<float32> = Reshape %0 %11;
517     %13 : tensor<float32> = Constant @value_float="255.0";
518     %14 : tensor<float32> = Div %12 %13;
519     %15 : tensor<float32> = Conv %14 %1 %2 @optional_inputs="[B]" @strides="[I@2b4bac49" @pads="[I@fd07cbb" @dilations="[I@3571b748" @group="1" @kernel_shape="[I@3e96bacf";
520     %16 : tensor<float32> = Relu %15;
521     %17 : tensor<float32> = MaxPool %16 @ceil_mode="0" @strides="[I@484970b0" @pads="[I@4470f8a6" @dilations="[I@7c83dc97" @kernel_shape="[I@7748410a";
522     %18 : tensor<float32> = Conv %17 %3 %4 @optional_inputs="[B]" @strides="[I@740773a3" @pads="[I@37f1104d" @dilations="[I@55740540" @group="1" @kernel_shape="[I@60015ef5";
523     %19 : tensor<float32> = Relu %18;
524     %20 : tensor<float32> = MaxPool %19 @ceil_mode="0" @strides="[I@2f54a33d" @pads="[I@1018bde2" @dilations="[I@65b3f4a4" @kernel_shape="[I@f2ff811";
525     %21 : tensor<float32> = Flatten %20 @axis="1";
526     %22 : tensor<float32> = Gemm %21 %5 %6 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
527     %23 : tensor<float32> = Relu %22;
528     %24 : tensor<float32> = Gemm %23 %7 %8 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
529     %25 : tensor<float32> = Relu %24;
530     %26 : tensor<float32> = Gemm %25 %9 %10 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
531     %27 : tensor<float32> = Softmax %26 @axis="1";
532     return %27;
533 };
534      */
535 }