1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx;
 25 
import java.io.*;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.nio.channels.FileChannel;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;

import jdk.incubator.code.Block;
import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.Op;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.core.CoreType;
import jdk.incubator.code.dialect.core.FunctionType;
import jdk.incubator.code.dialect.core.TupleType;
import jdk.incubator.code.extern.OpWriter;

import oracle.code.onnx.compiler.OnnxTransformer;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import static java.util.Optional.empty;
import static java.util.Optional.of;
import static oracle.code.onnx.OnnxOperators.Cast;
import static oracle.code.onnx.OnnxOperators.Constant;
import static oracle.code.onnx.OnnxOperators.Conv;
import static oracle.code.onnx.OnnxOperators.Div;
import static oracle.code.onnx.OnnxOperators.Flatten;
import static oracle.code.onnx.OnnxOperators.Gemm;
import static oracle.code.onnx.OnnxOperators.MaxPool;
import static oracle.code.onnx.OnnxOperators.Relu;
import static oracle.code.onnx.OnnxOperators.Softmax;
 63 
// A rough CNN implementation which expects an input [batch_size, 1, 28, 28].
 65 // Over time we will improve the operator expressions to reduce
 66 // the verbosity e.g., esp. scalar constant expressions
 67 public class CNNTest {
 68 
    // Test data: image and label resources checked in alongside this test.
    // These resources carry no IDX header, hence the zero header sizes below.
    private static final String IMAGES_PATH = CNNTest.class.getResource("images-ubyte").getPath();
    private static final String LABELS_PATH = CNNTest.class.getResource("labels-ubyte").getPath();
    private static final int IMAGES_HEADER_SIZE = 0;
    private static final int LABELS_HEADER_SIZE = 0;

    // Alternative: the full MNIST test set in IDX format (16/8-byte headers).
//    static final String IMAGES_PATH = CNNTest.class.getResource("t10k-images-idx3-ubyte").getPath();
//    static final String LABELS_PATH = CNNTest.class.getResource("t10k-labels-idx1-ubyte").getPath();
//    static final int IMAGES_HEADER_SIZE = 16;
//    static final int LABELS_HEADER_SIZE = 8;

    // ASCII ramp from sparse (' ') to dense ('$'); printImage scales a pixel
    // value in 0..255 into an index in this string to render images on the console.
    private static final String GREY_SCALE = " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$";
    // Maximum pixel value of the uint8 images; used to scale features into 0..1.
    private static final int PIXEL_DEPTH = 255;
    private static final int NUM_CHANNELS = 1;
    private static final int IMAGE_SIZE = 28;
 83 
    /**
     * A LeNet-style CNN over single-channel 28x28 images, written directly in
     * terms of ONNX operator calls. Because it is annotated {@code @CodeReflection},
     * its code model is available at runtime and can be transformed into an ONNX
     * model; {@code testModels} compares that transformation against the
     * hand-built model from {@code cnnModel()}.
     *
     * NOTE(review): the per-parameter shapes below are documentation only; they
     * are not enforced by the {@code Tensor} types — confirm against the weight
     * resources loaded in the tests.
     *
     * @param ubyteImage input batch, shape [batch_size, 1, 28, 28], uint8 pixels
     * @return per-image class scores produced by the final Softmax (axis 1)
     */
    @CodeReflection
    public static Tensor<Float> cnn(
            // Weights and biases
            // [6, 1, 5, 5]
            Tensor<Float> conv1Weights,
            // [6]
            Tensor<Float> conv1Biases,
            // [16, 6, 5, 5]
            Tensor<Float> conv2Weights,
            // [16]
            Tensor<Float> conv2Biases,
            // [120, 256]
            Tensor<Float> fc1Weights,
            // [120]
            Tensor<Float> fc1Biases,
            // [84, 120]
            Tensor<Float> fc2Weights,
            // [84]
            Tensor<Float> fc2Biases,
            // [NUM_LABELS, 84]
            Tensor<Float> fc3Weights,
            // [NUM_LABELS]
            Tensor<Float> fc3Biases,
            // Inputs
            Tensor<Byte> ubyteImage) {

        // uint8 -> float32
        Tensor<Float> inputImage = Cast(ubyteImage, empty(), Tensor.ElementType.FLOAT.id, empty());

        // Scaling the features to 0-1
        var scalingFactor = Constant((float) PIXEL_DEPTH);
        var scaledInput = Div(inputImage, scalingFactor);

        // First conv layer
        var conv1 = Conv(scaledInput, conv1Weights, of(conv1Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu1 = Relu(conv1);

        // First pooling layer
        var pool1 = MaxPool(relu1, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Second conv layer
        var conv2 = Conv(pool1.Y(), conv2Weights, of(conv2Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu2 = Relu(conv2);

        // Second pooling layer
        var pool2 = MaxPool(relu2, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Flatten inputs
        var flatten = Flatten(pool2.Y(), of(1L));

        // First fully connected layer
        var fc1 = Gemm(flatten, fc1Weights, of(fc1Biases), of(1f), of(1L), of(1f), empty());
        var relu3 = Relu(fc1);

        // Second fully connected layer
        var fc2 = Gemm(relu3, fc2Weights, of(fc2Biases), of(1f), of(1L), of(1f), empty());
        var relu4 = Relu(fc2);

        // Softmax layer
        var fc3 = Gemm(relu4, fc3Weights, of(fc3Biases), of(1f), of(1L), of(1f), empty());
        var prediction = Softmax(fc3, of(1L));

        return prediction;
    }
153 
    /**
     * Hand-builds the ONNX code model that the transformer is expected to produce
     * from {@code cnn}. The exact op sequence, result types and attribute values
     * are load-bearing: {@code testModels} compares the serialized text of this
     * model against the transformed one, so any structural change here (or in
     * {@code cnn}) must be mirrored in the other.
     */
    static CoreOp.ModuleOp cnnModel() {
        // @@@ function type and result types with correct tensor element and shape

        FunctionType functionType = CoreType.functionType(
                OnnxType.TENSOR_FLOAT32, // return
                OnnxType.TENSOR_FLOAT32, // conv1Weights
                OnnxType.TENSOR_FLOAT32, // conv1Biases
                OnnxType.TENSOR_FLOAT32, // conv2Weights
                OnnxType.TENSOR_FLOAT32, // conv2Biases
                OnnxType.TENSOR_FLOAT32, // fc1Weights
                OnnxType.TENSOR_FLOAT32, // fc1Biases
                OnnxType.TENSOR_FLOAT32, // fc2Weights
                OnnxType.TENSOR_FLOAT32, // fc2Biases
                OnnxType.TENSOR_FLOAT32, // fc3Weights
                OnnxType.TENSOR_FLOAT32,  // fc3Biases
                OnnxType.TENSOR_UINT8 // input
        );

        return CoreOp.module(CoreOp.func("cnn", functionType).body(b -> {
            // weights & biases
            Block.Parameter conv1Weights = b.parameters().get(0);
            Block.Parameter conv1Biases = b.parameters().get(1);
            Block.Parameter conv2Weights = b.parameters().get(2);
            Block.Parameter conv2Biases = b.parameters().get(3);
            Block.Parameter fc1Weights = b.parameters().get(4);
            Block.Parameter fc1Biases = b.parameters().get(5);
            Block.Parameter fc2Weights = b.parameters().get(6);
            Block.Parameter fc2Biases = b.parameters().get(7);
            Block.Parameter fc3Weights = b.parameters().get(8);
            Block.Parameter fc3Biases = b.parameters().get(9);
            Block.Parameter ubyteImage = b.parameters().get(10);

            // uint8 -> float32 cast, mirroring the Cast call in cnn
            var inputImage = b.op(OnnxOps.Cast(OnnxType.TENSOR_FLOAT32,
                    ubyteImage,
                    empty(),
                    OnnxType.TENSOR_FLOAT32.eType().id(),
                    empty()
            ));

            // Scaling the features
            var scalingFactor = b.op(OnnxOps.Constant(OnnxType.TENSOR_FLOAT32,
                    empty(),
                    empty(),
                    empty(),
                    of((float) PIXEL_DEPTH),
                    empty(),
                    empty(),
                    empty(),
                    empty()));
            var scaledInput = b.op(OnnxOps.Div(inputImage.type(), inputImage, scalingFactor));

            // First conv layer
            var conv1 = b.op(OnnxOps.Conv(scaledInput.type(),
                    scaledInput,
                    conv1Weights,
                    of(conv1Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu1 = b.op(OnnxOps.Relu(conv1.type(),
                    conv1));

            // First pooling layer
            // @@@ multiple results?
            // MaxPool produces a tuple (values, indices); the tuple element 0 is
            // loaded below where the value tensor is consumed.
            var pool1Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu1.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu1,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Second conv layer
            var pool1 = b.op(CoreOp.tupleLoad(pool1Result, 0));
            var conv2 = b.op(OnnxOps.Conv(pool1.type(),
                    pool1,
                    conv2Weights,
                    of(conv2Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu2 = b.op(OnnxOps.Relu(conv2.type(),
                    conv2));

            // Second pooling layer
            // @@@ multiple results?
            var pool2Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu2.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu2,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Flatten inputs
            var pool2 = b.op(CoreOp.tupleLoad(pool2Result, 0));
            var flatten = b.op(OnnxOps.Flatten(pool2.type(),
                    pool2,
                    of(1L)));

            // First fully connected layer
            var fc1 = b.op(OnnxOps.Gemm(flatten.type(),
                    flatten,
                    fc1Weights,
                    of(fc1Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu3 = b.op(OnnxOps.Relu(fc1.type(),
                    fc1));

            // Second fully connected layer
            var fc2 = b.op(OnnxOps.Gemm(relu3.type(),
                    relu3,
                    fc2Weights,
                    of(fc2Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu4 = b.op(OnnxOps.Relu(fc2.type(),
                    fc2));

            // Softmax layer
            var fc3 = b.op(OnnxOps.Gemm(relu4.type(),
                    relu4,
                    fc3Weights,
                    of(fc3Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var prediction = b.op(OnnxOps.Softmax(fc3.type(),
                    fc3,
                    of(1L)));

            b.op(CoreOp.return_(prediction));
        }));
    }
306 
307     static void printImage(int imageIndex, MemorySegment data) {
308         System.out.println("Image #" + imageIndex + " :");
309         int offset = imageIndex * 28 * 28;
310         for (int y = 0; y < 28; y++) {
311             for (int x = 0; x < 28; x++) {
312                 System.out.print(GREY_SCALE.charAt(GREY_SCALE.length() * (0xff & data.get(ValueLayout.JAVA_BYTE, offset + y * 28 + x)) / 256));
313             }
314             System.out.println();
315         }
316     }
317 
318     private Tensor<Float> floatTensor(Arena arena, String resource, long... shape) throws IOException {
319         try (var file = new RandomAccessFile(CNNTest.class.getResource(resource).getPath(), "r")) {
320             return new Tensor(arena, file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length(), arena), Tensor.ElementType.FLOAT, shape);
321         }
322     }
323 
324     @Test
325     public void testModels() {
326         try (var arena = Arena.ofConfined()) {
327             CoreOp.FuncOp f = getFuncOp("cnn");
328             CoreOp.ModuleOp onnxModel = OnnxTransformer.transform(MethodHandles.lookup(), f).module();
329             System.out.println(onnxModel.toText());
330 
331             var expectedOnnxModel = cnnModel();
332             System.out.println(expectedOnnxModel.toText());
333 
334             Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel));
335         }
336     }
337 
338     @Test
339     public void testInterpreter() throws Exception {
340         try (var arena = Arena.ofConfined()) {
341             var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
342             var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
343             var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
344             var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
345             var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
346             var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
347             var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
348             var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
349             var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
350             var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
351             test(arena, inputImage -> cnn(conv1Weight, conv1Bias,
352                                           conv2Weight, conv2Bias,
353                                           fc1Weight, fc1Bias,
354                                           fc2Weight, fc2Bias,
355                                           fc3Weight, fc3Bias,
356                                           inputImage));
357         }
358     }
359 
360     @Test
361     public void testProtobufModel() throws Exception {
362         try (var arena = Arena.ofConfined()) {
363             var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
364             var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
365             var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
366             var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
367             var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
368             var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
369             var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
370             var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
371             var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
372             var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
373             test(arena, inputImage -> OnnxRuntime.execute(arena, MethodHandles.lookup(), () ->
374                     cnn(conv1Weight, conv1Bias, conv2Weight, conv2Bias,
375                         fc1Weight, fc1Bias, fc2Weight, fc2Bias, fc3Weight, fc3Bias,
376                         inputImage)));
377         }
378     }
379 
380     private static final long[] IMAGE_SHAPE = new long[]{-1, 1, 28, 28};
381 
382     private void test(Arena arena, Function<Tensor<Byte>, Tensor<Float>> executor) throws Exception {
383         try (RandomAccessFile imagesF = new RandomAccessFile(IMAGES_PATH, "r");
384              RandomAccessFile labelsF = new RandomAccessFile(LABELS_PATH, "r")) {
385 
386             MemorySegment imagesIn = imagesF.getChannel().map(FileChannel.MapMode.READ_ONLY, IMAGES_HEADER_SIZE, imagesF.length() - IMAGES_HEADER_SIZE, arena);
387             MemorySegment labelsIn = labelsF.getChannel().map(FileChannel.MapMode.READ_ONLY, LABELS_HEADER_SIZE, labelsF.length() - LABELS_HEADER_SIZE, arena);
388 
389             Tensor<Byte> inputImage = new Tensor(arena, imagesIn, Tensor.ElementType.UINT8, IMAGE_SHAPE);
390 
391             MemorySegment result = executor.apply(inputImage).data();
392 
393             int matched = 0, mismatched = 0;
394             int i = 0;
395             int resultSize = (int)result.byteSize() / 4;
396             while (i < resultSize) {
397                 int expected = labelsIn.get(ValueLayout.JAVA_BYTE, i / 10);
398 
399                 int actual = 0;
400                 float maxW = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
401                 for (int j = 1; j < 10; j++) {
402                     float w = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
403                     if (w > maxW) {
404                         maxW = w;
405                         actual = j;
406                     }
407                 }
408 
409                 if (expected == actual) {
410                     matched++;
411                 } else {
412                     int imageIndex = i / 10 - 1;
413                     printImage(imageIndex, imagesIn);
414                     System.out.println("expected: " + expected + " actual: " + actual);
415                     System.out.println("-".repeat(28));
416                     mismatched++;
417                 }
418             }
419             System.out.println("matched: " + matched + " mismatched: " + mismatched);
420             Assertions.assertTrue(mismatched / matched < 0.05);
421         }
422     }
423 
424     static String serialize(Op o) {
425         StringWriter w = new StringWriter();
426         OpWriter.writeTo(w, o, OpWriter.LocationOption.DROP_LOCATION);
427         return w.toString();
428     }
429 
430     static CoreOp.FuncOp getFuncOp(String name) {
431         Optional<Method> om = Stream.of(CNNTest.class.getDeclaredMethods())
432                 .filter(m -> m.getName().equals(name))
433                 .findFirst();
434 
435         Method m = om.get();
436         return Op.ofMethod(m).get();
437     }
438 
439 //    @CodeReflection
440 //    public Tensor<Float> loadWeight(Initializer init) {
441 //        var buf = ByteBuffer.allocate(init.values().length).order(ByteOrder.nativeOrder());
442 //        buf.put(init.values());
443 //        buf.rewind();
444 //        var floatBuf = buf.asFloatBuffer();
445 //        var floatArr = new float[floatBuf.remaining()];
446 //        floatBuf.get(floatArr);
447 //        Tensor<Long> shape = Constant(
448 //                empty(), empty(), empty(), empty(), empty(), of(init.shape()), empty(), empty()
449 //        );
450 //        Tensor<Float> floats = Constant(
451 //                empty(), of(floatArr), empty(), empty(), empty(), empty(), empty(), empty()
452 //        );
453 //        var shaped = Reshape(floats, shape, empty());
454 //        return shaped;
455 //    }
456 //
457 //    public static void extractWeights(Path inputOnnx, Path outputSerialized) throws IOException  {
458 //        try (InputStream is = Files.newInputStream(inputOnnx)) {
459 //            OnnxMl.ModelProto model = OnnxMl.ModelProto.parseFrom(is);
460 //            OnnxMl.GraphProto graph = model.getGraph();
461 //            List<Initializer> initList = new ArrayList<>();
462 //            for (var init : graph.getInitializerList()) {
463 //                var name = init.getName();
464 //                var type = init.getDataType();
465 //                var shape = init.getDimsList().stream().mapToLong(a -> a).toArray();
466 //                var valuesBuf = init.getRawData().asReadOnlyByteBuffer();
467 //                var valuesArr = new byte[valuesBuf.remaining()];
468 //                valuesBuf.get(valuesArr);
469 //                var initializer = new Initializer(name, type, shape, valuesArr);
470 //                System.out.println(initializer);
471 //                initList.add(initializer);
472 //            }
473 //            try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outputSerialized))) {
474 //                oos.writeObject(initList);
475 //            }
476 //        }
477 //    }
478 //
479 //    public record Initializer(String name, int type, long[] shape, byte[] values) implements java.io.Serializable {
480 //        @Override
481 //        public String toString() {
482 //            return "Initializer{" +
483 //                    "name='" + name + '\'' +
484 //                    ", type=" + type +
485 //                    ", shape=" + Arrays.toString(shape) +
486 //                    ", values.length=" + values.length +
487 //                    '}';
488 //        }
489 //    }
490 //
491 //    public static void main(String[] args) throws IOException {
492 //        Path inputPath = Path.of(args[0]);
493 //
494 //        Path outputPath = Path.of(args[1]);
495 //
496 //        extractWeights(inputPath, outputPath);
497 //    }
498 
499     /*
500     ONNX code model
501 
502 func @"cnn" (
503 %0 : tensor<float32>,
504 %1 : tensor<float32>,
505 %2 : tensor<float32>,
506 %3 : tensor<float32>,
507 %4 : tensor<float32>,
508 %5 : tensor<float32>,
509 %6 : tensor<float32>,
510 %7 : tensor<float32>,
511 %8 : tensor<float32>,
512 %9 : tensor<float32>,
513 %10 : tensor<float32>)tensor<float32> -> {
514     %11 : tensor<int64> = Constant @value_ints="[I@32910148";
515     %12 : tensor<float32> = Reshape %0 %11;
516     %13 : tensor<float32> = Constant @value_float="255.0";
517     %14 : tensor<float32> = Div %12 %13;
518     %15 : tensor<float32> = Conv %14 %1 %2 @optional_inputs="[B]" @strides="[I@2b4bac49" @pads="[I@fd07cbb" @dilations="[I@3571b748" @group="1" @kernel_shape="[I@3e96bacf";
519     %16 : tensor<float32> = Relu %15;
520     %17 : tensor<float32> = MaxPool %16 @ceil_mode="0" @strides="[I@484970b0" @pads="[I@4470f8a6" @dilations="[I@7c83dc97" @kernel_shape="[I@7748410a";
521     %18 : tensor<float32> = Conv %17 %3 %4 @optional_inputs="[B]" @strides="[I@740773a3" @pads="[I@37f1104d" @dilations="[I@55740540" @group="1" @kernel_shape="[I@60015ef5";
522     %19 : tensor<float32> = Relu %18;
523     %20 : tensor<float32> = MaxPool %19 @ceil_mode="0" @strides="[I@2f54a33d" @pads="[I@1018bde2" @dilations="[I@65b3f4a4" @kernel_shape="[I@f2ff811";
524     %21 : tensor<float32> = Flatten %20 @axis="1";
525     %22 : tensor<float32> = Gemm %21 %5 %6 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
526     %23 : tensor<float32> = Relu %22;
527     %24 : tensor<float32> = Gemm %23 %7 %8 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
528     %25 : tensor<float32> = Relu %24;
529     %26 : tensor<float32> = Gemm %25 %9 %10 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
530     %27 : tensor<float32> = Softmax %26 @axis="1";
531     return %27;
532 };
533      */
534 }