1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.codebuilders;
 26 
 27 import hat.buffer.HAType;
 28 import hat.device.DeviceType;
 29 import hat.dialect.HATF16VarOp;
 30 import hat.dialect.HATLocalVarOp;
 31 import hat.dialect.HATMemoryOp;
 32 import hat.dialect.HATPrivateInitVarOp;
 33 import hat.dialect.HATPrivateVarOp;
 34 import hat.dialect.HATVectorBinaryOp;
 35 import hat.dialect.HATVectorLoadOp;
 36 import hat.dialect.HATVectorVarOp;
 37 import hat.ifacemapper.MappableIface;
 38 import hat.optools.FuncOpParams;
 39 import hat.optools.OpTk;
 40 import hat.util.StreamMutable;
 41 import jdk.incubator.code.Op;
 42 import jdk.incubator.code.dialect.core.CoreOp;
 43 import jdk.incubator.code.dialect.java.ClassType;
 44 import jdk.incubator.code.dialect.java.JavaOp;
 45 import jdk.incubator.code.dialect.java.JavaType;
 46 import jdk.incubator.code.dialect.java.PrimitiveType;
 47 
 48 import java.util.List;
 49 
 50 public abstract class C99HATCodeBuilderContext<T extends C99HATCodeBuilderContext<T>> extends C99HATCodeBuilder<T> implements BabylonCoreOpBuilder<T> {
 51 
 52   /*  public final  T type(ScopedCodeBuilderContext buildContext, JavaType javaType) {
 53         if (OpTk.isAssignable(buildContext.lookup, javaType, MappableIface.class)
 54                         && javaType instanceof ClassType classType) {
 55             suffix_t(classType).asterisk();
 56         } else {
 57             typeName(javaType.toBasicType().toString());
 58         }
 59         return self();
 60     } */
 61 
 62     @Override
 63     public final T varLoadOp(ScopedCodeBuilderContext buildContext, CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 64         Op resolve = buildContext.scope.resolve(varLoadOp.operands().getFirst());
 65         switch (resolve) {
 66             case CoreOp.VarOp $ -> varName($);
 67             case HATMemoryOp $ -> varName($);
 68             case HATVectorVarOp $ -> varName($);
 69             case HATVectorLoadOp $ -> varName($);
 70             case HATVectorBinaryOp $ -> varName($);
 71             case HATF16VarOp $ -> varName($);
 72             case null, default -> {
 73             }
 74         }
 75         return self();
 76     }
 77 
 78     @Override
 79     public final T varStoreOp(ScopedCodeBuilderContext buildContext, CoreOp.VarAccessOp.VarStoreOp varStoreOp) {
 80         Op op = buildContext.scope.resolve(varStoreOp.operands().getFirst());
 81         // When the op is intended to operate as VarOp, then we need to include it in the following switch.
 82         // This is because HAT has its own dialect, and some of the Ops operate on HAT Types (not included in the Java
 83         // dialect). For instance, private data structures, local data structures, vector types, etc.
 84         switch (op) {
 85             case CoreOp.VarOp varOp -> varName(varOp).equals();
 86             case HATF16VarOp hatf16VarOp -> varName(hatf16VarOp).equals();
 87             case HATPrivateInitVarOp hatPrivateInitVarOp -> varName(hatPrivateInitVarOp).equals();
 88             case HATPrivateVarOp hatPrivateVarOp -> varName(hatPrivateVarOp).equals();
 89             case HATLocalVarOp hatLocalVarOp -> varName(hatLocalVarOp).equals();
 90             case HATVectorVarOp hatVectorVarOp -> varName(hatVectorVarOp).equals();
 91             case null, default -> {
 92             }
 93         }
 94         parenthesisIfNeeded(buildContext, varStoreOp, ((Op.Result)varStoreOp.operands().get(1)).op());
 95         return self();
 96     }
 97 
 98 
 99     private void varDeclarationWithInitialization(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp) {
100         if (buildContext.isVarOpFinal(varOp)) {
101             constKeyword().space();
102         }
103         type(buildContext, (JavaType) varOp.varValueType()).space().varName(varOp).space().equals().space();
104         parenthesisIfNeeded(buildContext, varOp, ((Op.Result)varOp.operands().getFirst()).op());
105     }
106 
107     @Override
108     public T varOp(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp) {
109         if (varOp.isUninitialized()) {
110             type(buildContext, (JavaType) varOp.varValueType()).space().varName(varOp);
111         } else {
112             varDeclarationWithInitialization(buildContext, varOp);
113         }
114         return self();
115     }
116 
117 
118 
119     @Override
120     public T varOp(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp, OpTk.ParamVar paramVar) {
121         return self();
122     }
123 
124     @Override
125     public T fieldLoadOp(ScopedCodeBuilderContext buildContext, JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
126         if (fieldLoadOp.operands().isEmpty() && fieldLoadOp.result().type() instanceof PrimitiveType) {
127             Object value = OpTk.getStaticFinalPrimitiveValue(buildContext.lookup,fieldLoadOp);
128             literal(value.toString());
129         } else {
130             throw new IllegalStateException("What is this field load ?" + fieldLoadOp);
131         }
132         return self();
133     }
134 
135     @Override
136     public T fieldStoreOp(ScopedCodeBuilderContext buildContext, JavaOp.FieldAccessOp.FieldStoreOp fieldStoreOp) {
137         return self();
138     }
139 
140 
141     @Override
142     public T unaryOp(ScopedCodeBuilderContext buildContext, JavaOp.UnaryOp unaryOp) {
143         symbol(unaryOp).parenthesisIfNeeded(buildContext, unaryOp, ((Op.Result)unaryOp.operands().getFirst()).op());
144         return self();
145     }
146 
147     @Override
148     public T binaryOp(ScopedCodeBuilderContext buildContext, JavaOp.BinaryOp binaryOp) {
149         parenthesisIfNeeded(buildContext, binaryOp, OpTk.lhsResult(binaryOp).op());
150         symbol(binaryOp);
151         parenthesisIfNeeded(buildContext, binaryOp, OpTk.rhsResult(binaryOp).op());
152         return self();
153     }
154 
155 
156     @Override
157     public T conditionalOp(ScopedCodeBuilderContext buildContext, JavaOp.JavaConditionalOp logicalOp) {
158         OpTk.lhsOps(logicalOp).stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o ->  recurse(buildContext, o));
159         space().symbol(logicalOp).space();
160         OpTk.rhsOps(logicalOp).stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o-> recurse(buildContext, o));
161         return self();
162     }
163 
164     @Override
165     public T binaryTestOp(ScopedCodeBuilderContext buildContext, JavaOp.BinaryTestOp binaryTestOp) {
166         parenthesisIfNeeded(buildContext, binaryTestOp, OpTk.lhsResult(binaryTestOp).op());
167         symbol(binaryTestOp);
168         parenthesisIfNeeded(buildContext, binaryTestOp, OpTk.rhsResult(binaryTestOp).op());
169         return self();
170     }
171 
172     @Override
173     public T convOp(ScopedCodeBuilderContext buildContext, JavaOp.ConvOp convOp) {
174         if (convOp.resultType() == JavaType.DOUBLE) {
175             paren(_ -> type(buildContext,JavaType.FLOAT)); // why double to float?
176         } else {
177             paren(_ -> type(buildContext,(JavaType)convOp.resultType()));
178         }
179         parenthesisIfNeeded(buildContext, convOp, OpTk.result(convOp).op());
180         return self();
181     }
182 
183     @Override
184     public T constantOp(ScopedCodeBuilderContext buildContext, CoreOp.ConstantOp constantOp) {
185         if (constantOp.value() == null) {
186             nullConst();
187         } else {
188             literal(constantOp.value().toString());
189         }
190         return self();
191     }
192 
193     @Override
194     public T yieldOp(ScopedCodeBuilderContext buildContext, CoreOp.YieldOp yieldOp) {
195         if (yieldOp.operands().getFirst() instanceof Op.Result result) {
196             recurse(buildContext, result.op());
197         }
198         return self();
199     }
200 
201     @Override
202     public T lambdaOp(ScopedCodeBuilderContext buildContext, JavaOp.LambdaOp lambdaOp) {
203         return comment("/*LAMBDA*/");
204     }
205 
206     @Override
207     public T tupleOp(ScopedCodeBuilderContext buildContext, CoreOp.TupleOp tupleOp) {
208         commaSpaceSeparated(tupleOp.operands(),operand->{
209             if (operand instanceof Op.Result result) {
210                 recurse(buildContext, result.op());
211             } else {
212                 comment("/*nothing to tuple*/");
213             }
214         });
215         return self();
216     }
217 
218     @Override
219     public T funcCallOp(ScopedCodeBuilderContext buildContext, CoreOp.FuncCallOp funcCallOp) {
220         funcName(funcCallOp);
221         paren(_ ->
222             commaSpaceSeparated(
223                     funcCallOp.operands().stream().filter(e->e instanceof Op.Result ).map(e->(Op.Result)e),
224                     result -> recurse(buildContext,result.op())
225             )
226         );
227         return self();
228     }
229 
230     @Override
231     public T labeledOp(ScopedCodeBuilderContext buildContext, JavaOp.LabeledOp labeledOp) {
232         var labelNameOp = labeledOp.bodies().getFirst().entryBlock().ops().getFirst();
233         CoreOp.ConstantOp constantOp = (CoreOp.ConstantOp) labelNameOp;
234         literal(constantOp.value().toString()).colon().nl();
235         var forLoopOp = labeledOp.bodies().getFirst().entryBlock().ops().get(1);
236         recurse(buildContext,forLoopOp);
237         return self();
238     }
239     @Override
240     public T breakOp(ScopedCodeBuilderContext buildContext, JavaOp.BreakOp breakOp) {
241         breakKeyword();
242         if (!breakOp.operands().isEmpty() && breakOp.operands().getFirst() instanceof Op.Result result) {
243             space();
244             if (result.op() instanceof CoreOp.ConstantOp c) {
245                 literal(c.value().toString());
246             }
247         }
248         return self();
249     }
250     @Override
251     public T continueOp(ScopedCodeBuilderContext buildContext, JavaOp.ContinueOp continueOp) {
252         if (!continueOp.operands().isEmpty()
253                 && continueOp.operands().getFirst() instanceof Op.Result result
254                 && result.op() instanceof CoreOp.ConstantOp c
255         ) {
256             continueKeyword().space().literal(c.value().toString());
257         } else if (buildContext.scope.parent instanceof ScopedCodeBuilderContext.ForScope) {
258             // nope
259         } else {
260             continueKeyword();
261         }
262 
263         return self();
264     }
265 
266     @Override
267     public T ifOp(ScopedCodeBuilderContext buildContext, JavaOp.IfOp ifOp) {
268         buildContext.ifScope(ifOp, () -> {
269             var lastWasBody = StreamMutable.of(false);
270             var i = StreamMutable.of(0);
271             // We probably should just use a regular for loop here ;)
272             ifOp.bodies().forEach(b->{
273                 int idx = i.get();
274                 if (b.yieldType() instanceof JavaType javaType && javaType == JavaType.VOID) {
275                     if (ifOp.bodies().size() > idx && ifOp.bodies().get(idx).entryBlock().ops().size() > 1){
276                         if (lastWasBody.get()) {
277                             elseKeyword();
278                         }
279                         braceNlIndented(_ ->
280                                         nlSeparated(OpTk.statements(ifOp.bodies().get(idx).entryBlock()),
281                                         root-> statement(buildContext,root)
282                                         ));
283                     }
284                     lastWasBody.set(true);
285                 } else {
286                     if (idx>0) {
287                         elseKeyword().space();
288                     }
289                     ifKeyword().paren(_ ->
290                             ifOp.bodies().get(idx).entryBlock()            // get the entryblock if bodies[c.value]
291                                     .ops().stream().filter(o->o instanceof CoreOp.YieldOp) // we want all the yields
292                                     .forEach((yield) -> recurse(buildContext, yield))
293                     );
294                     lastWasBody.set(false);
295                 }
296                 i.set(i.get()+1);
297             });
298         });
299         return self();
300     }
301 
302     @Override
303     public T whileOp(ScopedCodeBuilderContext buildContext, JavaOp.WhileOp whileOp) {
304         whileKeyword().paren(_ ->
305                 OpTk.condBlock(whileOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp)
306                         .forEach(o -> recurse(buildContext, o))
307         );
308         braceNlIndented(_ ->
309                         nlSeparated(OpTk.statements(whileOp),
310                         statement->statement(buildContext,statement)
311                 )
312         );
313         return self();
314     }
315 
316     @Override
317     public T forOp(ScopedCodeBuilderContext buildContext, JavaOp.ForOp forOp) {
318         buildContext.forScope(forOp, () ->
319                 forKeyword().paren(_ -> {
320                     OpTk.initBlock(forOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
321                     semicolon().space();
322                     OpTk.condBlock(forOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
323                     semicolon().space();
324                     commaSpaceSeparated(
325                             OpTk.statements(OpTk.mutateBlock(forOp)),
326                             op -> recurse(buildContext, op)
327                     );
328                 }).braceNlIndented(_ ->
329                             nlSeparated(OpTk.statements(forOp),
330                                     statement ->statement(buildContext,statement)
331                         )
332                 )
333         );
334         return self();
335     }
336 
337     public abstract  T atomicInc(ScopedCodeBuilderContext buildContext, Op.Result instanceResult, String name);
338 
339     @Override
340     public T invokeOp(ScopedCodeBuilderContext buildContext, JavaOp.InvokeOp invokeOp) {
341         if (OpTk.isIfaceBufferMethod(buildContext.lookup, invokeOp)
342                 || OpTk.isInvokeDescriptorSubtypeOfAnyMatch(invokeOp, List.of(HAType.class, DeviceType.class))) {
343             if (invokeOp.operands().size() == 1
344                     && OpTk.funcName(invokeOp) instanceof String funcName
345                     && funcName.startsWith("atomic")
346                     && funcName.endsWith("Inc")
347                     && OpTk.javaReturnType(invokeOp).equals(JavaType.INT)) {
348                 // this is a bit of a hack for atomics.
349                 if (invokeOp.operands().getFirst() instanceof Op.Result instanceResult) {
350                     atomicInc(buildContext, instanceResult, funcName.substring(0, funcName.length() - "Inc".length()));
351                 } else {
352                     throw new IllegalStateException("bad atomic");
353                 }
354             } else {
355 
356                if (invokeOp.operands().getFirst() instanceof Op.Result instanceResult) {
357                 /*
358                 We have three types of returned values from an ifaceBuffer
359                 A primitive
360                     int id = stage.firstTreeId(); -> stage->firstTreeId;
361 
362                 Or a sub interface from an array
363                      Tree tree = cascade.tree(treeIdx); -> Tree_t * tree = &cascade->tree[treeIdx]
364                                                         ->               = cascade->tree + treeIdx;
365 
366                 Or a sub interface from a field
367 
368                 var left = feature.left();              ->  LinkOrValue_t * left= &feature->left
369 
370                                 -
371                     if (left.hasValue()) {                  left->hasValue
372                         sum += left.anon().value();         left->anon.value;
373                         feature = null; // loop ends
374                     } else {
375                         feature = cascade.feature(tree.firstFeatureId() + left.anon().featureId());
376                     }
377                  sumOfThisStage += left.anon().value();
378 
379 
380                 For a primitive we know that the accessor refers to a field so we just  map
381                          stage.firstTreeId() -> stage->firstTreeId;
382 
383                 For the sub interface we need to treat the call
384                           cascade.tree(treeIdx);
385 
386                 As an array index into cascade->tree[] that returns a typedef of Tree_t
387                 so we need to prefix with an & to return a Tree_t ptr
388                           &cascade->tree[treeIdx]
389 
390                  of course we could return
391                           cascade->tree + treeIdx;
392                  */
393 
394                    // TODO: extra parenthesis to be removed if we have a dialect to express iface memory access
395                    boolean needExtraParenthesis = OpTk.needExtraParenthesis(invokeOp);
396                    when(needExtraParenthesis, _ -> oparen());
397 
398                    if (OpTk.javaReturnType(invokeOp) instanceof ClassType) { // isAssignable?
399                        ampersand();
400                         /* This is way more complicated I think we need to determine the expression type.
401                          * sumOfThisStage=sumOfThisStage+&left->anon->value; from    sumOfThisStage += left.anon().value();
402                          */
403                    }
404 
405                    recurse(buildContext, instanceResult.op());
406 
407                    // TODO: extra parenthesis to be removed if we have a dialect to express iface memory access
408                    when(needExtraParenthesis, _ -> cparen());
409 
410                     // Check if the varOpLoad that could follow corresponds to a local/private type
411                     boolean isLocalOrPrivateDS = false;
412                     if (instanceResult.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
413                         Op resolve = buildContext.scope.resolve(varLoadOp.operands().getFirst());
414                         //if (localDataStructures.contains(resolve)) {
415                         if (resolve instanceof HATMemoryOp) {
416                             isLocalOrPrivateDS = true;
417                         }
418                     }
419 
420                     either(isLocalOrPrivateDS, CodeBuilder::dot, CodeBuilder::rarrow);
421 
422                     funcName(invokeOp);
423 
424                     if (OpTk.javaReturnTypeIsVoid(invokeOp)) {
425                         //   setter
426                         switch (invokeOp.operands().size()) {
427                             case 2: {
428                                 if (invokeOp.operands().get(1) instanceof Op.Result result1) {
429                                     equals().recurse(buildContext, result1.op());
430                                 } else {
431                                     throw new IllegalStateException("How ");
432                                 }
433                                 break;
434                             }
435                             case 3: {
436                                 if (invokeOp.operands().get(1) instanceof Op.Result result1
437                                         && invokeOp.operands().get(2) instanceof Op.Result result2) {
438                                     sbrace(_ -> recurse(buildContext, result1.op()));
439                                     equals().recurse(buildContext, result2.op());
440                                 } else {
441                                     throw new IllegalStateException("How ");
442                                 }
443                                 break;
444                             }
445                             default: {
446                                 throw new IllegalStateException("How ");
447                             }
448                         }
449                     } else {
450                         if (OpTk.resultOrNull(invokeOp,1) instanceof Op.Result result1) {
451                             sbrace(_ -> recurse(buildContext, result1.op()));
452                         } else {
453                             // This is a simple usage.   So scaleTable->multiScaleAccumRange
454                         }
455                     }
456                 } else {
457                     throw new IllegalStateException("[Illegal] Expected a parameter for the InvokOpWrapper Node");
458                 }
459             }
460         } else {
461             // General case
462             funcName(invokeOp).paren(_ ->
463                     commaSpaceSeparated(invokeOp.operands(),
464                             op -> {if (op instanceof Op.Result result) {recurse(buildContext, result.op());}
465                     })
466             );
467         }
468         return self();
469     }
470 
471 
472     @Override
473     public T conditionalExpressionOp(ScopedCodeBuilderContext buildContext, JavaOp.ConditionalExpressionOp ternaryOp) {
474         OpTk.condBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
475         questionMark();
476         OpTk.thenBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
477         colon();
478         OpTk.elseBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
479         return self();
480     }
481 
482     /**
483      * Wrap paren() of precedence of op is higher than parent.
484      *
485      * @param buildContext
486      * @param parent
487      * @param child
488      */
489     @Override
490     public T parenthesisIfNeeded(ScopedCodeBuilderContext buildContext, Op parent, Op child) {
491         return parenWhen(OpTk.needsParenthesis(parent,child), _ -> recurse(buildContext, child));
492     }
493 
494     @Override
495     public T returnOp(ScopedCodeBuilderContext buildContext, CoreOp.ReturnOp returnOp) {
496         returnKeyword().when(!returnOp.operands().isEmpty(),
497                         $-> $.space().parenthesisIfNeeded(buildContext, returnOp, OpTk.result(returnOp).op())
498                 );
499         return self();
500     }
501 
502     public T statement(ScopedCodeBuilderContext buildContext,Op op) {
503         recurse(buildContext, op);
504         if (switch (op){
505                 case JavaOp.ForOp _ -> false;
506                 case JavaOp.WhileOp _ -> false;
507                 case JavaOp.IfOp _ -> false;
508                 case JavaOp.LabeledOp _ -> false;
509                 case JavaOp.YieldOp _ -> false;
510                 case CoreOp.TupleOp _ ->false;
511                 default -> true;
512             }
513         ){
514             semicolon();
515         }
516         return self();
517     }
518 
519 
520 
521     public T declareParam(ScopedCodeBuilderContext buildContext, FuncOpParams.Info param){
522         return  type(buildContext,(JavaType) param.parameter.type()).space().varName(param.varOp);
523     }
524 
525 
526 }