1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package normmap;
 25 
 26 import javax.imageio.ImageIO;
 27 import javax.swing.JFrame;
 28 
 29 import java.awt.Graphics;
 30 import java.awt.Graphics2D;
 31 import java.awt.Color;
 32 import java.awt.image.BufferedImage;
 33 import java.awt.image.DataBufferInt;
 34 import java.io.IOException;
 35 import java.util.Random;
 36 import javax.swing.JPanel;
 37 import java.awt.Font;
 38 
 39 import java.net.URL;
 40 import java.net.URISyntaxException;
 41 import java.io.File;
 42 import java.util.Arrays;
 43 
 44 /**
 * Based on a demo presented at the JVMLS 2025 conference by Emanuel Peter, when giving
 * a talk about Auto-Vectorization in HotSpot, see:
 *   https://inside.java/2025/08/16/jvmls-hotspot-auto-vectorization/
 *
 * The rest of this comment is based on Emanuel's original code.
 50  *
 51  * If you want to disable the auto-vectorizer, you can run:
 52  *   java -XX:-UseSuperWord NormalMapping.java
 53  *
 54  * On x86, you can also play with the UseAVX flag:
 55  *   java -XX:UseAVX=1 NormalMapping.java
 56  *
 57  * The motivation for JVMLS 2025 was to present something that vectorizes
 * in an "embarrassingly parallel" way. It should be something that C2's
 59  * SuperWord Auto Vectorizer could already do for many JDK releases,
 60  * and also has some visual appeal. I decided to use normal mapping, see:
 61  *   https://en.wikipedia.org/wiki/Normal_mapping
 62  *
 63  * At the conference, I only had the version that loads a normal map
 64  * from an image. I now also added some "generated" cases, which are
 65  * created from 2d height functions, and then converted to normal
 66  * maps. This allows us to show more "surfaces" without having to
 67  * store the images for all those cases.
 68  *
 69  * If you are interested in understanding the components, then look at these:
 70  * - computeLight: the normal mapping "shader / kernel".
 71  * - generateNormals / computeNormals: computing normals from height functions.
 72  * - main: setup and endless-loop that triggers normals to be swapped periodically.
 73  * - MyDrawingPanel: drawing all the parts to the screen.
 74  */
 75 public class Main {
 76     public static Random RANDOM = new Random();
 77 
 78     public static void main(String[] args) {
 79         System.out.println("Welcome to the Normal Mapping Demo!");
 80         // Create an application state with 5 lights.
 81         State state = new State(5);
 82 
 83         // Set up a panel we can draw on, and put it in a window.
 84         System.out.println("Setting up Window...");
 85         MyDrawingPanel panel = new MyDrawingPanel(state);
 86         JFrame frame = new JFrame("Normal Mapping Demo (Auto-Vectorization)");
 87         frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
 88         frame.setSize(2000, 1000);
 89         frame.add(panel);
 90         frame.setVisible(true);
 91         System.out.println("Running Demo...");
 92 
 93         try {
 94             // Tight loop where we redraw the panel as fast as possible.
 95             int count = 0;
 96             while (true) {
 97                 Thread.sleep(1);
 98                 state.update();
 99                 panel.repaint();
100                 if (count++ > 500) {
101                     count = 0;
102                     state.nextNormals();
103                 }
104             }
105         } catch (InterruptedException e) {
106             System.out.println("Interrupted, terminating demo.");
107         } finally {
108             System.out.println("Shut down demo.");
109             frame.setVisible(false);
110             frame.dispose();
111         }
112     }
113 
114 /*    public static File getLocalFile(String name) {
115         // If we are in JTREG IR testing mode, we have to get the path via system property,
116         // if it is run in stand-alone that property is not available, and we can load
117         // via getResource.
118         String file = //System.getProperty("test.src",
119             "/Users/grfrost/github/babylon-grfrost-fork/hat/examples/normmap/src/main/resources/images/"+name;
120         System.out.println("file = "+file);
121         return new File(file);
122 
123     } */
124 
125     public static BufferedImage loadImage(String resourcePath) {
126         try {
127             var inputStream = Main.class.getResourceAsStream(resourcePath);
128             return ImageIO.read(inputStream);
129            // return ImageIO.read(file);
130         } catch (IOException e) {
131             throw new RuntimeException("Could not load: ", e);
132         }
133     }
134 
135     public static class Light {
136         public float x = 0.5f;
137         public float y = 0.5f;
138         private float dx;
139         private float dy;
140 
141         private float h;
142         public float r;
143         public float g;
144         public float b;
145 
146         Light() {
147             this.h = RANDOM.nextFloat();
148         }
149 
150         // Random movement of the Light
151         public void update() {
152             // Random acceleration with dampening.
153             dx *= 0.99;
154             dy *= 0.99;
155             dx += RANDOM.nextFloat() * 0.001 - 0.0005;
156             dy += RANDOM.nextFloat() * 0.001 - 0.0005;
157             x += dx;
158             y += dy;
159 
160             // Boounce off the walls.
161             if (x < 0) { dx = +Math.abs(dx); }
162             if (x > 1) { dx = -Math.abs(dx); }
163             if (y < 0) { dy = +Math.abs(dy); }
164             if (y > 1) { dy = -Math.abs(dy); }
165 
166             // Rotate the hue -> gets us nice rainbow colors.
167             h += 0.001 + RANDOM.nextFloat() * 0.0002;
168             Color c = Color.getHSBColor(h, 1f, 1f);
169             r = (1f / 256f) * c.getRed();
170             g = (1f / 256f) * c.getGreen();
171             b = (1f / 256f) * c.getBlue();
172         }
173     }
174 
175     public static class State {
176         private static final int sizeX = 1000;
177         private static final int sizeY = 1000;
178 
179         public Light[] lights;
180         private int nextNormalsId = 0;
181 
182         public BufferedImage normals;
183         public float[] coordsX;
184         public float[] coordsY;
185         public float[] normalsX;
186         public float[] normalsY;
187         public float[] normalsZ;
188 
189         public BufferedImage output;
190         public BufferedImage output_2;
191         public int[] outputRGB;
192         public int[] outputRGB_2;
193 
194         public long lastTime;
195         public float fps;
196 
197         float luminosityCorrection = 1f;
198 
199         public State(int numberOfLights) {
200             lights = new Light[numberOfLights];
201             for (int i = 0; i < lights.length; i++) {
202                 lights[i] = new Light();
203             }
204 
205             // Coordinates
206             this.coordsX = new float[sizeX * sizeY];
207             this.coordsY = new float[sizeX * sizeY];
208             for (int y = 0; y < sizeY; y++) {
209                 for (int x = 0; x < sizeX; x++) {
210                     this.coordsX[y * sizeX + x] = x * (1f / sizeX);
211                     this.coordsY[y * sizeX + x] = y * (1f / sizeY);
212                 }
213             }
214 
215             nextNormals();
216 
217             // Double buffered output images, where we render to.
218             // Without double buffering, we would get some flickering effects,
219             // because we would be concurrently updating the buffer and drawing it.
220             this.output   = new BufferedImage(sizeX, sizeY, BufferedImage.TYPE_INT_RGB);
221             this.output_2 = new BufferedImage(sizeX, sizeY, BufferedImage.TYPE_INT_RGB);
222             this.outputRGB   = ((DataBufferInt) output.getRaster().getDataBuffer()).getData();
223             this.outputRGB_2 = ((DataBufferInt) output_2.getRaster().getDataBuffer()).getData();
224 
225             // Set up the FPS tracker
226             lastTime = System.nanoTime();
227         }
228 
229         public void nextNormals() {
230             switch (nextNormalsId) {
231                 case 0 -> setNormals(loadNormals("normal_map.png"));
232                 case 1 -> setNormals(generateNormals("heart"));
233                 case 2 -> setNormals(generateNormals("hex"));
234                 case 3 -> setNormals(generateNormals("cone"));
235                 case 4 -> setNormals(generateNormals("ripple"));
236                 case 5 -> setNormals(generateNormals("hill"));
237                 case 6 -> setNormals(generateNormals("ripple2"));
238                 case 7 -> setNormals(generateNormals("cones"));
239                 case 8 -> setNormals(generateNormals("spheres"));
240                 case 9 -> setNormals(generateNormals("donut"));
241                 default -> throw new RuntimeException();
242             }
243             nextNormalsId = (nextNormalsId + 1) % 10;
244         }
245 
246         public BufferedImage loadNormals(String name) {
247             // Extract normal values from RGB image
248             // The loaded image may not have the desired INT_RGB format, so first convert it
249             BufferedImage normalsLoaded = loadImage("/images/"+name);
250             BufferedImage buf = new BufferedImage(sizeX, sizeY, BufferedImage.TYPE_INT_RGB);
251             buf.getGraphics().drawImage(normalsLoaded, 0, 0, null);
252             return buf;
253         }
254 
255         public void setNormals(BufferedImage buf) {
256             this.normals = buf;
257 
258             int[] normalsRGB = ((DataBufferInt) this.normals.getRaster().getDataBuffer()).getData();
259             this.normalsX = new float[sizeX * sizeY];
260             this.normalsY = new float[sizeX * sizeY];
261             this.normalsZ = new float[sizeX * sizeY];
262             for (int y = 0; y < sizeY; y++) {
263                 for (int x = 0; x < sizeX; x++) {
264                     this.coordsY[y * sizeX + x] = y * (1f / sizeY);
265                     int normal = normalsRGB[y * sizeX + x];
266                     // RGB values in range [0 ... 255]
267                     int nr = (normal >> 16) & 0xff;
268                     int ng = (normal >>  8) & 0xff;
269                     int nb = (normal >>  0) & 0xff;
270 
271                     // Map range [0..255] -> [-1 .. 1]
272                     float nx = ((float)nr) * (1f / 128f) - 1f;
273                     float ny = ((float)ng) * (1f / 128f) - 1f;
274                     float nz = ((float)nb) * (1f / 128f) - 1f;
275 
276                     this.normalsX[y * sizeX + x] = -nx;
277                     this.normalsY[y * sizeX + x] = ny;
278                     this.normalsZ[y * sizeX + x] = nz;
279                 }
280             }
281         }
282 
283         interface HeightFunction {
284             // x and y should be in [0..1]
285             double call(double x, double y);
286         }
287 
288         public BufferedImage generateNormals(String name) {
289             System.out.println("  generate normals for: " + name);
290             return computeNormals((double x, double y) -> {
291                 // Scale out, so we see a little more
292                 x = 10 * (x - 0.5);
293                 y = 10 * (y - 0.5);
294 
295                 // A selection of "height functions":
296                 return switch (name) {
297                     case "cone" -> 0.1 * Math.max(0, 2 - Math.sqrt(x * x + y * y));
298                     case "heart" -> {
299                         double heart = Math.abs(Math.pow(x * x + y * y - 1, 3) - x * x * Math.pow(-y, 3));
300                         double decay = Math.exp(-(x * x + y * y));
301                         yield 0.1 * heart * decay;
302                     }
303                     case "hill" ->    0.5 * Math.exp(-(x * x + y * y));
304                     case "ripple" ->  0.01 * Math.sin(x * x + y * y);
305                     case "ripple2" -> 0.3 * Math.sin(x) * Math.sin(y);
306                     case "donut" -> {
307                         double d = Math.sqrt(x * x + y * y) - 2;
308                         double i = 1 - d*d;
309                         yield (i >= 0) ? 0.1 * Math.sqrt(i) : 0;
310                     }
311                     case "hex" -> {
312                         double f = 3.0;
313                         double a = Math.cos(f * x);
314                         double b = Math.cos(f * (-0.5 * x + Math.sqrt(3) / 2.0 * y));
315                         double c = Math.cos(f * (-0.5 * x - Math.sqrt(3) / 2.0 * y));
316                         yield 0.03 * (a + b + c);
317                     }
318                     case "cones" -> {
319                         double scale = 2.0;
320                         double r = 0.8;
321                         double cx = scale * (Math.floor(x / scale) + 0.5);
322                         double cy = scale * (Math.floor(y / scale) + 0.5);
323                         double dx = x - cx;
324                         double dy = y - cy;
325                         double d = Math.sqrt(dx * dx + dy * dy);
326                         yield 0.1 * Math.max(0, 0.8 - d);
327                     }
328                     case "spheres" -> {
329                         double scale = 2.0;
330                         double r = 0.8;
331                         double cx = scale * (Math.floor(x / scale) + 0.5);
332                         double cy = scale * (Math.floor(y / scale) + 0.5);
333                         double dx = x - cx;
334                         double dy = y - cy;
335                         double d2 = dx * dx + dy * dy;
336                         if (d2 <= r * r) {
337                             yield 0.03 * Math.sqrt(r * r - d2);
338                         }
339                         yield 0.0;
340                     }
341                     default -> throw new RuntimeException("not supported: " + name);
342                 };
343             });
344         }
345 
346         public static BufferedImage computeNormals(HeightFunction fun) {
347             BufferedImage out = new BufferedImage(1000, 1000, BufferedImage.TYPE_INT_RGB);
348             int[] arr = ((DataBufferInt) out.getRaster().getDataBuffer()).getData();
349             int sx = out.getWidth();
350             int sy = out.getHeight();
351 
352             double delta = 0.00001;
353             double dxx = 1.0 / sx;
354             double dyy = 1.0 / sy;
355             for (int yy = 0; yy < sy; yy++) {
356                 int nStart = sy * yy;
357                 for (int xx = 0; xx < sx; xx++) {
358                     double x = xx * dxx;
359                     double y = yy * dyy;
360 
361                     // Compute the partial derivatives in x and y direction;
362                     double fdx = fun.call(x + delta, y) - fun.call(x - delta, y);
363                     double fdy = fun.call(x, y + delta) - fun.call(x, y - delta);
364                     // We can compute the normal from the cross product of:
365                     //
366                     //  df/dx  x  df/dy = [2*delta, 0, fdx]  x  [0, 2*delta, fdy]
367                     //                  = [0*fdy - fdx*2*delta, fdx*0 - 2*delta*fdy, 2*delta*2*delta - 0*0]
368                     double nx = -fdx * 2 * delta;
369                     double ny = -2 * delta * fdy;
370                     double nz = 2 * delta * 2 * delta;
371 
372                     // normalize
373                     float dist = (float)Math.sqrt(nx * nx + ny * ny + nz * nz);
374                     nx /= dist;
375                     ny /= dist;
376                     nz /= dist;
377 
378                     // Now transform [-1..1] -> [0..255]
379                     int r = (int)(nx * 127f + 127f) & 0xff;
380                     int g = (int)(ny * 127f + 127f) & 0xff;
381                     int b = (int)(nz * 127f + 127f) & 0xff;
382                     int c = (r << 16) + (g << 8) + b;
383                     arr[nStart + xx] = c;
384                 }
385             }
386             return out;
387         }
388 
389         public void update() {
390             long nowTime = System.nanoTime();
391             float newFPS = 1e9f / (nowTime - lastTime);
392             fps = 0.99f * fps + 0.01f * newFPS;
393             lastTime = nowTime;
394 
395             for (Light light : lights) {
396                 light.update();
397             }
398 
399             // Reset the buffer
400             int[] outputArray = ((DataBufferInt) output.getRaster().getDataBuffer()).getData();
401             Arrays.fill(outputArray, 0);
402 
403             // Add in the contribution of each light
404             for (Light l : lights) {
405                 computeLight(l);
406             }
407             computeLuminosityCorrection();
408 
409             // Swap the buffers for double buffering.
410             var outputTmp = output;
411             output = output_2;
412             output_2 = outputTmp;
413 
414             var outputRGBTmp = outputRGB;
415             outputRGB = outputRGB_2;
416             outputRGB_2 = outputRGBTmp;
417         }
418 
419         public void computeLight(Light l) {
420             for (int i = 0; i < outputRGB.length; i++) {
421                 float x = coordsX[i];
422                 float y = coordsY[i];
423                 float nx = normalsX[i];
424                 float ny = normalsY[i];
425                 float nz = normalsZ[i];
426 
427                 // Compute distance vector between the light and the pixel
428                 float dx = x - l.x;
429                 float dy = y - l.y;
430                 float dz = 0.2f; // how much the lights float above the scene
431 
432                 // Compute the distance (dot product of d with itself)
433                 float d2 = dx * dx + dy * dy + dz * dz;
434                 float d = (float)Math.sqrt(d2);
435                 float d3 = d * d2;
436 
437                 // Compute dot-product between distance and normal vector
438                 float dotProduct = nx * dx + ny * dy + nz * dz;
439 
440                 // If the dot-product is negative:
441                 //   Light on wrong side -> 0
442                 // If the dot-product is positive:
443                 //   There should be light normalize by distance (d), and divide by the
444                 //   squared distance (d2) to have physically accurately decaying light.
445                 //   Correct the luminosity so the RGB values are going to be close
446                 //   to 255, but not over.
447                 float luminosity = Math.max(0, dotProduct / d3) * luminosityCorrection;
448 
449                 // Now we compute the color values that hopefully end up in the range
450                 // [0..255]. If the hack/trick with luminosityCorrection fails, we may
451                 // occasionally go out of the range and generate an overflow in the masking.
452                 // This can lead to some funky visual artifacts around the lights, but it
453                 // is quite rare.
454                 //
455                 // Feel free to play with the targetExposure below, and see if you can
456                 // observe the artefacts.
457                 int r = (int)(luminosity * l.r) & 0xff;
458                 int g = (int)(luminosity * l.g) & 0xff;
459                 int b = (int)(luminosity * l.b) & 0xff;
460                 int c = (r << 16) + (g << 8) + b;
461                 outputRGB[i] += c;
462             }
463         }
464 
465         // This is a bit of a horrible hack, but it mostly works.
466         // Essentially, it tries to solve the "exposure" problem:
467         // It is hard to know how much light a pixel will receive at most, and
468         // we have to convert this value to a byte [0..255] at some point.
469         // If we chose the "exposure" too low, we get a very dark picture
470         // that is not very exciting to look at. If we over-expose, then we
471         // may overflow/clip the range [0..255], leading to unpleasant visual
472         // artifacts.
473         public void computeLuminosityCorrection() {
474             // Find maximum R, G, and B value.
475             float maxR = 0;
476             float maxG = 0;
477             float maxB = 0;
478             for (int i = 0; i < outputRGB.length; i++) {
479                 int c = outputRGB[i];
480                 int cr = (c >> 16) & 0xff;
481                 int cg = (c >>  8) & 0xff;
482                 int cb = (c >>  0) & 0xff;
483 
484                 maxR = Math.max(maxR, cr);
485                 maxG = Math.max(maxG, cg);
486                 maxB = Math.max(maxB, cb);
487             }
488 
489             float maxC = Math.max(Math.max(maxR, maxG), maxB);
490 
491             // Correct the maximum value to be 230, so we are safely in range 0..255
492             // Setting it instead to 255 will make the image brighter, but most likely
493             // it will give you some funky artefacts.
494             // Setting it to 100 will make the image darker.
495             float targetExposure = 230f;
496             luminosityCorrection *= targetExposure / maxC;
497         }
498     }
499 
500     public static class MyDrawingPanel extends JPanel {
501         private final State state;
502 
503         public MyDrawingPanel(State state) {
504             this.state = state;
505         }
506 
507         @Override
508         protected void paintComponent(Graphics g) {
509             super.paintComponent(g);
510             Graphics2D g2d = (Graphics2D) g;
511 
512             // Draw color output
513             g2d.drawImage(state.output_2, 0, 0, null);
514 
515             // Draw position of lights
516             for (Light l : state.lights) {
517                 g2d.setColor(new Color(l.r, l.g, l.b));
518                 g2d.fillRect((int)(1000f * l.x) - 3, (int)(1000f * l.y) - 3, 6, 6);
519             }
520 
521             g2d.setColor(new Color(0, 0, 0));
522             g2d.fillRect(0, 0, 150, 35);
523             g2d.setColor(new Color(255, 255, 255));
524             g2d.setFont(new Font("Consolas", Font.PLAIN, 30));
525             g2d.drawString("FPS: " + (int)Math.floor(state.fps), 0, 30);
526 
527             g2d.drawImage(state.normals, 1000, 0, null);
528         }
529     }
530 }