1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @modules jdk.incubator.code 27 * @run testng TestExpressionGraphs 28 */ 29 30 import org.testng.annotations.Test; 31 32 import java.io.Writer; 33 import java.lang.reflect.Method; 34 import jdk.incubator.code.*; 35 import jdk.incubator.code.analysis.SSA; 36 import jdk.incubator.code.op.CoreOp; 37 import jdk.incubator.code.writer.OpWriter; 38 import jdk.incubator.code.CodeReflection; 39 import java.util.*; 40 import java.util.function.Function; 41 import java.util.stream.IntStream; 42 import java.util.stream.Stream; 43 44 public class TestExpressionGraphs { 45 46 @CodeReflection 47 static double sub(double a, double b) { 48 return a - b; 49 } 50 51 @CodeReflection 52 static double distance1(double a, double b) { 53 return Math.abs(a - b); 54 } 55 56 @CodeReflection 57 static double distance1a(final double a, final double b) { 58 final double diff = a - b; 59 final double result = Math.abs(diff); 60 return result; 61 } 62 63 @CodeReflection 64 static double distance1b(final double a, final double b) { 65 final double diff = a - b; 66 // Note, incorrect for negative zero values 67 final double result = diff < 0d ? -diff : diff; 68 return result; 69 } 70 71 @CodeReflection 72 static double distanceN(double[] a, double[] b) { 73 double sum = 0d; 74 for (int i = 0; i < a.length; i++) { 75 sum += Math.pow(a[i] - b[i], 2d); 76 } 77 return Math.sqrt(sum); 78 } 79 80 @CodeReflection 81 static double squareDiff(double a, double b) { 82 // a^2 - b^2 = (a + b) * (a - b) 83 final double plus = a + b; 84 final double minus = a - b; 85 return plus * minus; 86 } 87 88 @Test 89 void traverseSub() throws ReflectiveOperationException { 90 // Get the reflective object for method sub 91 Method m = TestExpressionGraphs.class.getDeclaredMethod( 92 "sub", double.class, double.class); 93 // Get the code model for method sub 94 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m); 95 CoreOp.FuncOp model = oModel.orElseThrow(); 96 97 // Depth-first search, reporting elements in pre-order 98 model.traverse(null, (acc, codeElement) -> { 99 // Count the depth of the code element by 100 // traversing up the tree from child to parent 101 int depth = 0; 102 CodeElement<?, ?> parent = codeElement; 103 while ((parent = parent.parent()) != null) depth++; 104 // Print out code element class 105 System.out.println(" ".repeat(depth) + codeElement.getClass()); 106 return acc; 107 }); 108 109 // Stream of elements topologically sorted in depth-first search pre-order 110 model.elements().forEach(codeElement -> { 111 // Count the depth of the code element 112 int depth = 0; 113 CodeElement<?, ?> parent = codeElement; 114 while ((parent = parent.parent()) != null) depth++; 115 // Print out code element class 116 System.out.println(" ".repeat(depth) + codeElement.getClass()); 117 }); 118 } 119 120 @Test 121 void traverseDistance1() throws ReflectiveOperationException { 122 // Get the reflective object for method distance1 123 Method m = TestExpressionGraphs.class.getDeclaredMethod( 124 "distance1", double.class, double.class); 125 // Get the code model for method distance1 126 Optional<CoreOp.FuncOp> oModel = Op.ofMethod(m); 127 CoreOp.FuncOp model = oModel.orElseThrow(); 128 129 // Depth-first search, reporting elements in pre-order 130 model.traverse(null, (acc, codeElement) -> { 131 // Count the depth of the code element by 132 // traversing up the tree from child to parent 133 int depth = 0; 134 CodeElement<?, ?> parent = codeElement; 135 while ((parent = parent.parent()) != null) depth++; 136 // Print out code element class 137 System.out.println(" ".repeat(depth) + codeElement.getClass()); 138 return acc; 139 }); 140 141 // Stream of elements topologically sorted in depth-first search pre-order 142 model.elements().forEach(codeElement -> { 143 // Count the depth of the code element 144 int depth = 0; 145 CodeElement<?, ?> parent = codeElement; 146 while ((parent = parent.parent()) != null) depth++; 147 // Print out code element class 148 System.out.println(" ".repeat(depth) + codeElement.getClass()); 149 }); 150 } 151 152 153 @Test 154 void printSub() { 155 CoreOp.FuncOp model = getFuncOp("sub"); 156 print(model); 157 } 158 159 @Test 160 void printDistance1() { 161 CoreOp.FuncOp model = getFuncOp("distance1"); 162 print(model); 163 } 164 165 @Test 166 void printDistance1a() { 167 CoreOp.FuncOp model = getFuncOp("distance1a"); 168 print(model); 169 } 170 171 @Test 172 void printDistance1b() { 173 CoreOp.FuncOp model = getFuncOp("distance1b"); 174 print(model); 175 } 176 177 @Test 178 void printDistanceN() { 179 CoreOp.FuncOp model = getFuncOp("distanceN"); 180 print(model); 181 } 182 183 void print(CoreOp.FuncOp f) { 184 System.out.println(f.toText()); 185 186 f = f.transform(OpTransformer.LOWERING_TRANSFORMER); 187 System.out.println(f.toText()); 188 189 f = SSA.transform(f); 190 System.out.println(f.toText()); 191 } 192 193 194 @Test 195 void graphsDistance1() { 196 CoreOp.FuncOp model = getFuncOp("distance1"); 197 Function<CodeItem, String> names = names(model); 198 System.out.println(printOpWriteVoid(names, model)); 199 200 // Create the expression graph for the terminating operation result 201 Op.Result returnResult = model.body().entryBlock().terminatingOp().result(); 202 Node<Value> returnGraph = expressionGraph(returnResult); 203 System.out.println("Expression graph for terminating operation result"); 204 // Transform from Node<Value> to Node<String> and print the graph 205 System.out.println(returnGraph.transformValues(v -> printValue(names, v))); 206 207 System.out.println("Use graphs for block parameters"); 208 for (Block.Parameter parameter : model.parameters()) { 209 Node<Value> useNode = useGraph(parameter); 210 System.out.println(useNode.transformValues(v -> printValue(names, v))); 211 } 212 213 // Create the expression graphs for all values 214 Map<Value, Node<Value>> graphs = expressionGraphs(model); 215 System.out.println("Expression graphs for all declared values in the model"); 216 graphs.values().forEach(n -> { 217 System.out.println(n.transformValues(v -> printValue(names, v))); 218 }); 219 220 // The graphs for the terminating operation result are the same 221 assert returnGraph.equals(graphs.get(returnGraph.value())); 222 223 // Filter for root graphs, operation results with no uses 224 List<Node<Value>> rootGraphs = graphs.values().stream() 225 .filter(n -> n.value() instanceof Op.Result opr && 226 switch (opr.op()) { 227 // An operation result with no uses 228 default -> opr.uses().isEmpty(); 229 }) 230 .toList(); 231 System.out.println("Root expression graphs"); 232 rootGraphs.forEach(n -> { 233 System.out.println(n.transformValues(v -> printValue(names, v))); 234 }); 235 } 236 237 @Test 238 void graphsDistance1a() { 239 CoreOp.FuncOp f = getFuncOp("distance1a"); 240 Function<CodeItem, String> names = names(f); 241 System.out.println(printOpWriteVoid(names, f)); 242 243 { 244 Map<Value, Node<Value>> graphs = expressionGraphs(f); 245 List<Node<Value>> rootGraphs = graphs.values().stream() 246 .filter(n -> n.value() instanceof Op.Result opr && 247 switch (opr.op()) { 248 // An operation result with no uses 249 default -> opr.uses().isEmpty(); 250 }) 251 .toList(); 252 System.out.println("Root expression graphs"); 253 rootGraphs.forEach(n -> { 254 System.out.println(n.transformValues(v -> printValue(names, v))); 255 }); 256 } 257 258 { 259 Map<Value, Node<Value>> graphs = expressionGraphs(f); 260 List<Node<Value>> rootGraphs = graphs.values().stream() 261 .filter(n -> n.value() instanceof Op.Result opr && 262 switch (opr.op()) { 263 // Variable declarations modeling local variables 264 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 265 // An operation result with no uses 266 default -> opr.uses().isEmpty(); 267 }) 268 .toList(); 269 System.out.println("Root (with variable) expression graphs"); 270 rootGraphs.forEach(n -> { 271 System.out.println(n.transformValues(v -> printValue(names, v))); 272 }); 273 } 274 275 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 276 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 277 .filter(n -> n.value() instanceof Op.Result opr && 278 switch (opr.op()) { 279 // Variable declarations modeling local variables 280 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 281 // An operation result with no uses 282 default -> opr.uses().isEmpty(); 283 }) 284 .toList(); 285 System.out.println("Pruned root expression graphs"); 286 prunedRootGraphs.forEach(n -> { 287 System.out.println(n.transformValues(v -> printValue(names, v))); 288 }); 289 } 290 291 @Test 292 void graphsDistance1b() { 293 CoreOp.FuncOp f = getFuncOp("distance1b"); 294 Function<CodeItem, String> names = names(f); 295 System.out.println(printOpWriteVoid(names, f)); 296 297 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 298 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 299 .filter(n -> n.value() instanceof Op.Result opr && 300 switch (opr.op()) { 301 // Variable declarations modeling declaration of local variables 302 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 303 // An operation result with no uses 304 default -> opr.uses().isEmpty(); 305 }) 306 .toList(); 307 System.out.println("Pruned root expression graphs"); 308 prunedRootGraphs.forEach(n -> { 309 System.out.println(n.transformValues(v -> printValue(names, v))); 310 }); 311 } 312 313 @Test 314 void graphsDistanceN() { 315 CoreOp.FuncOp f = getFuncOp("distanceN"); 316 Function<CodeItem, String> names = names(f); 317 System.out.println(printOpWriteVoid(names, f)); 318 319 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 320 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 321 .filter(n -> n.value() instanceof Op.Result opr && 322 switch (opr.op()) { 323 // Variable declarations modeling declaration of local variables 324 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 325 // An operation result with no uses 326 default -> opr.uses().isEmpty(); 327 }) 328 .toList(); 329 System.out.println("Pruned root expression graphs"); 330 prunedRootGraphs.forEach(n -> { 331 System.out.println(n.transformValues(v -> printValue(names, v))); 332 }); 333 } 334 335 @Test 336 void graphsSquareDiff() { 337 CoreOp.FuncOp f = getFuncOp("squareDiff"); 338 Function<CodeItem, String> names = names(f); 339 System.out.println(printOpWriteVoid(names, f)); 340 341 { 342 Map<Value, Node<Value>> graphs = expressionGraphs(f); 343 List<Node<Value>> rootGraphs = graphs.values().stream() 344 .filter(n -> n.value() instanceof Op.Result opr && 345 switch (opr.op()) { 346 // An operation result with no uses 347 default -> opr.uses().isEmpty(); 348 }) 349 .toList(); 350 System.out.println("Root expression graphs"); 351 rootGraphs.forEach(n -> { 352 System.out.println(n.transformValues(v -> printValue(names, v))); 353 }); 354 } 355 356 { 357 Map<Value, Node<Value>> graphs = expressionGraphs(f); 358 List<Node<Value>> rootGraphs = graphs.values().stream() 359 .filter(n -> n.value() instanceof Op.Result opr && 360 switch (opr.op()) { 361 // Variable declarations modeling local variables 362 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 363 // An operation result with no uses 364 default -> opr.uses().isEmpty(); 365 }) 366 .toList(); 367 System.out.println("Root (with variable) expression graphs"); 368 rootGraphs.forEach(n -> { 369 System.out.println(n.transformValues(v -> printValue(names, v))); 370 }); 371 } 372 373 Map<Value, Node<Value>> prunedGraphs = prunedExpressionGraphs(f); 374 List<Node<Value>> prunedRootGraphs = prunedGraphs.values().stream() 375 .filter(n -> n.value() instanceof Op.Result opr && 376 switch (opr.op()) { 377 // Variable declarations modeling local variables 378 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 379 // An operation result with no uses 380 default -> opr.uses().isEmpty(); 381 }) 382 .toList(); 383 System.out.println("Pruned root expression graphs"); 384 prunedRootGraphs.forEach(n -> { 385 System.out.println(n.transformValues(v -> printValue(names, v))); 386 }); 387 } 388 389 390 391 @CodeReflection 392 static int h(int x) { 393 x += 2; // Statement 1 394 g(x); // Statement 2 395 int y = 1 + g(x) + (x += 2) + (x > 2 ? x : 10); // Statement 3 396 for ( // Statement 4 397 int i = 0, j = 1; // Statements 4.1 398 i < 3 && j < 3; 399 i++, j++) { // Statements 4.2 400 System.out.println(i); // Statement 4.3 401 } 402 return x + y; // Statement 5 403 } 404 405 static int g(int i) { 406 return i; 407 } 408 409 @Test 410 void graphsH() { 411 CoreOp.FuncOp f = getFuncOp("h"); 412 Function<CodeItem, String> names = names(f); 413 System.out.println(printOpWriteVoid(names, f)); 414 415 Map<Value, Node<Value>> graphs = prunedExpressionGraphs(f); 416 List<Node<Value>> rootGraphs = graphs.values().stream() 417 .filter(n -> n.value() instanceof Op.Result opr && 418 switch (opr.op()) { 419 // Variable declarations modeling declaration of local variables 420 case CoreOp.VarOp vop -> vop.operands().get(0) instanceof Op.Result; 421 // Variable stores modeling assignment expressions whose result is used 422 case CoreOp.VarAccessOp.VarStoreOp vsop -> vsop.operands().get(1).uses().size() == 1; 423 // An operation result with no uses 424 default -> opr.uses().isEmpty(); 425 }) 426 .toList(); 427 rootGraphs.forEach(n -> { 428 System.out.println(n.transformValues(v -> printValue(names, v))); 429 }); 430 } 431 432 433 static String printValue(Function<CodeItem, String> names, Value v) { 434 if (v instanceof Op.Result opr) { 435 return printOpHeader(names, opr.op()); 436 } else { 437 return "%" + names.apply(v) + " <block parameter>"; 438 } 439 } 440 441 static String printOpHeader(Function<CodeItem, String> names, Op op) { 442 return OpWriter.toText(op, 443 OpWriter.OpDescendantsOption.DROP_DESCENDANTS, 444 OpWriter.VoidOpResultOption.WRITE_VOID, 445 OpWriter.CodeItemNamerOption.of(names)); 446 } 447 448 static String printOpWriteVoid(Function<CodeItem, String> names, Op op) { 449 return OpWriter.toText(op, 450 OpWriter.VoidOpResultOption.WRITE_VOID, 451 OpWriter.CodeItemNamerOption.of(names)); 452 } 453 454 static Function<CodeItem, String> names(Op op) { 455 OpWriter w = new OpWriter(Writer.nullWriter(), 456 OpWriter.VoidOpResultOption.WRITE_VOID); 457 w.writeOp(op); 458 return w.namer(); 459 } 460 461 462 static Node<Value> expressionGraph(Value value) { 463 return expressionGraph(new HashMap<>(), value); 464 } 465 466 static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) { 467 // If value has already been visited return its node 468 if (visited.containsKey(value)) { 469 return visited.get(value); 470 } 471 472 // Find the expression graphs for each operand 473 List<Node<Value>> edges = new ArrayList<>(); 474 for (Value operand : value.dependsOn()) { 475 edges.add(expressionGraph(operand)); 476 } 477 Node<Value> node = new Node<>(value, edges); 478 visited.put(value, node); 479 return node; 480 } 481 482 static Node<Value> expressionGraphDetailed(Map<Value, Node<Value>> visited, Value value) { 483 // If value has already been visited return its node 484 if (visited.containsKey(value)) { 485 return visited.get(value); 486 } 487 488 List<Node<Value>> edges; 489 if (value instanceof Op.Result result) { 490 edges = new ArrayList<>(); 491 // Find the expression graphs for each operand 492 Set<Value> valueVisited = new HashSet<>(); 493 for (Value operand : result.op().operands()) { 494 // Ensure an operand is visited only once 495 if (valueVisited.add(operand)) { 496 edges.add(expressionGraph(operand)); 497 } 498 } 499 // TODO if terminating operation find expression graphs 500 // for each successor argument 501 } else { 502 assert value instanceof Block.Parameter; 503 // A block parameter has no outgoing edges 504 edges = List.of(); 505 } 506 Node<Value> node = new Node<>(value, edges); 507 visited.put(value, node); 508 return node; 509 } 510 511 512 static Node<Value> useGraph(Value value) { 513 return useGraph(new HashMap<>(), value); 514 } 515 516 static Node<Value> useGraph(Map<Value, Node<Value>> visited, Value value) { 517 // If value has already been visited return its node 518 if (visited.containsKey(value)) { 519 return visited.get(value); 520 } 521 522 // Find the use graphs for each use 523 List<Node<Value>> edges = new ArrayList<>(); 524 for (Op.Result use : value.uses()) { 525 edges.add(useGraph(visited, use)); 526 } 527 Node<Value> node = new Node<>(value, edges); 528 visited.put(value, node); 529 return node; 530 } 531 532 static Map<Value, Node<Value>> expressionGraphs(CoreOp.FuncOp f) { 533 return expressionGraphs(f.body()); 534 } 535 536 static Map<Value, Node<Value>> expressionGraphs(Body b) { 537 // Traverse the model building structurally shared expression graphs 538 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> { 539 switch (codeElement) { 540 case Body _ -> { 541 // Do nothing 542 } 543 case Block block -> { 544 // Create the expression graphs for each block parameter 545 // A block parameter has no outgoing edges 546 for (Block.Parameter parameter : block.parameters()) { 547 graphs.put(parameter, new Node<>(parameter, List.of())); 548 } 549 } 550 case Op op -> { 551 // Find the expression graphs for each operand 552 List<Node<Value>> edges = new ArrayList<>(); 553 for (Value operand : op.result().dependsOn()) { 554 // Get expression graph for the operand 555 // It must be previously computed since we encounter the 556 // declaration of values before their use 557 edges.add(graphs.get(operand)); 558 } 559 // Create the expression graph for this operation result 560 graphs.put(op.result(), new Node<>(op.result(), edges)); 561 } 562 } 563 return graphs; 564 }); 565 } 566 567 static Map<Value, Node<Value>> prunedExpressionGraphs(CoreOp.FuncOp f) { 568 return prunedExpressionGraphs(f.body()); 569 } 570 571 static Map<Value, Node<Value>> prunedExpressionGraphs(Body b) { 572 // Traverse the model building structurally shared expression graphs 573 return b.traverse(new LinkedHashMap<>(), (graphs, codeElement) -> { 574 switch (codeElement) { 575 case Body _ -> { 576 // Do nothing 577 } 578 case Block block -> { 579 // Create the expression graphs for each block parameter 580 // A block parameter has no outgoing edges 581 for (Block.Parameter parameter : block.parameters()) { 582 graphs.put(parameter, new Node<>(parameter, List.of())); 583 } 584 } 585 // Prune graph for variable load operation 586 case CoreOp.VarAccessOp.VarLoadOp op -> { 587 // Ignore edge for the variable value operand 588 graphs.put(op.result(), new Node<>(op.result(), List.of())); 589 } 590 // Prune graph for variable store operation 591 case CoreOp.VarAccessOp.VarStoreOp op -> { 592 // Ignore edge for the variable value operand 593 // Add edge for value to store 594 List<Node<Value>> edges = List.of(graphs.get(op.operands().get(1))); 595 graphs.put(op.result(), new Node<>(op.result(), edges)); 596 } 597 case Op op -> { 598 // Find the expression graphs for each operand 599 List<Node<Value>> edges = new ArrayList<>(); 600 for (Value operand : op.result().dependsOn()) { 601 // Get expression graph for the operand 602 // It must be previously computed since we encounter the 603 // declaration of values before their use 604 edges.add(graphs.get(operand)); 605 } 606 // Create the expression graph for this operation result 607 graphs.put(op.result(), new Node<>(op.result(), edges)); 608 } 609 } 610 return graphs; 611 }); 612 } 613 614 record Node<T>(T value, List<Node<T>> edges) { 615 <U> Node<U> transformValues(Function<T, U> f) { 616 List<Node<U>> transformedEdges = new ArrayList<>(); 617 for (Node<T> edge : edges()) { 618 transformedEdges.add(edge.transformValues(f)); 619 } 620 return new Node<>(f.apply(value()), transformedEdges); 621 } 622 623 Node<T> transformGraph(Function<Node<T>, Node<T>> f) { 624 Node<T> apply = f.apply(this); 625 if (apply != this) { 626 // The function returned a new node 627 return apply; 628 } else { 629 // The function returned the same node 630 // Apply the transformation to the children 631 List<Node<T>> transformedEdges = new ArrayList<>(); 632 for (Node<T> edge : edges()) { 633 transformedEdges.add(edge.transformGraph(f)); 634 } 635 boolean same = IntStream.range(0, edges().size()) 636 .allMatch(i -> edges().get(i) == 637 transformedEdges.get(i)); 638 if (same) { 639 return this; 640 } else { 641 return new Node<>(this.value(), transformedEdges); 642 } 643 } 644 } 645 646 @Override 647 public String toString() { 648 StringBuilder sb = new StringBuilder(); 649 print(sb, "", ""); 650 return sb.toString(); 651 } 652 653 private void print(StringBuilder sb, String prefix, String edgePrefix) { 654 sb.append(prefix); 655 sb.append(value); 656 sb.append('\n'); 657 for (Iterator<Node<T>> it = edges.iterator(); it.hasNext(); ) { 658 Node<T> edge = it.next(); 659 if (it.hasNext()) { 660 edge.print(sb, edgePrefix + "├── ", edgePrefix + "│ "); 661 } else { 662 edge.print(sb, edgePrefix + "└── ", edgePrefix + " "); 663 } 664 } 665 } 666 } 667 668 669 static CoreOp.FuncOp getFuncOp(String name) { 670 Optional<Method> om = Stream.of(TestExpressionGraphs.class.getDeclaredMethods()) 671 .filter(m -> m.getName().equals(name)) 672 .findFirst(); 673 674 Method m = om.get(); 675 return Op.ofMethod(m).get(); 676 } 677 678 }