1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @modules jdk.incubator.code
 27  * @run junit TestExpressionGraphs
 28  */
 29 
import jdk.incubator.code.*;
import jdk.incubator.code.Reflect;
import jdk.incubator.code.analysis.SSA;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.extern.OpWriter;
import org.junit.jupiter.api.Test;

import java.io.Writer;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;
 43 
 44 public class TestExpressionGraphs {
 45 
 46     @Reflect
 47     static double sub(double a, double b) {
 48         return a - b;
 49     }
 50 
 51     @Reflect
 52     static double distance1(double a, double b) {
 53         return Math.abs(a - b);
 54     }
 55 
 56     @Reflect
 57     static double distance1a(final double a, final double b) {
 58         final double diff = a - b;
 59         final double result = Math.abs(diff);
 60         return result;
 61     }
 62 
 63     @Reflect
 64     static double distance1b(final double a, final double b) {
 65         final double diff = a - b;
 66         // Note, incorrect for negative zero values
 67         final double result = diff < 0d ? -diff : diff;
 68         return result;
 69     }
 70 
 71     @Reflect
 72     static double distanceN(double[] a, double[] b) {
 73         double sum = 0d;
 74         for (int i = 0; i < a.length; i++) {
 75             sum += Math.pow(a[i] - b[i], 2d);
 76         }
 77         return Math.sqrt(sum);
 78     }
 79 
 80     @Reflect
 81     static double squareDiff(double a, double b) {
 82         // a^2 - b^2 = (a + b) * (a - b)
 83         final double plus = a + b;
 84         final double minus = a - b;
 85         return plus * minus;
 86     }
 87 
 88     @Test
 89     void traverseSub() throws ReflectiveOperationException {
 90         // Get the reflective object for method sub
 91         Method m = TestExpressionGraphs.class.getDeclaredMethod(
 92                 "sub", double.class, double.class);
 93         // Get the code model for method sub
 94         Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
 95         CoreOp.FuncOp model = oModel.orElseThrow();
 96 
 97         // Stream of elements topologically sorted in depth-first search pre-order
 98         model.elements().forEach(codeElement -> {
 99             // Count the depth of the code element
100             int depth = 0;
101             CodeElement<?, ?> parent = codeElement;
102             while ((parent = parent.parent()) != null) depth++;
103             // Print out code element class
104             System.out.println("  ".repeat(depth) + codeElement.getClass());
105         });
106     }
107 
108     @Test
109     void traverseDistance1() throws ReflectiveOperationException {
110         // Get the reflective object for method distance1
111         Method m = TestExpressionGraphs.class.getDeclaredMethod(
112                 "distance1", double.class, double.class);
113         // Get the code model for method distance1
114         Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
115         CoreOp.FuncOp model = oModel.orElseThrow();
116 
117         // Stream of elements topologically sorted in depth-first search pre-order
118         model.elements().forEach(codeElement -> {
119             // Count the depth of the code element
120             int depth = 0;
121             CodeElement<?, ?> parent = codeElement;
122             while ((parent = parent.parent()) != null) depth++;
123             // Print out code element class
124             System.out.println("  ".repeat(depth) + codeElement.getClass());
125         });
126     }
127 
128 
129     @Test
130     void printSub() {
131         CoreOp.FuncOp model = getFuncOp("sub");
132         print(model);
133     }
134 
135     @Test
136     void printDistance1() {
137         CoreOp.FuncOp model = getFuncOp("distance1");
138         print(model);
139     }
140 
141     @Test
142     void printDistance1a() {
143         CoreOp.FuncOp model = getFuncOp("distance1a");
144         print(model);
145     }
146 
147     @Test
148     void printDistance1b() {
149         CoreOp.FuncOp model = getFuncOp("distance1b");
150         print(model);
151     }
152 
153     @Test
154     void printDistanceN() {
155         CoreOp.FuncOp model = getFuncOp("distanceN");
156         print(model);
157     }
158 
159     void print(CoreOp.FuncOp f) {
160         System.out.println(f.toText());
161 
162         f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
163         System.out.println(f.toText());
164 
165         f = SSA.transform(f);
166         System.out.println(f.toText());
167     }
168 
169 
170     @Test
171     void graphsDistance1() {
172         CoreOp.FuncOp model = getFuncOp("distance1");
173         Function<CodeItem, String> names = names(model);
174         System.out.println(printOpWriteVoid(names, model));
175 
176         // Create the expression graph for the terminating operation result
177         Op.Result returnResult = model.body().entryBlock().terminatingOp().result();
178         Node<Value> returnGraph = expressionGraph(returnResult);
179         System.out.println("Expression graph for terminating operation result");
180         // Transform from Node<Value> to Node<String> and print the graph
181         System.out.println(returnGraph.transformValues(v -> printValue(names, v)));
182 
183         System.out.println("Use graphs for block parameters");
184         for (Block.Parameter parameter : model.parameters()) {
185             Node<Value> useNode = useGraph(parameter);
186             System.out.println(useNode.transformValues(v -> printValue(names, v)));
187         }
188 
189         // Create the expression graphs for all values
190         Map<Value, Node<Value>> graphs = expressionGraphs(model);
191         System.out.println("Expression graphs for all declared values in the model");
192         graphs.values().forEach(n -> {
193             System.out.println(n.transformValues(v -> printValue(names, v)));
194         });
195 
196         // The graphs for the terminating operation result are the same
197         assert returnGraph.equals(graphs.get(returnGraph.value()));
198 
199         // Filter for root graphs, operation results with no uses
200         List<Node<Value>> rootGraphs = graphs.values().stream()
201                 .filter(n -> n.value() instanceof Op.Result opr &&
202                         switch (opr.op()) {
203                             // An operation result with no uses
204                             default -> opr.uses().isEmpty();
205                         })
206                 .toList();
207         System.out.println("Root expression graphs");
208         rootGraphs.forEach(n -> {
209             System.out.println(n.transformValues(v -> printValue(names, v)));
210         });
211     }
212 
213     @Test
214     void graphsDistance1a() {
215         CoreOp.FuncOp f = getFuncOp("distance1a");
216         Function<CodeItem, String> names = names(f);
217         System.out.println(printOpWriteVoid(names, f));
218 
219         {
220             Map<Value, Node<Value>> graphs = expressionGraphs(f);
221             List<Node<Value>> rootGraphs = graphs.values().stream()
222                     .filter(n -> n.value() instanceof Op.Result opr &&
223                             switch (opr.op()) {
224                                 // An operation result with no uses
225                                 default -> opr.uses().isEmpty();
226                             })
227                     .toList();
228             System.out.println("Root expression graphs");
229             rootGraphs.forEach(n -> {
230                 System.out.println(n.transformValues(v -> printValue(names, v)));
231             });
232         }
233 
234         {
235             Map<Value, Node<Value>> graphs = expressionGraphs(f);
236             List<Node<Value>> rootGraphs = graphs.values().stream()
237                     .filter(n -> n.value() instanceof Op.Result opr &&
238                             switch (opr.op()) {
239                                 // Variable declarations modeling local variables
240                                 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
241                                 // An operation result with no uses
242                                 default -> opr.uses().isEmpty();
243                             })
244                     .toList();
245             System.out.println("Root (with variable) expression graphs");
246             rootGraphs.forEach(n -> {
247                 System.out.println(n.transformValues(v -> printValue(names, v)));
248             });
249         }
250 
251         Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
252         List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
253                 .filter(n -> n.value() instanceof Op.Result opr &&
254                         switch (opr.op()) {
255                             // Variable declarations modeling local variables
256                             case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
257                             // An operation result with no uses
258                             default -> opr.uses().isEmpty();
259                         })
260                 .toList();
261         System.out.println("Pruned root expression graphs");
262         prunedRootGraphs.forEach(n -> {
263             System.out.println(n.transformValues(v -> printValue(names, v)));
264         });
265     }
266 
267     @Test
268     void graphsDistance1b() {
269         CoreOp.FuncOp f = getFuncOp("distance1b");
270         Function<CodeItem, String> names = names(f);
271         System.out.println(printOpWriteVoid(names, f));
272 
273         Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
274         List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
275                 .filter(n -> n.value() instanceof Op.Result opr &&
276                         switch (opr.op()) {
277                             // Variable declarations modeling declaration of local variables
278                             case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
279                             // An operation result with no uses
280                             default -> opr.uses().isEmpty();
281                         })
282                 .toList();
283         System.out.println("Pruned root expression graphs");
284         prunedRootGraphs.forEach(n -> {
285             System.out.println(n.transformValues(v -> printValue(names, v)));
286         });
287     }
288 
289     @Test
290     void graphsDistanceN() {
291         CoreOp.FuncOp f = getFuncOp("distanceN");
292         Function<CodeItem, String> names = names(f);
293         System.out.println(printOpWriteVoid(names, f));
294 
295         Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
296         List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
297                 .filter(n -> n.value() instanceof Op.Result opr &&
298                         switch (opr.op()) {
299                             // Variable declarations modeling declaration of local variables
300                             case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
301                             // An operation result with no uses
302                             default -> opr.uses().isEmpty();
303                         })
304                 .toList();
305         System.out.println("Pruned root expression graphs");
306         prunedRootGraphs.forEach(n -> {
307             System.out.println(n.transformValues(v -> printValue(names, v)));
308         });
309     }
310 
311     @Test
312     void graphsSquareDiff() {
313         CoreOp.FuncOp f = getFuncOp("squareDiff");
314         Function<CodeItem, String> names = names(f);
315         System.out.println(printOpWriteVoid(names, f));
316 
317         {
318             Map<Value, Node<Value>> graphs = expressionGraphs(f);
319             List<Node<Value>> rootGraphs = graphs.values().stream()
320                     .filter(n -> n.value() instanceof Op.Result opr &&
321                             switch (opr.op()) {
322                                 // An operation result with no uses
323                                 default -> opr.uses().isEmpty();
324                             })
325                     .toList();
326             System.out.println("Root expression graphs");
327             rootGraphs.forEach(n -> {
328                 System.out.println(n.transformValues(v -> printValue(names, v)));
329             });
330         }
331 
332         {
333             Map<Value, Node<Value>> graphs = expressionGraphs(f);
334             List<Node<Value>> rootGraphs = graphs.values().stream()
335                     .filter(n -> n.value() instanceof Op.Result opr &&
336                             switch (opr.op()) {
337                                 // Variable declarations modeling local variables
338                                 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
339                                 // An operation result with no uses
340                                 default -> opr.uses().isEmpty();
341                             })
342                     .toList();
343             System.out.println("Root (with variable) expression graphs");
344             rootGraphs.forEach(n -> {
345                 System.out.println(n.transformValues(v -> printValue(names, v)));
346             });
347         }
348 
349         Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
350         List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
351                 .filter(n -> n.value() instanceof Op.Result opr &&
352                         switch (opr.op()) {
353                             // Variable declarations modeling local variables
354                             case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
355                             // An operation result with no uses
356                             default -> opr.uses().isEmpty();
357                         })
358                 .toList();
359         System.out.println("Pruned root expression graphs");
360         prunedRootGraphs.forEach(n -> {
361             System.out.println(n.transformValues(v -> printValue(names, v)));
362         });
363     }
364 
365 
366 
367     @Reflect
368     static int h(int x) {
369         x += 2;                                                 // Statement 1
370         g(x);                                                   // Statement 2
371         int y = 1 + g(x) + (x += 2) + (x > 2 ? x : 10);         // Statement 3
372         for (                                                   // Statement 4
373                 int i = 0, j = 1;                               // Statements 4.1
374                 i < 3 && j < 3;
375                 i++, j++) {                                     // Statements 4.2
376             System.out.println(i);                              // Statement 4.3
377         }
378         return x + y;                                           // Statement 5
379     }
380 
381     static int g(int i) {
382         return i;
383     }
384 
385     @Test
386     void graphsH() {
387         CoreOp.FuncOp f = getFuncOp("h");
388         Function<CodeItem, String> names = names(f);
389         System.out.println(printOpWriteVoid(names, f));
390 
391         Map<Value, Node<Value>> graphs = prunedExpressionGraphs(f);
392         List<Node<Value>> rootGraphs = graphs.values().stream()
393                 .filter(n -> n.value() instanceof Op.Result opr &&
394                         switch (opr.op()) {
395                             // Variable declarations modeling declaration of local variables
396                             case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
397                             // Variable stores modeling assignment expressions whose result is used
398                             case CoreOp.VarAccessOp.VarStoreOp vsop -> vsop.operands().get(1).uses().size() == 1;
399                             // An operation result with no uses
400                             default -> opr.uses().isEmpty();
401                         })
402                 .toList();
403         rootGraphs.forEach(n -> {
404             System.out.println(n.transformValues(v -> printValue(names, v)));
405         });
406     }
407 
408 
409     static String printValue(Function<CodeItem, String> names, Value v) {
410         if (v instanceof Op.Result opr) {
411             return printOpHeader(names, opr.op());
412         } else {
413             return "%" + names.apply(v) + " <block parameter>";
414         }
415     }
416 
417     static String printOpHeader(Function<CodeItem, String> names, Op op) {
418         return OpWriter.toText(op,
419                 OpWriter.OpDescendantsOption.DROP_DESCENDANTS,
420                 OpWriter.VoidOpResultOption.WRITE_VOID,
421                 OpWriter.CodeItemNamerOption.of(names));
422     }
423 
424     static String printOpWriteVoid(Function<CodeItem, String> names, Op op) {
425         return OpWriter.toText(op,
426                 OpWriter.VoidOpResultOption.WRITE_VOID,
427                 OpWriter.CodeItemNamerOption.of(names));
428     }
429 
430     static Function<CodeItem, String> names(Op op) {
431         OpWriter w = new OpWriter(Writer.nullWriter(),
432                 OpWriter.VoidOpResultOption.WRITE_VOID);
433         w.writeOp(op);
434         return w.namer();
435     }
436 
437 
438     static Node<Value> expressionGraph(Value value) {
439         return expressionGraph(new HashMap<>(), value);
440     }
441 
442     static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) {
443         // If value has already been visited return its node
444         if (visited.containsKey(value)) {
445             return visited.get(value);
446         }
447 
448         // Find the expression graphs for each operand
449         List<Node<Value>> edges = new ArrayList<>();
450         for (Value operand : value.dependsOn()) {
451             edges.add(expressionGraph(operand));
452         }
453         Node<Value> node = new Node<>(value, edges);
454         visited.put(value, node);
455         return node;
456     }
457 
458     static Node<Value> expressionGraphDetailed(Map<Value, Node<Value>> visited, Value value) {
459         // If value has already been visited return its node
460         if (visited.containsKey(value)) {
461             return visited.get(value);
462         }
463 
464         List<Node<Value>> edges;
465         if (value instanceof Op.Result result) {
466             edges = new ArrayList<>();
467             // Find the expression graphs for each operand
468             Set<Value> valueVisited = new HashSet<>();
469             for (Value operand : result.op().operands()) {
470                 // Ensure an operand is visited only once
471                 if (valueVisited.add(operand)) {
472                     edges.add(expressionGraph(operand));
473                 }
474             }
475             // TODO if terminating operation find expression graphs
476             //      for each successor argument
477         } else {
478             assert value instanceof Block.Parameter;
479             // A block parameter has no outgoing edges
480             edges = List.of();
481         }
482         Node<Value> node = new Node<>(value, edges);
483         visited.put(value, node);
484         return node;
485     }
486 
487 
488     static Node<Value> useGraph(Value value) {
489         return useGraph(new HashMap<>(), value);
490     }
491 
492     static Node<Value> useGraph(Map<Value, Node<Value>> visited, Value value) {
493         // If value has already been visited return its node
494         if (visited.containsKey(value)) {
495             return visited.get(value);
496         }
497 
498         // Find the use graphs for each use
499         List<Node<Value>> edges = new ArrayList<>();
500         for (Op.Result use : value.uses()) {
501             edges.add(useGraph(visited, use));
502         }
503         Node<Value> node = new Node<>(value, edges);
504         visited.put(value, node);
505         return node;
506     }
507 
508     static Map<Value, Node<Value>> expressionGraphs(CoreOp.FuncOp f) {
509         return expressionGraphs(f.body());
510     }
511 
512     static Map<Value, Node<Value>> expressionGraphs(Body b) {
513         LinkedHashMap<Value, Node<Value>> graphs = new LinkedHashMap<>();
514         // Traverse the model building structurally shared expression graphs
515         b.elements().forEach(codeElement -> {
516             switch (codeElement) {
517                 case Body _ -> {
518                     // Do nothing
519                 }
520                 case Block block -> {
521                     // Create the expression graphs for each block parameter
522                     // A block parameter has no outgoing edges
523                     for (Block.Parameter parameter : block.parameters()) {
524                         graphs.put(parameter, new Node<>(parameter, List.of()));
525                     }
526                 }
527                 case Op op -> {
528                     // Find the expression graphs for each operand
529                     List<Node<Value>> edges = new ArrayList<>();
530                     for (Value operand : op.result().dependsOn()) {
531                         // Get expression graph for the operand
532                         // It must be previously computed since we encounter the
533                         // declaration of values before their use
534                         edges.add(graphs.get(operand));
535                     }
536                     // Create the expression graph for this operation result
537                     graphs.put(op.result(), new Node<>(op.result(), edges));
538                 }
539             }
540         });
541         return graphs;
542     }
543 
544     static Map<Value, Node<Value>> prunedExpressionGraphs(CoreOp.FuncOp f) {
545         return prunedExpressionGraphs(f.body());
546     }
547 
548     static Map<Value, Node<Value>> prunedExpressionGraphs(Body b) {
549         LinkedHashMap<Value, Node<Value>> graphs = new LinkedHashMap<>();
550         // Traverse the model building structurally shared expression graphs
551         b.elements().forEach(codeElement -> {
552             switch (codeElement) {
553                 case Body _ -> {
554                     // Do nothing
555                 }
556                 case Block block -> {
557                     // Create the expression graphs for each block parameter
558                     // A block parameter has no outgoing edges
559                     for (Block.Parameter parameter : block.parameters()) {
560                         graphs.put(parameter, new Node<>(parameter, List.of()));
561                     }
562                 }
563                 // Prune graph for variable load operation
564                 case CoreOp.VarAccessOp.VarLoadOp op -> {
565                     // Ignore edge for the variable value operand
566                     graphs.put(op.result(), new Node<>(op.result(), List.of()));
567                 }
568                 // Prune graph for variable store operation
569                 case CoreOp.VarAccessOp.VarStoreOp op -> {
570                     // Ignore edge for the variable value operand
571                     // Add edge for value to store
572                     List<Node<Value>> edges = List.of(graphs.get(op.operands().get(1)));
573                     graphs.put(op.result(), new Node<>(op.result(), edges));
574                 }
575                 case Op op -> {
576                     // Find the expression graphs for each operand
577                     List<Node<Value>> edges = new ArrayList<>();
578                     for (Value operand : op.result().dependsOn()) {
579                         // Get expression graph for the operand
580                         // It must be previously computed since we encounter the
581                         // declaration of values before their use
582                         edges.add(graphs.get(operand));
583                     }
584                     // Create the expression graph for this operation result
585                     graphs.put(op.result(), new Node<>(op.result(), edges));
586                 }
587             }
588         });
589         return graphs;
590     }
591 
592     record Node<T>(T value, List<Node<T>> edges) {
593         <U> Node<U> transformValues(Function<T, U> f) {
594             List<Node<U>> transformedEdges = new ArrayList<>();
595             for (Node<T> edge : edges()) {
596                 transformedEdges.add(edge.transformValues(f));
597             }
598             return new Node<>(f.apply(value()), transformedEdges);
599         }
600 
601         Node<T> transformGraph(Function<Node<T>, Node<T>> f) {
602             Node<T> apply = f.apply(this);
603             if (apply != this) {
604                 // The function returned a new node
605                 return apply;
606             } else {
607                 // The function returned the same node
608                 // Apply the transformation to the children
609                 List<Node<T>> transformedEdges = new ArrayList<>();
610                 for (Node<T> edge : edges()) {
611                     transformedEdges.add(edge.transformGraph(f));
612                 }
613                 boolean same = IntStream.range(0, edges().size())
614                         .allMatch(i -> edges().get(i) ==
615                                 transformedEdges.get(i));
616                 if (same) {
617                     return this;
618                 } else {
619                     return new Node<>(this.value(), transformedEdges);
620                 }
621             }
622         }
623 
624         @Override
625         public String toString() {
626             StringBuilder sb = new StringBuilder();
627             print(sb, "", "");
628             return sb.toString();
629         }
630 
631         private void print(StringBuilder sb, String prefix, String edgePrefix) {
632             sb.append(prefix);
633             sb.append(value);
634             sb.append('\n');
635             for (Iterator<Node<T>> it = edges.iterator(); it.hasNext(); ) {
636                 Node<T> edge = it.next();
637                 if (it.hasNext()) {
638                     edge.print(sb, edgePrefix + "├── ", edgePrefix + "│   ");
639                 } else {
640                     edge.print(sb, edgePrefix + "└── ", edgePrefix + "    ");
641                 }
642             }
643         }
644     }
645 
646 
647     static CoreOp.FuncOp getFuncOp(String name) {
648         Optional<Method> om = Stream.of(TestExpressionGraphs.class.getDeclaredMethods())
649                 .filter(m -> m.getName().equals(name))
650                 .findFirst();
651 
652         Method m = om.get();
653         return Op.ofMethod(m).get();
654     }
655 
656 }