1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 /*
25 * @test
26 * @modules jdk.incubator.code
27 * @run junit TestExpressionGraphs
28 */
29
30 import jdk.incubator.code.*;
31 import jdk.incubator.code.analysis.SSA;
32 import jdk.incubator.code.dialect.core.CoreOp;
33 import jdk.incubator.code.extern.OpWriter;
34 import org.junit.jupiter.api.Test;
35
36 import java.io.Writer;
37 import java.lang.reflect.Method;
38 import java.util.*;
39 import java.util.function.Function;
40 import java.util.stream.IntStream;
41 import java.util.stream.Stream;
42
43 public class TestExpressionGraphs {
44
/**
 * Reflectable fixture: returns {@code a - b}.
 * The tests in this class obtain its code model by name ("sub") and
 * traverse or print it, so the body is kept as a single return expression.
 */
45 @CodeReflection
46 static double sub(double a, double b) {
47 return a - b;
48 }
49
/**
 * Reflectable fixture: returns the absolute difference of {@code a} and
 * {@code b}, written as one nested expression without local variables.
 */
50 @CodeReflection
51 static double distance1(double a, double b) {
52 return Math.abs(a - b);
53 }
54
/**
 * Reflectable fixture: same computation as {@code distance1}, but each
 * intermediate value is held in a local variable before being returned.
 */
55 @CodeReflection
56 static double distance1a(final double a, final double b) {
57 final double diff = a - b;
58 final double result = Math.abs(diff);
59 return result;
60 }
61
/**
 * Reflectable fixture: absolute difference computed with a conditional
 * expression on the sign of the difference instead of {@code Math.abs}.
 */
62 @CodeReflection
63 static double distance1b(final double a, final double b) {
64 final double diff = a - b;
65 // Note, incorrect for negative zero values
66 final double result = diff < 0d ? -diff : diff;
67 return result;
68 }
69
/**
 * Reflectable fixture: square root of the sum of squared element-wise
 * differences of {@code a} and {@code b}.
 * The loop bound is {@code a.length}; {@code b} is indexed with the same
 * index and is not length-checked here.
 */
70 @CodeReflection
71 static double distanceN(double[] a, double[] b) {
72 double sum = 0d;
73 for (int i = 0; i < a.length; i++) {
74 sum += Math.pow(a[i] - b[i], 2d);
75 }
76 return Math.sqrt(sum);
77 }
78
/**
 * Reflectable fixture: computes the difference of squares of {@code a} and
 * {@code b} as the product of their sum and their difference, with both
 * factors held in local variables.
 */
79 @CodeReflection
80 static double squareDiff(double a, double b) {
81 // a^2 - b^2 = (a + b) * (a - b)
82 final double plus = a + b;
83 final double minus = a - b;
84 return plus * minus;
85 }
86
87 @Test
88 void traverseSub() throws ReflectiveOperationException {
89 // Get the reflective object for method sub
90 Method m = TestExpressionGraphs.class.getDeclaredMethod(
91 "sub", double.class, double.class);
92 // Get the code model for method sub
93 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
94 CoreOp.FuncOp model = oModel.orElseThrow();
95
96 // Depth-first search, reporting elements in pre-order
97 model.traverse(null, (acc, codeElement) -> {
98 // Count the depth of the code element by
99 // traversing up the tree from child to parent
100 int depth = 0;
101 CodeElement<?, ?> parent = codeElement;
102 while ((parent = parent.parent()) != null) depth++;
103 // Print out code element class
104 System.out.println(" ".repeat(depth) + codeElement.getClass());
105 return acc;
106 });
107
108 // Stream of elements topologically sorted in depth-first search pre-order
109 model.elements().forEach(codeElement -> {
110 // Count the depth of the code element
111 int depth = 0;
112 CodeElement<?, ?> parent = codeElement;
113 while ((parent = parent.parent()) != null) depth++;
114 // Print out code element class
115 System.out.println(" ".repeat(depth) + codeElement.getClass());
116 });
117 }
118
119 @Test
120 void traverseDistance1() throws ReflectiveOperationException {
121 // Get the reflective object for method distance1
122 Method m = TestExpressionGraphs.class.getDeclaredMethod(
123 "distance1", double.class, double.class);
124 // Get the code model for method distance1
125 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
126 CoreOp.FuncOp model = oModel.orElseThrow();
127
128 // Depth-first search, reporting elements in pre-order
129 model.traverse(null, (acc, codeElement) -> {
130 // Count the depth of the code element by
131 // traversing up the tree from child to parent
132 int depth = 0;
133 CodeElement<?, ?> parent = codeElement;
134 while ((parent = parent.parent()) != null) depth++;
135 // Print out code element class
136 System.out.println(" ".repeat(depth) + codeElement.getClass());
137 return acc;
138 });
139
140 // Stream of elements topologically sorted in depth-first search pre-order
141 model.elements().forEach(codeElement -> {
142 // Count the depth of the code element
143 int depth = 0;
144 CodeElement<?, ?> parent = codeElement;
145 while ((parent = parent.parent()) != null) depth++;
146 // Print out code element class
147 System.out.println(" ".repeat(depth) + codeElement.getClass());
148 });
149 }
150
151
152 @Test
153 void printSub() {
154 CoreOp.FuncOp model = getFuncOp("sub");
155 print(model);
156 }
157
158 @Test
159 void printDistance1() {
160 CoreOp.FuncOp model = getFuncOp("distance1");
161 print(model);
162 }
163
164 @Test
165 void printDistance1a() {
166 CoreOp.FuncOp model = getFuncOp("distance1a");
167 print(model);
168 }
169
170 @Test
171 void printDistance1b() {
172 CoreOp.FuncOp model = getFuncOp("distance1b");
173 print(model);
174 }
175
176 @Test
177 void printDistanceN() {
178 CoreOp.FuncOp model = getFuncOp("distanceN");
179 print(model);
180 }
181
182 void print(CoreOp.FuncOp f) {
183 System.out.println(f.toText());
184
185 f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
186 System.out.println(f.toText());
187
188 f = SSA.transform(f);
189 System.out.println(f.toText());
190 }
191
192
193 @Test
194 void graphsDistance1() {
195 CoreOp.FuncOp model = getFuncOp("distance1");
196 Function<CodeItem, String> names = names(model);
197 System.out.println(printOpWriteVoid(names, model));
198
199 // Create the expression graph for the terminating operation result
200 Op.Result returnResult = model.body().entryBlock().terminatingOp().result();
201 Node<Value> returnGraph = expressionGraph(returnResult);
202 System.out.println("Expression graph for terminating operation result");
203 // Transform from Node<Value> to Node<String> and print the graph
204 System.out.println(returnGraph.transformValues(v -> printValue(names, v)));
205
206 System.out.println("Use graphs for block parameters");
207 for (Block.Parameter parameter : model.parameters()) {
208 Node<Value> useNode = useGraph(parameter);
209 System.out.println(useNode.transformValues(v -> printValue(names, v)));
210 }
211
212 // Create the expression graphs for all values
213 Map<Value, Node<Value>> graphs = expressionGraphs(model);
214 System.out.println("Expression graphs for all declared values in the model");
215 graphs.values().forEach(n -> {
216 System.out.println(n.transformValues(v -> printValue(names, v)));
217 });
218
219 // The graphs for the terminating operation result are the same
220 assert returnGraph.equals(graphs.get(returnGraph.value()));
221
222 // Filter for root graphs, operation results with no uses
223 List<Node<Value>> rootGraphs = graphs.values().stream()
224 .filter(n -> n.value() instanceof Op.Result opr &&
225 switch (opr.op()) {
226 // An operation result with no uses
227 default -> opr.uses().isEmpty();
228 })
229 .toList();
230 System.out.println("Root expression graphs");
231 rootGraphs.forEach(n -> {
232 System.out.println(n.transformValues(v -> printValue(names, v)));
233 });
234 }
235
236 @Test
237 void graphsDistance1a() {
238 CoreOp.FuncOp f = getFuncOp("distance1a");
239 Function<CodeItem, String> names = names(f);
240 System.out.println(printOpWriteVoid(names, f));
241
242 {
243 Map<Value, Node<Value>> graphs = expressionGraphs(f);
244 List<Node<Value>> rootGraphs = graphs.values().stream()
245 .filter(n -> n.value() instanceof Op.Result opr &&
246 switch (opr.op()) {
247 // An operation result with no uses
248 default -> opr.uses().isEmpty();
249 })
250 .toList();
251 System.out.println("Root expression graphs");
252 rootGraphs.forEach(n -> {
253 System.out.println(n.transformValues(v -> printValue(names, v)));
254 });
255 }
256
257 {
258 Map<Value, Node<Value>> graphs = expressionGraphs(f);
259 List<Node<Value>> rootGraphs = graphs.values().stream()
260 .filter(n -> n.value() instanceof Op.Result opr &&
261 switch (opr.op()) {
262 // Variable declarations modeling local variables
263 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
264 // An operation result with no uses
265 default -> opr.uses().isEmpty();
266 })
267 .toList();
268 System.out.println("Root (with variable) expression graphs");
269 rootGraphs.forEach(n -> {
270 System.out.println(n.transformValues(v -> printValue(names, v)));
271 });
272 }
273
274 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
275 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
276 .filter(n -> n.value() instanceof Op.Result opr &&
277 switch (opr.op()) {
278 // Variable declarations modeling local variables
279 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
280 // An operation result with no uses
281 default -> opr.uses().isEmpty();
282 })
283 .toList();
284 System.out.println("Pruned root expression graphs");
285 prunedRootGraphs.forEach(n -> {
286 System.out.println(n.transformValues(v -> printValue(names, v)));
287 });
288 }
289
290 @Test
291 void graphsDistance1b() {
292 CoreOp.FuncOp f = getFuncOp("distance1b");
293 Function<CodeItem, String> names = names(f);
294 System.out.println(printOpWriteVoid(names, f));
295
296 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
297 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
298 .filter(n -> n.value() instanceof Op.Result opr &&
299 switch (opr.op()) {
300 // Variable declarations modeling declaration of local variables
301 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
302 // An operation result with no uses
303 default -> opr.uses().isEmpty();
304 })
305 .toList();
306 System.out.println("Pruned root expression graphs");
307 prunedRootGraphs.forEach(n -> {
308 System.out.println(n.transformValues(v -> printValue(names, v)));
309 });
310 }
311
312 @Test
313 void graphsDistanceN() {
314 CoreOp.FuncOp f = getFuncOp("distanceN");
315 Function<CodeItem, String> names = names(f);
316 System.out.println(printOpWriteVoid(names, f));
317
318 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
319 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
320 .filter(n -> n.value() instanceof Op.Result opr &&
321 switch (opr.op()) {
322 // Variable declarations modeling declaration of local variables
323 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
324 // An operation result with no uses
325 default -> opr.uses().isEmpty();
326 })
327 .toList();
328 System.out.println("Pruned root expression graphs");
329 prunedRootGraphs.forEach(n -> {
330 System.out.println(n.transformValues(v -> printValue(names, v)));
331 });
332 }
333
334 @Test
335 void graphsSquareDiff() {
336 CoreOp.FuncOp f = getFuncOp("squareDiff");
337 Function<CodeItem, String> names = names(f);
338 System.out.println(printOpWriteVoid(names, f));
339
340 {
341 Map<Value, Node<Value>> graphs = expressionGraphs(f);
342 List<Node<Value>> rootGraphs = graphs.values().stream()
343 .filter(n -> n.value() instanceof Op.Result opr &&
344 switch (opr.op()) {
345 // An operation result with no uses
346 default -> opr.uses().isEmpty();
347 })
348 .toList();
349 System.out.println("Root expression graphs");
350 rootGraphs.forEach(n -> {
351 System.out.println(n.transformValues(v -> printValue(names, v)));
352 });
353 }
354
355 {
356 Map<Value, Node<Value>> graphs = expressionGraphs(f);
357 List<Node<Value>> rootGraphs = graphs.values().stream()
358 .filter(n -> n.value() instanceof Op.Result opr &&
359 switch (opr.op()) {
360 // Variable declarations modeling local variables
361 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
362 // An operation result with no uses
363 default -> opr.uses().isEmpty();
364 })
365 .toList();
366 System.out.println("Root (with variable) expression graphs");
367 rootGraphs.forEach(n -> {
368 System.out.println(n.transformValues(v -> printValue(names, v)));
369 });
370 }
371
372 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
373 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
374 .filter(n -> n.value() instanceof Op.Result opr &&
375 switch (opr.op()) {
376 // Variable declarations modeling local variables
377 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
378 // An operation result with no uses
379 default -> opr.uses().isEmpty();
380 })
381 .toList();
382 System.out.println("Pruned root expression graphs");
383 prunedRootGraphs.forEach(n -> {
384 System.out.println(n.transformValues(v -> printValue(names, v)));
385 });
386 }
387
388
389
/**
 * Reflectable fixture combining several statement forms so that the root
 * expression graphs printed by {@code graphsH} can be matched to source
 * statements (numbered in the trailing comments): a compound assignment,
 * a call whose result is discarded, a declaration whose initializer nests
 * a call, an assignment expression and a conditional expression, a
 * {@code for} loop with two loop variables, and a return.
 */
390 @CodeReflection
391 static int h(int x) {
392 x += 2; // Statement 1
393 g(x); // Statement 2
394 int y = 1 + g(x) + (x += 2) + (x > 2 ? x : 10); // Statement 3
395 for ( // Statement 4
396 int i = 0, j = 1; // Statements 4.1
397 i < 3 && j < 3;
398 i++, j++) { // Statements 4.2
399 System.out.println(i); // Statement 4.3
400 }
401 return x + y; // Statement 5
402 }
403
/**
 * Returns its argument unchanged. Called from {@code h} so that the model
 * of {@code h} contains method invocations; it carries no
 * {@code @CodeReflection} annotation itself.
 */
404 static int g(int i) {
405 return i;
406 }
407
408 @Test
409 void graphsH() {
410 CoreOp.FuncOp f = getFuncOp("h");
411 Function<CodeItem, String> names = names(f);
412 System.out.println(printOpWriteVoid(names, f));
413
414 Map<Value, Node<Value>> graphs = prunedExpressionGraphs(f);
415 List<Node<Value>> rootGraphs = graphs.values().stream()
416 .filter(n -> n.value() instanceof Op.Result opr &&
417 switch (opr.op()) {
418 // Variable declarations modeling declaration of local variables
419 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
420 // Variable stores modeling assignment expressions whose result is used
421 case CoreOp.VarAccessOp.VarStoreOp vsop -> vsop.operands().get(1).uses().size() == 1;
422 // An operation result with no uses
423 default -> opr.uses().isEmpty();
424 })
425 .toList();
426 rootGraphs.forEach(n -> {
427 System.out.println(n.transformValues(v -> printValue(names, v)));
428 });
429 }
430
431
432 static String printValue(Function<CodeItem, String> names, Value v) {
433 if (v instanceof Op.Result opr) {
434 return printOpHeader(names, opr.op());
435 } else {
436 return "%" + names.apply(v) + " <block parameter>";
437 }
438 }
439
440 static String printOpHeader(Function<CodeItem, String> names, Op op) {
441 return OpWriter.toText(op,
442 OpWriter.OpDescendantsOption.DROP_DESCENDANTS,
443 OpWriter.VoidOpResultOption.WRITE_VOID,
444 OpWriter.CodeItemNamerOption.of(names));
445 }
446
447 static String printOpWriteVoid(Function<CodeItem, String> names, Op op) {
448 return OpWriter.toText(op,
449 OpWriter.VoidOpResultOption.WRITE_VOID,
450 OpWriter.CodeItemNamerOption.of(names));
451 }
452
453 static Function<CodeItem, String> names(Op op) {
454 OpWriter w = new OpWriter(Writer.nullWriter(),
455 OpWriter.VoidOpResultOption.WRITE_VOID);
456 w.writeOp(op);
457 return w.namer();
458 }
459
460
461 static Node<Value> expressionGraph(Value value) {
462 return expressionGraph(new HashMap<>(), value);
463 }
464
465 static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) {
466 // If value has already been visited return its node
467 if (visited.containsKey(value)) {
468 return visited.get(value);
469 }
470
471 // Find the expression graphs for each operand
472 List<Node<Value>> edges = new ArrayList<>();
473 for (Value operand : value.dependsOn()) {
474 edges.add(expressionGraph(operand));
475 }
476 Node<Value> node = new Node<>(value, edges);
477 visited.put(value, node);
478 return node;
479 }
480
481 static Node<Value> expressionGraphDetailed(Map<Value, Node<Value>> visited, Value value) {
482 // If value has already been visited return its node
483 if (visited.containsKey(value)) {
484 return visited.get(value);
485 }
486
487 List<Node<Value>> edges;
488 if (value instanceof Op.Result result) {
489 edges = new ArrayList<>();
490 // Find the expression graphs for each operand
491 Set<Value> valueVisited = new HashSet<>();
492 for (Value operand : result.op().operands()) {
493 // Ensure an operand is visited only once
494 if (valueVisited.add(operand)) {
495 edges.add(expressionGraph(operand));
496 }
497 }
498 // TODO if terminating operation find expression graphs
499 // for each successor argument
500 } else {
501 assert value instanceof Block.Parameter;
502 // A block parameter has no outgoing edges
503 edges = List.of();
504 }
505 Node<Value> node = new Node<>(value, edges);
506 visited.put(value, node);
507 return node;
508 }
509
510
511 static Node<Value> useGraph(Value value) {
512 return useGraph(new HashMap<>(), value);
513 }
514
515 static Node<Value> useGraph(Map<Value, Node<Value>> visited, Value value) {
516 // If value has already been visited return its node
517 if (visited.containsKey(value)) {
518 return visited.get(value);
519 }
520
521 // Find the use graphs for each use
522 List<Node<Value>> edges = new ArrayList<>();
523 for (Op.Result use : value.uses()) {
524 edges.add(useGraph(visited, use));
525 }
526 Node<Value> node = new Node<>(value, edges);
527 visited.put(value, node);
528 return node;
529 }
530
531 static Map<Value, Node<Value>> expressionGraphs(CoreOp.FuncOp f) {
532 return expressionGraphs(f.body());
533 }
534
535 static Map<Value, Node<Value>> expressionGraphs(Body b) {
536 // Traverse the model building structurally shared expression graphs
537 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> {
538 switch (codeElement) {
539 case Body _ -> {
540 // Do nothing
541 }
542 case Block block -> {
543 // Create the expression graphs for each block parameter
544 // A block parameter has no outgoing edges
545 for (Block.Parameter parameter : block.parameters()) {
546 graphs.put(parameter, new Node<>(parameter, List.of()));
547 }
548 }
549 case Op op -> {
550 // Find the expression graphs for each operand
551 List<Node<Value>> edges = new ArrayList<>();
552 for (Value operand : op.result().dependsOn()) {
553 // Get expression graph for the operand
554 // It must be previously computed since we encounter the
555 // declaration of values before their use
556 edges.add(graphs.get(operand));
557 }
558 // Create the expression graph for this operation result
559 graphs.put(op.result(), new Node<>(op.result(), edges));
560 }
561 }
562 return graphs;
563 });
564 }
565
566 static Map<Value, Node<Value>> prunedExpressionGraphs(CoreOp.FuncOp f) {
567 return prunedExpressionGraphs(f.body());
568 }
569
570 static Map<Value, Node<Value>> prunedExpressionGraphs(Body b) {
571 // Traverse the model building structurally shared expression graphs
572 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> {
573 switch (codeElement) {
574 case Body _ -> {
575 // Do nothing
576 }
577 case Block block -> {
578 // Create the expression graphs for each block parameter
579 // A block parameter has no outgoing edges
580 for (Block.Parameter parameter : block.parameters()) {
581 graphs.put(parameter, new Node<>(parameter, List.of()));
582 }
583 }
584 // Prune graph for variable load operation
585 case CoreOp.VarAccessOp.VarLoadOp op -> {
586 // Ignore edge for the variable value operand
587 graphs.put(op.result(), new Node<>(op.result(), List.of()));
588 }
589 // Prune graph for variable store operation
590 case CoreOp.VarAccessOp.VarStoreOp op -> {
591 // Ignore edge for the variable value operand
592 // Add edge for value to store
593 List<Node<Value>> edges = List.of(graphs.get(op.operands().get(1)));
594 graphs.put(op.result(), new Node<>(op.result(), edges));
595 }
596 case Op op -> {
597 // Find the expression graphs for each operand
598 List<Node<Value>> edges = new ArrayList<>();
599 for (Value operand : op.result().dependsOn()) {
600 // Get expression graph for the operand
601 // It must be previously computed since we encounter the
602 // declaration of values before their use
603 edges.add(graphs.get(operand));
604 }
605 // Create the expression graph for this operation result
606 graphs.put(op.result(), new Node<>(op.result(), edges));
607 }
608 }
609 return graphs;
610 });
611 }
612
613 record Node<T>(T value, List<Node<T>> edges) {
614 <U> Node<U> transformValues(Function<T, U> f) {
615 List<Node<U>> transformedEdges = new ArrayList<>();
616 for (Node<T> edge : edges()) {
617 transformedEdges.add(edge.transformValues(f));
618 }
619 return new Node<>(f.apply(value()), transformedEdges);
620 }
621
622 Node<T> transformGraph(Function<Node<T>, Node<T>> f) {
623 Node<T> apply = f.apply(this);
624 if (apply != this) {
625 // The function returned a new node
626 return apply;
627 } else {
628 // The function returned the same node
629 // Apply the transformation to the children
630 List<Node<T>> transformedEdges = new ArrayList<>();
631 for (Node<T> edge : edges()) {
632 transformedEdges.add(edge.transformGraph(f));
633 }
634 boolean same = IntStream.range(0, edges().size())
635 .allMatch(i -> edges().get(i) ==
636 transformedEdges.get(i));
637 if (same) {
638 return this;
639 } else {
640 return new Node<>(this.value(), transformedEdges);
641 }
642 }
643 }
644
645 @Override
646 public String toString() {
647 StringBuilder sb = new StringBuilder();
648 print(sb, "", "");
649 return sb.toString();
650 }
651
652 private void print(StringBuilder sb, String prefix, String edgePrefix) {
653 sb.append(prefix);
654 sb.append(value);
655 sb.append('\n');
656 for (Iterator<Node<T>> it = edges.iterator(); it.hasNext(); ) {
657 Node<T> edge = it.next();
658 if (it.hasNext()) {
659 edge.print(sb, edgePrefix + "├── ", edgePrefix + "│ ");
660 } else {
661 edge.print(sb, edgePrefix + "└── ", edgePrefix + " ");
662 }
663 }
664 }
665 }
666
667
668 static CoreOp.FuncOp getFuncOp(String name) {
669 Optional<Method> om = Stream.of(TestExpressionGraphs.class.getDeclaredMethods())
670 .filter(m -> m.getName().equals(name))
671 .findFirst();
672
673 Method m = om.get();
674 return Op.ofMethod(m).get();
675 }
676
677 }