1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.ir;
27
28 import jdk.incubator.code.CopyContext;
29 import jdk.incubator.code.Op;
30 import jdk.incubator.code.TypeElement;
31 import jdk.incubator.code.Value;
32 import jdk.incubator.code.extern.ExternalizedOp;
33
34 import java.util.*;
35
36 public abstract class OnnxOp extends Op {
37
38 public interface OnnxAttribute {
39 String name();
40
41 Class<?> type();
42
43 Object defaultValue();
44
45 boolean isOptional();
46
47 default void process(Map<String, Object> attrs, Object value) {
48 if (value instanceof Optional<?> o) {
49 value = o.orElse(null);
50 }
51 // @@@ Parse attribute from string value
52 // @@@ Arrays don't serialize
53 if (type().isInstance(value)) {
54 attrs.put(name(), value);
55 } else if (value == null) {
56 // Ignore
57 } else {
58 throw new UnsupportedOperationException();
59 }
60 }
61
62 default <T> T access(Class<T> type, Map<String, Object> attrs) {
63 Object value = attrs.get(name());
64 if (value == null && !isOptional()) {
65 throw new NoSuchElementException();
66 }
67 return type.cast(value);
68 }
69
70 static Map<String, Object> process(ExternalizedOp eop,
71 List<OnnxAttribute> attributes) {
72 Map<String, Object> attrs = new HashMap<>();
73 for (OnnxAttribute attribute : attributes) {
74 Object v = eop.attributes().get(attribute.name());
75 if (v == null && !attribute.isOptional()) {
76 throw new NoSuchElementException(attribute.name());
77 }
78 attribute.process(attrs, v);
79 }
80
81 return Map.copyOf(attrs);
82 }
83
84 interface None extends OnnxAttribute {
85 @Override
86 default String name() {
87 throw new UnsupportedOperationException();
88 }
89
90 @Override
91 default Class<?> type() {
92 throw new UnsupportedOperationException();
93 }
94
95 @Override
96 default Object defaultValue() {
97 throw new UnsupportedOperationException();
98 }
99
100 @Override
101 default boolean isOptional() {
102 throw new UnsupportedOperationException();
103 }
104 }
105
106 }
107
108 public interface OnnxTypeConstraint {
109 String name();
110
111 OnnxType.TypeVariable typeVariable();
112
113 interface None extends OnnxTypeConstraint {
114 @Override
115 default String name() {
116 throw new UnsupportedOperationException();
117 }
118
119 @Override
120 default OnnxType.TypeVariable typeVariable() {
121 throw new UnsupportedOperationException();
122 }
123 }
124 }
125
126 public interface OnnxParameter {
127 enum Quantifier {
128 REQUIRED, // Exactly once
129 OPTIONAL, // Once or none
130 VARIADIC, // One or more
131 ;
132
133 public boolean isOptional() {
134 return this == OPTIONAL;
135 }
136
137 public boolean isRequired() {
138 return this == REQUIRED;
139 }
140
141 public boolean isVariadoc() {
142 return this == VARIADIC;
143 }
144 }
145
146 String name();
147
148 OnnxType type();
149
150 Quantifier quantifier();
151
152 interface None extends OnnxParameter {
153 @Override
154 default String name() {
155 throw new UnsupportedOperationException();
156 }
157
158 @Override
159 default OnnxType type() {
160 throw new UnsupportedOperationException();
161 }
162
163 @Override
164 default Quantifier quantifier() {
165 throw new UnsupportedOperationException();
166 }
167 }
168 }
169
170 public interface OnnxSchema {
171 String name();
172
173 List<OnnxAttribute> attributes();
174
175 List<OnnxTypeConstraint> typeConstraints();
176
177 List<OnnxParameter> inputs();
178
179 List<OnnxParameter> outputs();
180 }
181
182 record OnnxSchemaRecord(
183 String name,
184 List<OnnxAttribute> attributes,
185 List<OnnxTypeConstraint> typeConstraints,
186 List<OnnxParameter> inputs,
187 List<OnnxParameter> outputs
188 ) implements OnnxSchema {}
189
190 static List<Value> concatValues(Value operand) {
191 return List.of(operand);
192 }
193
194 static List<Value> concatValues(Value... operands) {
195 return List.of(operands);
196 }
197
198 static List<Value> concatValues(List<Object> operands) {
199 return concatValues(operands.toArray());
200 }
201
202 static List<Value> concatValues(Object... operands) {
203 List<Value> l = new ArrayList<>();
204 for (Object operand : operands) {
205 switch (operand) {
206 case Value v -> l.add(v);
207 case Optional<?> ov -> {
208 if (ov.isPresent()) {
209 l.add((Value) ov.get());
210 }
211 }
212 case List<?> vs -> {
213 for (Object v : vs) {
214 l.add((Value) v);
215 }
216 }
217 default -> throw new UnsupportedOperationException();
218 }
219 }
220 return l;
221 }
222
223 static final String ATTRIBUTE_OPTIONAL_INPUTS = "optional_inputs";
224 static final String ATTRIBUTE_OPTIONAL_OUTPUTS = "optional_outputs";
225
226 final OnnxSchema schema;
227 final Map<String, Object> onnxAttributes;
228 final TypeElement resultType;
229 final List<OnnxParameter> optionalInputArguments;
230 final List<OnnxParameter> optionalOutputParameters;
231
232 @SuppressWarnings("unchecked")
233 OnnxOp(OnnxSchema schema, ExternalizedOp def) {
234 super(def.operands());
235
236 this.schema = schema;
237 this.onnxAttributes = schema.attributes().isEmpty()
238 ? Map.of()
239 : OnnxAttribute.process(def, schema.attributes());
240 this.resultType = def.resultType();
241
242 // @@@ Filter optional
243 this.optionalInputArguments = def.extractAttributeValue(ATTRIBUTE_OPTIONAL_INPUTS,
244 false, v -> switch (v) {
245 case List<?> s -> (List<OnnxParameter>) s;
246 case null -> List.of();
247 default -> throw new UnsupportedOperationException();
248 });
249
250 // @@@ Filter optional
251 this.optionalOutputParameters = def.extractAttributeValue(ATTRIBUTE_OPTIONAL_OUTPUTS,
252 false, v -> switch (v) {
253 case List<?> s -> (List<OnnxParameter>) s;
254 case null -> List.of();
255 default -> throw new UnsupportedOperationException();
256 });
257 }
258
259 OnnxOp(OnnxOp that, CopyContext cc) {
260 super(that, cc);
261
262 this.schema = that.schema;
263 this.onnxAttributes = Map.copyOf(that.onnxAttributes);
264 this.resultType = that.resultType;
265 this.optionalInputArguments = List.copyOf(that.optionalInputArguments);
266 this.optionalOutputParameters = List.copyOf(that.optionalOutputParameters);
267 }
268
269 OnnxOp(OnnxSchema schema, TypeElement resultType,
270 Set<? extends OnnxParameter> optionalOutputParameters,
271 List<Object> inputArguments,
272 List<Object> attributeValues) {
273 super(concatValues(inputArguments));
274
275 this.schema = schema;
276 this.resultType = resultType;
277
278 // Optional output parameters
279
280 if (!optionalOutputParameters.isEmpty()) {
281 List<OnnxParameter> l = new ArrayList<>();
282
283 for (int i = 0; i < schema.outputs().size(); i++) {
284 OnnxParameter p = schema.outputs().get(i);
285 if (p.quantifier().isOptional()
286 && optionalOutputParameters.contains(p)) {
287 l.add(p);
288 }
289 }
290 this.optionalOutputParameters = List.copyOf(l);
291 } else {
292 this.optionalOutputParameters = List.of();
293 }
294
295 // Optional input parameters
296
297 if (!inputArguments.isEmpty()) {
298 List<OnnxParameter> l = new ArrayList<>();
299
300 for (int i = 0; i < schema.inputs().size(); i++) {
301 OnnxParameter p = schema.inputs().get(i);
302 if (p.quantifier().isOptional()) {
303 assert inputArguments.get(i) instanceof Optional;
304 if (inputArguments.get(i) instanceof Optional<?> optionalValue
305 && optionalValue.isPresent()) {
306 l.add(p);
307 }
308 }
309 }
310 if (!l.isEmpty()) {
311 this.optionalInputArguments = List.copyOf(l);
312 } else {
313 this.optionalInputArguments = List.of();
314 }
315 } else {
316 this.optionalInputArguments = List.of();
317 }
318
319 // Attributes
320
321 if (!attributeValues.isEmpty()) {
322 Map<String, Object> attrs = new HashMap<>();
323 assert schema.attributes().size() == attributeValues.size();
324 for (int i = 0; i < schema.attributes().size(); i++) {
325 schema.attributes().get(i).process(attrs, attributeValues.get(i));
326 }
327 this.onnxAttributes = Map.copyOf(attrs);
328 } else {
329 this.onnxAttributes = Map.of();
330 }
331 }
332
333 @Override
334 public TypeElement resultType() {
335 return resultType;
336 }
337
338 @Override
339 public String externalizeOpName() {
340 return schema.name();
341 }
342
343 @Override
344 public Map<String, Object> externalize() {
345 HashMap<String, Object> m = new HashMap<>(onnxAttributes);
346 if (!optionalInputArguments.isEmpty()) {
347 m.put(ATTRIBUTE_OPTIONAL_INPUTS, optionalInputArguments);
348 }
349 if (!optionalOutputParameters.isEmpty()) {
350 m.put(ATTRIBUTE_OPTIONAL_OUTPUTS, optionalOutputParameters);
351 }
352 return Collections.unmodifiableMap(m);
353 }
354
355 public OnnxSchema schema() {
356 return schema;
357 }
358
359 // @@@ Change to Map<OnnxAttribute, Object>
360 public Map<String, Object> onnxAttributes() {
361 return onnxAttributes;
362 }
363
364 public SequencedSet<OnnxParameter> onnxOutputs() {
365 return Collections.emptyNavigableSet();
366 }
367
368 SequencedSet<OnnxParameter> onnxOutputs(OnnxSchema schema) {
369 LinkedHashSet<OnnxParameter> s = new LinkedHashSet<>();
370 for (OnnxParameter p : schema.outputs()) {
371 if (!p.quantifier().isOptional() || optionalOutputParameters.contains(p)) {
372 s.add(p);
373 }
374 }
375
376 return s;
377 }
378
379 public SequencedMap<OnnxParameter, Object> onnxInputs() {
380 return Collections.emptyNavigableMap();
381 }
382
383 SequencedMap<OnnxParameter, Object> onnxInputs(OnnxSchema schema, List<Object> inputArguments) {
384 assert schema.inputs().size() == inputArguments.size();
385 if (!inputArguments.isEmpty()) {
386 SequencedMap<OnnxParameter, Object> inputs = new LinkedHashMap<>();
387 for (int i = 0; i < schema.inputs().size(); i++) {
388 inputs.put(schema.inputs().get(i), inputArguments.get(i));
389 }
390 return inputs;
391 } else {
392 return Collections.emptyNavigableMap();
393 }
394 }
395 }