1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 /*
25 * @test
26 * @modules jdk.incubator.code
27 * @run junit TestExpressionGraphs
28 */
29
30 import jdk.incubator.code.*;
31 import jdk.incubator.code.Reflect;
32 import jdk.incubator.code.analysis.SSA;
33 import jdk.incubator.code.dialect.core.CoreOp;
34 import jdk.incubator.code.extern.OpWriter;
35 import org.junit.jupiter.api.Test;
36
37 import java.io.Writer;
38 import java.lang.reflect.Method;
39 import java.util.*;
40 import java.util.function.Function;
41 import java.util.stream.IntStream;
42 import java.util.stream.Stream;
43
44 public class TestExpressionGraphs {
45
/**
 * Sample reflected method: subtracts {@code b} from {@code a}.
 * Its code model is obtained by name in the traverse and print tests of this class.
 *
 * @param a the minuend
 * @param b the subtrahend
 * @return {@code a - b}
 */
@Reflect
static double sub(double a, double b) {
    return a - b;
}
50
/**
 * Sample reflected method: one-dimensional distance between {@code a} and {@code b},
 * written as a single expression.
 *
 * @param a the first coordinate
 * @param b the second coordinate
 * @return {@code Math.abs(a - b)}
 */
@Reflect
static double distance1(double a, double b) {
    return Math.abs(a - b);
}
55
/**
 * Sample reflected method: same computation as {@code distance1}, but with each
 * intermediate result stored in a local variable.
 *
 * @param a the first coordinate
 * @param b the second coordinate
 * @return {@code Math.abs(a - b)}
 */
@Reflect
static double distance1a(final double a, final double b) {
    final double diff = a - b;
    final double result = Math.abs(diff);
    return result;
}
62
/**
 * Sample reflected method: one-dimensional distance computed with a conditional
 * expression instead of {@code Math.abs}.
 *
 * @param a the first coordinate
 * @param b the second coordinate
 * @return {@code -diff} when {@code diff < 0}, otherwise {@code diff}, where {@code diff = a - b}
 */
@Reflect
static double distance1b(final double a, final double b) {
    final double diff = a - b;
    // Note, incorrect for negative zero values
    final double result = diff < 0d ? -diff : diff;
    return result;
}
70
/**
 * Sample reflected method: Euclidean distance between two points given as arrays.
 * The loop runs over the indices of {@code a}; {@code b} is indexed with the same
 * indices, so it must be at least as long as {@code a}.
 *
 * @param a coordinates of the first point
 * @param b coordinates of the second point
 * @return the square root of the sum of squared coordinate differences
 */
@Reflect
static double distanceN(double[] a, double[] b) {
    double sum = 0d;
    for (int i = 0; i < a.length; i++) {
        sum += Math.pow(a[i] - b[i], 2d);
    }
    return Math.sqrt(sum);
}
79
/**
 * Sample reflected method: difference of squares computed through its factored form,
 * with both factors held in local variables.
 *
 * @param a the first operand
 * @param b the second operand
 * @return {@code (a + b) * (a - b)}
 */
@Reflect
static double squareDiff(double a, double b) {
    // a^2 - b^2 = (a + b) * (a - b)
    final double plus = a + b;
    final double minus = a - b;
    return plus * minus;
}
87
88 @Test
89 void traverseSub() throws ReflectiveOperationException {
90 // Get the reflective object for method sub
91 Method m = TestExpressionGraphs.class.getDeclaredMethod(
92 "sub", double.class, double.class);
93 // Get the code model for method sub
94 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
95 CoreOp.FuncOp model = oModel.orElseThrow();
96
97 // Stream of elements topologically sorted in depth-first search pre-order
98 model.elements().forEach(codeElement -> {
99 // Count the depth of the code element
100 int depth = 0;
101 CodeElement<?, ?> parent = codeElement;
102 while ((parent = parent.parent()) != null) depth++;
103 // Print out code element class
104 System.out.println(" ".repeat(depth) + codeElement.getClass());
105 });
106 }
107
108 @Test
109 void traverseDistance1() throws ReflectiveOperationException {
110 // Get the reflective object for method distance1
111 Method m = TestExpressionGraphs.class.getDeclaredMethod(
112 "distance1", double.class, double.class);
113 // Get the code model for method distance1
114 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m);
115 CoreOp.FuncOp model = oModel.orElseThrow();
116
117 // Stream of elements topologically sorted in depth-first search pre-order
118 model.elements().forEach(codeElement -> {
119 // Count the depth of the code element
120 int depth = 0;
121 CodeElement<?, ?> parent = codeElement;
122 while ((parent = parent.parent()) != null) depth++;
123 // Print out code element class
124 System.out.println(" ".repeat(depth) + codeElement.getClass());
125 });
126 }
127
128
129 @Test
130 void printSub() {
131 CoreOp.FuncOp model = getFuncOp("sub");
132 print(model);
133 }
134
135 @Test
136 void printDistance1() {
137 CoreOp.FuncOp model = getFuncOp("distance1");
138 print(model);
139 }
140
141 @Test
142 void printDistance1a() {
143 CoreOp.FuncOp model = getFuncOp("distance1a");
144 print(model);
145 }
146
147 @Test
148 void printDistance1b() {
149 CoreOp.FuncOp model = getFuncOp("distance1b");
150 print(model);
151 }
152
153 @Test
154 void printDistanceN() {
155 CoreOp.FuncOp model = getFuncOp("distanceN");
156 print(model);
157 }
158
159 void print(CoreOp.FuncOp f) {
160 System.out.println(f.toText());
161
162 f = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
163 System.out.println(f.toText());
164
165 f = SSA.transform(f);
166 System.out.println(f.toText());
167 }
168
169
170 @Test
171 void graphsDistance1() {
172 CoreOp.FuncOp model = getFuncOp("distance1");
173 Function<CodeItem, String> names = names(model);
174 System.out.println(printOpWriteVoid(names, model));
175
176 // Create the expression graph for the terminating operation result
177 Op.Result returnResult = model.body().entryBlock().terminatingOp().result();
178 Node<Value> returnGraph = expressionGraph(returnResult);
179 System.out.println("Expression graph for terminating operation result");
180 // Transform from Node<Value> to Node<String> and print the graph
181 System.out.println(returnGraph.transformValues(v -> printValue(names, v)));
182
183 System.out.println("Use graphs for block parameters");
184 for (Block.Parameter parameter : model.parameters()) {
185 Node<Value> useNode = useGraph(parameter);
186 System.out.println(useNode.transformValues(v -> printValue(names, v)));
187 }
188
189 // Create the expression graphs for all values
190 Map<Value, Node<Value>> graphs = expressionGraphs(model);
191 System.out.println("Expression graphs for all declared values in the model");
192 graphs.values().forEach(n -> {
193 System.out.println(n.transformValues(v -> printValue(names, v)));
194 });
195
196 // The graphs for the terminating operation result are the same
197 assert returnGraph.equals(graphs.get(returnGraph.value()));
198
199 // Filter for root graphs, operation results with no uses
200 List<Node<Value>> rootGraphs = graphs.values().stream()
201 .filter(n -> n.value() instanceof Op.Result opr &&
202 switch (opr.op()) {
203 // An operation result with no uses
204 default -> opr.uses().isEmpty();
205 })
206 .toList();
207 System.out.println("Root expression graphs");
208 rootGraphs.forEach(n -> {
209 System.out.println(n.transformValues(v -> printValue(names, v)));
210 });
211 }
212
213 @Test
214 void graphsDistance1a() {
215 CoreOp.FuncOp f = getFuncOp("distance1a");
216 Function<CodeItem, String> names = names(f);
217 System.out.println(printOpWriteVoid(names, f));
218
219 {
220 Map<Value, Node<Value>> graphs = expressionGraphs(f);
221 List<Node<Value>> rootGraphs = graphs.values().stream()
222 .filter(n -> n.value() instanceof Op.Result opr &&
223 switch (opr.op()) {
224 // An operation result with no uses
225 default -> opr.uses().isEmpty();
226 })
227 .toList();
228 System.out.println("Root expression graphs");
229 rootGraphs.forEach(n -> {
230 System.out.println(n.transformValues(v -> printValue(names, v)));
231 });
232 }
233
234 {
235 Map<Value, Node<Value>> graphs = expressionGraphs(f);
236 List<Node<Value>> rootGraphs = graphs.values().stream()
237 .filter(n -> n.value() instanceof Op.Result opr &&
238 switch (opr.op()) {
239 // Variable declarations modeling local variables
240 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
241 // An operation result with no uses
242 default -> opr.uses().isEmpty();
243 })
244 .toList();
245 System.out.println("Root (with variable) expression graphs");
246 rootGraphs.forEach(n -> {
247 System.out.println(n.transformValues(v -> printValue(names, v)));
248 });
249 }
250
251 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
252 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
253 .filter(n -> n.value() instanceof Op.Result opr &&
254 switch (opr.op()) {
255 // Variable declarations modeling local variables
256 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
257 // An operation result with no uses
258 default -> opr.uses().isEmpty();
259 })
260 .toList();
261 System.out.println("Pruned root expression graphs");
262 prunedRootGraphs.forEach(n -> {
263 System.out.println(n.transformValues(v -> printValue(names, v)));
264 });
265 }
266
267 @Test
268 void graphsDistance1b() {
269 CoreOp.FuncOp f = getFuncOp("distance1b");
270 Function<CodeItem, String> names = names(f);
271 System.out.println(printOpWriteVoid(names, f));
272
273 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
274 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
275 .filter(n -> n.value() instanceof Op.Result opr &&
276 switch (opr.op()) {
277 // Variable declarations modeling declaration of local variables
278 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
279 // An operation result with no uses
280 default -> opr.uses().isEmpty();
281 })
282 .toList();
283 System.out.println("Pruned root expression graphs");
284 prunedRootGraphs.forEach(n -> {
285 System.out.println(n.transformValues(v -> printValue(names, v)));
286 });
287 }
288
289 @Test
290 void graphsDistanceN() {
291 CoreOp.FuncOp f = getFuncOp("distanceN");
292 Function<CodeItem, String> names = names(f);
293 System.out.println(printOpWriteVoid(names, f));
294
295 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
296 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
297 .filter(n -> n.value() instanceof Op.Result opr &&
298 switch (opr.op()) {
299 // Variable declarations modeling declaration of local variables
300 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
301 // An operation result with no uses
302 default -> opr.uses().isEmpty();
303 })
304 .toList();
305 System.out.println("Pruned root expression graphs");
306 prunedRootGraphs.forEach(n -> {
307 System.out.println(n.transformValues(v -> printValue(names, v)));
308 });
309 }
310
311 @Test
312 void graphsSquareDiff() {
313 CoreOp.FuncOp f = getFuncOp("squareDiff");
314 Function<CodeItem, String> names = names(f);
315 System.out.println(printOpWriteVoid(names, f));
316
317 {
318 Map<Value, Node<Value>> graphs = expressionGraphs(f);
319 List<Node<Value>> rootGraphs = graphs.values().stream()
320 .filter(n -> n.value() instanceof Op.Result opr &&
321 switch (opr.op()) {
322 // An operation result with no uses
323 default -> opr.uses().isEmpty();
324 })
325 .toList();
326 System.out.println("Root expression graphs");
327 rootGraphs.forEach(n -> {
328 System.out.println(n.transformValues(v -> printValue(names, v)));
329 });
330 }
331
332 {
333 Map<Value, Node<Value>> graphs = expressionGraphs(f);
334 List<Node<Value>> rootGraphs = graphs.values().stream()
335 .filter(n -> n.value() instanceof Op.Result opr &&
336 switch (opr.op()) {
337 // Variable declarations modeling local variables
338 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
339 // An operation result with no uses
340 default -> opr.uses().isEmpty();
341 })
342 .toList();
343 System.out.println("Root (with variable) expression graphs");
344 rootGraphs.forEach(n -> {
345 System.out.println(n.transformValues(v -> printValue(names, v)));
346 });
347 }
348
349 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f);
350 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream()
351 .filter(n -> n.value() instanceof Op.Result opr &&
352 switch (opr.op()) {
353 // Variable declarations modeling local variables
354 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
355 // An operation result with no uses
356 default -> opr.uses().isEmpty();
357 })
358 .toList();
359 System.out.println("Pruned root expression graphs");
360 prunedRootGraphs.forEach(n -> {
361 System.out.println(n.transformValues(v -> printValue(names, v)));
362 });
363 }
364
365
366
/**
 * Sample reflected method mixing several statement kinds: a compound assignment, a
 * call whose result is discarded, a declaration whose initializer contains a nested
 * assignment and a conditional expression, a {@code for} loop with two loop variables,
 * and a return. The trailing comments number the statements.
 *
 * @param x the input value; it is reassigned inside the method
 * @return {@code x + y} computed from the final values of {@code x} and {@code y}
 */
