1 /*
  2  * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @run testng TestReshape
 28  */
 29 
 30 import java.lang.foreign.MemoryLayout;
 31 import java.lang.foreign.SequenceLayout;
 32 import java.lang.foreign.ValueLayout;
 33 import java.util.ArrayList;
 34 import java.util.Iterator;
 35 import java.util.List;
 36 import java.util.stream.LongStream;
 37 
 38 import org.testng.annotations.*;
 39 import static org.testng.Assert.*;
 40 
 41 public class TestReshape {
 42 
 43     @Test(dataProvider = "shapes")
 44     public void testReshape(MemoryLayout layout, long[] expectedShape) {
 45         long flattenedSize = LongStream.of(expectedShape).reduce(1L, Math::multiplyExact);
 46         SequenceLayout seq_flattened = MemoryLayout.sequenceLayout(flattenedSize, layout);
 47         assertDimensions(seq_flattened, flattenedSize);
 48         for (long[] shape : new Shape(expectedShape)) {
 49             SequenceLayout seq_shaped = seq_flattened.reshape(shape);
 50             assertDimensions(seq_shaped, expectedShape);
 51             assertEquals(seq_shaped.flatten(), seq_flattened);
 52         }
 53     }
 54 
 55     @Test(expectedExceptions = IllegalArgumentException.class)
 56     public void testInvalidReshape() {
 57         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 58         seq.reshape(3, 2);
 59     }
 60 
 61     @Test(expectedExceptions = IllegalArgumentException.class)
 62     public void testBadReshapeInference() {
 63         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 64         seq.reshape(-1, -1);
 65     }
 66 
 67     @Test(expectedExceptions = IllegalArgumentException.class)
 68     public void testBadReshapeParameterZero() {
 69         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 70         seq.reshape(0, 4);
 71     }
 72 
 73     @Test(expectedExceptions = IllegalArgumentException.class)
 74     public void testBadReshapeParameterNegative() {
 75         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 76         seq.reshape(-2, 2);
 77     }
 78 
 79     static void assertDimensions(SequenceLayout layout, long... dims) {
 80         SequenceLayout prev = null;
 81         for (int i = 0 ; i < dims.length ; i++) {
 82             if (prev != null) {
 83                 layout = (SequenceLayout)prev.elementLayout();
 84             }
 85             assertEquals(layout.elementCount(), dims[i]);
 86             prev = layout;
 87         }
 88     }
 89 
 90     static class Shape implements Iterable<long[]> {
 91         long[] shape;
 92 
 93         Shape(long... shape) {
 94             this.shape = shape;
 95         }
 96 
 97         public Iterator<long[]> iterator() {
 98             List<long[]> shapes = new ArrayList<>();
 99             shapes.add(shape);
100             for (int i = 0 ; i < shape.length ; i++) {
101                 long[] inferredShape = shape.clone();
102                 inferredShape[i] = -1;
103                 shapes.add(inferredShape);
104             }
105             return shapes.iterator();
106         }
107     }
108 
109     static MemoryLayout POINT = MemoryLayout.structLayout(
110             ValueLayout.JAVA_INT,
111             ValueLayout.JAVA_INT
112     );
113 
114     @DataProvider(name = "shapes")
115     Object[][] shapes() {
116         return new Object[][] {
117                 { ValueLayout.JAVA_BYTE, new long[] { 256 } },
118                 { ValueLayout.JAVA_BYTE, new long[] { 16, 16 } },
119                 { ValueLayout.JAVA_BYTE, new long[] { 4, 4, 4, 4 } },
120                 { ValueLayout.JAVA_BYTE, new long[] { 2, 8, 16 } },
121                 { ValueLayout.JAVA_BYTE, new long[] { 16, 8, 2 } },
122                 { ValueLayout.JAVA_BYTE, new long[] { 8, 16, 2 } },
123 
124                 { ValueLayout.JAVA_SHORT, new long[] { 256 } },
125                 { ValueLayout.JAVA_SHORT, new long[] { 16, 16 } },
126                 { ValueLayout.JAVA_SHORT, new long[] { 4, 4, 4, 4 } },
127                 { ValueLayout.JAVA_SHORT, new long[] { 2, 8, 16 } },
128                 { ValueLayout.JAVA_SHORT, new long[] { 16, 8, 2 } },
129                 { ValueLayout.JAVA_SHORT, new long[] { 8, 16, 2 } },
130 
131                 { ValueLayout.JAVA_CHAR, new long[] { 256 } },
132                 { ValueLayout.JAVA_CHAR, new long[] { 16, 16 } },
133                 { ValueLayout.JAVA_CHAR, new long[] { 4, 4, 4, 4 } },
134                 { ValueLayout.JAVA_CHAR, new long[] { 2, 8, 16 } },
135                 { ValueLayout.JAVA_CHAR, new long[] { 16, 8, 2 } },
136                 { ValueLayout.JAVA_CHAR, new long[] { 8, 16, 2 } },
137 
138                 { ValueLayout.JAVA_INT, new long[] { 256 } },
139                 { ValueLayout.JAVA_INT, new long[] { 16, 16 } },
140                 { ValueLayout.JAVA_INT, new long[] { 4, 4, 4, 4 } },
141                 { ValueLayout.JAVA_INT, new long[] { 2, 8, 16 } },
142                 { ValueLayout.JAVA_INT, new long[] { 16, 8, 2 } },
143                 { ValueLayout.JAVA_INT, new long[] { 8, 16, 2 } },
144 
145                 { ValueLayout.JAVA_LONG, new long[] { 256 } },
146                 { ValueLayout.JAVA_LONG, new long[] { 16, 16 } },
147                 { ValueLayout.JAVA_LONG, new long[] { 4, 4, 4, 4 } },
148                 { ValueLayout.JAVA_LONG, new long[] { 2, 8, 16 } },
149                 { ValueLayout.JAVA_LONG, new long[] { 16, 8, 2 } },
150                 { ValueLayout.JAVA_LONG, new long[] { 8, 16, 2 } },
151 
152                 { ValueLayout.JAVA_FLOAT, new long[] { 256 } },
153                 { ValueLayout.JAVA_FLOAT, new long[] { 16, 16 } },
154                 { ValueLayout.JAVA_FLOAT, new long[] { 4, 4, 4, 4 } },
155                 { ValueLayout.JAVA_FLOAT, new long[] { 2, 8, 16 } },
156                 { ValueLayout.JAVA_FLOAT, new long[] { 16, 8, 2 } },
157                 { ValueLayout.JAVA_FLOAT, new long[] { 8, 16, 2 } },
158 
159                 { ValueLayout.JAVA_DOUBLE, new long[] { 256 } },
160                 { ValueLayout.JAVA_DOUBLE, new long[] { 16, 16 } },
161                 { ValueLayout.JAVA_DOUBLE, new long[] { 4, 4, 4, 4 } },
162                 { ValueLayout.JAVA_DOUBLE, new long[] { 2, 8, 16 } },
163                 { ValueLayout.JAVA_DOUBLE, new long[] { 16, 8, 2 } },
164                 { ValueLayout.JAVA_DOUBLE, new long[] { 8, 16, 2 } },
165 
166                 { POINT, new long[] { 256 } },
167                 { POINT, new long[] { 16, 16 } },
168                 { POINT, new long[] { 4, 4, 4, 4 } },
169                 { POINT, new long[] { 2, 8, 16 } },
170                 { POINT, new long[] { 16, 8, 2 } },
171                 { POINT, new long[] { 8, 16, 2 } },
172         };
173     }
174 }