1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.codebuilders;
 26 
 27 import hat.dialect.HATF16Op;
 28 import hat.dialect.HATVectorOp;
 29 import hat.optools.IfaceBufferPattern;
 30 import hat.optools.RefactorMe;
 31 import hat.types.HAType;
 32 import hat.device.DeviceType;
 33 import hat.dialect.HATMemoryVarOp;
 34 import optkl.FuncOpParams;
 35 import optkl.OpTkl;
 36 import optkl.ParamVar;
 37 import optkl.util.ops.Precedence;
 38 import optkl.util.Regex;
 39 import optkl.util.StreamMutable;
 40 import jdk.incubator.code.Op;
 41 import jdk.incubator.code.dialect.core.CoreOp;
 42 import jdk.incubator.code.dialect.java.ClassType;
 43 import jdk.incubator.code.dialect.java.JavaOp;
 44 import jdk.incubator.code.dialect.java.JavaType;
 45 import jdk.incubator.code.dialect.java.PrimitiveType;
 46 import optkl.codebuilders.BabylonCoreOpBuilder;
 47 import optkl.codebuilders.CodeBuilder;
 48 import optkl.codebuilders.ScopedCodeBuilderContext;
 49 
 50 import static optkl.OpTkl.condBlock;
 51 import static optkl.OpTkl.elseBlock;
 52 import static optkl.OpTkl.getStaticFinalPrimitiveValue;
 53 import static optkl.OpTkl.initBlock;
 54 import static optkl.OpTkl.javaReturnType;
 55 import static optkl.OpTkl.javaReturnTypeIsVoid;
 56 import static optkl.OpTkl.lhsOps;
 57 import static optkl.OpTkl.lhsResult;
 58 import static optkl.OpTkl.mutateBlock;
 59 import static optkl.OpTkl.needExtraParenthesis;
 60 import static optkl.OpTkl.result;
 61 import static optkl.OpTkl.resultOrNull;
 62 import static optkl.OpTkl.rhsOps;
 63 import static optkl.OpTkl.rhsResult;
 64 import static optkl.OpTkl.thenBlock;
 65 
 66 public abstract class C99HATCodeBuilderContext<T extends C99HATCodeBuilderContext<T>> extends C99HATCodeBuilder<T>
 67         implements BabylonCoreOpBuilder<T, ScopedCodeBuilderContext> {
 68 
 69     @Override
 70     public final T varLoadOp(ScopedCodeBuilderContext buildContext, CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 71         Op resolve = buildContext.scope.resolve(varLoadOp.operands().getFirst());
 72         switch (resolve) {
 73             case CoreOp.VarOp $ -> varName($);
 74             case HATMemoryVarOp $ -> varName($);
 75             case HATVectorOp.HATVectorVarOp $ -> varName($);
 76             case HATVectorOp.HATVectorLoadOp $ -> varName($);
 77             case HATVectorOp.HATVectorBinaryOp $ -> varName($);
 78             case HATF16Op.HATF16VarOp $ -> varName($);
 79             case null, default -> {
 80             }
 81         }
 82         return self();
 83     }
 84 
 85     @Override
 86     public final T varStoreOp(ScopedCodeBuilderContext buildContext, CoreOp.VarAccessOp.VarStoreOp varStoreOp) {
 87         Op op = buildContext.scope.resolve(varStoreOp.operands().getFirst());
 88 
 89         //TODO see if VarLikeOp marker interface fixes this
 90 
 91         // When the op is intended to operate as VarOp, then we need to include it in the following switch.
 92         // This is because HAT has its own dialect, and some of the Ops operate on HAT Types (not included in the Java
 93         // dialect). For instance, private data structures, local data structures, vector types, etc.
 94         switch (op) {
 95             case CoreOp.VarOp varOp -> varName(varOp).equals();
 96             case HATF16Op.HATF16VarOp hatf16VarOp -> varName(hatf16VarOp).equals();
 97             case HATMemoryVarOp.HATPrivateInitVarOp hatPrivateInitVarOp -> varName(hatPrivateInitVarOp).equals();
 98             case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> varName(hatPrivateVarOp).equals();
 99             case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> varName(hatLocalVarOp).equals();
100             case HATVectorOp.HATVectorVarOp hatVectorVarOp -> varName(hatVectorVarOp).equals();
101             case null, default -> {
102             }
103         }
104         parenthesisIfNeeded(buildContext, varStoreOp, ((Op.Result)varStoreOp.operands().get(1)).op());
105         return self();
106     }
107 
108     private void varDeclarationWithInitialization(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp) {
109         if (buildContext.isVarOpFinal(varOp)) {
110             constKeyword().space();
111         }
112         type(buildContext, (JavaType) varOp.varValueType()).space().varName(varOp).space().equals().space();
113         parenthesisIfNeeded(buildContext, varOp, ((Op.Result)varOp.operands().getFirst()).op());
114     }
115 
116     @Override
117     public T varOp(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp) {
118         if (varOp.isUninitialized()) {
119             type(buildContext, (JavaType) varOp.varValueType()).space().varName(varOp);
120         } else {
121             varDeclarationWithInitialization(buildContext, varOp);
122         }
123         return self();
124     }
125 
126     @Override
127     public T varOp(ScopedCodeBuilderContext buildContext, CoreOp.VarOp varOp, ParamVar paramVar) {
128         varName(varOp);
129         return self();
130     }
131 
132     @Override
133     public T fieldLoadOp(ScopedCodeBuilderContext buildContext, JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
134         if (fieldLoadOp.operands().isEmpty() && fieldLoadOp.result().type() instanceof PrimitiveType) {
135             Object value = getStaticFinalPrimitiveValue(buildContext.lookup,fieldLoadOp);
136             literal(value.toString());
137         } else {
138             throw new IllegalStateException("What is this field load ?" + fieldLoadOp);
139         }
140         return self();
141     }
142 
143     @Override
144     public T fieldStoreOp(ScopedCodeBuilderContext buildContext, JavaOp.FieldAccessOp.FieldStoreOp fieldStoreOp) {
145         return self();
146     }
147 
148 
149     @Override
150     public T unaryOp(ScopedCodeBuilderContext buildContext, JavaOp.UnaryOp unaryOp) {
151         symbol(unaryOp).parenthesisIfNeeded(buildContext, unaryOp, ((Op.Result)unaryOp.operands().getFirst()).op());
152         return self();
153     }
154 
155     @Override
156     public T binaryOp(ScopedCodeBuilderContext buildContext, JavaOp.BinaryOp binaryOp) {
157         parenthesisIfNeeded(buildContext, binaryOp, lhsResult(binaryOp).op());
158         symbol(binaryOp);
159         parenthesisIfNeeded(buildContext, binaryOp, rhsResult(binaryOp).op());
160         return self();
161     }
162 
163 
164     @Override
165     public T conditionalOp(ScopedCodeBuilderContext buildContext, JavaOp.JavaConditionalOp logicalOp) {
166         lhsOps(logicalOp).stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o ->  recurse(buildContext, o));
167         space().symbol(logicalOp).space();
168         rhsOps(logicalOp).stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o-> recurse(buildContext, o));
169         return self();
170     }
171 
172     @Override
173     public T binaryTestOp(ScopedCodeBuilderContext buildContext, JavaOp.BinaryTestOp binaryTestOp) {
174         parenthesisIfNeeded(buildContext, binaryTestOp, lhsResult(binaryTestOp).op());
175         symbol(binaryTestOp);
176         parenthesisIfNeeded(buildContext, binaryTestOp, rhsResult(binaryTestOp).op());
177         return self();
178     }
179 
180     @Override
181     public T convOp(ScopedCodeBuilderContext buildContext, JavaOp.ConvOp convOp) {
182         if (convOp.resultType() == JavaType.DOUBLE) {
183             paren(_ -> type(buildContext,JavaType.FLOAT)); // why double to float?
184         } else {
185             paren(_ -> type(buildContext,(JavaType)convOp.resultType()));
186         }
187         parenthesisIfNeeded(buildContext, convOp, result(convOp).op());
188         return self();
189     }
190 
191     @Override
192     public T constantOp(ScopedCodeBuilderContext buildContext, CoreOp.ConstantOp constantOp) {
193         if (constantOp.value() == null) {
194             nullConst();
195         } else {
196             literal(constantOp.value().toString());
197         }
198         return self();
199     }
200 
201     @Override
202     public T yieldOp(ScopedCodeBuilderContext buildContext, CoreOp.YieldOp yieldOp) {
203         if (yieldOp.operands().getFirst() instanceof Op.Result result) {
204             recurse(buildContext, result.op());
205         }
206         return self();
207     }
208 
209     @Override
210     public T lambdaOp(ScopedCodeBuilderContext buildContext, JavaOp.LambdaOp lambdaOp) {
211         return comment("/*LAMBDA*/");
212     }
213 
214     @Override
215     public T tupleOp(ScopedCodeBuilderContext buildContext, CoreOp.TupleOp tupleOp) {
216         commaSpaceSeparated(tupleOp.operands(),operand->{
217             if (operand instanceof Op.Result result) {
218                 recurse(buildContext, result.op());
219             } else {
220                 comment("/*nothing to tuple*/");
221             }
222         });
223         return self();
224     }
225 
226     @Override
227     public T funcCallOp(ScopedCodeBuilderContext buildContext, CoreOp.FuncCallOp funcCallOp) {
228         funcName(funcCallOp);
229         paren(_ ->
230             commaSpaceSeparated(
231                     funcCallOp.operands().stream().filter(e->e instanceof Op.Result ).map(e->(Op.Result)e),
232                     result -> recurse(buildContext,result.op())
233             )
234         );
235         return self();
236     }
237 
238     @Override
239     public T labeledOp(ScopedCodeBuilderContext buildContext, JavaOp.LabeledOp labeledOp) {
240         var labelNameOp = labeledOp.bodies().getFirst().entryBlock().ops().getFirst();
241         CoreOp.ConstantOp constantOp = (CoreOp.ConstantOp) labelNameOp;
242         literal(constantOp.value().toString()).colon().nl();
243         var forLoopOp = labeledOp.bodies().getFirst().entryBlock().ops().get(1);
244         recurse(buildContext,forLoopOp);
245         return self();
246     }
247 
248     @Override
249     public T breakOp(ScopedCodeBuilderContext buildContext, JavaOp.BreakOp breakOp) {
250         breakKeyword();
251         if (!breakOp.operands().isEmpty() && breakOp.operands().getFirst() instanceof Op.Result result) {
252             space();
253             if (result.op() instanceof CoreOp.ConstantOp c) {
254                 literal(c.value().toString());
255             }
256         }
257         return self();
258     }
259 
260     @Override
261     public T continueOp(ScopedCodeBuilderContext buildContext, JavaOp.ContinueOp continueOp) {
262         if (!continueOp.operands().isEmpty()
263                 && continueOp.operands().getFirst() instanceof Op.Result result
264                 && result.op() instanceof CoreOp.ConstantOp c
265         ) {
266             continueKeyword().space().literal(c.value().toString());
267         } else if (buildContext.scope.parent instanceof ScopedCodeBuilderContext.ForScope) {
268             // nope
269         } else {
270             continueKeyword();
271         }
272 
273         return self();
274     }
275 
276     @Override
277     public T ifOp(ScopedCodeBuilderContext buildContext, JavaOp.IfOp ifOp) {
278         buildContext.ifScope(ifOp, () -> {
279             var lastWasBody = StreamMutable.of(false);
280             var i = StreamMutable.of(0);
281             // We probably should just use a regular for loop here ;)
282             ifOp.bodies().forEach(b->{
283                 int idx = i.get();
284                 if (b.yieldType() instanceof JavaType javaType && javaType == JavaType.VOID) {
285                     if (ifOp.bodies().size() > idx && ifOp.bodies().get(idx).entryBlock().ops().size() > 1){
286                         if (lastWasBody.get()) {
287                             elseKeyword();
288                         }
289                         braceNlIndented(_ ->
290                                         nlSeparated(OpTkl.statements(ifOp.bodies().get(idx).entryBlock()),
291                                         root-> statement(buildContext,root)
292                                         ));
293                     }
294                     lastWasBody.set(true);
295                 } else {
296                     if (idx>0) {
297                         elseKeyword().space();
298                     }
299                     ifKeyword().paren(_ ->
300                             ifOp.bodies().get(idx).entryBlock()            // get the entryblock if bodies[c.value]
301                                     .ops().stream().filter(o->o instanceof CoreOp.YieldOp) // we want all the yields
302                                     .forEach((yield) -> recurse(buildContext, yield))
303                     );
304                     lastWasBody.set(false);
305                 }
306                 i.set(i.get()+1);
307             });
308         });
309         return self();
310     }
311 
312     @Override
313     public T whileOp(ScopedCodeBuilderContext buildContext, JavaOp.WhileOp whileOp) {
314         whileKeyword().paren(_ ->
315                 condBlock(whileOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp)
316                         .forEach(o -> recurse(buildContext, o))
317         );
318         braceNlIndented(_ ->
319                         nlSeparated(OpTkl.loopBodyStatements(whileOp),
320                         statement->statement(buildContext,statement)
321                 )
322         );
323         return self();
324     }
325 
326     @Override
327     public T forOp(ScopedCodeBuilderContext buildContext, JavaOp.ForOp forOp) {
328         buildContext.forScope(forOp, () ->
329                 forKeyword().paren(_ -> {
330                     initBlock(forOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
331                     semicolon().space();
332                     condBlock(forOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
333                     semicolon().space();
334                     commaSpaceSeparated(
335                             OpTkl.statements(mutateBlock(forOp)),
336                             op -> recurse(buildContext, op)
337                     );
338                 }).braceNlIndented(_ ->
339                             nlSeparated(OpTkl.loopBodyStatements(forOp),
340                                     statement ->statement(buildContext,statement)
341                         )
342                 )
343         );
344         return self();
345     }
346 
347     public abstract  T atomicInc(ScopedCodeBuilderContext buildContext, Op.Result instanceResult, String name);
348 
349     static Regex atomicInc = Regex.of("(atomic.*)Inc");
350 
351     @Override
352     public T invokeOp(ScopedCodeBuilderContext buildContext, JavaOp.InvokeOp invokeOp) {
353         if (IfaceBufferPattern.isInvokeOp(buildContext.lookup, invokeOp)
354                 || RefactorMe.isInvokeDescriptorSubtypeOfAnyMatch(buildContext.lookup,invokeOp, HAType.class, DeviceType.class)) {
355             if (invokeOp.operands().size() == 1
356                    // && OpTk.funcName(invokeOp) instanceof String funcName
357                     && atomicInc.is(OpTkl.funcName(invokeOp)) instanceof Regex.Match matcher
358                     && javaReturnType(invokeOp).equals(JavaType.INT)) {
359                 if (invokeOp.operands().getFirst() instanceof Op.Result instanceResult) {
360                     atomicInc(buildContext, instanceResult, matcher.stringOf(1));
361                 } else {
362                     throw new IllegalStateException("bad atomic");
363                 }
364             } else {
365 
366                if (invokeOp.operands().getFirst() instanceof Op.Result instanceResult) {
367                 /*
368                 We have three types of returned values from an ifaceBuffer
369                 A primitive
370                     int id = stage.firstTreeId(); -> stage->firstTreeId;
371 
372                 Or a sub interface from an array
373                      Tree tree = cascade.tree(treeIdx); -> Tree_t * tree = &cascade->tree[treeIdx]
374                                                         ->               = cascade->tree + treeIdx;
375 
376                 Or a sub interface from a field
377 
378                 var left = feature.left();              ->  LinkOrValue_t * left= &feature->left
379 
380                                 -
381                     if (left.hasValue()) {                  left->hasValue
382                         sum += left.anon().value();         left->anon.value;
383                         feature = null; // loop ends
384                     } else {
385                         feature = cascade.feature(tree.firstFeatureId() + left.anon().featureId());
386                     }
387                  sumOfThisStage += left.anon().value();
388 
389 
390                 For a primitive we know that the accessor refers to a field so we just  map
391                          stage.firstTreeId() -> stage->firstTreeId;
392 
393                 For the sub interface we need to treat the call
394                           cascade.tree(treeIdx);
395 
396                 As an array index into cascade->tree[] that returns a typedef of Tree_t
397                 so we need to prefix with an & to return a Tree_t ptr
398                           &cascade->tree[treeIdx]
399 
400                  of course we could return
401                           cascade->tree + treeIdx;
402                  */
403 
404                    // TODO: extra parenthesis to be removed if we have a dialect to express iface memory access
405                    boolean needExtraParenthesis = needExtraParenthesis(invokeOp);
406                    when(needExtraParenthesis, _ -> oparen());
407 
408                    if (javaReturnType(invokeOp) instanceof ClassType) { // isAssignable?
409                        ampersand();
410                         /* This is way more complicated I think we need to determine the expression type.
411                          * sumOfThisStage=sumOfThisStage+&left->anon->value; from    sumOfThisStage += left.anon().value();
412                          */
413                    }
414 
415                    recurse(buildContext, instanceResult.op());
416 
417                    // TODO: extra parenthesis to be removed if we have a dialect to express iface memory access
418                    when(needExtraParenthesis, _ -> cparen());
419 
420                     // Check if the varOpLoad that could follow corresponds to a local/private type
421                     boolean isLocalOrPrivateDS = false;
422                     if (instanceResult.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
423                         Op resolve = buildContext.scope.resolve(varLoadOp.operands().getFirst());
424                         //if (localDataStructures.contains(resolve)) {
425                         if (resolve instanceof HATMemoryVarOp) {
426                             isLocalOrPrivateDS = true;
427                         }
428                     }
429 
430                     either(isLocalOrPrivateDS, CodeBuilder::dot, CodeBuilder::rarrow);
431 
432                     funcName(invokeOp);
433 
434                     if (javaReturnTypeIsVoid(invokeOp)) {
435                         //   setter
436                         switch (invokeOp.operands().size()) {
437                             case 2: {
438                                 if (invokeOp.operands().get(1) instanceof Op.Result result1) {
439                                     equals().recurse(buildContext, result1.op());
440                                 } else {
441                                     throw new IllegalStateException("How ");
442                                 }
443                                 break;
444                             }
445                             case 3: {
446                                 if (invokeOp.operands().get(1) instanceof Op.Result result1
447                                         && invokeOp.operands().get(2) instanceof Op.Result result2) {
448                                     sbrace(_ -> recurse(buildContext, result1.op()));
449                                     equals().recurse(buildContext, result2.op());
450                                 } else {
451                                     throw new IllegalStateException("How ");
452                                 }
453                                 break;
454                             }
455                             default: {
456                                 throw new IllegalStateException("How ");
457                             }
458                         }
459                     } else {
460                         if (resultOrNull(invokeOp,1) instanceof Op.Result result1) {
461                             sbrace(_ -> recurse(buildContext, result1.op()));
462                         } else {
463                             // This is a simple usage.   So scaleTable->multiScaleAccumRange
464                         }
465                     }
466                 } else {
467                     throw new IllegalStateException("[Illegal] Expected a parameter for the InvokOpWrapper Node");
468                 }
469             }
470         } else {
471             // General case
472             funcName(invokeOp).paren(_ ->
473                     commaSpaceSeparated(invokeOp.operands(),
474                             op -> {if (op instanceof Op.Result result) {recurse(buildContext, result.op());}
475                     })
476             );
477         }
478         return self();
479     }
480 
481     @Override
482     public T conditionalExpressionOp(ScopedCodeBuilderContext buildContext, JavaOp.ConditionalExpressionOp ternaryOp) {
483         condBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
484         questionMark();
485         thenBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
486         colon();
487         elseBlock(ternaryOp).ops().stream().filter(o -> o instanceof CoreOp.YieldOp).forEach(o -> recurse(buildContext, o));
488         return self();
489     }
490 
491     /**
492      * Wrap paren() of precedence of op is higher than parent.
493      *
494      * @param buildContext
495      * @param parent
496      * @param child
497      */
498     @Override
499     public T parenthesisIfNeeded(ScopedCodeBuilderContext buildContext, Op parent, Op child) {
500         return parenWhen(Precedence.needsParenthesis(parent,child), _ -> recurse(buildContext, child));
501     }
502 
503     @Override
504     public T returnOp(ScopedCodeBuilderContext buildContext, CoreOp.ReturnOp returnOp) {
505         returnKeyword().when(!returnOp.operands().isEmpty(),
506                         $-> $.space().parenthesisIfNeeded(buildContext, returnOp, OpTkl.result(returnOp).op())
507                 );
508         return self();
509     }
510 
511     public T statement(ScopedCodeBuilderContext buildContext,Op op) {
512         recurse(buildContext, op);
513         if (switch (op){
514                 case JavaOp.ForOp _ -> false;
515                 case JavaOp.WhileOp _ -> false;
516                 case JavaOp.IfOp _ -> false;
517                 case JavaOp.LabeledOp _ -> false;
518                 case JavaOp.YieldOp _ -> false;
519                 case CoreOp.TupleOp _ ->false;
520                 default -> true;
521             }
522         ){
523             semicolon();
524         }
525         return self();
526     }
527 
528     public T declareParam(ScopedCodeBuilderContext buildContext, FuncOpParams.Info param){
529         return  type(buildContext,(JavaType) param.parameter.type()).space().varName(param.varOp);
530     }
531 }