1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.protogen;
 27 
 28 import java.io.FileOutputStream;
 29 import java.io.PrintStream;
 30 import java.nio.file.Files;
 31 import java.nio.file.Path;
 32 import java.util.ArrayList;
 33 import java.util.Iterator;
 34 import java.util.List;
 35 import java.util.regex.Matcher;
 36 import java.util.regex.Pattern;
 37 import java.util.stream.Stream;
 38 
 39 public class ProtoGen {
 40 
 41     static final String COPYRIGHT_NOTICE = """
 42             /*
 43              * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 44              * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 45              *
 46              * This code is free software; you can redistribute it and/or modify it
 47              * under the terms of the GNU General Public License version 2 only, as
 48              * published by the Free Software Foundation.  Oracle designates this
 49              * particular file as subject to the "Classpath" exception as provided
 50              * by Oracle in the LICENSE file that accompanied this code.
 51              *
 52              * This code is distributed in the hope that it will be useful, but WITHOUT
 53              * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 54              * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 55              * version 2 for more details (a copy is included in the LICENSE file that
 56              * accompanied this code).
 57              *
 58              * You should have received a copy of the GNU General Public License version
 59              * 2 along with this work; if not, write to the Free Software Foundation,
 60              * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 61              *
 62              * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 63              * or visit www.oracle.com if you need additional information or have any
 64              * questions.
 65              */
 66 
 67             """;
 68 
 69     static final String PROTOGEN_PACKAGE = "oracle.code.onnx.proto";
 70     static final String PROTOGEN_CONSTANTS_CLASS = "OnnxConstants";
 71     static final String PROTOGEN_BUILDER_CLASS = "OnnxBuilder";
 72     static final String PROTOGEN_MODEL_CLASS = "OnnxModel";
 73 
 74     static final String SOURCES_PATH = "src/main/java/" + PROTOGEN_PACKAGE.replace(".", "/") + "/";
 75 
 76     private static final String E = "\\s*";
 77     private static final String C = "\\s*(?<comment>//.*)";
 78     private static final String OC = C + "?";
 79     private static final String NB = "\\s+(?<name>\\w+)\\s*\\{" + OC;
 80 
 81     enum TokenType {
 82         EMPTY(E),
 83         COMMENT(E + C),
 84         FIELD(E + "(?<flag>optional |repeated |)\\s*(?<type>\\w+)\\s+(?<name>\\w+)\\s*=\\s*(?<index>\\d+)\\s*(\\[.*\\])?\\s*;" + OC),
 85         ENUM_ELEMENT(E + "(?<name>\\w+)\\s*=\\s*(?<value>\\w+)\\s*;" + OC),
 86         END(E + "\\}\\s*;?" + OC),
 87         MESSAGE(E + "message" + NB),
 88         ENUM(E + "enum" + NB),
 89         ONEOF(E + "oneof" + NB),
 90         PACKAGE(E + "package\\s+(?<name>\\S+)\\s*;" + OC),
 91         RESERVED(E + "reserved\\s+(?<words>.+)\\s*;" + OC),
 92         SYNTAX(E + "syntax\\s*=\\s*(?<version>.+)\\s*;" + OC);
 93 
 94         final Pattern pattern;
 95 
 96         TokenType(String pattern) {
 97             this.pattern = Pattern.compile(pattern);
 98         }
 99     }
100 
101     record Token(TokenType type, Matcher matcher) {}
102 
103     record TreeNode(List<String> comments, Token token, List<TreeNode> nested) {}
104 
105     static Token lineToToken(String line) {
106         for (var tt : TokenType.values()) {
107             var m = tt.pattern.matcher(line);
108             if (m.matches()) {
109                 return new Token(tt, m);
110             }
111         }
112         throw new IllegalArgumentException(line);
113     }
114 
115     static void generateConstants(List<TreeNode> tree, PrintStream out) {
116         out.print(COPYRIGHT_NOTICE);
117         out.print("""
118                 package %1$s;
119 
120                 import java.util.function.IntSupplier;
121 
122                 // Generated from onnx.in.proto
123                 public final class %2$s {
124                 """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS));
125         for (TreeNode en : tree.stream().flatMap(n -> Stream.concat(Stream.of(n), n.nested().stream())).filter(n -> n.token().type() == TokenType.ENUM).toList()) {
126             out.println();
127             String name = en.token().matcher().group("name");
128             for (String c : en.comments()) {
129                 out.println("    /" + c);
130             }
131             out.println("    public enum " + name + " implements IntSupplier {");
132             for (TreeNode ev : en.nested()) {
133                 if (ev.token().type() == TokenType.ENUM_ELEMENT) {
134                     out.println();
135                     for (String c : ev.comments()) {
136                         out.println("        /" + c);
137                     }
138                     out.println("        " + ev.token().matcher().group("name") + "(" + ev.token().matcher().group("value") + "),");
139                 }
140             }
141             out.print("""
142                             ;
143 
144                             final int value;
145 
146                             %1$s(int value) {
147                                 this.value = value;
148                             }
149 
150                             @Override
151                             public int getAsInt() {
152                                 return value;
153                             }
154                         }
155                     """.formatted(name));
156         }
157         out.println("}");
158     }
159 
160     static void generateBuilder(List<TreeNode> tree, PrintStream out) {
161         out.print(COPYRIGHT_NOTICE);
162         out.print("""
163                 package %1$s;
164 
165                 import java.io.ByteArrayOutputStream;
166                 import java.nio.charset.StandardCharsets;
167                 import java.util.function.BiConsumer;
168                 import java.util.function.IntSupplier;
169 
170                 import %1$s.%2$s.*;
171 
172                 // Generated from onnx.in.proto
173                 public sealed class %3$s<T extends %3$s> {
174                 """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS, PROTOGEN_BUILDER_CLASS));
175         generateBuilderCode(null, "    ", tree, out);
176         out.print("""
177 
178                     // Implementation
179 
180                     final ByteArrayOutputStream buf = new ByteArrayOutputStream();
181 
182                     public byte[] getBytes() {
183                         return buf.toByteArray();
184                     }
185 
186                     @SuppressWarnings("unchecked")
187                     public <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) {
188                         sup.forEach(p -> cons.accept((T)this, p));
189                         return (T)this;
190                     }
191 
192                     void _encode(long number) {
193                         for (int i = 64 - Long.numberOfLeadingZeros(number); i > 7; i -= 7) {
194                             buf.write(0x80 | (int)number & 0x7f);
195                             number >>= 7;
196                         }
197                         buf.write((int)number & 0x7f);
198                     }
199 
200                     void _encode(float value) {
201                         int bits =  Float.floatToRawIntBits(value);
202                         buf.write((byte)bits);
203                         buf.write((byte)(bits >> 8));
204                         buf.write((byte)(bits >> 16));
205                         buf.write((byte)(bits >> 24));
206                     }
207 
208                     void _encode(double value) {
209                         long bits =  Double.doubleToRawLongBits(value);
210                         buf.write((byte)bits);
211                         buf.write((byte)(bits >> 8));
212                         buf.write((byte)(bits >> 16));
213                         buf.write((byte)(bits >> 24));
214                         buf.write((byte)(bits >> 32));
215                         buf.write((byte)(bits >> 40));
216                         buf.write((byte)(bits >> 48));
217                         buf.write((byte)(bits >> 56));
218                     }
219 
220                     @SuppressWarnings("unchecked")
221                     T _f(int fieldIndex, String value) {
222                         return value == null ? (T)this : _f(fieldIndex, value.getBytes(StandardCharsets.UTF_8));
223                     }
224 
225                     @SuppressWarnings("unchecked")
226                     T _f(int fieldIndex, byte[] bytes) {
227                         _encode(fieldIndex << 3 | 2);
228                         _encode(bytes.length);
229                         buf.writeBytes(bytes);
230                         return (T)this;
231                     }
232 
233                     @SuppressWarnings("unchecked")
234                     T _f(int fieldIndex, float value) {
235                         _encode(fieldIndex << 3 | 5);
236                         _encode(value);
237                         return (T)this;
238                     }
239 
240                     @SuppressWarnings("unchecked")
241                     T _f(int fieldIndex, float... values) {
242                         if (values.length == 1) {
243                             return _f(fieldIndex, values[0]);
244                         }
245                         var b = new %1$s();
246                         for (var v : values) b._encode(v);
247                         _f(fieldIndex, b);
248                         return (T)this;
249                     }
250 
251                     @SuppressWarnings("unchecked")
252                     T _f(int fieldIndex, double value) {
253                         _encode(fieldIndex << 3 | 1);
254                         _encode(value);
255                         return (T)this;
256                     }
257 
258                     @SuppressWarnings("unchecked")
259                     T _f(int fieldIndex, double... values) {
260                         if (values.length == 1) {
261                             return _f(fieldIndex, values[0]);
262                         }
263                         var b = new %1$s();
264                         for (var v : values) b._encode(v);
265                         _f(fieldIndex, b);
266                         return (T)this;
267                     }
268 
269                     @SuppressWarnings("unchecked")
270                     T _f(int fieldIndex, long value) {
271                         _encode(fieldIndex << 3);
272                         _encode(value);
273                         return (T)this;
274                     }
275 
276                     @SuppressWarnings("unchecked")
277                     T _f(int fieldIndex, long... values) {
278                         if (values.length == 1) {
279                             return _f(fieldIndex, values[0]);
280                         }
281                         var b = new %1$s();
282                         for (var v : values) b._encode(v);
283                         _f(fieldIndex, b);
284                         return (T)this;
285                     }
286 
287                     @SuppressWarnings("unchecked")
288                     T _f(int fieldIndex, int... values) {
289                         if (values.length == 1) {
290                             return _f(fieldIndex, values[0]);
291                         }
292                         var b = new %1$s();
293                         for (var v : values) b._encode(v);
294                         _f(fieldIndex, b);
295                         return (T)this;
296                     }
297 
298                     @SuppressWarnings("unchecked")
299                     T _f(int fieldIndex, %1$s value) {
300                         return _f(fieldIndex, value.buf.toByteArray());
301                     }
302 
303                     @SuppressWarnings("unchecked")
304                     T _f(int fieldIndex, IntSupplier value) {
305                         return _f(fieldIndex, value.getAsInt());
306                     }
307                 }
308                 """.formatted(PROTOGEN_BUILDER_CLASS));
309     }
310 
311     static void generateBuilderCode(String parentName, String indent, List<TreeNode> tree, PrintStream out) {
312         for (TreeNode n : tree) {
313             switch (n.token().type()) {
314                 case MESSAGE, FIELD -> {
315                     out.println();
316                     for (String c : n.comments()) out.println(indent + '/' + c);
317                     String name = snakeToCamelCase(n.token().matcher().group("name"));
318                     if (n.token().type() == TokenType.MESSAGE) {
319                         out.println(indent + "public static final class " + name + " extends " + PROTOGEN_BUILDER_CLASS + "<" + name + "> {");
320                         generateBuilderCode(name, indent + "    ", n.nested(), out);
321                         out.println(indent + "}");
322                     } else {
323                         String type = n.token().matcher().group("type");
324                         type = switch (type) {
325                             case "string" -> "String";
326                             case "int32" -> "int";
327                             case "int64" -> "long";
328                             case "uint64" -> "long";
329                             case "bytes" -> "byte[]";
330                             default -> type;
331                         };
332                         if (Character.isLowerCase(type.charAt(0)) && !type.equals("byte[]") && n.token().matcher().group("flag").equals("repeated ")) {
333                             type += "...";
334                         }
335                         String index = n.token().matcher().group("index");
336                         out.println(indent + "public " + parentName + " " + name + "(" + type + " " + name + ") {return _f(" + index + ", " + name + ");}");
337                     }
338                 }
339             }
340         }
341     }
342 
343     static void generateModel(List<TreeNode> tree, PrintStream out) {
344         out.print(COPYRIGHT_NOTICE);
345         out.print("""
346                 package %1$s;
347 
348                 import java.io.RandomAccessFile;
349                 import java.lang.annotation.ElementType;
350                 import java.lang.annotation.Retention;
351                 import java.lang.annotation.RetentionPolicy;
352                 import java.lang.annotation.Target;
353                 import java.lang.reflect.ParameterizedType;
354                 import java.lang.reflect.RecordComponent;
355                 import java.nio.ByteBuffer;
356                 import java.nio.ByteOrder;
357                 import java.nio.channels.FileChannel;
358                 import java.util.ArrayList;
359                 import java.util.Arrays;
360                 import java.util.List;
361                 import java.util.function.IntSupplier;
362                 import java.util.function.Supplier;
363 
364                 import %1$s.%2$s.*;
365 
366                 // Generated from onnx.in.proto
367                 public sealed interface %3$s {
368                 """.formatted(PROTOGEN_PACKAGE, PROTOGEN_CONSTANTS_CLASS, PROTOGEN_MODEL_CLASS));
369         generateModelCode("    ", tree, out);
370         out.print("""
371 
372                     // Implementation
373 
374 
375                     @Retention(RetentionPolicy.RUNTIME)
376                     @Target(ElementType.RECORD_COMPONENT)
377                     @interface f {
378                         int value();
379                     }
380 
381                     private static long decodeVarint(ByteBuffer data) {
382                         long i, shift = 0, value = 0;
383                         do {
384                             value |= ((i = data.get()) & 0x7f) << shift;
385                             shift += 7;
386                         } while ((i & 0x80) != 0);
387                         return value;
388                     }
389 
390                     private static int countVarInts(ByteBuffer data) {
391                         long end  = decodeVarint(data);
392                         int start = data.position();
393                         end += start;
394                         int count = 0;
395                         while (data.position() < end) {
396                             if ((data.get() & 0x80) == 0) count++;
397                         }
398                         data.position(start);
399                         return count;
400                     }
401 
402                     private static int[] readPackedInts(ByteBuffer data) {
403                         var ret = new int[countVarInts(data)];
404                         for (int i = 0; i < ret.length; i++) {
405                             ret[i] = (int)decodeVarint(data);
406                         }
407                         return ret;
408                     }
409 
410                     private static long[] readPackedLongs(ByteBuffer data) {
411                         var ret = new long[countVarInts(data)];
412                         for (int i = 0; i < ret.length; i++) {
413                             ret[i] = decodeVarint(data);
414                         }
415                         return ret;
416                     }
417 
418                     private static float[] readPackedFloats(ByteBuffer data) {
419                         var ret = new float[(int)(decodeVarint(data)/4)];
420                         for (int i = 0; i < ret.length; i++) {
421                             ret[i] = data.getFloat();
422                         }
423                         return ret;
424                     }
425 
426                     private static double[] readPackedDoubles(ByteBuffer data) {
427                         var ret = new double[(int)(decodeVarint(data)/8)];
428                         for (int i = 0; i < ret.length; i++) {
429                             ret[i] = data.getDouble();
430                         }
431                         return ret;
432                     }
433 
434                     private static byte[] readBytes(ByteBuffer data) {
435                         var bytes = new byte[(int)decodeVarint(data)];
436                         data.get(bytes);
437                         return bytes;
438                     }
439 
440                     private static Object readData(Class<?> baseType, boolean packed, ByteBuffer bb) {
441                         if (baseType == Integer.class) {
442                             return (int)decodeVarint(bb);
443                         } else if (baseType == int[].class) {
444                             return packed ? readPackedInts(bb) : new int[]{(int)decodeVarint(bb)};
445                         } else if (baseType == Long.class) {
446                             return decodeVarint(bb);
447                         } else if (baseType == long[].class) {
448                             return packed ? readPackedLongs(bb) : new long[]{decodeVarint(bb)};
449                         } else if (baseType == Float.class) {
450                             return bb.getFloat();
451                         } else if (baseType == float[].class) {
452                             return packed ? readPackedFloats(bb) : new float[] {bb.getFloat()};
453                         } else if (baseType == Double.class) {
454                             return bb.getDouble();
455                         } else if (baseType == double[].class) {
456                             return packed ? readPackedDoubles(bb) : new double[] {bb.getDouble()};
457                         } else if (baseType == byte[].class) {
458                             return readBytes(bb);
459                         } else if (baseType == String.class) {
460                             return new String(readBytes(bb));
461                         } else if (baseType.getEnclosingClass() == %1$s.class) {
462                             int value = (int)decodeVarint(bb);
463                             for (Object cs : baseType.getEnumConstants()) {
464                                 if (cs instanceof IntSupplier is && is.getAsInt() == value) {
465                                     return cs;
466                                 }
467                             }
468                             throw new IllegalArgumentException(baseType.toString());
469                         } else {
470                             var size = decodeVarint(bb);
471                             int limit = bb.limit();
472                             var data = readFrom((Class<Record>)baseType, bb.limit(bb.position() + (int)size));
473                             bb.limit(limit);
474                             return data;
475                         }
476                     }
477 
478                     private static int getRecordFieldIndex(RecordComponent[] rcs, int fieldIndex) {
479                         for (int i = 0; i < rcs.length; i++) {
480                             if (rcs[i].getAnnotation(f.class).value() == fieldIndex) {
481                                 return i;
482                             }
483                         }
484                         throw new IllegalArgumentException("Field index " + fieldIndex + " not found in " + rcs[0].getDeclaringRecord());
485                     }
486 
487                     private static <T> T readFrom(Class<T> type, ByteBuffer bb) {
488                         Object[] fieldsData = new Object[type.getRecordComponents().length];
489                         while (bb.remaining() > 0) {
490                             long tag = decodeVarint(bb);
491                             RecordComponent[] rcs = type.getRecordComponents();
492                             int rfi = getRecordFieldIndex(rcs, (int)tag >> 3);
493                             boolean packed = (tag & 7) == 2;
494                             RecordComponent rc = rcs[rfi];
495                             Class<?> rcType = rc.getType();
496                             if (rcType == List.class) {
497                                 List list;
498                                 if (fieldsData[rfi] instanceof List l) {
499                                     list = l;
500                                 } else {
501                                     list = new ArrayList();
502                                     fieldsData[rfi] = list;
503                                 }
504                                 Class baseType = (Class)((ParameterizedType)rc.getGenericType()).getActualTypeArguments()[0];
505                                 list.add(readData(baseType, packed, bb));
506                             } else {
507                                 fieldsData[rfi] = readData(rcType, packed, bb);
508                             }
509                         }
510                         try {
511                             return (T)type.getDeclaredConstructors()[0].newInstance(fieldsData);
512                         } catch (ReflectiveOperationException e) {
513                             throw new RuntimeException(e);
514                         }
515                     }
516 
517                     private static void print(StringBuilder out, int indent, String name, Object value, boolean skipBigData) throws ReflectiveOperationException {
518                         if (value == null) return;
519                         out.append("  ".repeat(indent)).append(name);
520                         switch (value) {
521                             case List l -> {
522                                 out.append(name.endsWith("s") ? ":" : "s:").append(System.lineSeparator());
523                                 for (var el : l) print(out, indent + 1, "- " + (name.endsWith("s") ? name.substring(0, name.length() - 1) : name), el, skipBigData);
524                             }
525                             case Record r -> {
526                                 out.append(':').append(System.lineSeparator());
527                                 for (var rc : r.getClass().getRecordComponents()) {
528                                     print(out, indent + 2, rc.getName(), rc.getAccessor().invoke(r), skipBigData);
529                                 }
530                             }
531                             case byte[] a ->
532                                 out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
533                             case long[] a ->
534                                 out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
535                             case float[] a ->
536                                 out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
537                             case double[] a ->
538                                 out.append(checkSize(a.length, () -> Arrays.toString(a), skipBigData));
539                             case String s ->
540                                 out.append(": \\"").append(s).append('"').append(System.lineSeparator());
541                             default ->
542                                 out.append(": ").append(value).append(System.lineSeparator());
543                         }
544                     }
545 
546                     static final int SKIP_LIMIT = 1000;
547 
548                     private static String checkSize(int size, Supplier<String> sup, boolean skipBigData) {
549                         return ": " + (skipBigData && size > SKIP_LIMIT ? "# skipped " + size + " values" : sup.get()) + System.lineSeparator();
550                     }
551 
552                     default String toText() {
553                         return toText(true);
554                     }
555 
556                     default String toText(boolean skipBigData) {
557                         try {
558                             var sb = new StringBuilder();
559                             print(sb, 0, getClass().getSimpleName(), this, skipBigData);
560                             return sb.toString();
561                         } catch (ReflectiveOperationException e) {
562                             throw new RuntimeException(e);
563                         }
564                     }
565 
566                     public static %2$s.ModelProto readFrom(byte[] onnxProtoModel) {
567                         return readFrom(ByteBuffer.wrap(onnxProtoModel));
568                     }
569 
570                     public static %2$s.ModelProto readFrom(ByteBuffer onnxProtoModel) {
571                         return readFrom(%2$s.ModelProto.class, onnxProtoModel.order(ByteOrder.LITTLE_ENDIAN));
572                     }
573 
574                     public static void main(String... args) throws Exception {
575                         for (var fName : args) {
576                             try (var in = new RandomAccessFile(fName, "r")) {
577                                 %2$s.ModelProto model = readFrom(in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()));
578                                 System.out.println(model.toText());
579                             }
580                         }
581                     }
582                 }
583                 """.formatted(PROTOGEN_CONSTANTS_CLASS, PROTOGEN_MODEL_CLASS));
584     }
585 
586     static void generateModelCode(String indent, List<TreeNode> tree, PrintStream out) {
587         for (TreeNode n : tree) {
588             if (n.token().type() == TokenType.MESSAGE) {
589                 out.println();
590                 for (String c : n.comments()) out.println(indent + '/' + c);
591                 String recordName = n.token().matcher().group("name");
592                 out.println(indent + "public record " + recordName + " (");
593                 boolean first = true;
594                 for (TreeNode nn : n.nested()) {
595                     if (nn.token().type() == TokenType.FIELD) {
596                         if (first) {
597                             first = false;
598                         } else {
599                             out.println(",");
600                         }
601                         out.println();
602                         for (String c : nn.comments()) out.println(indent + "    /" + c);
603                         String name = snakeToCamelCase(nn.token().matcher().group("name"));
604                         String type = nn.token().matcher().group("type");
605                         if (nn.token().matcher().group("flag").equals("repeated ")) {
606                             type = switch (type) {
607                                 case "float" -> "List<float[]>";
608                                 case "double" -> "List<double[]>";
609                                 case "string" -> "List<String>";
610                                 case "int32" -> "List<int[]>";
611                                 case "int64", "uint64" -> "List<long[]>";
612                                 case "bytes" -> "List<byte[]>";
613                                 default -> "List<" + type + ">";
614                             };
615                         } else {
616                             type = switch (type) {
617                                 case "float" -> "Float";
618                                 case "double" -> "Double";
619                                 case "string" -> "String";
620                                 case "int32" -> "Integer";
621                                 case "int64", "uint64" -> "Long";
622                                 case "bytes" -> "byte[]";
623                                 default -> type;
624                             };
625                         }
626                         String index = nn.token().matcher().group("index");
627                         out.print(indent + "    @f(" + index + ") " + type + " " + name);
628                     }
629                 }
630                 out.println(") implements " + PROTOGEN_MODEL_CLASS + " {");
631                 generateModelCode(indent + "    ", n.nested(), out);
632                 out.println(indent + "}");
633             }
634         }
635     }
636 
637     static List<TreeNode> toTree(Iterator<Token> tokens) {
638         List<TreeNode> nodes = new ArrayList<>();
639         List<String> comments = new ArrayList<>();
640         int oneofs = 0;
641         while (tokens.hasNext()) {
642             Token t = tokens.next();
643             switch (t.type()) {
644                 case COMMENT -> comments.add(t.matcher().group("comment"));
645                 case EMPTY -> comments.clear(); // do not merge isolated comment blocks
646                 case ONEOF -> oneofs++; // flat ONEOF
647                 case ENUM_ELEMENT, FIELD, RESERVED, SYNTAX, PACKAGE -> {
648                     if (t.matcher().group("comment") instanceof String c) comments.add(c);
649                     nodes.add(new TreeNode(comments, t, List.of()));
650                     comments = new ArrayList<>();
651                 }
652                 case ENUM, MESSAGE -> {
653                     if (t.matcher().group("comment") instanceof String c) comments.add(c);
654                     nodes.add(new TreeNode(comments, t, toTree(tokens)));
655                     comments = new ArrayList<>();
656                 }
657                 case END -> {
658                     if (oneofs-- == 0) return nodes;
659                 }
660             }
661         }
662         return nodes;
663     }
664 
665     static final Pattern SNAKE = Pattern.compile("_([a-z])");
666 
667     static String snakeToCamelCase(String name) {
668         return SNAKE.matcher(name).replaceAll(mr -> mr.group(1).toUpperCase());
669     }
670 
671     public static void main(String[] args) throws Exception {
672         List<TreeNode> tree = toTree(Files.lines(Path.of("opgen/onnx.in.proto")).map(ProtoGen::lineToToken).iterator());
673         try (var constants = new PrintStream(new FileOutputStream(SOURCES_PATH + PROTOGEN_CONSTANTS_CLASS + ".java"))) {
674             generateConstants(tree, constants);
675         }
676         try (var builder = new PrintStream(new FileOutputStream(SOURCES_PATH + PROTOGEN_BUILDER_CLASS + ".java"))) {
677             generateBuilder(tree, builder);
678         }
679         try (var model = new PrintStream(new FileOutputStream(SOURCES_PATH + PROTOGEN_MODEL_CLASS + ".java"))) {
680             generateModel(tree, model);
681         }
682     }
683 }