1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.ir;
 27 
 28 import jdk.incubator.code.CopyContext;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.TypeElement;
 31 import jdk.incubator.code.Value;
 32 import jdk.incubator.code.extern.ExternalizedOp;
 33 
 34 import java.util.*;
 35 
 36 public abstract class OnnxOp extends Op {
 37 
 38     public interface OnnxAttribute {
 39         String name();
 40 
 41         Class<?> type();
 42 
 43         Object defaultValue();
 44 
 45         boolean isOptional();
 46 
 47         default void process(Map<String, Object> attrs, Object value) {
 48             if (value instanceof Optional<?> o) {
 49                 value = o.orElse(null);
 50             }
 51             // @@@ Parse attribute from string value
 52             // @@@ Arrays don't serialize
 53             if (type().isInstance(value)) {
 54                 attrs.put(name(), value);
 55             } else if (value == null) {
 56                 // Ignore
 57             } else {
 58                 throw new UnsupportedOperationException();
 59             }
 60         }
 61 
 62         default <T> T access(Class<T> type, Map<String, Object> attrs) {
 63             Object value = attrs.get(name());
 64             if (value == null && !isOptional()) {
 65                 throw new NoSuchElementException();
 66             }
 67             return type.cast(value);
 68         }
 69 
 70         static Map<String, Object> process(ExternalizedOp eop,
 71                                            List<OnnxAttribute> attributes) {
 72             Map<String, Object> attrs = new HashMap<>();
 73             for (OnnxAttribute attribute : attributes) {
 74                 Object v = eop.attributes().get(attribute.name());
 75                 if (v == null && !attribute.isOptional()) {
 76                     throw new NoSuchElementException(attribute.name());
 77                 }
 78                 attribute.process(attrs, v);
 79             }
 80 
 81             return Map.copyOf(attrs);
 82         }
 83 
 84         interface None extends OnnxAttribute {
 85             @Override
 86             default String name() {
 87                 throw new UnsupportedOperationException();
 88             }
 89 
 90             @Override
 91             default Class<?> type() {
 92                 throw new UnsupportedOperationException();
 93             }
 94 
 95             @Override
 96             default Object defaultValue() {
 97                 throw new UnsupportedOperationException();
 98             }
 99 
100             @Override
101             default boolean isOptional() {
102                 throw new UnsupportedOperationException();
103             }
104         }
105 
106     }
107 
108     public interface OnnxTypeConstraint {
109         String name();
110 
111         OnnxType.TypeVariable typeVariable();
112 
113         interface None extends OnnxTypeConstraint {
114             @Override
115             default String name() {
116                 throw new UnsupportedOperationException();
117             }
118 
119             @Override
120             default OnnxType.TypeVariable typeVariable() {
121                 throw new UnsupportedOperationException();
122             }
123         }
124     }
125 
126     public interface OnnxParameter {
127         enum Quantifier {
128             REQUIRED, // Exactly once
129             OPTIONAL, // Once or none
130             VARIADIC, // One or more
131             ;
132 
133             public boolean isOptional() {
134                 return this == OPTIONAL;
135             }
136 
137             public boolean isRequired() {
138                 return this == REQUIRED;
139             }
140 
141             public boolean isVariadoc() {
142                 return this == VARIADIC;
143             }
144         }
145 
146         String name();
147 
148         OnnxType type();
149 
150         Quantifier quantifier();
151 
152         interface None extends OnnxParameter {
153             @Override
154             default String name() {
155                 throw new UnsupportedOperationException();
156             }
157 
158             @Override
159             default OnnxType type() {
160                 throw new UnsupportedOperationException();
161             }
162 
163             @Override
164             default Quantifier quantifier() {
165                 throw new UnsupportedOperationException();
166             }
167         }
168     }
169 
170     public interface OnnxSchema {
171         String name();
172 
173         List<OnnxAttribute> attributes();
174 
175         List<OnnxTypeConstraint> typeConstraints();
176 
177         List<OnnxParameter> inputs();
178 
179         List<OnnxParameter> outputs();
180     }
181 
182     record OnnxSchemaRecord(
183             String name,
184             List<OnnxAttribute> attributes,
185             List<OnnxTypeConstraint> typeConstraints,
186             List<OnnxParameter> inputs,
187             List<OnnxParameter> outputs
188     ) implements OnnxSchema {}
189 
190     static List<Value> concatValues(Value operand) {
191         return List.of(operand);
192     }
193 
194     static List<Value> concatValues(Value... operands) {
195         return List.of(operands);
196     }
197 
198     static List<Value> concatValues(List<Object> operands) {
199         return concatValues(operands.toArray());
200     }
201 
202     static List<Value> concatValues(Object... operands) {
203         List<Value> l = new ArrayList<>();
204         for (Object operand : operands) {
205             switch (operand) {
206                 case Value v -> l.add(v);
207                 case Optional<?> ov -> {
208                     if (ov.isPresent()) {
209                         l.add((Value) ov.get());
210                     }
211                 }
212                 case List<?> vs -> {
213                     for (Object v : vs) {
214                         l.add((Value) v);
215                     }
216                 }
217                 default -> throw new UnsupportedOperationException();
218             }
219         }
220         return l;
221     }
222 
223     static final String ATTRIBUTE_OPTIONAL_INPUTS = "optional_inputs";
224     static final String ATTRIBUTE_OPTIONAL_OUTPUTS = "optional_outputs";
225 
226     final OnnxSchema schema;
227     final Map<String, Object> onnxAttributes;
228     final TypeElement resultType;
229     final List<OnnxParameter> optionalInputArguments;
230     final List<OnnxParameter> optionalOutputParameters;
231 
232     @SuppressWarnings("unchecked")
233     OnnxOp(OnnxSchema schema, ExternalizedOp def) {
234         super(def.operands());
235 
236         this.schema = schema;
237         this.onnxAttributes = schema.attributes().isEmpty()
238                 ? Map.of()
239                 : OnnxAttribute.process(def, schema.attributes());
240         this.resultType = def.resultType();
241 
242         // @@@ Filter optional
243         this.optionalInputArguments = def.extractAttributeValue(ATTRIBUTE_OPTIONAL_INPUTS,
244                 false, v -> switch (v) {
245                     case List<?> s -> (List<OnnxParameter>) s;
246                     case null -> List.of();
247                     default -> throw new UnsupportedOperationException();
248                 });
249 
250         // @@@ Filter optional
251         this.optionalOutputParameters = def.extractAttributeValue(ATTRIBUTE_OPTIONAL_OUTPUTS,
252                 false, v -> switch (v) {
253                     case List<?> s -> (List<OnnxParameter>) s;
254                     case null -> List.of();
255                     default -> throw new UnsupportedOperationException();
256                 });
257     }
258 
259     OnnxOp(OnnxOp that, CopyContext cc) {
260         super(that, cc);
261 
262         this.schema = that.schema;
263         this.onnxAttributes = Map.copyOf(that.onnxAttributes);
264         this.resultType = that.resultType;
265         this.optionalInputArguments = List.copyOf(that.optionalInputArguments);
266         this.optionalOutputParameters = List.copyOf(that.optionalOutputParameters);
267     }
268 
269     OnnxOp(OnnxSchema schema, TypeElement resultType,
270            Set<? extends OnnxParameter> optionalOutputParameters,
271            List<Object> inputArguments,
272            List<Object> attributeValues) {
273         super(concatValues(inputArguments));
274 
275         this.schema = schema;
276         this.resultType = resultType;
277 
278         // Optional output parameters
279 
280         if (!optionalOutputParameters.isEmpty()) {
281             List<OnnxParameter> l = new ArrayList<>();
282 
283             for (int i = 0; i < schema.outputs().size(); i++) {
284                 OnnxParameter p = schema.outputs().get(i);
285                 if (p.quantifier().isOptional()
286                         && optionalOutputParameters.contains(p)) {
287                     l.add(p);
288                 }
289             }
290             this.optionalOutputParameters = List.copyOf(l);
291         } else {
292             this.optionalOutputParameters = List.of();
293         }
294 
295         // Optional input parameters
296 
297         if (!inputArguments.isEmpty()) {
298             List<OnnxParameter> l = new ArrayList<>();
299 
300             for (int i = 0; i < schema.inputs().size(); i++) {
301                 OnnxParameter p = schema.inputs().get(i);
302                 if (p.quantifier().isOptional()) {
303                     assert inputArguments.get(i) instanceof Optional;
304                     if (inputArguments.get(i) instanceof Optional<?> optionalValue
305                             && optionalValue.isPresent()) {
306                         l.add(p);
307                     }
308                 }
309             }
310             if (!l.isEmpty()) {
311                 this.optionalInputArguments = List.copyOf(l);
312             } else {
313                 this.optionalInputArguments = List.of();
314             }
315         } else {
316             this.optionalInputArguments = List.of();
317         }
318 
319         // Attributes
320 
321         if (!attributeValues.isEmpty()) {
322             Map<String, Object> attrs = new HashMap<>();
323             assert schema.attributes().size() == attributeValues.size();
324             for (int i = 0; i < schema.attributes().size(); i++) {
325                 schema.attributes().get(i).process(attrs, attributeValues.get(i));
326             }
327             this.onnxAttributes = Map.copyOf(attrs);
328         } else {
329             this.onnxAttributes = Map.of();
330         }
331     }
332 
333     @Override
334     public TypeElement resultType() {
335         return resultType;
336     }
337 
338     @Override
339     public String externalizeOpName() {
340         return schema.name();
341     }
342 
343     @Override
344     public Map<String, Object> externalize() {
345         HashMap<String, Object> m = new HashMap<>(onnxAttributes);
346         if (!optionalInputArguments.isEmpty()) {
347             m.put(ATTRIBUTE_OPTIONAL_INPUTS, optionalInputArguments);
348         }
349         if (!optionalOutputParameters.isEmpty()) {
350             m.put(ATTRIBUTE_OPTIONAL_OUTPUTS, optionalOutputParameters);
351         }
352         return Collections.unmodifiableMap(m);
353     }
354 
355     public OnnxSchema schema() {
356         return schema;
357     }
358 
359     // @@@ Change to Map<OnnxAttribute, Object>
360     public Map<String, Object> onnxAttributes() {
361         return onnxAttributes;
362     }
363 
364     public SequencedSet<OnnxParameter> onnxOutputs() {
365         return Collections.emptyNavigableSet();
366     }
367 
368     SequencedSet<OnnxParameter> onnxOutputs(OnnxSchema schema) {
369         LinkedHashSet<OnnxParameter> s = new LinkedHashSet<>();
370         for (OnnxParameter p : schema.outputs()) {
371             if (!p.quantifier().isOptional() || optionalOutputParameters.contains(p)) {
372                 s.add(p);
373             }
374         }
375 
376         return s;
377     }
378 
379     public SequencedMap<OnnxParameter, Object> onnxInputs() {
380         return Collections.emptyNavigableMap();
381     }
382 
383     SequencedMap<OnnxParameter, Object> onnxInputs(OnnxSchema schema, List<Object> inputArguments) {
384         assert schema.inputs().size() == inputArguments.size();
385         if (!inputArguments.isEmpty()) {
386             SequencedMap<OnnxParameter, Object> inputs = new LinkedHashMap<>();
387             for (int i = 0; i < schema.inputs().size(); i++) {
388                 inputs.put(schema.inputs().get(i), inputArguments.get(i));
389             }
390             return inputs;
391         } else {
392             return Collections.emptyNavigableMap();
393         }
394     }
395 }