1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @modules jdk.incubator.code 27 * @run junit TestExpressionGraphs 28 */ 29 30 import jdk.incubator.code.*; 31 import jdk.incubator.code.analysis.SSA; 32 import jdk.incubator.code.dialect.core.CoreOp; 33 import jdk.incubator.code.extern.OpWriter; 34 import org.junit.jupiter.api.Test; 35 36 import java.io.Writer; 37 import java.lang.reflect.Method; 38 import java.util.*; 39 import java.util.function.Function; 40 import java.util.stream.IntStream; 41 import java.util.stream.Stream; 42 43 public class TestExpressionGraphs { 44 45 @CodeReflection 46 static double sub(double a, double b) { 47 return a - b; 48 } 49 50 @CodeReflection 51 static double distance1(double a, double b) { 52 return Math.abs(a - b); 53 } 54 55 @CodeReflection 56 static double distance1a(final double a, final double b) { 57 final double diff = a - b; 58 final double result = Math.abs(diff); 59 return result; 60 } 61 62 @CodeReflection 63 static double distance1b(final double a, final double b) { 64 final double diff = a - b; 65 // Note, incorrect for negative zero values 66 final double result = diff < 0d ? -diff : diff; 67 return result; 68 } 69 70 @CodeReflection 71 static double distanceN(double[] a, double[] b) { 72 double sum = 0d; 73 for (int i = 0; i < a.length; i++) { 74 sum += Math.pow(a[i] - b[i], 2d); 75 } 76 return Math.sqrt(sum); 77 } 78 79 @CodeReflection 80 static double squareDiff(double a, double b) { 81 // a^2 - b^2 = (a + b) * (a - b) 82 final double plus = a + b; 83 final double minus = a - b; 84 return plus * minus; 85 } 86 87 @Test 88 void traverseSub() throws ReflectiveOperationException { 89 // Get the reflective object for method sub 90 Method m = TestExpressionGraphs.class.getDeclaredMethod( 91 "sub", double.class, double.class); 92 // Get the code model for method sub 93 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m); 94 CoreOp.FuncOp model = oModel.orElseThrow(); 95 96 // Depth-first search, reporting elements in pre-order 97 model.traverse(null, (acc, codeElement) -> { 98 // Count the depth of the code element by 99 // traversing up the tree from child to parent 100 int depth = 0; 101 CodeElement<?, ?> parent = codeElement; 102 while ((parent = parent.parent()) != null) depth++; 103 // Print out code element class 104 System.out.println(" ".repeat(depth) + codeElement.getClass()); 105 return acc; 106 }); 107 108 // Stream of elements topologically sorted in depth-first search pre-order 109 model.elements().forEach(codeElement -> { 110 // Count the depth of the code element 111 int depth = 0; 112 CodeElement<?, ?> parent = codeElement; 113 while ((parent = parent.parent()) != null) depth++; 114 // Print out code element class 115 System.out.println(" ".repeat(depth) + codeElement.getClass()); 116 }); 117 } 118 119 @Test 120 void traverseDistance1() throws ReflectiveOperationException { 121 // Get the reflective object for method distance1 122 Method m = TestExpressionGraphs.class.getDeclaredMethod( 123 "distance1", double.class, double.class); 124 // Get the code model for method distance1 125 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m); 126 CoreOp.FuncOp model = oModel.orElseThrow(); 127 128 // Depth-first search, reporting elements in pre-order 129 model.traverse(null, (acc, codeElement) -> { 130 // Count the depth of the code element by 131 // traversing up the tree from child to parent 132 int depth = 0; 133 CodeElement<?, ?> parent = codeElement; 134 while ((parent = parent.parent()) != null) depth++; 135 // Print out code element class 136 System.out.println(" ".repeat(depth) + codeElement.getClass()); 137 return acc; 138 }); 139 140 // Stream of elements topologically sorted in depth-first search pre-order 141 model.elements().forEach(codeElement -> { 142 // Count the depth of the code element 143 int depth = 0; 144 CodeElement<?, ?> parent = codeElement; 145 while ((parent = parent.parent()) != null) depth++; 146 // Print out code element class 147 System.out.println(" ".repeat(depth) + codeElement.getClass()); 148 }); 149 } 150 151 152 @Test 153 void printSub() { 154 CoreOp.FuncOp model = getFuncOp("sub"); 155 print(model); 156 } 157 158 @Test 159 void printDistance1() { 160 CoreOp.FuncOp model = getFuncOp("distance1"); 161 print(model); 162 } 163 164 @Test 165 void printDistance1a() { 166 CoreOp.FuncOp model = getFuncOp("distance1a"); 167 print(model); 168 } 169 170 @Test 171 void printDistance1b() { 172 CoreOp.FuncOp model = getFuncOp("distance1b"); 173 print(model); 174 } 175 176 @Test 177 void printDistanceN() { 178 CoreOp.FuncOp model = getFuncOp("distanceN"); 179 print(model); 180 } 181 182 void print(CoreOp.FuncOp f) { 183 System.out.println(f.toText()); 184 185 f = f.transform(OpTransformer.LOWERING_TRANSFORMER); 186 System.out.println(f.toText()); 187 188 f = SSA.transform(f); 189 System.out.println(f.toText()); 190 } 191 192 193 @Test 194 void graphsDistance1() { 195 CoreOp.FuncOp model = getFuncOp("distance1"); 196 Function<CodeItem, String> names = names(model); 197 System.out.println(printOpWriteVoid(names, model)); 198 199 // Create the expression graph for the terminating operation result 200 Op.Result returnResult = model.body().entryBlock().terminatingOp().result(); 201 Node<Value> returnGraph = expressionGraph(returnResult); 202 System.out.println("Expression graph for terminating operation result"); 203 // Transform from Node<Value> to Node<String> and print the graph 204 System.out.println(returnGraph.transformValues(v -> printValue(names, v))); 205 206 System.out.println("Use graphs for block parameters"); 207 for (Block.Parameter parameter : model.parameters()) { 208 Node<Value> useNode = useGraph(parameter); 209 System.out.println(useNode.transformValues(v -> printValue(names, v))); 210 } 211 212 // Create the expression graphs for all values 213 Map<Value, Node<Value>> graphs = expressionGraphs(model); 214 System.out.println("Expression graphs for all declared values in the model"); 215 graphs.values().forEach(n -> { 216 System.out.println(n.transformValues(v -> printValue(names, v))); 217 }); 218 219 // The graphs for the terminating operation result are the same 220 assert returnGraph.equals(graphs.get(returnGraph.value())); 221 222 // Filter for root graphs, operation results with no uses 223 List<Node<Value>> rootGraphs = graphs.values().stream() 224 .filter(n -> n.value() instanceof Op.Result opr && 225 switch (opr.op()) { 226 // An operation result with no uses 227 default -> opr.uses().isEmpty(); 228 }) 229 .toList(); 230 System.out.println("Root expression graphs"); 231 rootGraphs.forEach(n -> { 232 System.out.println(n.transformValues(v -> printValue(names, v))); 233 }); 234 } 235 236 @Test 237 void graphsDistance1a() { 238 CoreOp.FuncOp f = getFuncOp("distance1a"); 239 Function<CodeItem, String> names = names(f); 240 System.out.println(printOpWriteVoid(names, f)); 241 242 { 243 Map<Value, Node<Value>> graphs = expressionGraphs(f); 244 List<Node<Value>> rootGraphs = graphs.values().stream() 245 .filter(n -> n.value() instanceof Op.Result opr && 246 switch (opr.op()) { 247 // An operation result with no uses 248 default -> opr.uses().isEmpty(); 249 }) 250 .toList(); 251 System.out.println("Root expression graphs"); 252 rootGraphs.forEach(n -> { 253 System.out.println(n.transformValues(v -> printValue(names, v))); 254 }); 255 } 256 257 { 258 Map<Value, Node<Value>> graphs = expressionGraphs(f); 259 List<Node<Value>> rootGraphs = graphs.values().stream() 260 .filter(n -> n.value() instanceof Op.Result opr && 261 switch (opr.op()) { 262 // Variable declarations modeling local variables 263 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 264 // An operation result with no uses 265 default -> opr.uses().isEmpty(); 266 }) 267 .toList(); 268 System.out.println("Root (with variable) expression graphs"); 269 rootGraphs.forEach(n -> { 270 System.out.println(n.transformValues(v -> printValue(names, v))); 271 }); 272 } 273 274 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 275 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 276 .filter(n -> n.value() instanceof Op.Result opr && 277 switch (opr.op()) { 278 // Variable declarations modeling local variables 279 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 280 // An operation result with no uses 281 default -> opr.uses().isEmpty(); 282 }) 283 .toList(); 284 System.out.println("Pruned root expression graphs"); 285 prunedRootGraphs.forEach(n -> { 286 System.out.println(n.transformValues(v -> printValue(names, v))); 287 }); 288 } 289 290 @Test 291 void graphsDistance1b() { 292 CoreOp.FuncOp f = getFuncOp("distance1b"); 293 Function<CodeItem, String> names = names(f); 294 System.out.println(printOpWriteVoid(names, f)); 295 296 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 297 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 298 .filter(n -> n.value() instanceof Op.Result opr && 299 switch (opr.op()) { 300 // Variable declarations modeling declaration of local variables 301 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 302 // An operation result with no uses 303 default -> opr.uses().isEmpty(); 304 }) 305 .toList(); 306 System.out.println("Pruned root expression graphs"); 307 prunedRootGraphs.forEach(n -> { 308 System.out.println(n.transformValues(v -> printValue(names, v))); 309 }); 310 } 311 312 @Test 313 void graphsDistanceN() { 314 CoreOp.FuncOp f = getFuncOp("distanceN"); 315 Function<CodeItem, String> names = names(f); 316 System.out.println(printOpWriteVoid(names, f)); 317 318 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 319 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 320 .filter(n -> n.value() instanceof Op.Result opr && 321 switch (opr.op()) { 322 // Variable declarations modeling declaration of local variables 323 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 324 // An operation result with no uses 325 default -> opr.uses().isEmpty(); 326 }) 327 .toList(); 328 System.out.println("Pruned root expression graphs"); 329 prunedRootGraphs.forEach(n -> { 330 System.out.println(n.transformValues(v -> printValue(names, v))); 331 }); 332 } 333 334 @Test 335 void graphsSquareDiff() { 336 CoreOp.FuncOp f = getFuncOp("squareDiff"); 337 Function<CodeItem, String> names = names(f); 338 System.out.println(printOpWriteVoid(names, f)); 339 340 { 341 Map<Value, Node<Value>> graphs = expressionGraphs(f); 342 List<Node<Value>> rootGraphs = graphs.values().stream() 343 .filter(n -> n.value() instanceof Op.Result opr && 344 switch (opr.op()) { 345 // An operation result with no uses 346 default -> opr.uses().isEmpty(); 347 }) 348 .toList(); 349 System.out.println("Root expression graphs"); 350 rootGraphs.forEach(n -> { 351 System.out.println(n.transformValues(v -> printValue(names, v))); 352 }); 353 } 354 355 { 356 Map<Value, Node<Value>> graphs = expressionGraphs(f); 357 List<Node<Value>> rootGraphs = graphs.values().stream() 358 .filter(n -> n.value() instanceof Op.Result opr && 359 switch (opr.op()) { 360 // Variable declarations modeling local variables 361 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 362 // An operation result with no uses 363 default -> opr.uses().isEmpty(); 364 }) 365 .toList(); 366 System.out.println("Root (with variable) expression graphs"); 367 rootGraphs.forEach(n -> { 368 System.out.println(n.transformValues(v -> printValue(names, v))); 369 }); 370 } 371 372 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 373 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 374 .filter(n -> n.value() instanceof Op.Result opr && 375 switch (opr.op()) { 376 // Variable declarations modeling local variables 377 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 378 // An operation result with no uses 379 default -> opr.uses().isEmpty(); 380 }) 381 .toList(); 382 System.out.println("Pruned root expression graphs"); 383 prunedRootGraphs.forEach(n -> { 384 System.out.println(n.transformValues(v -> printValue(names, v))); 385 }); 386 } 387 388 389 390 @CodeReflection 391 static int h(int x) { 392 x += 2; // Statement 1 393 g(x); // Statement 2 394 int y = 1 + g(x) + (x += 2) + (x > 2 ? x : 10); // Statement 3 395 for ( // Statement 4 396 int i = 0, j = 1; // Statements 4.1 397 i < 3 && j < 3; 398 i++, j++) { // Statements 4.2 399 System.out.println(i); // Statement 4.3 400 } 401 return x + y; // Statement 5 402 } 403 404 static int g(int i) { 405 return i; 406 } 407 408 @Test 409 void graphsH() { 410 CoreOp.FuncOp f = getFuncOp("h"); 411 Function<CodeItem, String> names = names(f); 412 System.out.println(printOpWriteVoid(names, f)); 413 414 Map<Value, Node<Value>> graphs = prunedExpressionGraphs(f); 415 List<Node<Value>> rootGraphs = graphs.values().stream() 416 .filter(n -> n.value() instanceof Op.Result opr && 417 switch (opr.op()) { 418 // Variable declarations modeling declaration of local variables 419 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 420 // Variable stores modeling assignment expressions whose result is used 421 case CoreOp.VarAccessOp.VarStoreOp vsop -> vsop.operands().get(1).uses().size() == 1; 422 // An operation result with no uses 423 default -> opr.uses().isEmpty(); 424 }) 425 .toList(); 426 rootGraphs.forEach(n -> { 427 System.out.println(n.transformValues(v -> printValue(names, v))); 428 }); 429 } 430 431 432 static String printValue(Function<CodeItem, String> names, Value v) { 433 if (v instanceof Op.Result opr) { 434 return printOpHeader(names, opr.op()); 435 } else { 436 return "%" + names.apply(v) + " <block parameter>"; 437 } 438 } 439 440 static String printOpHeader(Function<CodeItem, String> names, Op op) { 441 return OpWriter.toText(op, 442 OpWriter.OpDescendantsOption.DROP_DESCENDANTS, 443 OpWriter.VoidOpResultOption.WRITE_VOID, 444 OpWriter.CodeItemNamerOption.of(names)); 445 } 446 447 static String printOpWriteVoid(Function<CodeItem, String> names, Op op) { 448 return OpWriter.toText(op, 449 OpWriter.VoidOpResultOption.WRITE_VOID, 450 OpWriter.CodeItemNamerOption.of(names)); 451 } 452 453 static Function<CodeItem, String> names(Op op) { 454 OpWriter w = new OpWriter(Writer.nullWriter(), 455 OpWriter.VoidOpResultOption.WRITE_VOID); 456 w.writeOp(op); 457 return w.namer(); 458 } 459 460 461 static Node<Value> expressionGraph(Value value) { 462 return expressionGraph(new HashMap<>(), value); 463 } 464 465 static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) { 466 // If value has already been visited return its node 467 if (visited.containsKey(value)) { 468 return visited.get(value); 469 } 470 471 // Find the expression graphs for each operand 472 List<Node<Value>> edges = new ArrayList<>(); 473 for (Value operand : value.dependsOn()) { 474 edges.add(expressionGraph(operand)); 475 } 476 Node<Value> node = new Node<>(value, edges); 477 visited.put(value, node); 478 return node; 479 } 480 481 static Node<Value> expressionGraphDetailed(Map<Value, Node<Value>> visited, Value value) { 482 // If value has already been visited return its node 483 if (visited.containsKey(value)) { 484 return visited.get(value); 485 } 486 487 List<Node<Value>> edges; 488 if (value instanceof Op.Result result) { 489 edges = new ArrayList<>(); 490 // Find the expression graphs for each operand 491 Set<Value> valueVisited = new HashSet<>(); 492 for (Value operand : result.op().operands()) { 493 // Ensure an operand is visited only once 494 if (valueVisited.add(operand)) { 495 edges.add(expressionGraph(operand)); 496 } 497 } 498 // TODO if terminating operation find expression graphs 499 // for each successor argument 500 } else { 501 assert value instanceof Block.Parameter; 502 // A block parameter has no outgoing edges 503 edges = List.of(); 504 } 505 Node<Value> node = new Node<>(value, edges); 506 visited.put(value, node); 507 return node; 508 } 509 510 511 static Node<Value> useGraph(Value value) { 512 return useGraph(new HashMap<>(), value); 513 } 514 515 static Node<Value> useGraph(Map<Value, Node<Value>> visited, Value value) { 516 // If value has already been visited return its node 517 if (visited.containsKey(value)) { 518 return visited.get(value); 519 } 520 521 // Find the use graphs for each use 522 List<Node<Value>> edges = new ArrayList<>(); 523 for (Op.Result use : value.uses()) { 524 edges.add(useGraph(visited, use)); 525 } 526 Node<Value> node = new Node<>(value, edges); 527 visited.put(value, node); 528 return node; 529 } 530 531 static Map<Value, Node<Value>> expressionGraphs(CoreOp.FuncOp f) { 532 return expressionGraphs(f.body()); 533 } 534 535 static Map<Value, Node<Value>> expressionGraphs(Body b) { 536 // Traverse the model building structurally shared expression graphs 537 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> { 538 switch (codeElement) { 539 case Body _ -> { 540 // Do nothing 541 } 542 case Block block -> { 543 // Create the expression graphs for each block parameter 544 // A block parameter has no outgoing edges 545 for (Block.Parameter parameter : block.parameters()) { 546 graphs.put(parameter, new Node<>(parameter, List.of())); 547 } 548 } 549 case Op op -> { 550 // Find the expression graphs for each operand 551 List<Node<Value>> edges = new ArrayList<>(); 552 for (Value operand : op.result().dependsOn()) { 553 // Get expression graph for the operand 554 // It must be previously computed since we encounter the 555 // declaration of values before their use 556 edges.add(graphs.get(operand)); 557 } 558 // Create the expression graph for this operation result 559 graphs.put(op.result(), new Node<>(op.result(), edges)); 560 } 561 } 562 return graphs; 563 }); 564 } 565 566 static Map<Value, Node<Value>> prunedExpressionGraphs(CoreOp.FuncOp f) { 567 return prunedExpressionGraphs(f.body()); 568 } 569 570 static Map<Value, Node<Value>> prunedExpressionGraphs(Body b) { 571 // Traverse the model building structurally shared expression graphs 572 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> { 573 switch (codeElement) { 574 case Body _ -> { 575 // Do nothing 576 } 577 case Block block -> { 578 // Create the expression graphs for each block parameter 579 // A block parameter has no outgoing edges 580 for (Block.Parameter parameter : block.parameters()) { 581 graphs.put(parameter, new Node<>(parameter, List.of())); 582 } 583 } 584 // Prune graph for variable load operation 585 case CoreOp.VarAccessOp.VarLoadOp op -> { 586 // Ignore edge for the variable value operand 587 graphs.put(op.result(), new Node<>(op.result(), List.of())); 588 } 589 // Prune graph for variable store operation 590 case CoreOp.VarAccessOp.VarStoreOp op -> { 591 // Ignore edge for the variable value operand 592 // Add edge for value to store 593 List<Node<Value>> edges = List.of(graphs.get(op.operands().get(1))); 594 graphs.put(op.result(), new Node<>(op.result(), edges)); 595 } 596 case Op op -> { 597 // Find the expression graphs for each operand 598 List<Node<Value>> edges = new ArrayList<>(); 599 for (Value operand : op.result().dependsOn()) { 600 // Get expression graph for the operand 601 // It must be previously computed since we encounter the 602 // declaration of values before their use 603 edges.add(graphs.get(operand)); 604 } 605 // Create the expression graph for this operation result 606 graphs.put(op.result(), new Node<>(op.result(), edges)); 607 } 608 } 609 return graphs; 610 }); 611 } 612 613 record Node<T>(T value, List<Node<T>> edges) { 614 <U> Node<U> transformValues(Function<T, U> f) { 615 List<Node<U>> transformedEdges = new ArrayList<>(); 616 for (Node<T> edge : edges()) { 617 transformedEdges.add(edge.transformValues(f)); 618 } 619 return new Node<>(f.apply(value()), transformedEdges); 620 } 621 622 Node<T> transformGraph(Function<Node<T>, Node<T>> f) { 623 Node<T> apply = f.apply(this); 624 if (apply != this) { 625 // The function returned a new node 626 return apply; 627 } else { 628 // The function returned the same node 629 // Apply the transformation to the children 630 List<Node<T>> transformedEdges = new ArrayList<>(); 631 for (Node<T> edge : edges()) { 632 transformedEdges.add(edge.transformGraph(f)); 633 } 634 boolean same = IntStream.range(0, edges().size()) 635 .allMatch(i -> edges().get(i) == 636 transformedEdges.get(i)); 637 if (same) { 638 return this; 639 } else { 640 return new Node<>(this.value(), transformedEdges); 641 } 642 } 643 } 644 645 @Override 646 public String toString() { 647 StringBuilder sb = new StringBuilder(); 648 print(sb, "", ""); 649 return sb.toString(); 650 } 651 652 private void print(StringBuilder sb, String prefix, String edgePrefix) { 653 sb.append(prefix); 654 sb.append(value); 655 sb.append('\n'); 656 for (Iterator<Node<T>> it = edges.iterator(); it.hasNext(); ) { 657 Node<T> edge = it.next(); 658 if (it.hasNext()) { 659 edge.print(sb, edgePrefix + "├── ", edgePrefix + "│ "); 660 } else { 661 edge.print(sb, edgePrefix + "└── ", edgePrefix + " "); 662 } 663 } 664 } 665 } 666 667 668 static CoreOp.FuncOp getFuncOp(String name) { 669 Optional<Method> om = Stream.of(TestExpressionGraphs.class.getDeclaredMethods()) 670 .filter(m -> m.getName().equals(name)) 671 .findFirst(); 672 673 Method m = om.get(); 674 return Op.ofMethod(m).get(); 675 } 676 677 }