1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.triton; 27 28 import java.lang.reflect.code.TypeElement; 29 import java.util.ArrayList; 30 import java.util.List; 31 import java.util.Objects; 32 33 public final class TensorType extends TritonType { 34 static final String NAME = "tensor"; 35 36 final TypeElement eType; 37 final List<Integer> shape; 38 final int size; 39 40 public TensorType(TypeElement eType, List<Integer> shape) { 41 this.eType = eType; 42 this.shape = List.copyOf(shape); 43 int s = 1; 44 for (Integer i : shape) { 45 s *= i; 46 } 47 this.size = s; 48 } 49 50 public TypeElement eType() { 51 return eType; 52 } 53 54 public List<Integer> shape() { 55 return shape; 56 } 57 58 public int size() { 59 return size; 60 } 61 62 @Override 63 public boolean equals(Object o) { 64 if (this == o) return true; 65 if (o == null || getClass() != o.getClass()) return false; 66 TensorType that = (TensorType) o; 67 return Objects.equals(eType, that.eType) && Objects.equals(shape, that.shape); 68 } 69 70 @Override 71 public int hashCode() { 72 return Objects.hash(eType, shape); 73 } 74 75 @Override 76 public ExternalizedTypeElement externalize() { 77 List<ExternalizedTypeElement> args = new ArrayList<>(); 78 for (int i : shape) { 79 args.add(new ExternalizedTypeElement("x" + i, List.of())); 80 } 81 args.add(eType.externalize()); 82 return new ExternalizedTypeElement(NAME, args); 83 } 84 85 @Override 86 public String toString() { 87 return externalize().toString(); 88 } 89 }