1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import jdk.incubator.code.TypeElement; 29 import jdk.incubator.code.extern.ExternalizedTypeElement; 30 import java.util.ArrayList; 31 import java.util.List; 32 import java.util.Objects; 33 34 public final class TensorType extends TritonType { 35 static final String NAME = "tensor"; 36 37 final TypeElement eType; 38 final List<Integer> shape; 39 final int size; 40 41 public TensorType(TypeElement eType, List<Integer> shape) { 42 this.eType = eType; 43 this.shape = List.copyOf(shape); 44 int s = 1; 45 for (Integer i : shape) { 46 s *= i; 47 } 48 this.size = s; 49 } 50 51 public TypeElement eType() { 52 return eType; 53 } 54 55 public List<Integer> shape() { 56 return shape; 57 } 58 59 public int size() { 60 return size; 61 } 62 63 @Override 64 public boolean equals(Object o) { 65 if (this == o) return true; 66 if (o == null || getClass() != o.getClass()) return false; 67 TensorType that = (TensorType) o; 68 return Objects.equals(eType, that.eType) && Objects.equals(shape, that.shape); 69 } 70 71 @Override 72 public int hashCode() { 73 return Objects.hash(eType, shape); 74 } 75 76 @Override 77 public ExternalizedTypeElement externalize() { 78 List<ExternalizedTypeElement> args = new ArrayList<>(); 79 for (int i : shape) { 80 args.add(new ExternalizedTypeElement("x" + i, List.of())); 81 } 82 args.add(eType.externalize()); 83 return ExternalizedTypeElement.of(NAME, args); 84 } 85 86 @Override 87 public String toString() { 88 return externalize().toString(); 89 } 90 }