1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.protogen;
27
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
38
/**
 * Generator for the ONNX protobuf bindings in package {@value #PROTOGEN_PACKAGE}.
 *
 * <p>{@link #main} reads {@code opgen/onnx.in.proto} line by line, classifies each line with
 * {@link #lineToToken}, builds a parse tree with {@link #toTree} and writes three sources
 * under {@link #SOURCES_PATH}:
 * <ul>
 *   <li>{@value #PROTOGEN_CONSTANTS_CLASS} - the proto enums as Java enums implementing
 *       {@code IntSupplier};</li>
 *   <li>{@value #PROTOGEN_BUILDER_CLASS} - fluent builders that encode messages to the
 *       protobuf wire format;</li>
 *   <li>{@value #PROTOGEN_MODEL_CLASS} - records plus a reflective decoder for reading models.</li>
 * </ul>
 *
 * <p>The parser is intentionally minimal: every input line must match one of the
 * {@link TokenType} patterns, so the proto file has to keep one declaration per line.
 */
public class ProtoGen {

    /** License header written verbatim at the top of every generated file. */
    static final String COPYRIGHT_NOTICE = """
            /*
            * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
            * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
            *
            * This code is free software; you can redistribute it and/or modify it
            * under the terms of the GNU General Public License version 2 only, as
            * published by the Free Software Foundation. Oracle designates this
            * particular file as subject to the "Classpath" exception as provided
            * by Oracle in the LICENSE file that accompanied this code.
            *
            * This code is distributed in the hope that it will be useful, but WITHOUT
            * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
            * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
            * version 2 for more details (a copy is included in the LICENSE file that
            * accompanied this code).
            *
            * You should have received a copy of the GNU General Public License version
            * 2 along with this work; if not, write to the Free Software Foundation,
            * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
            *
            * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
            * or visit www.oracle.com if you need additional information or have any
            * questions.
            */

            """;

    static final String PROTOGEN_PACKAGE = "oracle.code.onnx.proto";
    static final String PROTOGEN_CONSTANTS_CLASS = "OnnxConstants";
    static final String PROTOGEN_BUILDER_CLASS = "OnnxBuilder";
    static final String PROTOGEN_MODEL_CLASS = "OnnxModel";

    /** Output directory, relative to the working directory, for the generated sources. */
    static final String SOURCES_PATH = "src/main/java/" + PROTOGEN_PACKAGE.replace(".", "/") + "/";

    // Regex fragments shared by the TokenType patterns.
    // E: optional leading whitespace.
    private static final String E = "\\s*";
    // C: a mandatory "//" line comment, captured as named group "comment".
    private static final String C = "\\s*(?<comment>//.*)";
    // OC: an optional trailing line comment (group "comment" is null when absent).
    private static final String OC = C + "?";
    // NB: "<name> {" opening a named block, followed by an optional trailing comment.
    private static final String NB = "\\s+(?<name>\\w+)\\s*\\{" + OC;

    /**
     * Kinds of proto source lines. {@link #lineToToken} tries the constants in declaration
     * order and takes the first whose pattern matches the entire line, so the order matters.
     */
    enum TokenType {
        EMPTY(E),
        COMMENT(E + C),
        FIELD(E + "(?<flag>optional |repeated |)\\s*(?<type>\\w+)\\s+(?<name>\\w+)\\s*=\\s*(?<index>\\d+)\\s*(\\[.*\\])?\\s*;" + OC),
        ENUM_ELEMENT(E + "(?<name>\\w+)\\s*=\\s*(?<value>\\w+)\\s*;" + OC),
        END(E + "\\}\\s*;?" + OC),
        MESSAGE(E + "message" + NB),
        ENUM(E + "enum" + NB),
        ONEOF(E + "oneof" + NB),
        PACKAGE(E + "package\\s+(?<name>\\S+)\\s*;" + OC),
        RESERVED(E + "reserved\\s+(?<words>.+)\\s*;" + OC),
        SYNTAX(E + "syntax\\s*=\\s*(?<version>.+)\\s*;" + OC);

