1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx;
25
26 import java.io.*;
27 import java.lang.foreign.Arena;
28 import jdk.incubator.code.Block;
29 import jdk.incubator.code.Reflect;
30 import jdk.incubator.code.Op;
31 import jdk.incubator.code.dialect.core.CoreOp;
32 import jdk.incubator.code.dialect.core.CoreType;
33 import jdk.incubator.code.dialect.core.FunctionType;
34 import jdk.incubator.code.dialect.core.TupleType;
35 import jdk.incubator.code.extern.OpWriter;
36 import oracle.code.onnx.ir.OnnxOps;
37 import oracle.code.onnx.ir.OnnxType;
38 import org.junit.jupiter.api.Test;
39 import org.junit.jupiter.api.Assertions;
40
41 import java.lang.invoke.MethodHandles;
42 import java.lang.reflect.Method;
43 import java.util.Optional;
44 import java.util.Set;
45 import java.util.function.Supplier;
46 import java.util.stream.Stream;
47 import java.lang.foreign.MemorySegment;
48 import java.lang.foreign.ValueLayout;
49 import java.nio.channels.FileChannel;
50 import java.util.function.Function;
51 import oracle.code.onnx.compiler.OnnxTransformer;
52
53 import static java.util.Optional.empty;
54 import static java.util.Optional.of;
55 import static oracle.code.onnx.OnnxOperators.Cast;
56 import static oracle.code.onnx.OnnxOperators.Constant;
57 import static oracle.code.onnx.OnnxOperators.Conv;
58 import static oracle.code.onnx.OnnxOperators.Div;
59 import static oracle.code.onnx.OnnxOperators.Flatten;
60 import static oracle.code.onnx.OnnxOperators.Gemm;
61 import static oracle.code.onnx.OnnxOperators.MaxPool;
62 import static oracle.code.onnx.OnnxOperators.Relu;
63 import static oracle.code.onnx.OnnxOperators.Softmax;
64
// A rough CNN implementation which expects an input of shape [batch_size, 1, 28, 28].
// Over time we will improve the operator expressions to reduce
// the verbosity, especially for scalar constant expressions.
public class CNNTest {

    // Paths to the raw image/label data bundled as test resources. The bundled
    // "images-ubyte"/"labels-ubyte" files carry no idx header, hence the zero
    // header sizes below.
    private static final String IMAGES_PATH = CNNTest.class.getResource("images-ubyte").getPath();
    private static final String LABELS_PATH = CNNTest.class.getResource("labels-ubyte").getPath();
    private static final int IMAGES_HEADER_SIZE = 0;
    private static final int LABELS_HEADER_SIZE = 0;

    // Alternative: the original MNIST t10k files, which carry idx headers
    // (16 bytes for images, 8 bytes for labels).
//    static final String IMAGES_PATH = CNNTest.class.getResource("t10k-images-idx3-ubyte").getPath();
//    static final String LABELS_PATH = CNNTest.class.getResource("t10k-labels-idx1-ubyte").getPath();
//    static final int IMAGES_HEADER_SIZE = 16;
//    static final int LABELS_HEADER_SIZE = 8;

