1 /*
  2  * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /**
 25  * @test
 26  * @bug 8074981
 27  * @summary Add C2 x86 Superword support for scalar product reduction optimizations : int test
 28  * @requires os.arch=="x86" | os.arch=="i386" | os.arch=="amd64" | os.arch=="x86_64" | os.arch=="aarch64" | os.arch=="riscv64"
 29  *
 30  * @run main/othervm -XX:+IgnoreUnrecognizedVMOptions
 31  *      -XX:LoopUnrollLimit=250 -XX:CompileThresholdScaling=0.1
 32  *      -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main
 33  *      -XX:+SuperWordReductions
 34  *      compiler.loopopts.superword.ReductionPerf
 35  * @run main/othervm -XX:+IgnoreUnrecognizedVMOptions
 36  *      -XX:LoopUnrollLimit=250 -XX:CompileThresholdScaling=0.1
 37  *      -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main
 38  *      -XX:-SuperWordReductions
 39  *      compiler.loopopts.superword.ReductionPerf
 40  */
 41 
 42 package compiler.loopopts.superword;
 43 
 44 public class ReductionPerf {
 45     public static void main(String[] args) throws Exception {
 46         int[] a1 = new int[8 * 1024];
 47         int[] a2 = new int[8 * 1024];
 48         int[] a3 = new int[8 * 1024];
 49         long[] b1 = new long[8 * 1024];
 50         long[] b2 = new long[8 * 1024];
 51         long[] b3 = new long[8 * 1024];
 52         float[] c1 = new float[8 * 1024];
 53         float[] c2 = new float[8 * 1024];
 54         float[] c3 = new float[8 * 1024];
 55         double[] d1 = new double[8 * 1024];
 56         double[] d2 = new double[8 * 1024];
 57         double[] d3 = new double[8 * 1024];
 58 
 59         ReductionInit(a1, a2, a3, b1, b2, b3, c1, c2, c3, d1, d2, d3);
 60 
 61         int sumIv = sumInt(a1, a2, a3);
 62         long sumLv = sumLong(b1, b2, b3);
 63         float sumFv = sumFloat(c1, c2, c3);
 64         double sumDv = sumDouble(d1, d2, d3);
 65         int mulIv = prodInt(a1, a2, a3);
 66         long mulLv = prodLong(b1, b2, b3);
 67         float mulFv = prodFloat(c1, c2, c3);
 68         double mulDv = prodDouble(d1, d2, d3);
 69 
 70         int sumI = 0;
 71         long sumL = 0;
 72         float sumF = 0.f;
 73         double sumD = 0.;
 74         int mulI = 0;
 75         long mulL = 0;
 76         float mulF = 0.f;
 77         double mulD = 0.;
 78 
 79         System.out.println("Warmup ...");
 80         long start = System.currentTimeMillis();
 81 
 82         for (int j = 0; j < 2000; j++) {
 83             sumI = sumInt(a1, a2, a3);
 84             sumL = sumLong(b1, b2, b3);
 85             sumF = sumFloat(c1, c2, c3);
 86             sumD = sumDouble(d1, d2, d3);
 87             mulI = prodInt(a1, a2, a3);
 88             mulL = prodLong(b1, b2, b3);
 89             mulF = prodFloat(c1, c2, c3);
 90             mulD = prodDouble(d1, d2, d3);
 91         }
 92 
 93         long stop = System.currentTimeMillis();
 94         System.out.println(" Warmup is done in " + (stop - start) + " msec");
 95 
 96         if (sumIv != sumI) {
 97             System.out.println("sum int:    " + sumIv + " != " + sumI);
 98         }
 99         if (sumLv != sumL) {
100             System.out.println("sum long:   " + sumLv + " != " + sumL);
101         }
102         if (sumFv != sumF) {
103             System.out.println("sum float:  " + sumFv + " != " + sumF);
104         }
105         if (sumDv != sumD) {
106             System.out.println("sum double: " + sumDv + " != " + sumD);
107         }
108         if (mulIv != mulI) {
109             System.out.println("prod int:    " + mulIv + " != " + mulI);
110         }
111         if (mulLv != mulL) {
112             System.out.println("prod long:   " + mulLv + " != " + mulL);
113         }
114         if (mulFv != mulF) {
115             System.out.println("prod float:  " + mulFv + " != " + mulF);
116         }
117         if (mulDv != mulD) {
118             System.out.println("prod double: " + mulDv + " != " + mulD);
119         }
120 
121         start = System.currentTimeMillis();
122         for (int j = 0; j < 5000; j++) {
123             sumI = sumInt(a1, a2, a3);
124         }
125         stop = System.currentTimeMillis();
126         System.out.println("sum int:    " + (stop - start));
127 
128         start = System.currentTimeMillis();
129         for (int j = 0; j < 5000; j++) {
130             sumL = sumLong(b1, b2, b3);
131         }
132         stop = System.currentTimeMillis();
133         System.out.println("sum long:   " + (stop - start));
134 
135         start = System.currentTimeMillis();
136         for (int j = 0; j < 5000; j++) {
137             sumF = sumFloat(c1, c2, c3);
138         }
139         stop = System.currentTimeMillis();
140         System.out.println("sum float:  " + (stop - start));
141 
142         start = System.currentTimeMillis();
143         for (int j = 0; j < 5000; j++) {
144             sumD = sumDouble(d1, d2, d3);
145         }
146         stop = System.currentTimeMillis();
147         System.out.println("sum double: " + (stop - start));
148 
149         start = System.currentTimeMillis();
150         for (int j = 0; j < 5000; j++) {
151             mulI = prodInt(a1, a2, a3);
152         }
153         stop = System.currentTimeMillis();
154         System.out.println("prod int:    " + (stop - start));
155 
156         start = System.currentTimeMillis();
157         for (int j = 0; j < 5000; j++) {
158             mulL = prodLong(b1, b2, b3);
159         }
160         stop = System.currentTimeMillis();
161         System.out.println("prod long:   " + (stop - start));
162 
163         start = System.currentTimeMillis();
164         for (int j = 0; j < 5000; j++) {
165             mulF = prodFloat(c1, c2, c3);
166         }
167         stop = System.currentTimeMillis();
168         System.out.println("prod float:  " + (stop - start));
169 
170         start = System.currentTimeMillis();
171         for (int j = 0; j < 5000; j++) {
172             mulD = prodDouble(d1, d2, d3);
173         }
174         stop = System.currentTimeMillis();
175         System.out.println("prod double: " + (stop - start));
176 
177     }
178 
179     public static void ReductionInit(int[] a1, int[] a2, int[] a3,
180                                      long[] b1, long[] b2, long[] b3,
181                                      float[] c1, float[] c2, float[] c3,
182                                      double[] d1, double[] d2, double[] d3) {
183         for(int i = 0; i < a1.length; i++) {
184             a1[i] =          (i + 0);
185             a2[i] =          (i + 1);
186             a3[i] =          (i + 2);
187             b1[i] =   (long) (i + 0);
188             b2[i] =   (long) (i + 1);
189             b3[i] =   (long) (i + 2);
190             c1[i] =  (float) (i + 0);
191             c2[i] =  (float) (i + 1);
192             c3[i] =  (float) (i + 2);
193             d1[i] = (double) (i + 0);
194             d2[i] = (double) (i + 1);
195             d3[i] = (double) (i + 2);
196         }
197     }
198 
199     public static int sumInt(int[] a1, int[] a2, int[] a3) {
200         int total = 0;
201         for (int i = 0; i < a1.length; i++) {
202             total += (a1[i] * a2[i]) + (a1[i] * a3[i]) + (a2[i] * a3[i]);
203         }
204         return total;
205     }
206 
207     public static long sumLong(long[] b1, long[] b2, long[] b3) {
208         long total = 0;
209         for (int i = 0; i < b1.length; i++) {
210             total += (b1[i] * b2[i]) + (b1[i] * b3[i]) + (b2[i] * b3[i]);
211         }
212         return total;
213     }
214 
215     public static float sumFloat(float[] c1, float[] c2, float[] c3) {
216         float total = 0;
217         for (int i = 0; i < c1.length; i++) {
218             total += (c1[i] * c2[i]) + (c1[i] * c3[i]) + (c2[i] * c3[i]);
219         }
220         return total;
221     }
222 
223     public static double sumDouble(double[] d1, double[] d2, double[] d3) {
224         double total = 0;
225         for (int i = 0; i < d1.length; i++) {
226             total += (d1[i] * d2[i]) + (d1[i] * d3[i]) + (d2[i] * d3[i]);
227         }
228         return total;
229     }
230 
231     public static int prodInt(int[] a1, int[] a2, int[] a3) {
232         int total = 1;
233         for (int i = 0; i < a1.length; i++) {
234             total *= (a1[i] * a2[i]) + (a1[i] * a3[i]) + (a2[i] * a3[i]);
235         }
236         return total;
237     }
238 
239     public static long prodLong(long[] b1, long[] b2, long[] b3) {
240         long total = 1;
241         for (int i = 0; i < b1.length; i++) {
242             total *= (b1[i] * b2[i]) + (b1[i] * b3[i]) + (b2[i] * b3[i]);
243         }
244         return total;
245     }
246 
247     public static float prodFloat(float[] c1, float[] c2, float[] c3) {
248         float total = 1;
249         for (int i = 0; i < c1.length; i++) {
250             total *= (c1[i] * c2[i]) + (c1[i] * c3[i]) + (c2[i] * c3[i]);
251         }
252         return total;
253     }
254 
255     public static double prodDouble(double[] d1, double[] d2, double[] d3) {
256         double total = 1;
257         for (int i = 0; i < d1.length; i++) {
258             total *= (d1[i] * d2[i]) + (d1[i] * d3[i]) + (d2[i] * d3[i]);
259         }
260         return total;
261     }
262 }