        final Pattern pattern;

        TokenType(String pattern) {
            this.pattern = Pattern.compile(pattern);
        }
    }

    /** A classified input line; {@code matcher} has already matched, so its named groups are readable. */
    record Token(TokenType type, Matcher matcher) {}

    /**
     * Parse-tree node: the comment lines attached to {@code token}, the token itself, and the
     * children of a {@code message}/{@code enum} block (empty for leaf tokens).
     */
    record TreeNode(List<String> comments, Token token, List<TreeNode> nested) {}

    /**
     * Classifies one line of proto source.
     *
     * @param line a single input line without its terminator
     * @return a token for the first {@link TokenType} (in declaration order) matching the whole line
     * @throws IllegalArgumentException if no pattern matches; the message is the offending line
     */
    static Token lineToToken(String line) {
        for (var tt : TokenType.values()) {
            var m = tt.pattern.matcher(line);
            if (m.matches()) {
                return new Token(tt, m);
            }
        }
        throw new IllegalArgumentException(line);
    }

    /**
     * Writes the constants class: one {@code IntSupplier} enum per proto enum found anywhere
     * in {@code tree}, each constant carrying its proto value.
     *
     * @param tree the parse tree produced by {@link #toTree}
     * @param out  destination for the generated Java source
     */
    static void generateConstants(List<TreeNode> tree, PrintStream out) {
        out.print(COPYRIGHT_NOTICE);
        out.print("""
                package %1$s;

                import java.util.function.IntSupplier;

                // Generated from onnx.in.proto
                public final class %2$s {
                """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS));
        List<TreeNode> enums = new ArrayList<>();
        collectEnums(tree, enums);
        for (TreeNode en : enums) {
            out.println();
            String name = en.token().matcher().group("name");
            // Comments are captured with their leading "//"; the extra "/" makes them "///" doc comments.
            for (String c : en.comments()) {
                out.println(" /" + c);
            }
            out.println(" public enum " + name + " implements IntSupplier {");
            for (TreeNode ev : en.nested()) {
                if (ev.token().type() == TokenType.ENUM_ELEMENT) {
                    out.println();
                    for (String c : ev.comments()) {
                        out.println(" /" + c);
                    }
                    out.println(" " + ev.token().matcher().group("name") + "(" + ev.token().matcher().group("value") + "),");
                }
            }
            // A trailing comma before ';' is legal Java, so every constant can end with ','.
            out.print("""
                    ;

                    final int value;

                    %1$s(int value) {
                    this.value = value;
                    }

                    @Override
                    public int getAsInt() {
                    return value;
                    }
                    }
                    """.formatted(name));
        }
        out.println("}");
    }

    /**
     * Appends every ENUM node of {@code nodes}, at any nesting depth, to {@code enums} in
     * pre-order (a node before its children, children in source order).
     */
    private static void collectEnums(List<TreeNode> nodes, List<TreeNode> enums) {
        for (TreeNode n : nodes) {
            if (n.token().type() == TokenType.ENUM) {
                enums.add(n);
            }
            collectEnums(n.nested(), enums);
        }
    }

    /**
     * Writes the builder class: a nested builder per proto message with one setter per field,
     * plus the shared protobuf wire-format encoding helpers.
     *
     * @param tree the parse tree produced by {@link #toTree}
     * @param out  destination for the generated Java source
     */
    static void generateBuilder(List<TreeNode> tree, PrintStream out) {
        out.print(COPYRIGHT_NOTICE);
        out.print("""
                package %1$s;

                import java.io.ByteArrayOutputStream;
                import java.nio.charset.StandardCharsets;
                import java.util.function.BiConsumer;
                import java.util.function.IntSupplier;

                import %1$s.%2$s.*;

                // Generated from onnx.in.proto
                public sealed class %3$s<T extends %3$s> {
                """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS, PROTOGEN_BUILDER_CLASS));
        generateBuilderCode(null, " ", tree, out);
        // Shared encoder. _encode(long) must use an unsigned shift (>>>=): with an arithmetic
        // shift a negative value (e.g. axis = -1) keeps its sign bits and the 10th varint byte
        // becomes 0x7f instead of the canonical 0x01.
        out.print("""

                // Implementation

                final ByteArrayOutputStream buf = new ByteArrayOutputStream();

                public byte[] getBytes() {
                return buf.toByteArray();
                }

                @SuppressWarnings("unchecked")
                public <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) {
                sup.forEach(p -> cons.accept((T)this, p));
                return (T)this;
                }

                void _encode(long number) {
                for (int i = 64 - Long.numberOfLeadingZeros(number); i > 7; i -= 7) {
                buf.write(0x80 | (int)number & 0x7f);
                number >>>= 7;
                }
                buf.write((int)number & 0x7f);
                }

                void _encode(float value) {
                int bits = Float.floatToRawIntBits(value);
                buf.write((byte)bits);
                buf.write((byte)(bits >> 8));
                buf.write((byte)(bits >> 16));
                buf.write((byte)(bits >> 24));
                }

                void _encode(double value) {
                long bits = Double.doubleToRawLongBits(value);
                buf.write((byte)bits);
                buf.write((byte)(bits >> 8));
                buf.write((byte)(bits >> 16));
                buf.write((byte)(bits >> 24));
                buf.write((byte)(bits >> 32));
                buf.write((byte)(bits >> 40));
                buf.write((byte)(bits >> 48));
                buf.write((byte)(bits >> 56));
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, String value) {
                return value == null ? (T)this : _f(fieldIndex, value.getBytes(StandardCharsets.UTF_8));
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, byte[] bytes) {
                _encode(fieldIndex << 3 | 2);
                _encode(bytes.length);
                buf.writeBytes(bytes);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, float value) {
                _encode(fieldIndex << 3 | 5);
                _encode(value);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, float... values) {
                if (values.length == 1) {
                return _f(fieldIndex, values[0]);
                }
                var b = new %1$s();
                for (var v : values) b._encode(v);
                _f(fieldIndex, b);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, double value) {
                _encode(fieldIndex << 3 | 1);
                _encode(value);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, double... values) {
                if (values.length == 1) {
                return _f(fieldIndex, values[0]);
                }
                var b = new %1$s();
                for (var v : values) b._encode(v);
                _f(fieldIndex, b);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, long value) {
                _encode(fieldIndex << 3);
                _encode(value);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, long... values) {
                if (values.length == 1) {
                return _f(fieldIndex, values[0]);
                }
                var b = new %1$s();
                for (var v : values) b._encode(v);
                _f(fieldIndex, b);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, int... values) {
                if (values.length == 1) {
                return _f(fieldIndex, values[0]);
                }
                var b = new %1$s();
                for (var v : values) b._encode(v);
                _f(fieldIndex, b);
                return (T)this;
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, %1$s value) {
                return _f(fieldIndex, value.buf.toByteArray());
                }

                @SuppressWarnings("unchecked")
                T _f(int fieldIndex, IntSupplier value) {
                return _f(fieldIndex, value.getAsInt());
                }
                }
                """.formatted(PROTOGEN_BUILDER_CLASS));
    }

    /**
     * Emits builder members for {@code tree}. A MESSAGE becomes a nested final builder class
     * (recursing into its children). A FIELD becomes a setter on {@code parentName} that
     * delegates to {@code _f(index, value)}. Other token kinds are skipped.
     *
     * @param parentName builder class that owns FIELD setters (null at top level, where the
     *                   proto is not expected to declare fields)
     * @param indent     prefix written before each emitted line
     */
    static void generateBuilderCode(String parentName, String indent, List<TreeNode> tree, PrintStream out) {
        for (TreeNode n : tree) {
            switch (n.token().type()) {
                case MESSAGE, FIELD -> {
                    out.println();
                    for (String c : n.comments()) out.println(indent + '/' + c);
                    String name = snakeToCamelCase(n.token().matcher().group("name"));
                    if (n.token().type() == TokenType.MESSAGE) {
                        out.println(indent + "public static final class " + name + " extends " + PROTOGEN_BUILDER_CLASS + "<" + name + "> {");
                        generateBuilderCode(name, indent + " ", n.nested(), out);
                        out.println(indent + "}");
                    } else {
                        String type = n.token().matcher().group("type");
                        type = switch (type) {
                            case "string" -> "String";
                            case "int32" -> "int";
                            case "int64" -> "long";
                            case "uint64" -> "long";
                            case "bytes" -> "byte[]";
                            default -> type;
                        };
                        // Repeated primitives take varargs and are written packed; repeated
                        // String/message/enum fields get a single-value setter called once per element.
                        if (Character.isLowerCase(type.charAt(0)) && !type.equals("byte[]") && n.token().matcher().group("flag").equals("repeated ")) {
                            type += "...";
                        }
                        String index = n.token().matcher().group("index");
                        out.println(indent + "public " + parentName + " " + name + "(" + type + " " + name + ") {return _f(" + index + ", " + name + ");}");
                    }
                }
            }
        }
    }

    /**
     * Writes the model interface: one record per proto message whose components are annotated
     * with their field numbers, plus a reflective decoder ({@code readFrom}) and a text dumper
     * ({@code toText}).
     *
     * @param tree the parse tree produced by {@link #toTree}
     * @param out  destination for the generated Java source
     */
    static void generateModel(List<TreeNode> tree, PrintStream out) {
        out.print(COPYRIGHT_NOTICE);
        out.print("""
                package %1$s;

                import java.io.RandomAccessFile;
                import java.lang.annotation.ElementType;
                import java.lang.annotation.Retention;
                import java.lang.annotation.RetentionPolicy;
                import java.lang.annotation.Target;
                import java.lang.reflect.ParameterizedType;
                import java.lang.reflect.RecordComponent;
                import java.nio.ByteBuffer;
                import java.nio.ByteOrder;
                import java.nio.channels.FileChannel;
                import java.nio.charset.StandardCharsets;
                import java.util.ArrayList;
                import java.util.Arrays;
                import java.util.List;
                import java.util.function.IntSupplier;
                import java.util.function.Supplier;

                import %1$s.%2$s.*;

                // Generated from onnx.in.proto
                public sealed interface %3$s {
                """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS, PROTOGEN_MODEL_CLASS));
        generateModelCode(" ", tree, out);
        // Decoder/printer implementation. Proto strings are UTF-8 (the builder encodes them as
        // such), and repeated int32 fields decode to int[], which print() must format explicitly.
        out.print("""

                // Implementation


                @Retention(RetentionPolicy.RUNTIME)
                @Target(ElementType.RECORD_COMPONENT)
                @interface f {
                int value();
                }

                private static long decodeVarint(ByteBuffer data) {
                long i, shift = 0, value = 0;
                do {
                value |= ((i = data.get()) & 0x7f) << shift;
                shift += 7;
                } while ((i & 0x80) != 0);
                return value;
                }

                private static int countVarInts(ByteBuffer data) {
                long end = decodeVarint(data);
                int start = data.position();
                end += start;
                int count = 0;
                while (data.position() < end) {
                if ((data.get() & 0x80) == 0) count++;
                }
                data.position(start);
                return count;
                }

                private static int[] readPackedInts(ByteBuffer data) {
                var ret = new int[countVarInts(data)];
                for (int i = 0; i < ret.length; i++) {
                ret[i] = (int)decodeVarint(data);
                }
                return ret;
                }

                private static long[] readPackedLongs(ByteBuffer data) {
                var ret = new long[countVarInts(data)];
                for (int i = 0; i < ret.length; i++) {
                ret[i] = decodeVarint(data);
                }
                return ret;
                }

                private static float[] readPackedFloats(ByteBuffer data) {
                var ret = new float[(int)(decodeVarint(data)/4)];
                for (int i = 0; i < ret.length; i++) {
                ret[i] = data.getFloat();
                }
                return ret;
                }

                private static double[] readPackedDoubles(ByteBuffer data) {
                var ret = new double[(int)(decodeVarint(data)/8)];
                for (int i = 0; i < ret.length; i++) {
                ret[i] = data.getDouble();
                }
                return ret;
                }

                private static byte[] readBytes(ByteBuffer data) {
                var bytes = new byte[(int)decodeVarint(data)];
                data.get(bytes);
                return bytes;
                }

                private static Object readData(Class<?> baseType, boolean packed, ByteBuffer bb) {
                if (baseType == Integer.class) {
                return (int)decodeVarint(bb);
                } else if (baseType == int[].class) {
                return packed ? readPackedInts(bb) : new int[]{(int)decodeVarint(bb)};
                } else if (baseType == Long.class) {
                return decodeVarint(bb);
                } else if (baseType == long[].class) {
                return packed ? readPackedLongs(bb) : new long[]{decodeVarint(bb)};
                } else if (baseType == Float.class) {
                return bb.getFloat();
                } else if (baseType == float[].class) {
                return packed ? readPackedFloats(bb) : new float[] {bb.getFloat()};
                } else if (baseType == Double.class) {
                return bb.getDouble();
                } else if (baseType == double[].class) {
                return packed ? readPackedDoubles(bb) : new double[] {bb.getDouble()};
                } else if (baseType == byte[].class) {
                return readBytes(bb);
                } else if (baseType == String.class) {
                return new String(readBytes(bb), StandardCharsets.UTF_8);
                } else if (baseType.getEnclosingClass() == %1$s.class) {
                int value = (int)decodeVarint(bb);
                for (Object cs : baseType.getEnumConstants()) {
                if (cs instanceof IntSupplier is && is.getAsInt() == value) {
                return cs;
                }
                }
                throw new IllegalArgumentException(baseType.toString());
                } else {
                var size = decodeVarint(bb);
                int limit = bb.limit();
                var data = readFrom((Class<Record>)baseType, bb.limit(bb.position() + (int)size));
                bb.limit(limit);
                return data;
                }
                }

                private static int getRecordFieldIndex(RecordComponent[] rcs, int fieldIndex) {
                for (int i = 0; i < rcs.length; i++) {
                if (rcs[i].getAnnotation(f.class).value() == fieldIndex) {
                return i;
                }
                }
                throw new IllegalArgumentException("Field index " + fieldIndex + " not found in " + rcs[0].getDeclaringRecord());
                }

                private static <T> T readFrom(Class<T> type, ByteBuffer bb) {
                RecordComponent[] rcs = type.getRecordComponents();
                Object[] fieldsData = new Object[rcs.length];
                while (bb.remaining() > 0) {
                long tag = decodeVarint(bb);
                int rfi = getRecordFieldIndex(rcs, (int)tag >> 3);
                boolean packed = (tag & 7) == 2;
                RecordComponent rc = rcs[rfi];
                Class<?> rcType = rc.getType();
                if (rcType == List.class) {
                List list;
                if (fieldsData[rfi] instanceof List l) {
                list = l;
                } else {
                list = new ArrayList();
                fieldsData[rfi] = list;
                }
                Class baseType = (Class)((ParameterizedType)rc.getGenericType()).getActualTypeArguments()[0];
                list.add(readData(baseType, packed, bb));
                } else {
                fieldsData[rfi] = readData(rcType, packed, bb);
                }
                }
                try {
                return (T)type.getDeclaredConstructors()[0].newInstance(fieldsData);
                } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
                }
                }

                private static void print(StringBuilder out, int indent, String name, Object value, boolean skipBigData) throws ReflectiveOperationException {
                if (value == null) return;
                out.append(" ".repeat(indent)).append(name);
                switch (value) {
                case List l -> {
                out.append(name.endsWith("s") ? ":" : "s:").append(System.lineSeparator());
                for (var el : l) print(out, indent + 1, "- " + (name.endsWith("s") ? name.substring(0, name.length() - 1) : name), el, skipBigData);
                }
                case Record r -> {
                out.append(':').append(System.lineSeparator());
                for (var rc : r.getClass().getRecordComponents()) {
                print(out, indent + 2, rc.getName(), rc.getAccessor().invoke(r), skipBigData);
                }
                }
                case byte[] a ->
                out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
                case int[] a ->
                out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
                case long[] a ->
                out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
                case float[] a ->
                out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
                case double[] a ->
                out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
                case String s ->
                out.append(": \\"").append(s).append('"').append(System.lineSeparator());
                default ->
                out.append(": ").append(value).append(System.lineSeparator());
                }
                }

                static final int SKIP_LIMIT = 1000;

                private static String checkSize(int size, Supplier<String> sup, boolean skipBigData) {
                return ": " + (skipBigData && size > SKIP_LIMIT ? "# skipped " + size + " values" : sup.get()) + System.lineSeparator();
                }

                default String toText() {
                return toText(true);
                }

                default String toText(boolean skipBigData) {
                try {
                var sb = new StringBuilder();
                print(sb, 0, getClass().getSimpleName(), this, skipBigData);
                return sb.toString();
                } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
                }
                }

                public static %2$s.ModelProto readFrom(byte[] onnxProtoModel) {
                return readFrom(ByteBuffer.wrap(onnxProtoModel));
                }

                public static %2$s.ModelProto readFrom(ByteBuffer onnxProtoModel) {
                return readFrom(%2$s.ModelProto.class, onnxProtoModel.order(ByteOrder.LITTLE_ENDIAN));
                }

                public static void main(String... args) throws Exception {
                for (var fName : args) {
                try (var in = new RandomAccessFile(fName, "r")) {
                %2$s.ModelProto model = readFrom(in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()));
                System.out.println(model.toText());
                }
                }
                }
                }
                """.formatted(PROTOGEN_CONSTANTS_CLASS, PROTOGEN_MODEL_CLASS));
    }

    /**
     * Emits one record per MESSAGE in {@code tree} (recursing into nested messages). Each FIELD
     * becomes a component annotated {@code @f(fieldNumber)}. Repeated fields map to
     * {@code List<...>}, with primitive element types as arrays because packed encodings decode
     * to arrays. Scalar fields map to boxed types, so an absent field reads as null.
     */
    static void generateModelCode(String indent, List<TreeNode> tree, PrintStream out) {
        for (TreeNode n : tree) {
            if (n.token().type() == TokenType.MESSAGE) {
                out.println();
                for (String c : n.comments()) out.println(indent + '/' + c);
                String recordName = n.token().matcher().group("name");
                out.println(indent + "public record " + recordName + " (");
                boolean first = true;
                for (TreeNode nn : n.nested()) {
                    if (nn.token().type() == TokenType.FIELD) {
                        if (first) {
                            first = false;
                        } else {
                            out.println(",");
                        }
                        out.println();
                        for (String c : nn.comments()) out.println(indent + " /" + c);
                        String name = snakeToCamelCase(nn.token().matcher().group("name"));
                        String type = nn.token().matcher().group("type");
                        if (nn.token().matcher().group("flag").equals("repeated ")) {
                            type = switch (type) {
                                case "float" -> "List<float[]>";
                                case "double" -> "List<double[]>";
                                case "string" -> "List<String>";
                                case "int32" -> "List<int[]>";
                                case "int64", "uint64" -> "List<long[]>";
                                case "bytes" -> "List<byte[]>";
                                default -> "List<" + type + ">";
                            };
                        } else {
                            type = switch (type) {
                                case "float" -> "Float";
                                case "double" -> "Double";
                                case "string" -> "String";
                                case "int32" -> "Integer";
                                case "int64", "uint64" -> "Long";
                                case "bytes" -> "byte[]";
                                default -> type;
                            };
                        }
                        String index = nn.token().matcher().group("index");
                        out.print(indent + " @f(" + index + ") " + type + " " + name);
                    }
                }
                out.println(") implements " + PROTOGEN_MODEL_CLASS + " {");
                generateModelCode(indent + " ", n.nested(), out);
                out.println(indent + "}");
            }
        }
    }

    /**
     * Builds a parse tree from {@code tokens}, consuming them up to the END that closes the
     * current block (then returns) or until the iterator is exhausted.
     *
     * <p>Comment lines accumulate and are attached to the next declaration, together with
     * that declaration's own trailing comment. An empty line discards the pending comments.
     * {@code oneof} groups are flattened: their fields become direct children of the
     * enclosing block, and the closing brace of each group is skipped.
     */
    static List<TreeNode> toTree(Iterator<Token> tokens) {
        List<TreeNode> nodes = new ArrayList<>();
        List<String> comments = new ArrayList<>();
        int oneofs = 0;
        while (tokens.hasNext()) {
            Token t = tokens.next();
            switch (t.type()) {
                case COMMENT -> comments.add(t.matcher().group("comment"));
                case EMPTY -> comments.clear(); // do not merge isolated comment blocks
                case ONEOF -> oneofs++; // flat ONEOF
                case ENUM_ELEMENT, FIELD, RESERVED, SYNTAX, PACKAGE -> {
                    if (t.matcher().group("comment") instanceof String c) comments.add(c);
                    nodes.add(new TreeNode(comments, t, List.of()));
                    comments = new ArrayList<>();
                }
                case ENUM, MESSAGE -> {
                    if (t.matcher().group("comment") instanceof String c) comments.add(c);
                    nodes.add(new TreeNode(comments, t, toTree(tokens)));
                    comments = new ArrayList<>();
                }
                case END -> {
                    // Closes a oneof while any is open, otherwise closes the current block.
                    if (oneofs-- == 0) return nodes;
                }
            }
        }
        return nodes;
    }

    static final Pattern SNAKE = Pattern.compile("_([a-z])");

    /**
     * Converts a snake_case proto name to lowerCamelCase (e.g. {@code opset_import} becomes
     * {@code opsetImport}). Only an underscore followed by an ASCII lowercase letter is
     * rewritten. Upper-casing uses {@link Locale#ROOT}, so the result does not depend on the
     * default locale (a Turkish locale would otherwise turn {@code _i} into a dotted capital I).
     */
    static String snakeToCamelCase(String name) {
        return SNAKE.matcher(name).replaceAll(mr -> mr.group(1).toUpperCase(Locale.ROOT));
    }

    /** Reads and parses a proto file, closing the underlying line stream afterwards. */
    private static List<TreeNode> readTree(Path protoFile) throws IOException {
        try (Stream<String> lines = Files.lines(protoFile)) {
            return toTree(lines.map(ProtoGen::lineToToken).iterator());
        }
    }

    /**
     * Writes {@code SOURCES_PATH + className + ".java"} as UTF-8 using {@code generator}.
     *
     * @throws IOException if the file cannot be opened or any write fails (PrintStream itself
     *                     swallows write errors, so they are surfaced via checkError)
     */
    private static void writeSource(String className, Consumer<PrintStream> generator) throws IOException {
        String file = SOURCES_PATH + className + ".java";
        try (var out = new PrintStream(new FileOutputStream(file), false, StandardCharsets.UTF_8)) {
            generator.accept(out);
            if (out.checkError()) {
                throw new IOException("Failed writing " + file);
            }
        }
    }

    /**
     * Regenerates the constants, builder and model sources from {@code opgen/onnx.in.proto}.
     * Paths are resolved against the current working directory.
     */
    public static void main(String[] args) throws Exception {
        List<TreeNode> tree = readTree(Path.of("opgen/onnx.in.proto"));
        writeSource(PROTOGEN_CONSTANTS_CLASS, out -> generateConstants(tree, out));
        writeSource(PROTOGEN_BUILDER_CLASS, out -> generateBuilder(tree, out));
        writeSource(PROTOGEN_MODEL_CLASS, out -> generateModel(tree, out));
    }
}