1 /*
  2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package org.openjdk.tests.java.util.stream;
 25 
 26 import jdk.incubator.foreign.MemoryLayout;
 27 import jdk.incubator.foreign.MemorySegment;
 28 
 29 import java.lang.invoke.VarHandle;
 30 import java.util.Collection;
 31 import java.util.List;
 32 import java.util.SpliteratorTestHelper;
 33 import java.util.function.Consumer;
 34 import java.util.function.Function;
 35 import java.util.stream.Collectors;
 36 
 37 import jdk.incubator.foreign.ValueLayout;
 38 import org.testng.annotations.DataProvider;
 39 
 40 public class SegmentTestDataProvider {
 41 
 42     static boolean compareSegmentsByte(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 43         Function<MemorySegment, Byte> mapper = s -> s.get(ValueLayout.JAVA_BYTE, 0);
 44         List<Byte> list1 = segments1.stream()
 45                 .map(mapper)
 46                 .collect(Collectors.toList());
 47         List<Byte> list2 = segments2.stream()
 48                 .map(mapper)
 49                 .collect(Collectors.toList());
 50         return list1.equals(list2);
 51     }
 52 
 53     static boolean compareSegmentsChar(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 54         Function<MemorySegment, Character> mapper = s -> s.get(ValueLayout.JAVA_CHAR, 0);
 55         List<Character> list1 = segments1.stream()
 56                 .map(mapper)
 57                 .collect(Collectors.toList());
 58         List<Character> list2 = segments2.stream()
 59                 .map(mapper)
 60                 .collect(Collectors.toList());
 61         return list1.equals(list2);
 62     }
 63 
 64     static boolean compareSegmentsShort(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 65         Function<MemorySegment, Short> mapper = s -> s.get(ValueLayout.JAVA_SHORT, 0);
 66         List<Short> list1 = segments1.stream()
 67                 .map(mapper)
 68                 .collect(Collectors.toList());
 69         List<Short> list2 = segments2.stream()
 70                 .map(mapper)
 71                 .collect(Collectors.toList());
 72         return list1.equals(list2);
 73     }
 74 
 75     static boolean compareSegmentsInt(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 76         Function<MemorySegment, Integer> mapper = s -> s.get(ValueLayout.JAVA_INT, 0);
 77         List<Integer> list1 = segments1.stream()
 78                 .map(mapper)
 79                 .collect(Collectors.toList());
 80         List<Integer> list2 = segments2.stream()
 81                 .map(mapper)
 82                 .collect(Collectors.toList());
 83         return list1.equals(list2);
 84     }
 85 
 86     static boolean compareSegmentsLong(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 87         Function<MemorySegment, Long> mapper = s-> s.get(ValueLayout.JAVA_LONG, 0);
 88         List<Long> list1 = segments1.stream()
 89                 .map(mapper)
 90                 .collect(Collectors.toList());
 91         List<Long> list2 = segments2.stream()
 92                 .map(mapper)
 93                 .collect(Collectors.toList());
 94         return list1.equals(list2);
 95     }
 96 
 97     static boolean compareSegmentsFloat(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
 98         Function<MemorySegment, Float> mapper = s -> s.get(ValueLayout.JAVA_FLOAT, 0);
 99         List<Float> list1 = segments1.stream()
100                 .map(mapper)
101                 .collect(Collectors.toList());
102         List<Float> list2 = segments2.stream()
103                 .map(mapper)
104                 .collect(Collectors.toList());
105         return list1.equals(list2);
106     }
107 
108     static Consumer<MemorySegment> segmentCopier(Consumer<MemorySegment> input) {
109         return segment -> {
110             MemorySegment dest = MemorySegment.ofArray(new byte[(int)segment.byteSize()]);
111             dest.copyFrom(segment);
112             input.accept(dest);
113         };
114     }
115 
116     static boolean compareSegmentsDouble(Collection<MemorySegment> segments1, Collection<MemorySegment> segments2, boolean isOrdered) {
117         Function<MemorySegment, Double> mapper = s -> s.get(ValueLayout.JAVA_DOUBLE, 0);
118         List<Double> list1 = segments1.stream()
119                 .map(mapper)
120                 .collect(Collectors.toList());
121         List<Double> list2 = segments2.stream()
122                 .map(mapper)
123                 .collect(Collectors.toList());
124         return list1.equals(list2);
125     }
126 
127     static void initSegment(MemorySegment segment) {
128         for (int i = 0 ; i < segment.byteSize() ; i++) {
129             segment.set(ValueLayout.JAVA_BYTE, 0, (byte)i);
130         }
131     }
132 
133     static Object[][] spliteratorTestData = {
134             { "bytes", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_BYTE), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsByte },
135             { "chars", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_CHAR), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsChar },
136             { "shorts", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_SHORT), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsShort },
137             { "ints", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_INT), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsInt },
138             { "longs", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_LONG), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsLong },
139             { "floats", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_FLOAT), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsFloat },
140             { "doubles", MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_DOUBLE), (SpliteratorTestHelper.ContentAsserter<MemorySegment>)SegmentTestDataProvider::compareSegmentsDouble },
141     };
142 
143     // returns an array of (String name, Supplier<Spliterator<MemorySegment>>, ContentAsserter<MemorySegment>)
144     @DataProvider(name = "SegmentSpliterator")
145     public static Object[][] spliteratorProvider() {
146         return spliteratorTestData;
147     }
148 }