1 /*
  2  * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng TestReshape
 27  */
 28 
 29 import java.lang.foreign.MemoryLayout;
 30 import java.lang.foreign.SequenceLayout;
 31 import java.lang.foreign.ValueLayout;
 32 import java.util.ArrayList;
 33 import java.util.Iterator;
 34 import java.util.List;
 35 import java.util.stream.LongStream;
 36 
 37 import org.testng.annotations.*;
 38 import static org.testng.Assert.*;
 39 
 40 public class TestReshape {
 41 
 42     @Test(dataProvider = "shapes")
 43     public void testReshape(MemoryLayout layout, long[] expectedShape) {
 44         long flattenedSize = LongStream.of(expectedShape).reduce(1L, Math::multiplyExact);
 45         SequenceLayout seq_flattened = MemoryLayout.sequenceLayout(flattenedSize, layout);
 46         assertDimensions(seq_flattened, flattenedSize);
 47         for (long[] shape : new Shape(expectedShape)) {
 48             SequenceLayout seq_shaped = seq_flattened.reshape(shape);
 49             assertDimensions(seq_shaped, expectedShape);
 50             assertEquals(seq_shaped.flatten(), seq_flattened);
 51         }
 52     }
 53 
 54     @Test(expectedExceptions = IllegalArgumentException.class)
 55     public void testInvalidReshape() {
 56         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 57         seq.reshape(3, 2);
 58     }
 59 
 60     @Test(expectedExceptions = IllegalArgumentException.class)
 61     public void testBadReshapeInference() {
 62         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 63         seq.reshape(-1, -1);
 64     }
 65 
 66     @Test(expectedExceptions = IllegalArgumentException.class)
 67     public void testBadReshapeParameterZero() {
 68         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 69         seq.reshape(0, 4);
 70     }
 71 
 72     @Test(expectedExceptions = IllegalArgumentException.class)
 73     public void testBadReshapeParameterNegative() {
 74         SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
 75         seq.reshape(-2, 2);
 76     }
 77 
 78     static void assertDimensions(SequenceLayout layout, long... dims) {
 79         SequenceLayout prev = null;
 80         for (int i = 0 ; i < dims.length ; i++) {
 81             if (prev != null) {
 82                 layout = (SequenceLayout)prev.elementLayout();
 83             }
 84             assertEquals(layout.elementCount(), dims[i]);
 85             prev = layout;
 86         }
 87     }
 88 
 89     static class Shape implements Iterable<long[]> {
 90         long[] shape;
 91 
 92         Shape(long... shape) {
 93             this.shape = shape;
 94         }
 95 
 96         public Iterator<long[]> iterator() {
 97             List<long[]> shapes = new ArrayList<>();
 98             shapes.add(shape);
 99             for (int i = 0 ; i < shape.length ; i++) {
100                 long[] inferredShape = shape.clone();
101                 inferredShape[i] = -1;
102                 shapes.add(inferredShape);
103             }
104             return shapes.iterator();
105         }
106     }
107 
108     static MemoryLayout POINT = MemoryLayout.structLayout(
109             ValueLayout.JAVA_INT,
110             ValueLayout.JAVA_INT
111     );
112 
113     @DataProvider(name = "shapes")
114     Object[][] shapes() {
115         return new Object[][] {
116                 { ValueLayout.JAVA_BYTE, new long[] { 256 } },
117                 { ValueLayout.JAVA_BYTE, new long[] { 16, 16 } },
118                 { ValueLayout.JAVA_BYTE, new long[] { 4, 4, 4, 4 } },
119                 { ValueLayout.JAVA_BYTE, new long[] { 2, 8, 16 } },
120                 { ValueLayout.JAVA_BYTE, new long[] { 16, 8, 2 } },
121                 { ValueLayout.JAVA_BYTE, new long[] { 8, 16, 2 } },
122 
123                 { ValueLayout.JAVA_SHORT, new long[] { 256 } },
124                 { ValueLayout.JAVA_SHORT, new long[] { 16, 16 } },
125                 { ValueLayout.JAVA_SHORT, new long[] { 4, 4, 4, 4 } },
126                 { ValueLayout.JAVA_SHORT, new long[] { 2, 8, 16 } },
127                 { ValueLayout.JAVA_SHORT, new long[] { 16, 8, 2 } },
128                 { ValueLayout.JAVA_SHORT, new long[] { 8, 16, 2 } },
129 
130                 { ValueLayout.JAVA_CHAR, new long[] { 256 } },
131                 { ValueLayout.JAVA_CHAR, new long[] { 16, 16 } },
132                 { ValueLayout.JAVA_CHAR, new long[] { 4, 4, 4, 4 } },
133                 { ValueLayout.JAVA_CHAR, new long[] { 2, 8, 16 } },
134                 { ValueLayout.JAVA_CHAR, new long[] { 16, 8, 2 } },
135                 { ValueLayout.JAVA_CHAR, new long[] { 8, 16, 2 } },
136 
137                 { ValueLayout.JAVA_INT, new long[] { 256 } },
138                 { ValueLayout.JAVA_INT, new long[] { 16, 16 } },
139                 { ValueLayout.JAVA_INT, new long[] { 4, 4, 4, 4 } },
140                 { ValueLayout.JAVA_INT, new long[] { 2, 8, 16 } },
141                 { ValueLayout.JAVA_INT, new long[] { 16, 8, 2 } },
142                 { ValueLayout.JAVA_INT, new long[] { 8, 16, 2 } },
143 
144                 { ValueLayout.JAVA_LONG, new long[] { 256 } },
145                 { ValueLayout.JAVA_LONG, new long[] { 16, 16 } },
146                 { ValueLayout.JAVA_LONG, new long[] { 4, 4, 4, 4 } },
147                 { ValueLayout.JAVA_LONG, new long[] { 2, 8, 16 } },
148                 { ValueLayout.JAVA_LONG, new long[] { 16, 8, 2 } },
149                 { ValueLayout.JAVA_LONG, new long[] { 8, 16, 2 } },
150 
151                 { ValueLayout.JAVA_FLOAT, new long[] { 256 } },
152                 { ValueLayout.JAVA_FLOAT, new long[] { 16, 16 } },
153                 { ValueLayout.JAVA_FLOAT, new long[] { 4, 4, 4, 4 } },
154                 { ValueLayout.JAVA_FLOAT, new long[] { 2, 8, 16 } },
155                 { ValueLayout.JAVA_FLOAT, new long[] { 16, 8, 2 } },
156                 { ValueLayout.JAVA_FLOAT, new long[] { 8, 16, 2 } },
157 
158                 { ValueLayout.JAVA_DOUBLE, new long[] { 256 } },
159                 { ValueLayout.JAVA_DOUBLE, new long[] { 16, 16 } },
160                 { ValueLayout.JAVA_DOUBLE, new long[] { 4, 4, 4, 4 } },
161                 { ValueLayout.JAVA_DOUBLE, new long[] { 2, 8, 16 } },
162                 { ValueLayout.JAVA_DOUBLE, new long[] { 16, 8, 2 } },
163                 { ValueLayout.JAVA_DOUBLE, new long[] { 8, 16, 2 } },
164 
165                 { POINT, new long[] { 256 } },
166                 { POINT, new long[] { 16, 16 } },
167                 { POINT, new long[] { 4, 4, 4, 4 } },
168                 { POINT, new long[] { 2, 8, 16 } },
169                 { POINT, new long[] { 16, 8, 2 } },
170                 { POINT, new long[] { 8, 16, 2 } },
171         };
172     }
173 }