1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.ir;
 27 
 28 import java.util.*;
 29 import jdk.incubator.code.*;
 30 import jdk.incubator.code.Op.Nested;
 31 import jdk.incubator.code.extern.ExternalizedOp;
 32 import jdk.incubator.code.extern.OpFactory;
 33 
 34 public sealed class ExplicitOnnxOps permits OnnxOps {
 35 
 36     // @@@ this should be generated from contrib operators
 37     @OpFactoryHelper.OpDeclaration(GroupQueryAttention.NAME)
 38     public static final class GroupQueryAttention extends OnnxOp {
 39         public static final String NAME = "com.microsoft.GroupQueryAttention";
 40 
 41         public enum Attribute implements OnnxAttribute {
 42             do_rotary(Long.class, true, 0),
 43             kv_num_heads(Long.class, false, null),
 44             local_window_size(Long.class, true, -1),
 45             num_heads(Long.class, false, null),
 46             rotary_interleaved(Long.class, true, 0),
 47             scale(Float.class, true, null), // @@@ Default value is 1/sqrt(head_size)
 48             ;
 49 
 50                 final Class<?> t;
 51                 final boolean optional;
 52                 final Object defaultValue;
 53 
 54                 Attribute(Class<?> type, boolean optional, Object defaultValue) {
 55                     this.t = type;
 56                     this.optional = optional;
 57                     this.defaultValue = defaultValue;
 58                     assert optional || defaultValue == null;
 59                 }
 60 
 61                 public Class<?> type() {
 62                     return t;
 63                 }
 64 
 65                 public boolean isOptional() {
 66                     return optional;
 67                 }
 68 
 69                 public Object defaultValue() {
 70                     return defaultValue;
 71                 }
 72         }
 73 
 74         public enum TypeConstraint implements OnnxTypeConstraint {
 75             T(new OnnxType.TypeVariable("T", List.of(OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float32())))),
 76             M(new OnnxType.TypeVariable("M", List.of(OnnxType.tensor(OnnxType.int32())))),
 77             ;
 78 
 79             final OnnxType.TypeVariable typeVariable;
 80 
 81             TypeConstraint(OnnxType.TypeVariable typeVariable) {
 82                 assert typeVariable.name().equals(name());
 83                 this.typeVariable = typeVariable;
 84             }
 85 
 86             @Override
 87             public OnnxType.TypeVariable typeVariable() {
 88                 return typeVariable;
 89             }
 90         }
 91 
 92         public enum InputParameter implements OnnxParameter {
 93             query(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
 94             key(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
 95             value(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
 96             past_key(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
 97             past_value(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
 98             seqlens_k(TypeConstraint.M.typeVariable(), Quantifier.REQUIRED),
 99             total_sequence_length(TypeConstraint.M.typeVariable(), Quantifier.REQUIRED),
100             cos_cache(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
101             sin_cache(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
102             ;
103 
104             final OnnxType type;
105             final Quantifier quantifier;
106 
107             InputParameter(OnnxType type, Quantifier quantifier) {
108                 this.type = type;
109                 this.quantifier = quantifier;
110             }
111 
112             @Override
113             public OnnxType type() {
114                 return type;
115             }
116 
117             @Override
118             public Quantifier quantifier() {
119                 return quantifier;
120             }
121         }
122 
123         public enum OutputParameter implements OnnxParameter {
124             output(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
125             present_key(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
126             present_value(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
127             ;
128 
129             final OnnxType type;
130             final Quantifier quantifier;
131 
132             OutputParameter(OnnxType type, Quantifier quantifier) {
133                 this.type = type;
134                 this.quantifier = quantifier;
135             }
136 
137             @Override
138             public OnnxType type() {
139                 return type;
140             }
141 
142             @Override
143             public Quantifier quantifier() {
144                 return quantifier;
145             }
146         }
147 
148         public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
149                 NAME,
150                 List.of(Attribute.values()),
151                 List.of(TypeConstraint.values()),
152                 List.of(InputParameter.values()),
153                 List.of(OutputParameter.values())
154         );
155 
156         public GroupQueryAttention(ExternalizedOp def) {
157             super(SCHEMA, def);
158         }
159 
160         GroupQueryAttention(GroupQueryAttention that, CodeContext cc) {
161             super(that, cc);
162         }
163 
164         @Override
165         public GroupQueryAttention transform(CodeContext cc, CodeTransformer ot) {
166             return new GroupQueryAttention(this, cc);
167         }
168 
169         GroupQueryAttention(TypeElement resultType, Value query, java.util.Optional<Value> key, java.util.Optional<Value> value, java.util.Optional<Value> past_key, java.util.Optional<Value> past_value, Value seqlens_k, Value total_sequence_length, java.util.Optional<Value> cos_cache, java.util.Optional<Value> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
170             super(SCHEMA, resultType, Collections.emptySet(), List.of(query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache), List.of(do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale));
171         }
172 
173         @Override
174         public SequencedSet<OnnxParameter> onnxOutputs() {
175             return onnxOutputs(SCHEMA);
176         }
177 
178         @Override
179         public SequencedMap<OnnxParameter, Object> onnxInputs() {
180             return onnxInputs(SCHEMA, List.of(query(), key(), value(), past_key(), past_value(), seqlens_k(), total_sequence_length(), cos_cache(), sin_cache()));
181         }
182 
183         public Value query() {
184             return operands().get(0);
185         }
186 
187         public java.util.Optional<Value> key() {
188             int i = optionalInputArguments.indexOf(InputParameter.key);
189             return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
190         }
191 
192         public java.util.Optional<Value> value() {
193             int i = optionalInputArguments.indexOf(InputParameter.value);
194             return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
195         }
196 
197         public java.util.Optional<Value> past_key() {
198             int i = optionalInputArguments.indexOf(InputParameter.past_key);
199             return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
200         }
201 
202         public java.util.Optional<Value> past_value() {
203             int i = optionalInputArguments.indexOf(InputParameter.past_value);
204             return i != -1 ? java.util.Optional.of(operands().get(1 + i)) : java.util.Optional.empty();
205         }
206 
207         private int skipOptional() {
208             for (int i = optionalInputArguments.size() - 1; i >= 0; i--) {
209                 var opt = optionalInputArguments.get(i);
210                 if (opt != InputParameter.cos_cache && opt != InputParameter.sin_cache) return i;
211             }
212             return -1;
213         }
214 
215         public Value seqlens_k() {
216             return operands().get(skipOptional() + 2);
217         }
218 
219         public Value total_sequence_length() {
220             return operands().get(skipOptional() + 3);
221         }
222 
223         public java.util.Optional<Value> cos_cache() {
224             int i = optionalInputArguments.indexOf(InputParameter.cos_cache);
225             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
226         }
227 
228         public java.util.Optional<Value> sin_cache() {
229             int i = optionalInputArguments.indexOf(InputParameter.sin_cache);
230             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
231         }
232     }
233 
234     public static GroupQueryAttention GroupQueryAttention(TypeElement resultType, Value query, java.util.Optional<Value> key, java.util.Optional<Value> value, java.util.Optional<Value> past_key, java.util.Optional<Value> past_value, Value seqlens_k, Value total_sequence_length, java.util.Optional<Value> cos_cache, java.util.Optional<Value> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
235         return new GroupQueryAttention(resultType, query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache, do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale);
236     }
237 
238     // @@@ this should be generated from contrib operators
239     @OpFactoryHelper.OpDeclaration(MatMulNBits.NAME)
240     public static final class MatMulNBits extends OnnxOp {
241         public static final String NAME = "com.microsoft.MatMulNBits";
242 
243         public enum Attribute implements OnnxAttribute {
244             K(Long.class, false, null),
245             N(Long.class, false, null),
246             accuracy_level(Long.class, true, 0),
247             bits(Long.class, false, null),
248             block_size(Long.class, false, null),
249             ;
250 
251                 final Class<?> t;
252                 final boolean optional;
253                 final Object defaultValue;
254 
255                 Attribute(Class<?> type, boolean optional, Object defaultValue) {
256                     this.t = type;
257                     this.optional = optional;
258                     this.defaultValue = defaultValue;
259                     assert optional || defaultValue == null;
260                 }
261 
262                 public Class<?> type() {
263                     return t;
264                 }
265 
266                 public boolean isOptional() {
267                     return optional;
268                 }
269 
270                 public Object defaultValue() {
271                     return defaultValue;
272                 }
273         }
274 
275         public enum TypeConstraint implements OnnxTypeConstraint {
276             T1(new OnnxType.TypeVariable("T1", List.of(OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float16())))),
277             T2(new OnnxType.TypeVariable("T2", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.int32())))),
278             T3(new OnnxType.TypeVariable("T3", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32())))),
279             T4(new OnnxType.TypeVariable("T4", List.of(OnnxType.tensor(OnnxType.int32())))),
280             ;
281 
282             final OnnxType.TypeVariable typeVariable;
283 
284             TypeConstraint(OnnxType.TypeVariable typeVariable) {
285                 assert typeVariable.name().equals(name());
286                 this.typeVariable = typeVariable;
287             }
288 
289             @Override
290             public OnnxType.TypeVariable typeVariable() {
291                 return typeVariable;
292             }
293         }
294 
295         public enum InputParameter implements OnnxParameter {
296             A(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
297             B(TypeConstraint.T2.typeVariable(), Quantifier.REQUIRED),
298             scales(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
299             zero_points(TypeConstraint.T3.typeVariable(), Quantifier.OPTIONAL),
300             g_idx(TypeConstraint.T4.typeVariable(), Quantifier.OPTIONAL),
301             bias(TypeConstraint.T1.typeVariable(), Quantifier.OPTIONAL),
302             ;
303 
304             final OnnxType type;
305             final Quantifier quantifier;
306 
307             InputParameter(OnnxType type, Quantifier quantifier) {
308                 this.type = type;
309                 this.quantifier = quantifier;
310             }
311 
312             @Override
313             public OnnxType type() {
314                 return type;
315             }
316 
317             @Override
318             public Quantifier quantifier() {
319                 return quantifier;
320             }
321         }
322 
323         public enum OutputParameter implements OnnxParameter {
324             Y(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
325             ;
326 
327             final OnnxType type;
328             final Quantifier quantifier;
329 
330             OutputParameter(OnnxType type, Quantifier quantifier) {
331                 this.type = type;
332                 this.quantifier = quantifier;
333             }
334 
335             @Override
336             public OnnxType type() {
337                 return type;
338             }
339 
340             @Override
341             public Quantifier quantifier() {
342                 return quantifier;
343             }
344         }
345 
346         public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
347                 NAME,
348                 List.of(Attribute.values()),
349                 List.of(TypeConstraint.values()),
350                 List.of(InputParameter.values()),
351                 List.of(OutputParameter.values())
352         );
353 
354         public MatMulNBits(ExternalizedOp def) {
355             super(SCHEMA, def);
356         }
357 
358         MatMulNBits(MatMulNBits that, CodeContext cc) {
359             super(that, cc);
360         }
361 
362         @Override
363         public MatMulNBits transform(CodeContext cc, CodeTransformer ot) {
364             return new MatMulNBits(this, cc);
365         }
366 
367         MatMulNBits(TypeElement resultType, Value a, Value b, Value scales, java.util.Optional<Value> zero_points, java.util.Optional<Value> g_idx, java.util.Optional<Value> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
368             super(SCHEMA, resultType, Collections.emptySet(), List.of(a, b, scales, zero_points, g_idx, bias), List.of(K, N, accuracy_level, bits, block_size));
369         }
370 
371         @Override
372         public SequencedSet<OnnxParameter> onnxOutputs() {
373             return onnxOutputs(SCHEMA);
374         }
375 
376         @Override
377         public SequencedMap<OnnxParameter, Object> onnxInputs() {
378                     return onnxInputs(SCHEMA, List.of(a(), b(), scales(), zero_points(), g_idx(), bias()));
379         }
380 
381         public Value a() {
382             return operands().get(0);
383         }
384 
385         public Value b() {
386             return operands().get(1);
387         }
388 
389         public Value scales() {
390             return operands().get(2);
391         }
392 
393         public java.util.Optional<Value> zero_points() {
394             int i = optionalInputArguments.indexOf(InputParameter.zero_points);
395             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
396         }
397 
398         public java.util.Optional<Value> g_idx() {
399             int i = optionalInputArguments.indexOf(InputParameter.g_idx);
400             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
401         }
402 
403         public java.util.Optional<Value> bias() {
404             int i = optionalInputArguments.indexOf(InputParameter.bias);
405             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
406         }
407     }
408 
409     public static MatMulNBits MatMulNBits(TypeElement resultType, Value a, Value b, Value scales, java.util.Optional<Value> zero_points, java.util.Optional<Value> g_idx, java.util.Optional<Value> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
410         return new MatMulNBits(resultType, a, b, scales, zero_points, g_idx, bias, K, N, accuracy_level, bits, block_size);
411     }
412 
413     // @@@ this should be generated from contrib operators
414     @OpFactoryHelper.OpDeclaration(SkipSimplifiedLayerNormalization.NAME)
415     public static final class SkipSimplifiedLayerNormalization extends OnnxOp {
416         public static final String NAME = "com.microsoft.SkipSimplifiedLayerNormalization";
417 
418         public enum Attribute implements OnnxAttribute {
419             epsilon(Float.class, true, null),
420             ;
421 
422                 final Class<?> t;
423                 final boolean optional;
424                 final Object defaultValue;
425 
426                 Attribute(Class<?> type, boolean optional, Object defaultValue) {
427                     this.t = type;
428                     this.optional = optional;
429                     this.defaultValue = defaultValue;
430                     assert optional || defaultValue == null;
431                 }
432 
433                 public Class<?> type() {
434                     return t;
435                 }
436 
437                 public boolean isOptional() {
438                     return optional;
439                 }
440 
441                 public Object defaultValue() {
442                     return defaultValue;
443                 }
444         }
445 
446         public enum TypeConstraint implements OnnxTypeConstraint {
447             T(new OnnxType.TypeVariable("T", List.of(OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float16())))),
448             ;
449 
450             final OnnxType.TypeVariable typeVariable;
451 
452             TypeConstraint(OnnxType.TypeVariable typeVariable) {
453                 assert typeVariable.name().equals(name());
454                 this.typeVariable = typeVariable;
455             }
456 
457             @Override
458             public OnnxType.TypeVariable typeVariable() {
459                 return typeVariable;
460             }
461         }
462 
463         public enum InputParameter implements OnnxParameter {
464             input(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
465             skip(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
466             gamma(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
467             bias(TypeConstraint.T.typeVariable(), Quantifier.OPTIONAL),
468             ;
469 
470             final OnnxType type;
471             final Quantifier quantifier;
472 
473             InputParameter(OnnxType type, Quantifier quantifier) {
474                 this.type = type;
475                 this.quantifier = quantifier;
476             }
477 
478             @Override
479             public OnnxType type() {
480                 return type;
481             }
482 
483             @Override
484             public Quantifier quantifier() {
485                 return quantifier;
486             }
487         }
488 
489         public enum OutputParameter implements OnnxParameter {
490             output(TypeConstraint.T.typeVariable(), Quantifier.REQUIRED),
491             mean(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
492             inv_std_var(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
493             input_skip_bias_sum(OnnxType.TENSOR_FLOAT32, Quantifier.OPTIONAL),
494             ;
495 
496             final OnnxType type;
497             final Quantifier quantifier;
498 
499             OutputParameter(OnnxType type, Quantifier quantifier) {
500                 this.type = type;
501                 this.quantifier = quantifier;
502             }
503 
504             @Override
505             public OnnxType type() {
506                 return type;
507             }
508 
509             @Override
510             public Quantifier quantifier() {
511                 return quantifier;
512             }
513         }
514 
515         public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
516                 NAME,
517                 List.of(Attribute.values()),
518                 List.of(TypeConstraint.values()),
519                 List.of(InputParameter.values()),
520                 List.of(OutputParameter.values())
521         );
522 
523         public SkipSimplifiedLayerNormalization(ExternalizedOp def) {
524             super(SCHEMA, def);
525         }
526 
527         SkipSimplifiedLayerNormalization(SkipSimplifiedLayerNormalization that, CodeContext cc) {
528             super(that, cc);
529         }
530 
531         @Override
532         public SkipSimplifiedLayerNormalization transform(CodeContext cc, CodeTransformer ot) {
533             return new SkipSimplifiedLayerNormalization(this, cc);
534         }
535 
536         SkipSimplifiedLayerNormalization(TypeElement resultType, Set<OutputParameter> optionalOutputs, Value input, Value skip, Value gamma, java.util.Optional<Value> bias, java.util.Optional<Float> epsilon) {
537             super(SCHEMA, resultType, optionalOutputs, List.of(input, skip, gamma, bias), List.of(epsilon));
538         }
539 
540         @Override
541         public SequencedSet<OnnxParameter> onnxOutputs() {
542             return onnxOutputs(SCHEMA);
543         }
544 
545         @Override
546         public SequencedMap<OnnxParameter, Object> onnxInputs() {
547             return onnxInputs(SCHEMA, List.of(input(), skip(), gamma(), bias()));
548         }
549 
550         public Value input() {
551             return operands().get(0);
552         }
553 
554         public Value skip() {
555             return operands().get(1);
556         }
557 
558         public Value gamma() {
559             return operands().get(2);
560         }
561 
562         public java.util.Optional<Value> bias() {
563             int i = optionalInputArguments.indexOf(InputParameter.bias);
564             return i != -1 ? java.util.Optional.of(operands().get(3 + i)) : java.util.Optional.empty();
565         }
566     }
567 
568     public static SkipSimplifiedLayerNormalization SkipSimplifiedLayerNormalization(TypeElement resultType, Set<SkipSimplifiedLayerNormalization.OutputParameter> optionalOutputs, Value input, Value skip, Value gamma, java.util.Optional<Value> bias, java.util.Optional<Float> epsilon) {
569         return new SkipSimplifiedLayerNormalization(resultType, optionalOutputs, input, skip, gamma, bias, epsilon);
570     }
571 
572     // @@@ this should be generated from onnxruntime-extensions
573     @OpFactoryHelper.OpDeclaration(CLIPTokenizer.NAME)
574     public static final class CLIPTokenizer extends OnnxOp {
575         public static final String NAME = "ai.onnx.contrib.CLIPTokenizer";
576 
577         public enum Attribute implements OnnxAttribute {
578             vocab(String.class, false, null),
579             merges(String.class, false, null),
580             padding_length(Long.class, true, -1),
581             ;
582 
583                 final Class<?> t;
584                 final boolean optional;
585                 final Object defaultValue;
586 
587                 Attribute(Class<?> type, boolean optional, Object defaultValue) {
588                     this.t = type;
589                     this.optional = optional;
590                     this.defaultValue = defaultValue;
591                     assert optional || defaultValue == null;
592                 }
593 
594                 public Class<?> type() {
595                     return t;
596                 }
597 
598                 public boolean isOptional() {
599                     return optional;
600                 }
601 
602                 public Object defaultValue() {
603                     return defaultValue;
604                 }
605         }
606 
607         public enum TypeConstraint implements OnnxTypeConstraint.None { }
608 
609         public enum InputParameter implements OnnxParameter {
610             input_text(OnnxType.TENSOR_STRING, Quantifier.REQUIRED),
611             ;
612 
613             final OnnxType type;
614             final Quantifier quantifier;
615 
616             InputParameter(OnnxType type, Quantifier quantifier) {
617                 this.type = type;
618                 this.quantifier = quantifier;
619             }
620 
621             @Override
622             public OnnxType type() {
623                 return type;
624             }
625 
626             @Override
627             public Quantifier quantifier() {
628                 return quantifier;
629             }
630         }
631 
632         public enum OutputParameter implements OnnxParameter {
633             input_ids(OnnxType.TENSOR_INT64, Quantifier.REQUIRED),
634             attention_mask(OnnxType.TENSOR_INT64, Quantifier.OPTIONAL),
635             offset_mapping(OnnxType.TENSOR_INT64, Quantifier.OPTIONAL),
636             ;
637 
638             final OnnxType type;
639             final Quantifier quantifier;
640 
641             OutputParameter(OnnxType type, Quantifier quantifier) {
642                 this.type = type;
643                 this.quantifier = quantifier;
644             }
645 
646             @Override
647             public OnnxType type() {
648                 return type;
649             }
650 
651             @Override
652             public Quantifier quantifier() {
653                 return quantifier;
654             }
655         }
656 
657         public static final OnnxSchema SCHEMA = new OnnxSchemaRecord(
658                 NAME,
659                 List.of(Attribute.values()),
660                 List.of(TypeConstraint.values()),
661                 List.of(InputParameter.values()),
662                 List.of(OutputParameter.values())
663         );
664 
665         public CLIPTokenizer(ExternalizedOp def) {
666             super(SCHEMA, def);
667         }
668 
669         CLIPTokenizer(CLIPTokenizer that, CodeContext cc) {
670             super(that, cc);
671         }
672 
673         @Override
674         public CLIPTokenizer transform(CodeContext cc, CodeTransformer ot) {
675             return new CLIPTokenizer(this, cc);
676         }
677 
678         CLIPTokenizer(TypeElement resultType, Set<OutputParameter> optionalOutputs, Value input_text, String vocab, String merges, java.util.Optional<Long> padding_length) {
679             super(SCHEMA, resultType, optionalOutputs, List.of(input_text), List.of(vocab, merges, padding_length));
680         }
681 
682         @Override
683         public SequencedSet<OnnxParameter> onnxOutputs() {
684             return onnxOutputs(SCHEMA);
685         }
686 
687         @Override
688         public SequencedMap<OnnxParameter, Object> onnxInputs() {
689             return onnxInputs(SCHEMA, List.of(input_text()));
690         }
691 
692         public Value input_text() {
693             return operands().get(0);
694         }
695     }
696 
697     public static CLIPTokenizer CLIPTokenizer(TypeElement resultType, Set<CLIPTokenizer.OutputParameter> optionalOutputs, Value input_text, String vocab, String merges, java.util.Optional<Long> padding_length) {
698         return new CLIPTokenizer(resultType, optionalOutputs, input_text, vocab, merges, padding_length);
699     }
700 
701 
702     @OpFactoryHelper.OpDeclaration(If.NAME)
703     public static final class If extends OnnxOp implements Nested {
704         public static final String NAME = "If";
705 
706         final Body thenBody, elseBody;
707 
708         // @@@ make or fake elseBody as "else_branch" attribute and thenBody as "then_branch" attribute
709         public enum Attribute implements OnnxOp.OnnxAttribute.None { }
710 
711         public enum TypeConstraint implements OnnxOp.OnnxTypeConstraint {
712             V(new OnnxType.TypeVariable("V", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.uint16()), OnnxType.tensor(OnnxType.uint32()), OnnxType.tensor(OnnxType.uint64()), OnnxType.tensor(OnnxType.int8()), OnnxType.tensor(OnnxType.int16()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.int64()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float64()), OnnxType.tensor(OnnxType.bool())))),
713             B(new OnnxType.TypeVariable("B", List.of(OnnxType.tensor(OnnxType.bool())))),
714             ;
715 
716             final OnnxType.TypeVariable typeVariable;
717 
718             TypeConstraint(OnnxType.TypeVariable typeVariable) {
719                 assert typeVariable.name().equals(name());
720                 this.typeVariable = typeVariable;
721             }
722 
723             @Override
724             public OnnxType.TypeVariable typeVariable() {
725                 return typeVariable;
726             }
727         }
728 
729         public enum InputParameter implements OnnxOp.OnnxParameter {
730             cond(TypeConstraint.B.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
731             ;
732 
733             final OnnxType type;
734             final OnnxOp.OnnxParameter.Quantifier quantifier;
735 
736             InputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
737                 this.type = type;
738                 this.quantifier = quantifier;
739             }
740 
741             @Override
742             public OnnxType type() {
743                 return type;
744             }
745 
746             @Override
747             public OnnxOp.OnnxParameter.Quantifier quantifier() {
748                 return quantifier;
749             }
750         }
751 
752         public enum OutputParameter implements OnnxOp.OnnxParameter {
753             output(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
754             ;
755 
756             final OnnxType type;
757             final OnnxOp.OnnxParameter.Quantifier quantifier;
758 
759             OutputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
760                 this.type = type;
761                 this.quantifier = quantifier;
762             }
763 
764             @Override
765             public OnnxType type() {
766                 return type;
767             }
768 
769             @Override
770             public OnnxOp.OnnxParameter.Quantifier quantifier() {
771                 return quantifier;
772             }
773         }
774 
775         public static final OnnxOp.OnnxSchema SCHEMA = new OnnxSchemaRecord(
776                 NAME,
777                 List.of(Attribute.values()),
778                 List.of(TypeConstraint.values()),
779                 List.of(InputParameter.values()),
780                 List.of(OutputParameter.values())
781         );
782 
783         public If(ExternalizedOp def) {
784             super(SCHEMA, def);
785 
786             this.thenBody = def.bodyDefinitions().get(0).build(this);
787             this.elseBody = def.bodyDefinitions().get(1).build(this);
788         }
789 
790         If(If that, CodeContext cc, CodeTransformer ot) {
791             super(that, cc);
792 
793             this.thenBody = that.thenBody.transform(cc, ot).build(this);
794             this.elseBody = that.elseBody.transform(cc, ot).build(this);
795         }
796 
797         @Override
798         public If transform(CodeContext cc, CodeTransformer ot) {
799             return new If(this, cc, ot);
800         }
801 
802         If(TypeElement resultType, Value cond, Body.Builder thenBranch, Body.Builder elseBranch) {
803             super(SCHEMA, resultType, Set.of(), List.of(cond), List.of());
804 
805             this.thenBody = thenBranch.build(this);
806             this.elseBody = elseBranch.build(this);
807         }
808 
809         @Override
810         public List<Body> bodies() {
811             return List.of(thenBody, elseBody);
812         }
813 
814         @Override
815         public SequencedSet<OnnxOp.OnnxParameter> onnxOutputs() {
816             return onnxOutputs(SCHEMA);
817         }
818 
819         @Override
820         public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
821             return onnxInputs(SCHEMA, List.of(cond()));
822         }
823 
824         public Value cond() {
825             return operands().get(0);
826         }
827 
828         public Body elseBranch() {
829             return elseBody;
830         }
831 
832         public Body thenBranch() {
833             return thenBody;
834         }
835     }
836 
837     public static If If(TypeElement resultType, Value cond, Body.Builder thenBody, Body.Builder elseBody) {
838         return new If(resultType, cond, thenBody, elseBody);
839     }
840 
841     @OpFactoryHelper.OpDeclaration(Loop.NAME)
842     public static final class Loop extends OnnxOp implements Op.Loop {
843         public static final String NAME = "Loop";
844 
845         final Body body;
846 
847         // @@@ make or fake body
848         public enum Attribute implements OnnxOp.OnnxAttribute.None { }
849 
850         public enum TypeConstraint implements OnnxOp.OnnxTypeConstraint {
851             V(new OnnxType.TypeVariable("V", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.uint16()), OnnxType.tensor(OnnxType.uint32()), OnnxType.tensor(OnnxType.uint64()), OnnxType.tensor(OnnxType.int8()), OnnxType.tensor(OnnxType.int16()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.int64()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float64()), OnnxType.tensor(OnnxType.bool())))),
852             I(new OnnxType.TypeVariable("I", List.of(OnnxType.tensor(OnnxType.int64())))),
853             B(new OnnxType.TypeVariable("B", List.of(OnnxType.tensor(OnnxType.bool())))),
854             ;
855 
856             final OnnxType.TypeVariable typeVariable;
857 
858             TypeConstraint(OnnxType.TypeVariable typeVariable) {
859                 assert typeVariable.name().equals(name());
860                 this.typeVariable = typeVariable;
861             }
862 
863             @Override
864             public OnnxType.TypeVariable typeVariable() {
865                 return typeVariable;
866             }
867         }
868 
869         public enum InputParameter implements OnnxOp.OnnxParameter {
870             // @@@ Onnx spec declares the input parameters as optional, however it is causing problems
871             M(TypeConstraint.I.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
872             cond(TypeConstraint.B.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
873             v_initial(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
874             ;
875 
876             final OnnxType type;
877             final OnnxOp.OnnxParameter.Quantifier quantifier;
878 
879             InputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
880                 this.type = type;
881                 this.quantifier = quantifier;
882             }
883 
884             @Override
885             public OnnxType type() {
886                 return type;
887             }
888 
889             @Override
890             public OnnxOp.OnnxParameter.Quantifier quantifier() {
891                 return quantifier;
892             }
893         }
894 
895         public enum OutputParameter implements OnnxOp.OnnxParameter {
896             v_final_and_scan_outputs(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
897             ;
898 
899             final OnnxType type;
900             final OnnxOp.OnnxParameter.Quantifier quantifier;
901 
902             OutputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
903                 this.type = type;
904                 this.quantifier = quantifier;
905             }
906 
907             @Override
908             public OnnxType type() {
909                 return type;
910             }
911 
912             @Override
913             public OnnxOp.OnnxParameter.Quantifier quantifier() {
914                 return quantifier;
915             }
916         }
917 
918         public static final OnnxOp.OnnxSchema SCHEMA = new OnnxSchemaRecord(
919                 NAME,
920                 List.of(Attribute.values()),
921                 List.of(TypeConstraint.values()),
922                 List.of(InputParameter.values()),
923                 List.of(OutputParameter.values())
924         );
925 
926         public Loop(ExternalizedOp def) {
927             super(SCHEMA, def);
928 
929             this.body = def.bodyDefinitions().get(0).build(this);
930         }
931 
932         Loop(ExplicitOnnxOps.Loop that, CodeContext cc, CodeTransformer ot) {
933             super(that, cc);
934 
935             this.body = that.body.transform(cc, ot).build(this);
936         }
937 
938         @Override
939         public ExplicitOnnxOps.Loop transform(CodeContext cc, CodeTransformer ot) {
940             return new ExplicitOnnxOps.Loop(this, cc, ot);
941         }
942 
943         Loop(TypeElement resultType, Value m, Value cond, Object v_initial, Body.Builder body) {
944             super(SCHEMA, resultType, Set.of(), List.of(m, cond, v_initial), List.of());
945 
946             this.body = body.build(this);
947         }
948 
949         @Override
950         public List<Body> bodies() {
951             return List.of(body);
952         }
953 
954         @Override
955         public SequencedSet<OnnxOp.OnnxParameter> onnxOutputs() {
956             return onnxOutputs(SCHEMA);
957         }
958 
959         @Override
960         public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
961             return onnxInputs(SCHEMA, List.of(cond()));
962         }
963 
964         public Value max() {
965             return operands().get(0);
966         }
967 
968         public Value cond() {
969             return operands().get(1);
970         }
971 
972         public List<Value> v_initial() {
973             return operands().subList(2, operands().size());
974         }
975 
976         @Override
977         public Body loopBody() {
978             return body;
979         }
980     }
981 
982     public static Loop Loop(TypeElement resultType, Value m, Value cond, Object v_initial, Body.Builder body) {
983         return new Loop(resultType, m, cond, v_initial, body);
984     }
985 }