1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.ir;
27
28 import java.util.*;
29 import jdk.incubator.code.*;
30 import jdk.incubator.code.Op.Nested;
31 import jdk.incubator.code.extern.ExternalizedOp;
32 import jdk.incubator.code.extern.OpFactory;
33
34 public sealed class ExplicitOnnxOps permits OnnxOps {
35
36 // @@@ this should be generated from contrib operators
37 @OpFactoryHelper.OpDeclaration(GroupQueryAttention.NAME)
38 public static final class GroupQueryAttention extends OnnxOp {
39 public static final String NAME = "com.microsoft.GroupQueryAttention";
40
41 public enum Attribute implements OnnxAttribute {
42 do_rotary(Long.class, true, 0),
43 kv_num_heads(Long.class, false, null),
44 local_window_size(Long.class, true, -1),
45 num_heads(Long.class, false, null),
46 rotary_interleaved(Long.class, true, 0),
47 scale(Float.class, true, null), // @@@ Default value is 1/sqrt(head_size)
48 ;
49
50 final Class<?> t;
51 final boolean optional;
52 final Object defaultValue;
53
54 Attribute(Class<?> type, boolean optional, Object defaultValue) {
55 this.t = type;
56 this.optional = optional;
57 this.defaultValue = defaultValue;
58 assert optional || defaultValue == null;
59 }
60
61 public Class<?> type() {
62 return t;
63 }
64
65 public boolean isOptional() {
66 return optional;
67 }
68
69 public Object defaultValue() {
70 return defaultValue;
71 }
72 }
73
74 public enum TypeConstraint implements OnnxTypeConstraint {
75 T(new OnnxType.TypeVariable("T", List.of(OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float32())))),
76 M(new OnnxType.TypeVariable("M", List.of(OnnxType.tensor(OnnxType.int32())))),
77 ;
78
79 final OnnxType.TypeVariable typeVariable;
80
81 TypeConstraint(OnnxType.TypeVariable typeVariable) {
82 assert typeVariable.name().equals(name());
83 this.typeVariable = typeVariable;
84 }
85
86 @Override
87 public OnnxType.TypeVariable typeVariable() {
88 return typeVariable;
89 }
90 }
91
92 public enum InputParameter implements OnnxParameter {
93 query(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
94 key(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
95 value(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
96 past_key(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
97 past_value(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
98 seqlens_k(TypeConstraint.M.typeVariable(), Quantifier.REQUIRED),
99 total_sequence_length(TypeConstraint.M.typeVariable(), Quantifier.REQUIRED),
100 cos_cache(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
101 sin_cache(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
102 ;
103
104 final OnnxType type;
105 final Quantifier quantifier;
106
107 InputParameter(OnnxType type, Quantifier quantifier) {
108 this.type = type;
109 this.quantifier = quantifier;
110 }
111
112 @Override
113 public OnnxType type() {
114 return type;
115 }
116
117 @Override
118 public Quantifier quantifier() {
119 return quantifier;
120 }
121 }
122
123 public enum OutputParameter implements OnnxParameter {
124 output(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
125 present_key(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
126 present_value(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
127 ;
128
129 final OnnxType type;
130 final Quantifier quantifier;
131
132 OutputParameter(OnnxType type, Quantifier quantifier) {
133 this.type = type;
134 this.quantifier = quantifier;
135 }
136
137 @Override
138 public OnnxType type() {
139 return type;
140 }
141
142 @Override
143 public Quantifier quantifier() {
144 return quantifier;
145 }
146 }
147
148 public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
149 NAME,
150 List.of(Attribute.values()),
151 List.of(TypeConstraint.values()),
152 List.of(InputParameter.values()),
153 List.of(OutputParameter.values())
154 );
155
156 public GroupQueryAttention(ExternalizedOp def) {
157 super(SCHEMA, def);
158 }
159
160 GroupQueryAttention(GroupQueryAttention that, CodeContext cc) {
161 super(that, cc);
162 }
163
164 @Override
165 public GroupQueryAttention transform(CodeContext cc, CodeTransformer ot) {
166 return new GroupQueryAttention(this, cc);
167 }
168
169 GroupQueryAttention(TypeElement resultType, Value query, java.util.Optional<Value> key, java.util.Optional<Value> value, java.util.Optional<Value> past_key, java.util.Optional<Value> past_value, Value seqlens_k, Value total_sequence_length, java.util.Optional<Value> cos_cache, java.util.Optional<Value> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
170 super(SCHEMA, resultType, Collections.emptySet(), List.of(query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache), List.of(do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale));
171 }
172
173 @Override
174 public SequencedSet<OnnxParameter> onnxOutputs() {
175 return onnxOutputs(SCHEMA);
176 }
177
178 @Override
179 public SequencedMap<OnnxParameter, Object> onnxInputs() {
180 return onnxInputs(SCHEMA, List.of(query(), key(), value(), past_key(), past_value(), seqlens_k(), total_sequence_length(), cos_cache(), sin_cache()));
181 }
182
183 public Value query() {
184 return operands().get(0);
185 }
186
187 public java.util.Optional<Value> key() {
188 int i = optionalInputArguments.indexOf(InputParameter.key);
189 return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
190 }
191
192 public java.util.Optional<Value> value() {
193 int i = optionalInputArguments.indexOf(InputParameter.value);
194 return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
195 }
196
197 public java.util.Optional<Value> past_key() {
198 int i = optionalInputArguments.indexOf(InputParameter.past_key);
199 return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
200 }
201
202 public java.util.Optional<Value> past_value() {
203 int i = optionalInputArguments.indexOf(InputParameter.past_value);
204 return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
205 }
206
207 private int skipOptional() {
208 for (int i = optionalInputArguments.size() - 1; i >= 0; i--) {
209 var opt = optionalInputArguments.get(i);
210 if (opt != InputParameter.cos_cache && opt != InputParameter.sin_cache) return i;
211 }
212 return -1;
213 }
214
215 public Value seqlens_k() {
216 return operands().get(skipOptional() + 2);
217 }
218
219 public Value total_sequence_length() {
220 return operands().get(skipOptional() + 3);
221 }
222
223 public java.util.Optional<Value> cos_cache() {
224 int i = optionalInputArguments.indexOf(InputParameter.cos_cache);
225 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
226 }
227
228 public java.util.Optional<Value> sin_cache() {
229 int i = optionalInputArguments.indexOf(InputParameter.sin_cache);
230 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
231 }
232 }
233
234 public static GroupQueryAttention GroupQueryAttention(TypeElement resultType, Value query, java.util.Optional<Value> key, java.util.Optional<Value> value, java.util.Optional<Value> past_key, java.util.Optional<Value> past_value, Value seqlens_k, Value total_sequence_length, java.util.Optional<Value> cos_cache, java.util.Optional<Value> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
235 return new GroupQueryAttention(resultType, query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache, do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale);
236 }
237
238 // @@@ this should be generated from contrib operators
239 @OpFactoryHelper.OpDeclaration(MatMulNBits.NAME)
240 public static final class MatMulNBits extends OnnxOp {
241 public static final String NAME = "com.microsoft.MatMulNBits";
242
243 public enum Attribute implements OnnxAttribute {
244 K(Long.class, false, null),
245 N(Long.class, false, null),
246 accuracy_level(Long.class, true, 0),
247 bits(Long.class, false, null),
248 block_size(Long.class, false, null),
249 ;
250
251 final Class<?> t;
252 final boolean optional;
253 final Object defaultValue;
254
255 Attribute(Class<?> type, boolean optional, Object defaultValue) {
256 this.t = type;
257 this.optional = optional;
258 this.defaultValue = defaultValue;
259 assert optional || defaultValue == null;
260 }
261
262 public Class<?> type() {
263 return t;
264 }
265
266 public boolean isOptional() {
267 return optional;
268 }
269
270 public Object defaultValue() {
271 return defaultValue;
272 }
273 }
274
275 public enum TypeConstraint implements OnnxTypeConstraint {
276 T1(new OnnxType.TypeVariable("T1", List.of(OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float16())))),
277 T2(new OnnxType.TypeVariable("T2", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.int32())))),
278 T3(new OnnxType.TypeVariable("T3", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32())))),
279 T4(new OnnxType.TypeVariable("T4", List.of(OnnxType.tensor(OnnxType.int32())))),
280 ;
281
282 final OnnxType.TypeVariable typeVariable;
283
284 TypeConstraint(OnnxType.TypeVariable typeVariable) {
285 assert typeVariable.name().equals(name());
286 this.typeVariable = typeVariable;
287 }
288
289 @Override
290 public OnnxType.TypeVariable typeVariable() {
291 return typeVariable;
292 }
293 }
294
295 public enum InputParameter implements OnnxParameter {
296 A(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
297 B(TypeConstraint.T2.typeVariable(), Quantifier.REQUIRED),
298 scales(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
299 zero_points(TypeConstraint.T3.typeVariable(), Quantifier.OPTIONAL),
300 g_idx(TypeConstraint.T4.typeVariable(), Quantifier.OPTIONAL),
301 bias(TypeConstraint.T1.typeVariable(), Quantifier.OPTIONAL),
302 ;
303
304 final OnnxType type;
305 final Quantifier quantifier;
306
307 InputParameter(OnnxType type, Quantifier quantifier) {
308 this.type = type;
309 this.quantifier = quantifier;
310 }
311
312 @Override
313 public OnnxType type() {
314 return type;
315 }
316
317 @Override
318 public Quantifier quantifier() {
319 return quantifier;
320 }
321 }
322
323 public enum OutputParameter implements OnnxParameter {
324 Y(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
325 ;
326
327 final OnnxType type;
328 final Quantifier quantifier;
329
330 OutputParameter(OnnxType type, Quantifier quantifier) {
331 this.type = type;
332 this.quantifier = quantifier;
333 }
334
335 @Override
336 public OnnxType type() {
337 return type;
338 }
339
340 @Override
341 public Quantifier quantifier() {
342 return quantifier;
343 }
344 }
345
346 public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
347 NAME,
348 List.of(Attribute.values()),
349 List.of(TypeConstraint.values()),
350 List.of(InputParameter.values()),
351 List.of(OutputParameter.values())
352 );
353
354 public MatMulNBits(ExternalizedOp def) {
355 super(SCHEMA, def);
356 }
357
358 MatMulNBits(MatMulNBits that, CodeContext cc) {
359 super(that, cc);
360 }
361
362 @Override
363 public MatMulNBits transform(CodeContext cc, CodeTransformer ot) {
364 return new MatMulNBits(this, cc);
365 }
366
367 MatMulNBits(TypeElement resultType, Value a, Value b, Value scales, java.util.Optional<Value> zero_points, java.util.Optional<Value> g_idx, java.util.Optional<Value> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
368 super(SCHEMA, resultType, Collections.emptySet(), List.of(a, b, scales, zero_points, g_idx, bias), List.of(K, N, accuracy_level, bits, block_size));
369 }
370
371 @Override
372 public SequencedSet<OnnxParameter> onnxOutputs() {
373 return onnxOutputs(SCHEMA);
374 }
375
376 @Override
377 public SequencedMap<OnnxParameter, Object> onnxInputs() {
378 return onnxInputs(SCHEMA, List.of(a(), b(), scales(), zero_points(), g_idx(), bias()));
379 }
380
381 public Value a() {
382 return operands().get(0);
383 }
384
385 public Value b() {
386 return operands().get(1);
387 }
388
389 public Value scales() {
390 return operands().get(2);
391 }
392
393 public java.util.Optional<Value> zero_points() {
394 int i = optionalInputArguments.indexOf(InputParameter.zero_points);
395 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
396 }
397
398 public java.util.Optional<Value> g_idx() {
399 int i = optionalInputArguments.indexOf(InputParameter.g_idx);
400 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
401 }
402
403 public java.util.Optional<Value> bias() {
404 int i = optionalInputArguments.indexOf(InputParameter.bias);
405 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
406 }
407 }
408
409 public static MatMulNBits MatMulNBits(TypeElement resultType, Value a, Value b, Value scales, java.util.Optional<Value> zero_points, java.util.Optional<Value> g_idx, java.util.Optional<Value> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
410 return new MatMulNBits(resultType, a, b, scales, zero_points, g_idx, bias, K, N, accuracy_level, bits, block_size);
411 }
412
413 // @@@ this should be generated from contrib operators
414 @OpFactoryHelper.OpDeclaration(SkipSimplifiedLayerNormalization.NAME)
415 public static final class SkipSimplifiedLayerNormalization extends OnnxOp {
416 public static final String NAME = "com.microsoft.SkipSimplifiedLayerNormalization";
417
418 public enum Attribute implements OnnxAttribute {
419 epsilon(Float.class, true, null),
420 ;
421
422 final Class<?> t;
423 final boolean optional;
424 final Object defaultValue;
425
426 Attribute(Class<?> type, boolean optional, Object defaultValue) {
427 this.t = type;
428 this.optional = optional;
429 this.defaultValue = defaultValue;
430 assert optional || defaultValue == null;
431 }
432
433 public Class<?> type() {
434 return t;
435 }
436
437 public boolean isOptional() {
438 return optional;
439 }
440
441 public Object defaultValue() {
442 return defaultValue;
443 }
444 }
445
446 public enum TypeConstraint implements OnnxTypeConstraint {
447 T(new OnnxType.TypeVariable("T", List.of(OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float16())))),
448 ;
449
450 final OnnxType.TypeVariable typeVariable;
451
452 TypeConstraint(OnnxType.TypeVariable typeVariable) {
453 assert typeVariable.name().equals(name());
454 this.typeVariable = typeVariable;
455 }
456
457 @Override
458 public OnnxType.TypeVariable typeVariable() {
459 return typeVariable;
460 }
461 }
462
463 public enum InputParameter implements OnnxParameter {
464 input(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
465 skip(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
466 gamma(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
467 bias(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
468 ;
469
470 final OnnxType type;
471 final Quantifier quantifier;
472
473 InputParameter(OnnxType type, Quantifier quantifier) {
474 this.type = type;
475 this.quantifier = quantifier;
476 }
477
478 @Override
479 public OnnxType type() {
480 return type;
481 }
482
483 @Override
484 public Quantifier quantifier() {
485 return quantifier;
486 }
487 }
488
489 public enum OutputParameter implements OnnxParameter {
490 output(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
491 mean(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
492 inv_std_var(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
493 input_skip_bias_sum(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
494 ;
495
496 final OnnxType type;
497 final Quantifier quantifier;
498
499 OutputParameter(OnnxType type, Quantifier quantifier) {
500 this.type = type;
501 this.quantifier = quantifier;
502 }
503
504 @Override
505 public OnnxType type() {
506 return type;
507 }
508
509 @Override
510 public Quantifier quantifier() {
511 return quantifier;
512 }
513 }
514
515 public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
516 NAME,
517 List.of(Attribute.values()),
518 List.of(TypeConstraint.values()),
519 List.of(InputParameter.values()),
520 List.of(OutputParameter.values())
521 );
522
523 public SkipSimplifiedLayerNormalization(ExternalizedOp def) {
524 super(SCHEMA, def);
525 }
526
527 SkipSimplifiedLayerNormalization(SkipSimplifiedLayerNormalization that, CodeContext cc) {
528 super(that, cc);
529 }
530
531 @Override
532 public SkipSimplifiedLayerNormalization transform(CodeContext cc, CodeTransformer ot) {
533 return new SkipSimplifiedLayerNormalization(this, cc);
534 }
535
536 SkipSimplifiedLayerNormalization(TypeElement resultType, Set<OutputParameter> optionalOutputs, Value input, Value skip, Value gamma, java.util.Optional<Value> bias, java.util.Optional<Float> epsilon) {
537 super(SCHEMA, resultType, optionalOutputs, List.of(input, skip, gamma, bias), List.of(epsilon));
538 }
539
540 @Override
541 public SequencedSet<OnnxParameter> onnxOutputs() {
542 return onnxOutputs(SCHEMA);
543 }
544
545 @Override
546 public SequencedMap<OnnxParameter, Object> onnxInputs() {
547 return onnxInputs(SCHEMA, List.of(input(), skip(), gamma(), bias()));
548 }
549
550 public Value input() {
551 return operands().get(0);
552 }
553
554 public Value skip() {
555 return operands().get(1);
556 }
557
558 public Value gamma() {
559 return operands().get(2);
560 }
561
562 public java.util.Optional<Value> bias() {
563 int i = optionalInputArguments.indexOf(InputParameter.bias);
564 return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
565 }
566 }
567
568 public static SkipSimplifiedLayerNormalization SkipSimplifiedLayerNormalization(TypeElement resultType, Set<SkipSimplifiedLayerNormalization.OutputParameter> optionalOutputs, Value input, Value skip, Value gamma, java.util.Optional<Value> bias, java.util.Optional<Float> epsilon) {
569 return new SkipSimplifiedLayerNormalization(resultType, optionalOutputs, input, skip, gamma, bias, epsilon);
570 }
571
572 // @@@ this should be generated from onnxruntime-extensions
573 @OpFactoryHelper.OpDeclaration(CLIPTokenizer.NAME)
574 public static final class CLIPTokenizer extends OnnxOp {
575 public static final String NAME = "ai.onnx.contrib.CLIPTokenizer";
576
577 public enum Attribute implements OnnxAttribute {
578 vocab(String.class, false, null),
579 merges(String.class, false, null),
580 padding_length(Long.class, true, -1),
581 ;
582
583 final Class<?> t;
584 final boolean optional;
585 final Object defaultValue;
586
587 Attribute(Class<?> type, boolean optional, Object defaultValue) {
588 this.t = type;
589 this.optional = optional;
590 this.defaultValue = defaultValue;
591 assert optional || defaultValue == null;
592 }
593
594 public Class<?> type() {
595 return t;
596 }
597
598 public boolean isOptional() {
599 return optional;
600 }
601
602 public Object defaultValue() {
603 return defaultValue;
604 }
605 }
606
607 public enum TypeConstraint implements OnnxTypeConstraint.None { }
608
609 public enum InputParameter implements OnnxParameter {
610 input_text(OnnxType.TENSOR_STRING, Quantifier.REQUIRED),
611 ;
612
613 final OnnxType type;
614 final Quantifier quantifier;
615
616 InputParameter(OnnxType type, Quantifier quantifier) {
617 this.type = type;
618 this.quantifier = quantifier;
619 }
620
621 @Override
622 public OnnxType type() {
623 return type;
624 }
625
626 @Override
627 public Quantifier quantifier() {
628 return quantifier;
629 }
630 }
631
632 public enum OutputParameter implements OnnxParameter {
633 input_ids(OnnxType.TENSOR_INT64, Quantifier.REQUIRED),
634 attention_mask(OnnxType.TENSOR_INT64, Quantifier.OPTIONAL),
635 offset_mapping(OnnxType.TENSOR_INT64, Quantifier.OPTIONAL),
636 ;
637
638 final OnnxType type;
639 final Quantifier quantifier;
640
641 OutputParameter(OnnxType type, Quantifier quantifier) {
642 this.type = type;
643 this.quantifier = quantifier;
644 }
645
646 @Override
647 public OnnxType type() {
648 return type;
649 }
650
651 @Override
652 public Quantifier quantifier() {
653 return quantifier;
654 }
655 }
656
657 public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
658 NAME,
659 List.of(Attribute.values()),
660 List.of(TypeConstraint.values()),
661 List.of(InputParameter.values()),
662 List.of(OutputParameter.values())
663 );
664
665 public CLIPTokenizer(ExternalizedOp def) {
666 super(SCHEMA, def);
667 }
668
669 CLIPTokenizer(CLIPTokenizer that, CodeContext cc) {
670 super(that, cc);
671 }
672
673 @Override
674 public CLIPTokenizer transform(CodeContext cc, CodeTransformer ot) {
675 return new CLIPTokenizer(this, cc);
676 }
677
678 CLIPTokenizer(TypeElement resultType, Set<OutputParameter> optionalOutputs, Value input_text, String vocab, String merges, java.util.Optional<Long> padding_length) {
679 super(SCHEMA, resultType, optionalOutputs, List.of(input_text), List.of(vocab, merges, padding_length));
680 }
681
682 @Override
683 public SequencedSet<OnnxParameter> onnxOutputs() {
684 return onnxOutputs(SCHEMA);
685 }
686
687 @Override
688 public SequencedMap<OnnxParameter, Object> onnxInputs() {
689 return onnxInputs(SCHEMA, List.of(input_text()));
690 }
691
692 public Value input_text() {
693 return operands().get(0);
694 }
695 }
696
697 public static CLIPTokenizer CLIPTokenizer(TypeElement resultType, Set<CLIPTokenizer.OutputParameter> optionalOutputs, Value input_text, String vocab, String merges, java.util.Optional<Long> padding_length) {
698 return new CLIPTokenizer(resultType, optionalOutputs, input_text, vocab, merges, padding_length);
699 }
700
701
702 @OpFactoryHelper.OpDeclaration(If.NAME)
703 public static final class If extends OnnxOp implements Nested {
704 public static final String NAME = "If";
705
706 final Body thenBody, elseBody;
707
708 // @@@ make or fake elseBody as "else_branch" attribute and thenBody as "then_branch" attribute
709 public enum Attribute implements OnnxOp.OnnxAttribute.None { }
710
711 public enum TypeConstraint implements OnnxOp.OnnxTypeConstraint {
712 V(new OnnxType.TypeVariable("V", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.uint16()), OnnxType.tensor(OnnxType.uint32()), OnnxType.tensor(OnnxType.uint64()), OnnxType.tensor(OnnxType.int8()), OnnxType.tensor(OnnxType.int16()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.int64()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float64()), OnnxType.tensor(OnnxType.bool())))),
713 B(new OnnxType.TypeVariable("B", List.of(OnnxType.tensor(OnnxType.bool())))),
714 ;
715
716 final OnnxType.TypeVariable typeVariable;
717
718 TypeConstraint(OnnxType.TypeVariable typeVariable) {
719 assert typeVariable.name().equals(name());
720 this.typeVariable = typeVariable;
721 }
722
723 @Override
724 public OnnxType.TypeVariable typeVariable() {
725 return typeVariable;
726 }
727 }
728
729 public enum InputParameter implements OnnxOp.OnnxParameter {
730 cond(TypeConstraint.B.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
731 ;
732
733 final OnnxType type;
734 final OnnxOp.OnnxParameter.Quantifier quantifier;
735
736 InputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
737 this.type = type;
738 this.quantifier = quantifier;
739 }
740
741 @Override
742 public OnnxType type() {
743 return type;
744 }
745
746 @Override
747 public OnnxOp.OnnxParameter.Quantifier quantifier() {
748 return quantifier;
749 }
750 }
751
752 public enum OutputParameter implements OnnxOp.OnnxParameter {
753 output(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
754 ;
755
756 final OnnxType type;
757 final OnnxOp.OnnxParameter.Quantifier quantifier;
758
759 OutputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
760 this.type = type;
761 this.quantifier = quantifier;
762 }
763
764 @Override
765 public OnnxType type() {
766 return type;
767 }
768
769 @Override
770 public OnnxOp.OnnxParameter.Quantifier quantifier() {
771 return quantifier;
772 }
773 }
774
775 public static final OnnxOp.OnnxSchema SCHEMA = new OnnxSchemaRecord(
776 NAME,
777 List.of(Attribute.values()),
778 List.of(TypeConstraint.values()),
779 List.of(InputParameter.values()),
780 List.of(OutputParameter.values())
781 );
782
783 public If(ExternalizedOp def) {
784 super(SCHEMA, def);
785
786 this.thenBody = def.bodyDefinitions().get(0).build(this);
787 this.elseBody = def.bodyDefinitions().get(1).build(this);
788 }
789
790 If(If that, CodeContext cc, CodeTransformer ot) {
791 super(that, cc);
792
793 this.thenBody = that.thenBody.transform(cc, ot).build(this);
794 this.elseBody = that.elseBody.transform(cc, ot).build(this);
795 }
796
797 @Override
798 public If transform(CodeContext cc, CodeTransformer ot) {
799 return new If(this, cc, ot);
800 }
801
802 If(TypeElement resultType, Value cond, Body.Builder thenBranch, Body.Builder elseBranch) {
803 super(SCHEMA, resultType, Set.of(), List.of(cond), List.of());
804
805 this.thenBody = thenBranch.build(this);
806 this.elseBody = elseBranch.build(this);
807 }
808
809 @Override
810 public List<Body> bodies() {
811 return List.of(thenBody, elseBody);
812 }
813
814 @Override
815 public SequencedSet<OnnxOp.OnnxParameter> onnxOutputs() {
816 return onnxOutputs(SCHEMA);
817 }
818
819 @Override
820 public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
821 return onnxInputs(SCHEMA, List.of(cond()));
822 }
823
824 public Value cond() {
825 return operands().get(0);
826 }
827
828 public Body elseBranch() {
829 return elseBody;
830 }
831
832 public Body thenBranch() {
833 return thenBody;
834 }
835 }
836
837 public static If If(TypeElement resultType, Value cond, Body.Builder thenBody, Body.Builder elseBody) {
838 return new If(resultType, cond, thenBody, elseBody);
839 }
840
841 @OpFactoryHelper.OpDeclaration(Loop.NAME)
842 public static final class Loop extends OnnxOp implements Op.Loop {
843 public static final String NAME = "Loop";
844
845 final Body body;
846
847 // @@@ make or fake body
848 public enum Attribute implements OnnxOp.OnnxAttribute.None { }
849
850 public enum TypeConstraint implements OnnxOp.OnnxTypeConstraint {
851 V(new OnnxType.TypeVariable("V", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.uint16()), OnnxType.tensor(OnnxType.uint32()), OnnxType.tensor(OnnxType.uint64()), OnnxType.tensor(OnnxType.int8()), OnnxType.tensor(OnnxType.int16()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.int64()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float64()), OnnxType.tensor(OnnxType.bool())))),
852 I(new OnnxType.TypeVariable("I", List.of(OnnxType.tensor(OnnxType.int64())))),
853 B(new OnnxType.TypeVariable("B", List.of(OnnxType.tensor(OnnxType.bool())))),
854 ;
855
856 final OnnxType.TypeVariable typeVariable;
857
858 TypeConstraint(OnnxType.TypeVariable typeVariable) {
859 assert typeVariable.name().equals(name());
860 this.typeVariable = typeVariable;
861 }
862
863 @Override
864 public OnnxType.TypeVariable typeVariable() {
865 return typeVariable;
866 }
867 }
868
869 public enum InputParameter implements OnnxOp.OnnxParameter {
870 // @@@ Onnx spec declares the input parameters as optional, however it is causing problems
871 M(TypeConstraint.I.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
872 cond(TypeConstraint.B.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
873 v_initial(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
874 ;
875
876 final OnnxType type;
877 final OnnxOp.OnnxParameter.Quantifier quantifier;
878
879 InputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
880 this.type = type;
881 this.quantifier = quantifier;
882 }
883
884 @Override
885 public OnnxType type() {
886 return type;
887 }
888
889 @Override
890 public OnnxOp.OnnxParameter.Quantifier quantifier() {
891 return quantifier;
892 }
893 }
894
895 public enum OutputParameter implements OnnxOp.OnnxParameter {
896 v_final_and_scan_outputs(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
897 ;
898
899 final OnnxType type;
900 final OnnxOp.OnnxParameter.Quantifier quantifier;
901
902 OutputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
903 this.type = type;
904 this.quantifier = quantifier;
905 }
906
907 @Override
908 public OnnxType type() {
909 return type;
910 }
911
912 @Override
913 public OnnxOp.OnnxParameter.Quantifier quantifier() {
914 return quantifier;
915 }
916 }
917
918 public static final OnnxOp.OnnxSchema SCHEMA = new OnnxSchemaRecord(
919 NAME,
920 List.of(Attribute.values()),
921 List.of(TypeConstraint.values()),
922 List.of(InputParameter.values()),
923 List.of(OutputParameter.values())
924 );
925
926 public Loop(ExternalizedOp def) {
927 super(SCHEMA, def);
928
929 this.body = def.bodyDefinitions().get(0).build(this);
930 }
931
932 Loop(ExplicitOnnxOps.Loop that, CodeContext cc, CodeTransformer ot) {
933 super(that, cc);
934
935 this.body = that.body.transform(cc, ot).build(this);
936 }
937
938 @Override
939 public ExplicitOnnxOps.Loop transform(CodeContext cc, CodeTransformer ot) {
940 return new ExplicitOnnxOps.Loop(this, cc, ot);
941 }
942
943 Loop(TypeElement resultType, Value m, Value cond, Object v_initial, Body.Builder body) {
944 super(SCHEMA, resultType, Set.of(), List.of(m, cond, v_initial), List.of());
945
946 this.body = body.build(this);
947 }
948
949 @Override
950 public List<Body> bodies() {
951 return List.of(body);
952 }
953
954 @Override
955 public SequencedSet<OnnxOp.OnnxParameter> onnxOutputs() {
956 return onnxOutputs(SCHEMA);
957 }
958
959 @Override
960 public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
961 return onnxInputs(SCHEMA, List.of(cond()));
962 }
963
964 public Value max() {
965 return operands().get(0);
966 }
967
968 public Value cond() {
969 return operands().get(1);
970 }
971
972 public List<Value> v_initial() {
973 return operands().subList(2, operands().size());
974 }
975
976 @Override
977 public Body loopBody() {
978 return body;
979 }
980 }
981
982 public static Loop Loop(TypeElement resultType, Value m, Value cond, Object v_initial, Body.Builder body) {
983 return new Loop(resultType, m, cond, v_initial, body);
984 }
985 }