@Reflect
static int h(int x) {
    x += 2; // Statement 1
    g(x); // Statement 2
    int y = 1 + g(x) + (x += 2) + (x > 2 ? x : 10); // Statement 3
    for ( // Statement 4
            int i = 0, j = 1; // Statements 4.1
            i < 3 && j < 3;
            i++, j++) { // Statements 4.2
        System.out.println(i); // Statement 4.3
    }
    return x + y; // Statement 5
}
380
/**
 * Identity helper invoked from {@code h}; it is not annotated with {@code @Reflect}.
 *
 * @param i any value
 * @return {@code i} unchanged
 */
static int g(int i) {
    return i;
}
384
385 @Test
386 void graphsH() {
387 CoreOp.FuncOp f = getFuncOp("h");
388 Function<CodeItem, String> names = names(f);
389 System.out.println(printOpWriteVoid(names, f));
390
391 Map<Value, Node<Value>> graphs = prunedExpressionGraphs(f);
392 List<Node<Value>> rootGraphs = graphs.values().stream()
393 .filter(n -> n.value() instanceof Op.Result opr &&
394 switch (opr.op()) {
395 // Variable declarations modeling declaration of local variables
396 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result;
397 // Variable stores modeling assignment expressions whose result is used
398 case CoreOp.VarAccessOp.VarStoreOp vsop -> vsop.operands().get(1).uses().size() == 1;
399 // An operation result with no uses
400 default -> opr.uses().isEmpty();
401 })
402 .toList();
403 rootGraphs.forEach(n -> {
404 System.out.println(n.transformValues(v -> printValue(names, v)));
405 });
406 }
407
408
409 static String printValue(Function<CodeItem, String> names, Value v) {
410 if (v instanceof Op.Result opr) {
411 return printOpHeader(names, opr.op());
412 } else {
413 return "%" + names.apply(v) + " <block parameter>";
414 }
415 }
416
417 static String printOpHeader(Function<CodeItem, String> names, Op op) {
418 return OpWriter.toText(op,
419 OpWriter.OpDescendantsOption.DROP_DESCENDANTS,
420 OpWriter.VoidOpResultOption.WRITE_VOID,
421 OpWriter.CodeItemNamerOption.of(names));
422 }
423
424 static String printOpWriteVoid(Function<CodeItem, String> names, Op op) {
425 return OpWriter.toText(op,
426 OpWriter.VoidOpResultOption.WRITE_VOID,
427 OpWriter.CodeItemNamerOption.of(names));
428 }
429
430 static Function<CodeItem, String> names(Op op) {
431 OpWriter w = new OpWriter(Writer.nullWriter(),
432 OpWriter.VoidOpResultOption.WRITE_VOID);
433 w.writeOp(op);
434 return w.namer();
435 }
436
437
438 static Node<Value> expressionGraph(Value value) {
439 return expressionGraph(new HashMap<>(), value);
440 }
441
442 static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) {
443 // If value has already been visited return its node
444 if (visited.containsKey(value)) {
445 return visited.get(value);
446 }
447
448 // Find the expression graphs for each operand
449 List<Node<Value>> edges = new ArrayList<>();
450 for (Value operand : value.dependsOn()) {
451 edges.add(expressionGraph(operand));
452 }
453 Node<Value> node = new Node<>(value, edges);
454 visited.put(value, node);
455 return node;
456 }
457
458 static Node<Value> expressionGraphDetailed(Map<Value, Node<Value>> visited, Value value) {
459 // If value has already been visited return its node
460 if (visited.containsKey(value)) {
461 return visited.get(value);
462 }
463
464 List<Node<Value>> edges;
465 if (value instanceof Op.Result result) {
466 edges = new ArrayList<>();
467 // Find the expression graphs for each operand
468 Set<Value> valueVisited = new HashSet<>();
469 for (Value operand : result.op().operands()) {
470 // Ensure an operand is visited only once
471 if (valueVisited.add(operand)) {
472 edges.add(expressionGraph(operand));
473 }
474 }
475 // TODO if terminating operation find expression graphs
476 // for each successor argument
477 } else {
478 assert value instanceof Block.Parameter;
479 // A block parameter has no outgoing edges
480 edges = List.of();
481 }
482 Node<Value> node = new Node<>(value, edges);
483 visited.put(value, node);
484 return node;
485 }
486
487
488 static Node<Value> useGraph(Value value) {
489 return useGraph(new HashMap<>(), value);
490 }
491
492 static Node<Value> useGraph(Map<Value, Node<Value>> visited, Value value) {
493 // If value has already been visited return its node
494 if (visited.containsKey(value)) {
495 return visited.get(value);
496 }
497
498 // Find the use graphs for each use
499 List<Node<Value>> edges = new ArrayList<>();
500 for (Op.Result use : value.uses()) {
501 edges.add(useGraph(visited, use));
502 }
503 Node<Value> node = new Node<>(value, edges);
504 visited.put(value, node);
505 return node;
506 }
507
508 static Map<Value, Node<Value>> expressionGraphs(CoreOp.FuncOp f) {
509 return expressionGraphs(f.body());
510 }
511
512 static Map<Value, Node<Value>> expressionGraphs(Body b) {
513 LinkedHashMap<Value, Node<Value>> graphs = new LinkedHashMap<>();
514 // Traverse the model building structurally shared expression graphs
515 b.elements().forEach(codeElement -> {
516 switch (codeElement) {
517 case Body _ -> {
518 // Do nothing
519 }
520 case Block block -> {
521 // Create the expression graphs for each block parameter
522 // A block parameter has no outgoing edges
523 for (Block.Parameter parameter : block.parameters()) {
524 graphs.put(parameter, new Node<>(parameter, List.of()));
525 }
526 }
527 case Op op -> {
528 // Find the expression graphs for each operand
529 List<Node<Value>> edges = new ArrayList<>();
530 for (Value operand : op.result().dependsOn()) {
531 // Get expression graph for the operand
532 // It must be previously computed since we encounter the
533 // declaration of values before their use
534 edges.add(graphs.get(operand));
535 }
536 // Create the expression graph for this operation result
537 graphs.put(op.result(), new Node<>(op.result(), edges));
538 }
539 }
540 });
541 return graphs;
542 }
543
544 static Map<Value, Node<Value>> prunedExpressionGraphs(CoreOp.FuncOp f) {
545 return prunedExpressionGraphs(f.body());
546 }
547
548 static Map<Value, Node<Value>> prunedExpressionGraphs(Body b) {
549 LinkedHashMap<Value, Node<Value>> graphs = new LinkedHashMap<>();
550 // Traverse the model building structurally shared expression graphs
551 b.elements().forEach(codeElement -> {
552 switch (codeElement) {
553 case Body _ -> {
554 // Do nothing
555 }
556 case Block block -> {
557 // Create the expression graphs for each block parameter
558 // A block parameter has no outgoing edges
559 for (Block.Parameter parameter : block.parameters()) {
560 graphs.put(parameter, new Node<>(parameter, List.of()));
561 }
562 }
563 // Prune graph for variable load operation
564 case CoreOp.VarAccessOp.VarLoadOp op -> {
565 // Ignore edge for the variable value operand
566 graphs.put(op.result(), new Node<>(op.result(), List.of()));
567 }
568 // Prune graph for variable store operation
569 case CoreOp.VarAccessOp.VarStoreOp op -> {
570 // Ignore edge for the variable value operand
571 // Add edge for value to store
572 List<Node<Value>> edges = List.of(graphs.get(op.operands().get(1)));
573 graphs.put(op.result(), new Node<>(op.result(), edges));
574 }
575 case Op op -> {
576 // Find the expression graphs for each operand
577 List<Node<Value>> edges = new ArrayList<>();
578 for (Value operand : op.result().dependsOn()) {
579 // Get expression graph for the operand
580 // It must be previously computed since we encounter the
581 // declaration of values before their use
582 edges.add(graphs.get(operand));
583 }
584 // Create the expression graph for this operation result
585 graphs.put(op.result(), new Node<>(op.result(), edges));
586 }
587 }
588 });
589 return graphs;
590 }
591
592 record Node<T>(T value, List<Node<T>> edges) {
593 <U> Node<U> transformValues(Function<T, U> f) {
594 List<Node<U>> transformedEdges = new ArrayList<>();
595 for (Node<T> edge : edges()) {
596 transformedEdges.add(edge.transformValues(f));
597 }
598 return new Node<>(f.apply(value()), transformedEdges);
599 }
600
601 Node<T> transformGraph(Function<Node<T>, Node<T>> f) {
602 Node<T> apply = f.apply(this);
603 if (apply != this) {
604 // The function returned a new node
605 return apply;
606 } else {
607 // The function returned the same node
608 // Apply the transformation to the children
609 List<Node<T>> transformedEdges = new ArrayList<>();
610 for (Node<T> edge : edges()) {
611 transformedEdges.add(edge.transformGraph(f));
612 }
613 boolean same = IntStream.range(0, edges().size())
614 .allMatch(i -> edges().get(i) ==
615 transformedEdges.get(i));
616 if (same) {
617 return this;
618 } else {
619 return new Node<>(this.value(), transformedEdges);
620 }
621 }
622 }
623
624 @Override
625 public String toString() {
626 StringBuilder sb = new StringBuilder();
627 print(sb, "", "");
628 return sb.toString();
629 }
630
631 private void print(StringBuilder sb, String prefix, String edgePrefix) {
632 sb.append(prefix);
633 sb.append(value);
634 sb.append('\n');
635 for (Iterator<Node<T>> it = edges.iterator(); it.hasNext(); ) {
636 Node<T> edge = it.next();
637 if (it.hasNext()) {
638 edge.print(sb, edgePrefix + "├── ", edgePrefix + "│ ");
639 } else {
640 edge.print(sb, edgePrefix + "└── ", edgePrefix + " ");
641 }
642 }
643 }
644 }
645
646
647 static CoreOp.FuncOp getFuncOp(String name) {
648 Optional<Method> om = Stream.of(TestExpressionGraphs.class.getDeclaredMethods())
649 .filter(m -> m.getName().equals(name))
650 .findFirst();
651
652 Method m = om.get();
653 return Op.ofMethod(m).get();
654 }
655
656 }