/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package oracle.code.onnx;

import java.io.*;
import java.lang.foreign.Arena;
import jdk.incubator.code.Block;
import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.Op;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.core.CoreType;
import jdk.incubator.code.dialect.core.FunctionType;
import jdk.incubator.code.dialect.core.TupleType;
import jdk.incubator.code.extern.OpWriter;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Assertions;

import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.util.function.Function;
import oracle.code.onnx.compiler.OnnxTransformer;

import static java.util.Optional.empty;
import static java.util.Optional.of;
import static oracle.code.onnx.OnnxOperators.Cast;
import static oracle.code.onnx.OnnxOperators.Constant;
import static oracle.code.onnx.OnnxOperators.Conv;
import static oracle.code.onnx.OnnxOperators.Div;
import static oracle.code.onnx.OnnxOperators.Flatten;
import static oracle.code.onnx.OnnxOperators.Gemm;
import static oracle.code.onnx.OnnxOperators.MaxPool;
import static oracle.code.onnx.OnnxOperators.Relu;
import static oracle.code.onnx.OnnxOperators.Softmax;

// A rough CNN implementation which expects an input of shape [batch_size, 1, 28, 28].
// Over time we will improve the operator expressions to reduce the verbosity,
// especially for scalar constant expressions.
public class CNNTest {

    private static final String IMAGES_PATH = CNNTest.class.getResource("images-ubyte").getPath();
    private static final String LABELS_PATH = CNNTest.class.getResource("labels-ubyte").getPath();
    private static final int IMAGES_HEADER_SIZE = 0;
    private static final int LABELS_HEADER_SIZE = 0;

    // Alternative data set: the full MNIST test files in IDX format (which carry file headers).
//    static final String IMAGES_PATH = CNNTest.class.getResource("t10k-images-idx3-ubyte").getPath();
//    static final String LABELS_PATH = CNNTest.class.getResource("t10k-labels-idx1-ubyte").getPath();
//    static final int IMAGES_HEADER_SIZE = 16;
//    static final int LABELS_HEADER_SIZE = 8;

    // ASCII grey-scale ramp used to render images on the console.
    private static final String GREY_SCALE = " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$";
    private static final int PIXEL_DEPTH = 255;
    private static final int NUM_CHANNELS = 1;
    private static final int IMAGE_SIZE = 28;

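    // The model below is a small LeNet-style network: two convolution + ReLU + max-pool
    // stages, a flatten, two fully connected layers with ReLU, and a final fully connected
    // layer followed by Softmax over the 10 digit classes. The uint8 input image is first
    // cast to float and scaled to the 0-1 range.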
    @CodeReflection
    public static Tensor<Float> cnn(
            // Weights and biases
            // [6, 1, 5, 5]
            Tensor<Float> conv1Weights,
            // [6]
            Tensor<Float> conv1Biases,
            // [16, 6, 5, 5]
            Tensor<Float> conv2Weights,
            // [16]
            Tensor<Float> conv2Biases,
            // [120, 256]
            Tensor<Float> fc1Weights,
            // [120]
            Tensor<Float> fc1Biases,
            // [84, 120]
            Tensor<Float> fc2Weights,
            // [84]
            Tensor<Float> fc2Biases,
            // [NUM_LABELS, 84]
            Tensor<Float> fc3Weights,
            // [NUM_LABELS]
            Tensor<Float> fc3Biases,
            // Inputs
            Tensor<Byte> ubyteImage) {

        Tensor<Float> inputImage = Cast(ubyteImage, empty(), Tensor.ElementType.FLOAT.id, empty());

        // Scaling the features to 0-1
        var scalingFactor = Constant((float) PIXEL_DEPTH);
        var scaledInput = Div(inputImage, scalingFactor);

        // First conv layer
        var conv1 = Conv(scaledInput, conv1Weights, of(conv1Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu1 = Relu(conv1);

        // First pooling layer
        var pool1 = MaxPool(relu1, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Second conv layer
        var conv2 = Conv(pool1.Y(), conv2Weights, of(conv2Biases), of(new long[4]),
                of(new long[]{1,1}), empty(), of(new long[]{1, 1, 1, 1}),
                of(1L), of(new long[]{5,5}));
        var relu2 = Relu(conv2);

        // Second pooling layer
        var pool2 = MaxPool(relu2, of(new long[4]), of(new long[]{1,1}), empty(),
                of(0L), empty(), of(new long[]{2, 2}), new long[]{2, 2});

        // Flatten inputs
        var flatten = Flatten(pool2.Y(), of(1L));

        // First fully connected layer
        var fc1 = Gemm(flatten, fc1Weights, of(fc1Biases), of(1f), of(1L), of(1f), empty());
        var relu3 = Relu(fc1);

        // Second fully connected layer
        var fc2 = Gemm(relu3, fc2Weights, of(fc2Biases), of(1f), of(1L), of(1f), empty());
        var relu4 = Relu(fc2);

        // Softmax layer
        var fc3 = Gemm(relu4, fc3Weights, of(fc3Biases), of(1f), of(1L), of(1f), empty());
        var prediction = Softmax(fc3, of(1L));

        return prediction;
    }

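    // Hand-built ONNX code model equivalent to the cnn method above, constructed directly
    // with the code-model builder API. testModels uses it as the expected result of the
    // OnnxTransformer translation.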
    static CoreOp.ModuleOp cnnModel() {
        // @@@ function type and result types with correct tensor element and shape

        FunctionType functionType = CoreType.functionType(
                OnnxType.TENSOR_FLOAT32, // return
                OnnxType.TENSOR_FLOAT32, // conv1Weights
                OnnxType.TENSOR_FLOAT32, // conv1Biases
                OnnxType.TENSOR_FLOAT32, // conv2Weights
                OnnxType.TENSOR_FLOAT32, // conv2Biases
                OnnxType.TENSOR_FLOAT32, // fc1Weights
                OnnxType.TENSOR_FLOAT32, // fc1Biases
                OnnxType.TENSOR_FLOAT32, // fc2Weights
                OnnxType.TENSOR_FLOAT32, // fc2Biases
                OnnxType.TENSOR_FLOAT32, // fc3Weights
                OnnxType.TENSOR_FLOAT32, // fc3Biases
                OnnxType.TENSOR_UINT8    // input
        );

        return CoreOp.module(CoreOp.func("cnn", functionType).body(b -> {
            // weights & biases
            Block.Parameter conv1Weights = b.parameters().get(0);
            Block.Parameter conv1Biases = b.parameters().get(1);
            Block.Parameter conv2Weights = b.parameters().get(2);
            Block.Parameter conv2Biases = b.parameters().get(3);
            Block.Parameter fc1Weights = b.parameters().get(4);
            Block.Parameter fc1Biases = b.parameters().get(5);
            Block.Parameter fc2Weights = b.parameters().get(6);
            Block.Parameter fc2Biases = b.parameters().get(7);
            Block.Parameter fc3Weights = b.parameters().get(8);
            Block.Parameter fc3Biases = b.parameters().get(9);
            Block.Parameter ubyteImage = b.parameters().get(10);

            var inputImage = b.op(OnnxOps.Cast(OnnxType.TENSOR_FLOAT32,
                    ubyteImage,
                    empty(),
                    OnnxType.TENSOR_FLOAT32.eType().id(),
                    empty()
            ));

            // Scaling the features
            var scalingFactor = b.op(OnnxOps.Constant(OnnxType.TENSOR_FLOAT32,
                    empty(),
                    empty(),
                    empty(),
                    of((float) PIXEL_DEPTH),
                    empty(),
                    empty(),
                    empty(),
                    empty()));
            var scaledInput = b.op(OnnxOps.Div(inputImage.type(), inputImage, scalingFactor));

            // First conv layer
            var conv1 = b.op(OnnxOps.Conv(scaledInput.type(),
                    scaledInput,
                    conv1Weights,
                    of(conv1Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu1 = b.op(OnnxOps.Relu(conv1.type(),
                    conv1));

            // First pooling layer
            // @@@ multiple results?
            var pool1Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu1.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu1,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Second conv layer
            var pool1 = b.op(CoreOp.tupleLoad(pool1Result, 0));
            var conv2 = b.op(OnnxOps.Conv(pool1.type(),
                    pool1,
                    conv2Weights,
                    of(conv2Biases),
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(new long[]{1, 1, 1, 1}),
                    of(1L),
                    of(new long[]{5,5})));
            var relu2 = b.op(OnnxOps.Relu(conv2.type(),
                    conv2));

            // Second pooling layer
            // @@@ multiple results?
            var pool2Result = b.op(OnnxOps.MaxPool(CoreType.tupleType(relu2.type(), OnnxType.TENSOR_INT64),
                    Set.of(OnnxOps.MaxPool.OutputParameter.Indices),
                    relu2,
                    of(new long[4]),
                    of(new long[]{1,1}),
                    empty(),
                    of(0L),
                    empty(),
                    of(new long[]{2, 2}),
                    new long[]{2, 2}));

            // Flatten inputs
            var pool2 = b.op(CoreOp.tupleLoad(pool2Result, 0));
            var flatten = b.op(OnnxOps.Flatten(pool2.type(),
                    pool2,
                    of(1L)));

            // First fully connected layer
            var fc1 = b.op(OnnxOps.Gemm(flatten.type(),
                    flatten,
                    fc1Weights,
                    of(fc1Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu3 = b.op(OnnxOps.Relu(fc1.type(),
                    fc1));

            // Second fully connected layer
            var fc2 = b.op(OnnxOps.Gemm(relu3.type(),
                    relu3,
                    fc2Weights,
                    of(fc2Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var relu4 = b.op(OnnxOps.Relu(fc2.type(),
                    fc2));

            // Softmax layer
            var fc3 = b.op(OnnxOps.Gemm(relu4.type(),
                    relu4,
                    fc3Weights,
                    of(fc3Biases),
                    of(1f),
                    of(1L),
                    of(1f),
                    empty()));
            var prediction = b.op(OnnxOps.Softmax(fc3.type(),
                    fc3,
                    of(1L)));

            b.op(CoreOp.return_(prediction));
        }));
    }

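    // Renders one 28x28 image from the mapped image data as ASCII art on stdout,
    // mapping each pixel value (0-255) onto the GREY_SCALE character ramp.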
    static void printImage(int imageIndex, MemorySegment data) {
        System.out.println("Image #" + imageIndex + " :");
        int offset = imageIndex * 28 * 28;
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                System.out.print(GREY_SCALE.charAt(GREY_SCALE.length() * (0xff & data.get(ValueLayout.JAVA_BYTE, offset + y * 28 + x)) / 256));
            }
            System.out.println();
        }
    }

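    // Loads a float tensor of the given shape by memory-mapping a little-endian binary
    // resource file into the given arena.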
    private Tensor<Float> floatTensor(Arena arena, String resource, long... shape) throws IOException {
        try (var file = new RandomAccessFile(CNNTest.class.getResource(resource).getPath(), "r")) {
            return new Tensor<>(arena, file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, file.length(), arena), Tensor.ElementType.FLOAT, shape);
        }
    }

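    // Checks that the OnnxTransformer translation of the reflected cnn method produces
    // the same code model as the hand-built cnnModel(), comparing their serialized text
    // (with source locations dropped).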
    @Test
    public void testModels() {
        try (var arena = Arena.ofConfined()) {
            CoreOp.FuncOp f = getFuncOp("cnn");
            CoreOp.ModuleOp onnxModel = OnnxTransformer.transform(MethodHandles.lookup(), f).module();
            System.out.println(onnxModel.toText());

            var expectedOnnxModel = cnnModel();
            System.out.println(expectedOnnxModel.toText());

            Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel));
        }
    }

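    // Loads the pre-trained weights, then invokes the cnn method directly as plain Java,
    // so each OnnxOperators call is evaluated as it is reached, and checks the resulting
    // predictions against the MNIST labels.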
    @Test
    public void testInterpreter() throws Exception {
        try (var arena = Arena.ofConfined()) {
            var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
            var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
            var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
            var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
            var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
            var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
            var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
            var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
            var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
            var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
            test(arena, inputImage -> cnn(conv1Weight, conv1Bias,
                                          conv2Weight, conv2Bias,
                                          fc1Weight, fc1Bias,
                                          fc2Weight, fc2Bias,
                                          fc3Weight, fc3Bias,
                                          inputImage));
        }
    }

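    // Runs the same model via OnnxRuntime.execute, which is expected to reflect over the
    // cnn call, build an ONNX protobuf model from it, and execute it as a whole, then
    // checks the predictions against the MNIST labels.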
    @Test
    public void testProtobufModel() throws Exception {
        try (var arena = Arena.ofConfined()) {
            var conv1Weight = floatTensor(arena, "mnist/conv1-weight-float-le", 6, 1, 5, 5);
            var conv1Bias = floatTensor(arena, "mnist/conv1-bias-float-le", 6);
            var conv2Weight = floatTensor(arena, "mnist/conv2-weight-float-le", 16, 6, 5, 5);
            var conv2Bias = floatTensor(arena, "mnist/conv2-bias-float-le", 16);
            var fc1Weight = floatTensor(arena, "mnist/fc1-weight-float-le", 120, 256);
            var fc1Bias = floatTensor(arena, "mnist/fc1-bias-float-le", 120);
            var fc2Weight = floatTensor(arena, "mnist/fc2-weight-float-le", 84, 120);
            var fc2Bias = floatTensor(arena, "mnist/fc2-bias-float-le", 84);
            var fc3Weight = floatTensor(arena, "mnist/fc3-weight-float-le", 10, 84);
            var fc3Bias = floatTensor(arena, "mnist/fc3-bias-float-le", 10);
            test(arena, inputImage -> OnnxRuntime.execute(arena, MethodHandles.lookup(), () ->
                    cnn(conv1Weight, conv1Bias, conv2Weight, conv2Bias,
                        fc1Weight, fc1Bias, fc2Weight, fc2Bias, fc3Weight, fc3Bias,
                        inputImage)));
        }
    }

    private static final long[] IMAGE_SHAPE = new long[]{-1, 1, 28, 28};

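    // Shared test driver: memory-maps the image and label files, runs the given executor
    // over all images in one batch (the -1 in IMAGE_SHAPE presumably lets the batch
    // dimension be derived from the data size), picks the arg-max of the 10 per-image
    // scores as the predicted digit, prints any mismatches, and asserts that the number
    // of mismatches stays below 5% of the matches.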
    private void test(Arena arena, Function<Tensor<Byte>, Tensor<Float>> executor) throws Exception {
        try (RandomAccessFile imagesF = new RandomAccessFile(IMAGES_PATH, "r");
             RandomAccessFile labelsF = new RandomAccessFile(LABELS_PATH, "r")) {

            MemorySegment imagesIn = imagesF.getChannel().map(FileChannel.MapMode.READ_ONLY, IMAGES_HEADER_SIZE, imagesF.length() - IMAGES_HEADER_SIZE, arena);
            MemorySegment labelsIn = labelsF.getChannel().map(FileChannel.MapMode.READ_ONLY, LABELS_HEADER_SIZE, labelsF.length() - LABELS_HEADER_SIZE, arena);

            Tensor<Byte> inputImage = new Tensor<>(arena, imagesIn, Tensor.ElementType.UINT8, IMAGE_SHAPE);

            MemorySegment result = executor.apply(inputImage).data();

            int matched = 0, mismatched = 0;
            int i = 0;
            int resultSize = (int) result.byteSize() / 4;
            while (i < resultSize) {
                // The result holds 10 float scores per image; i / 10 is the image (and label) index.
                int expected = labelsIn.get(ValueLayout.JAVA_BYTE, i / 10);

                // Arg-max over the 10 class scores gives the predicted digit.
                int actual = 0;
                float maxW = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
                for (int j = 1; j < 10; j++) {
                    float w = result.getAtIndex(ValueLayout.JAVA_FLOAT, i++);
                    if (w > maxW) {
                        maxW = w;
                        actual = j;
                    }
                }

                if (expected == actual) {
                    matched++;
                } else {
                    int imageIndex = i / 10 - 1;
                    printImage(imageIndex, imagesIn);
                    System.out.println("expected: " + expected + " actual: " + actual);
                    System.out.println("-".repeat(28));
                    mismatched++;
                }
            }
            System.out.println("matched: " + matched + " mismatched: " + mismatched);
            // Allow at most ~5% mismatches relative to the matches (floating-point division,
            // otherwise the integer ratio would truncate to 0 and never fail).
            Assertions.assertTrue((double) mismatched / matched < 0.05);
        }
    }

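    // Serializes a code model op to its textual form without source location metadata,
    // so two models can be compared structurally.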
    static String serialize(Op o) {
        StringWriter w = new StringWriter();
        OpWriter.writeTo(w, o, OpWriter.LocationOption.DROP_LOCATION);
        return w.toString();
    }

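    // Looks up the named @CodeReflection method declared on this class and returns its
    // reflected code model.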
    static CoreOp.FuncOp getFuncOp(String name) {
        Optional<Method> om = Stream.of(CNNTest.class.getDeclaredMethods())
                .filter(m -> m.getName().equals(name))
                .findFirst();

        Method m = om.get();
        return Op.ofMethod(m).get();
    }

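    // Commented-out scaffolding, kept for reference: it extracts initializer weights from
    // an existing ONNX protobuf model and serializes them for use as test resources.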
//    @CodeReflection
//    public Tensor<Float> loadWeight(Initializer init) {
//        var buf = ByteBuffer.allocate(init.values().length).order(ByteOrder.nativeOrder());
//        buf.put(init.values());
//        buf.rewind();
//        var floatBuf = buf.asFloatBuffer();
//        var floatArr = new float[floatBuf.remaining()];
//        floatBuf.get(floatArr);
//        Tensor<Long> shape = Constant(
//                empty(), empty(), empty(), empty(), empty(), of(init.shape()), empty(), empty()
//        );
//        Tensor<Float> floats = Constant(
//                empty(), of(floatArr), empty(), empty(), empty(), empty(), empty(), empty()
//        );
//        var shaped = Reshape(floats, shape, empty());
//        return shaped;
//    }
//
//    public static void extractWeights(Path inputOnnx, Path outputSerialized) throws IOException {
//        try (InputStream is = Files.newInputStream(inputOnnx)) {
//            OnnxMl.ModelProto model = OnnxMl.ModelProto.parseFrom(is);
//            OnnxMl.GraphProto graph = model.getGraph();
//            List<Initializer> initList = new ArrayList<>();
//            for (var init : graph.getInitializerList()) {
//                var name = init.getName();
//                var type = init.getDataType();
//                var shape = init.getDimsList().stream().mapToLong(a -> a).toArray();
//                var valuesBuf = init.getRawData().asReadOnlyByteBuffer();
//                var valuesArr = new byte[valuesBuf.remaining()];
//                valuesBuf.get(valuesArr);
//                var initializer = new Initializer(name, type, shape, valuesArr);
//                System.out.println(initializer);
//                initList.add(initializer);
//            }
//            try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outputSerialized))) {
//                oos.writeObject(initList);
//            }
//        }
//    }
//
//    public record Initializer(String name, int type, long[] shape, byte[] values) implements java.io.Serializable {
//        @Override
//        public String toString() {
//            return "Initializer{" +
//                    "name='" + name + '\'' +
//                    ", type=" + type +
//                    ", shape=" + Arrays.toString(shape) +
//                    ", values.length=" + values.length +
//                    '}';
//        }
//    }
//
//    public static void main(String[] args) throws IOException {
//        Path inputPath = Path.of(args[0]);
//
//        Path outputPath = Path.of(args[1]);
//
//        extractWeights(inputPath, outputPath);
//    }

    /*
    ONNX code model

    func @"cnn" (
        %0 : tensor<float32>,
        %1 : tensor<float32>,
        %2 : tensor<float32>,
        %3 : tensor<float32>,
        %4 : tensor<float32>,
        %5 : tensor<float32>,
        %6 : tensor<float32>,
        %7 : tensor<float32>,
        %8 : tensor<float32>,
        %9 : tensor<float32>,
        %10 : tensor<float32>)tensor<float32> -> {
        %11 : tensor<int64> = Constant @value_ints="[I@32910148";
        %12 : tensor<float32> = Reshape %0 %11;
        %13 : tensor<float32> = Constant @value_float="255.0";
        %14 : tensor<float32> = Div %12 %13;
        %15 : tensor<float32> = Conv %14 %1 %2 @optional_inputs="[B]" @strides="[I@2b4bac49" @pads="[I@fd07cbb" @dilations="[I@3571b748" @group="1" @kernel_shape="[I@3e96bacf";
        %16 : tensor<float32> = Relu %15;
        %17 : tensor<float32> = MaxPool %16 @ceil_mode="0" @strides="[I@484970b0" @pads="[I@4470f8a6" @dilations="[I@7c83dc97" @kernel_shape="[I@7748410a";
        %18 : tensor<float32> = Conv %17 %3 %4 @optional_inputs="[B]" @strides="[I@740773a3" @pads="[I@37f1104d" @dilations="[I@55740540" @group="1" @kernel_shape="[I@60015ef5";
        %19 : tensor<float32> = Relu %18;
        %20 : tensor<float32> = MaxPool %19 @ceil_mode="0" @strides="[I@2f54a33d" @pads="[I@1018bde2" @dilations="[I@65b3f4a4" @kernel_shape="[I@f2ff811";
        %21 : tensor<float32> = Flatten %20 @axis="1";
        %22 : tensor<float32> = Gemm %21 %5 %6 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
        %23 : tensor<float32> = Relu %22;
        %24 : tensor<float32> = Gemm %23 %7 %8 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
        %25 : tensor<float32> = Relu %24;
        %26 : tensor<float32> = Gemm %25 %9 %10 @optional_inputs="[C]" @transB="1" @beta="1.0" @alpha="1.0";
        %27 : tensor<float32> = Softmax %26 @axis="1";
        return %27;
    };
    */
}