    // Characters ordered roughly from light to dark; used by printImage to
    // render pixel intensities as ASCII art.
    private static final String GREY_SCALE = " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$";
    private static final int PIXEL_DEPTH = 255;
    private static final int NUM_CHANNELS = 1;
    private static final int IMAGE_SIZE = 28;
    // LeNet-style CNN expressed with the OnnxOperators static methods, intended
    // to be reflected (via @Reflect) and transformed into an ONNX model.
    // NOTE(review): the op sequence here must stay in sync with the hand-built
    // model in cnnModel(); testModels() compares their serialized forms.
    @Reflect
    public static Tensor<Float> cnn(
            // Weights and biases
            // [6, 1, 5, 5]
            Tensor<Float> conv1Weights,
            // [6]
            Tensor<Float> conv1Biases,
            // [16, 6, 5, 5]
            Tensor<Float> conv2Weights,
            // [16]
            Tensor<Float> conv2Biases,
            // [120, 256]
            Tensor<Float> fc1Weights,
            // [120]
            Tensor<Float> fc1Biases,
            // [84, 120]
            Tensor<Float> fc2Weights,
            // [84]
            Tensor<Float> fc2Biases,
            // [NUM_LABELS, 84]
            Tensor<Float> fc3Weights,
            // [NUM_LABELS]
            Tensor<Float> fc3Biases,
            // Inputs
            Tensor<Byte> ubyteImage) {

        // Convert the uint8 pixels to float32 before any arithmetic
        Tensor<Float> inputImage = Cast(ubyteImage, empty(), Tensor.ElementType.FLOAT.id, empty());

        // Scaling the features to 0-1
        var scalingFactor = Constant((float) PIXEL_DEPTH);
        var scaledInput = Div(inputImage, scalingFactor);

        // First conv layer: 5x5 kernels, stride 1, no padding (pads = {0,0,0,0})
        var conv1 = Conv(scaledInput, conv1Weights, of(conv1Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu1 = Relu(conv1);

        // First pooling layer: 2x2 max pool with stride 2
        var pool1 = MaxPool(relu1, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Second conv layer (MaxPool has multiple outputs; Y() selects the pooled tensor)
        var conv2 = Conv(pool1.Y(), conv2Weights, of(conv2Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu2 = Relu(conv2);

        // Second pooling layer
        var pool2 = MaxPool(relu2, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Flatten inputs to [batch, features] (axis = 1)
        var flatten = Flatten(pool2.Y(), of(1L));

        // First fully connected layer (transB = 1: weights stored as [out, in])
        var fc1 = Gemm(flatten, fc1Weights, of(fc1Biases), of(1f), of(1L), of(1f), empty());
        var relu3 = Relu(fc1);

        // Second fully connected layer
        var fc2 = Gemm(relu3, fc2Weights, of(fc2Biases), of(1f), of(1L), of(1f), empty());
        var relu4 = Relu(fc2);

        // Softmax layer producing per-class probabilities
        var fc3 = Gemm(relu4, fc3Weights, of(fc3Biases), of(1f), of(1L), of(1f), empty());
        var prediction = Softmax(fc3, of(1L));

        return prediction;
    }
154
    // Hand-built ONNX code model equivalent to the reflected cnn() method.
    // testModels() asserts that transforming cnn() produces a model whose
    // serialized form equals this one, so op kinds, attributes, and ordering
    // here must mirror cnn() exactly.
    static CoreOp.ModuleOp cnnModel() {
        // @@@ function type and result types with correct tensor element and shape

        FunctionType functionType = CoreType.functionType(
                OnnxType.TENSOR_FLOAT32, // return
                OnnxType.TENSOR_FLOAT32, // conv1Weights
                OnnxType.TENSOR_FLOAT32, // conv1Biases
                OnnxType.TENSOR_FLOAT32, // conv2Weights
                OnnxType.TENSOR_FLOAT32, // conv2Biases
                OnnxType.TENSOR_FLOAT32, // fc1Weights
                OnnxType.TENSOR_FLOAT32, // fc1Biases
                OnnxType.TENSOR_FLOAT32, // fc2Weights
                OnnxType.TENSOR_FLOAT32, // fc2Biases
                OnnxType.TENSOR_FLOAT32, // fc3Weights
                OnnxType.TENSOR_FLOAT32, // fc3Biases
                OnnxType.TENSOR_UINT8 // input
        );

        return CoreOp.module(CoreOp.func("cnn", functionType).body(b -> {
            // weights & biases (block parameters in functionType order)
            Block.Parameter conv1Weights = b.parameters().get(0);
            Block.Parameter conv1Biases = b.parameters().get(1);
            Block.Parameter conv2Weights = b.parameters().get(2);
            Block.Parameter conv2Biases = b.parameters().get(3);
            Block.Parameter fc1Weights = b.parameters().get(4);
            Block.Parameter fc1Biases = b.parameters().get(5);
            Block.Parameter fc2Weights = b.parameters().get(6);
            Block.Parameter fc2Biases = b.parameters().get(7);
            Block.Parameter fc3Weights = b.parameters().get(8);
            Block.Parameter fc3Biases = b.parameters().get(9);
            Block.Parameter ubyteImage = b.parameters().get(10);

            // Cast the uint8 input to float32
            var inputImage = b.op(OnnxOps.Cast(OnnxType.TENSOR_FLOAT32,
                    ubyteImage,
                    empty(),
                    OnnxType.TENSOR_FLOAT32.eType().id(),
                    empty()
            ));

            // Scaling the features (constant 255.0 divisor)
            var scalingFactor = b.op(OnnxOps.Constant(OnnxType.TENSOR_FLOAT32,
                    empty(),
                    empty(),
                    empty(),
                    of((float) PIXEL_DEPTH),
                    empty(),
                    empty(),
                    empty(),
                    empty()));
            var scaledInput = b.op(OnnxOps.Div(inputImage.type(), inputImage, scalingFactor));

            // First conv layer: 5x5 kernels, stride 1, zero padding
            var conv1 = b.op(OnnxOps.Conv(scaledInput.type(),
                    scaledInput,
                    conv1Weights,
                    of(conv1Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu1 = b.op(OnnxOps.Relu(conv1.type(),
                    conv1));

            // First pooling layer: 2x2 max pool with stride 2
            // @@@ multiple results? — modeled as a tuple of (Y, Indices)
            var pool1Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu1.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu1,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Second conv layer (tuple component 0 is the pooled tensor Y)
            var pool1 = b.op(CoreOp.tupleLoad(pool1Result, 0));
            var conv2 = b.op(OnnxOps.Conv(pool1.type(),
                    pool1,
                    conv2Weights,
                    of(conv2Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu2 = b.op(OnnxOps.Relu(conv2.type(),
                    conv2));

            // Second pooling layer
            // @@@ multiple results?
            var pool2Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu2.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu2,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Flatten inputs to [batch, features]
            var pool2 = b.op(CoreOp.tupleLoad(pool2Result, 0));
            var flatten = b.op(OnnxOps.Flatten(pool2.type(),
                    pool2,
                    of(1L)));

            // First fully connected layer (transB = 1)
            var fc1 = b.op(OnnxOps.Gemm(flatten.type(),
                    flatten,
                    fc1Weights,
                    of(fc1Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu3 = b.op(OnnxOps.Relu(fc1.type(),
                    fc1));

            // Second fully connected layer
            var fc2 = b.op(OnnxOps.Gemm(relu3.type(),
                    relu3,
                    fc2Weights,
                    of(fc2Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu4 = b.op(OnnxOps.Relu(fc2.type(),
                    fc2));

            // Softmax layer
            var fc3 = b.op(OnnxOps.Gemm(relu4.type(),
                    relu4,
                    fc3Weights,
                    of(fc3Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var prediction = b.op(OnnxOps.Softmax(fc3.type(),
                    fc3,
                    of(1L)));

            b.op(CoreOp.return_(prediction));
        }));
    }
307
308 static void printImage(int imageIndex, MemorySegment data) {
309 System.out.println("Image #" + imageIndex + " :");
310 int offset = imageIndex * 28 * 28;
311 for (int y = 0; y < 28; y++) {
312 for (int x = 0; x < 28; x++) {
313 System.out.print(GREY_SCALE.charAt(GREY_SCALE.length() * (0xff & data.get(ValueLayout.JAVA_BYTE, offset + y * 28 + x)) / 256));
314 }
315 System.out.println();
316 }
317 }
318
319 private Tensor<Float> floatTensor(Arena arena, String resource, long... shape) throws IOException {
320 try (var file = new RandomAccessFile(CNNTest.class.getResource(resource).getPath(), "r")) {
321 return new Tensor(arena, file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length(), arena), Tensor.ElementType.FLOAT, shape);
322 }
323 }
324
325 @Test
326 public void testModels() {
327 try (var arena = Arena.ofConfined()) {
328 CoreOp.FuncOp f = getFuncOp("cnn");
329 CoreOp.ModuleOp onnxModel = OnnxTransformer.transform(MethodHandles.lookup(), f).module();
330 System.out.println(onnxModel.toText());
331
332 var expectedOnnxModel = cnnModel();
333 System.out.println(expectedOnnxModel.toText());
334
335 Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel));
336 }
337 }
338
339 @Test
340 public void testInterpreter() throws Exception {
341 try (var arena = Arena.ofConfined()) {
342 var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
343 var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
344 var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
345 var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
346 var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
347 var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
348 var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
349 var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
350 var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
351 var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
352 test(arena, inputImage -> cnn(conv1Weight, conv1Bias,
353 conv2Weight, conv2Bias,
354 fc1Weight, fc1Bias,
355 fc2Weight, fc2Bias,
356 fc3Weight, fc3Bias,
357 inputImage));
358 }
359 }
360
361 @Test
362 public void testProtobufModel() throws Exception {
363 try (var arena = Arena.ofConfined()) {
364 var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
365 var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
366 var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
367 var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
368 var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
369 var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
370 var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
371 var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
372 var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
373 var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
374 test(arena, inputImage -> OnnxRuntime.execute(arena, MethodHandles.lookup(), (@Reflect Supplier<Tensor<Float>>) () ->
375 cnn(conv1Weight, conv1Bias, conv2Weight, conv2Bias,
376 fc1Weight, fc1Bias, fc2Weight, fc2Bias, fc3Weight, fc3Bias,
377 inputImage)));
378 }
379 }
380
381 private static final long[] IMAGE_SHAPE = new long[]{-1, 1, 28, 28};
382
383 private void test(Arena arena, Function<Tensor<Byte>, Tensor<Float>> executor) throws Exception {
384 try (RandomAccessFile imagesF = new RandomAccessFile(IMAGES_PATH, "r");
385 RandomAccessFile labelsF = new RandomAccessFile(LABELS_PATH, "r")) {
386
387 MemorySegment imagesIn = imagesF.getChannel().map(FileChannel.MapMode.READ_ONLY, IMAGES_HEADER_SIZE, imagesF.length() - IMAGES_HEADER_SIZE, arena);
388 MemorySegment labelsIn = labelsF.getChannel().map(FileChannel.MapMode.READ_ONLY, LABELS_HEADER_SIZE, labelsF.length() - LABELS_HEADER_SIZE, arena);
389
390 Tensor<Byte> inputImage = new Tensor(arena, imagesIn, Tensor.ElementType.UINT8, IMAGE_SHAPE);
391
392 MemorySegment result = executor.apply(inputImage).data();
393
394 int matched = 0, mismatched = 0;
395 int i = 0;
396 int resultSize = (int)result.byteSize() / 4;
397 while (i < resultSize) {
398 int expected = labelsIn.get(ValueLayout.JAVA_BYTE, i / 10);
399
400 int actual = 0;
401 float maxW = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
402 for (int j = 1; j < 10; j++) {
403 float w = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
404 if (w > maxW) {
405 maxW = w;
406 actual = j;
407 }
408 }
409
410 if (expected == actual) {
411 matched++;
412 } else {
413 int imageIndex = i / 10 - 1;
414 printImage(imageIndex, imagesIn);
415 System.out.println("expected: " + expected + " actual: " + actual);
416 System.out.println("-".repeat(28));
417 mismatched++;
418 }
419 }
420 System.out.println("matched: " + matched + " mismatched: " + mismatched);
421 Assertions.assertTrue(mismatched / matched < 0.05);
422 }
423 }
424
425 static String serialize(Op o) {
426 StringWriter w = new StringWriter();
427 OpWriter.writeTo(w, o, OpWriter.LocationOption.DROP_LOCATION);
428 return w.toString();
429 }
430
431 static CoreOp.FuncOp getFuncOp(String name) {
432 Optional<Method> om = Stream.of(CNNTest.class.getDeclaredMethods())
433 .filter(m -> m.getName().equals(name))
434 .findFirst();
435
436 Method m = om.get();
437 return Op.ofMethod(m).get();
438 }
439
440 // @Reflect
441 // public Tensor<Float> loadWeight(Initializer init) {
442 // var buf = ByteBuffer.allocate(init.values().length).order(ByteOrder.nativeOrder());
443 // buf.put(init.values());
444 // buf.rewind();
445 // var floatBuf = buf.asFloatBuffer();
446 // var floatArr = new float[floatBuf.remaining()];
447 // floatBuf.get(floatArr);
448 // Tensor<Long> shape = Constant(
449 // empty(), empty(), empty(), empty(), empty(), of(init.shape()), empty(), empty()
450 // );
451 // Tensor<Float> floats = Constant(
452 // empty(), of(floatArr), empty(), empty(), empty(), empty(), empty(), empty()
453 // );
454 // var shaped = Reshape(floats, shape, empty());
455 // return shaped;
456 // }
457 //
458 // public static void extractWeights(Path inputOnnx, Path outputSerialized) throws IOException {
459 // try (InputStream is = Files.newInputStream(inputOnnx)) {
460 // OnnxMl.ModelProto model = OnnxMl.ModelProto.parseFrom(is);
461 // OnnxMl.GraphProto graph = model.getGraph();
462 // List<Initializer> initList = new ArrayList<>();
463 // for (var init : graph.getInitializerList()) {
464 // var name = init.getName();
465 // var type = init.getDataType();
466 // var shape = init.getDimsList().stream().mapToLong(a -> a).toArray();
467 // var valuesBuf = init.getRawData().asReadOnlyByteBuffer();
468 // var valuesArr = new byte[valuesBuf.remaining()];
469 // valuesBuf.get(valuesArr);
470 // var initializer = new Initializer(name, type, shape, valuesArr);
471 // System.out.println(initializer);
472 // initList.add(initializer);
473 // }
474 // try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outputSerialized))) {
475 // oos.writeObject(initList);
476 // }
477 // }
478 // }
479 //
480 // public record Initializer(String name, int type, long[] shape, byte[] values) implements java.io.Serializable {
481 // @Override
482 // public String toString() {
483 // return "Initializer{" +
484 // "name='" + name + '\'' +
485 // ", type=" + type +
486 // ", shape=" + Arrays.toString(shape) +
487 // ", values.length=" + values.length +
488 // '}';
489 // }
490 // }
491 //
492 // public static void main(String[] args) throws IOException {
493 // Path inputPath = Path.of(args[0]);
494 //
495 // Path outputPath = Path.of(args[1]);
496 //
497 // extractWeights(inputPath, outputPath);
498 // }
499
500 /*
501 ONNX code model
502
503 func @"cnn" (
504 %0 : tensor<float32>,
505 %1 : tensor<float32>,
506 %2 : tensor<float32>,
507 %3 : tensor<float32>,
508 %4 : tensor<float32>,
509 %5 : tensor<float32>,
510 %6 : tensor<float32>,
511 %7 : tensor<float32>,
512 %8 : tensor<float32>,
513 %9 : tensor<float32>,
514 %10 : tensor<float32>)tensor<float32> -> {
515 %11 : tensor<int64> = Constant @value_ints="[I@32910148";
516 %12 : tensor<float32> = Reshape %0 %11;
517 %13 : tensor<float32> = Constant @value_float="255.0";
518 %14 : tensor<float32> = Div %12 %13;
519 %15 : tensor<float32> = Conv %14 %1 %2 @optional_inputs="[B]" @strides="[I@2b4bac49" @pads="[I@fd07cbb" @dilations="[I@3571b748" @group="1" @kernel_shape="[I@3e96bacf";
520 %16 : tensor<float32> = Relu %15;
521 %17 : tensor<float32> = MaxPool %16 @ceil_mode="0" @strides="[I@484970b0" @pads="[I@4470f8a6" @dilations="[I@7c83dc97" @kernel_shape="[I@7748410a";
522 %18 : tensor<float32> = Conv %17 %3 %4 @optional_inputs="[B]" @strides="[I@740773a3" @pads="[I@37f1104d" @dilations="[I@55740540" @group="1" @kernel_shape="[I@60015ef5";
523 %19 : tensor<float32> = Relu %18;
524 %20 : tensor<float32> = MaxPool %19 @ceil_mode="0" @strides="[I@2f54a33d" @pads="[I@1018bde2" @dilations="[I@65b3f4a4" @kernel_shape="[I@f2ff811";
525 %21 : tensor<float32> = Flatten %20 @axis="1";
526 %22 : tensor<float32> = Gemm %21 %5 %6 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
527 %23 : tensor<float32> = Relu %22;
528 %24 : tensor<float32> = Gemm %23 %7 %8 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
529 %25 : tensor<float32> = Relu %24;
530 %26 : tensor<float32> = Gemm %25 %9 %10 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
531 %27 : tensor<float32> = Softmax %26 @axis="1";
532 return %27;
533 };
534 */
535 }