1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package shade;
26
import hat.Accelerator;
import hat.buffer.F32Array;
import hat.buffer.Uniforms;
import hat.types.vec2;
import hat.types.vec4;
import hat.util.ui.SevenSegmentDisplay;

import javax.imageio.ImageIO;
import javax.swing.Box;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JComboBox;
import javax.swing.JComponent;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JMenuBar;
import javax.swing.JTextField;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.Point;
import java.awt.Toolkit;
import java.awt.Transparency;
import java.awt.color.ColorSpace;
import java.awt.event.ComponentAdapter;
import java.awt.event.ComponentEvent;
import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.ComponentColorModel;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferFloat;
import java.awt.image.PixelInterleavedSampleModel;
import java.awt.image.Raster;
import java.awt.image.SampleModel;
import java.awt.image.VolatileImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
71
72
73 public class ShaderViewer {
74 public interface Shader{
75 void update(Uniforms uniforms, F32Array f32Array);
76 }
77 private final static ColorSpace colorSpace = ColorSpace.getInstance(ColorSpace.CS_sRGB);
78 // Create the Color Model. 32 bits per component, no alpha, non-premultiplied
79 private final static ColorModel colorModel = new ComponentColorModel(colorSpace, false, false,
80 Transparency.OPAQUE, DataBuffer.TYPE_FLOAT);
81 public JComponent view;
82 private VolatileImage volatileImage;
83 private BufferedImage buffer;
84 private float[] f32x3Arr;
85 private volatile boolean resized = false;
86
87
88 public record Texture(F32Array f32Array, int width, int height){}
89
90 private final Config config;
91 public record Config(
92 Accelerator acc,
93 Class<?> shaderClass,
94 Method mainImageMethod,
95 Uniforms uniforms,
96 F32Array f32Array,
97 int width,
98 int height,
99 Texture[] textures,
100 JMenuBar menuBar,
101 SevenSegmentDisplay fps,
102 SevenSegmentDisplay shaderTimeMs,
103 JComboBox<String> runWith
104 ){
105 public static Config of(Accelerator acc,Class<?> shaderClass, int width, int height, InputStream... textureInputStreams){
106 try {
107 var mainImageMethod =
108 textureInputStreams.length==0
109 ? shaderClass.getDeclaredMethod("mainImage", Uniforms.class, vec4.class, vec2.class)
110 : shaderClass.getDeclaredMethod("mainImage", Uniforms.class, vec4.class, vec2.class, F32Array.class, int.class, int.class);
111
112 var uniforms = Uniforms.create(acc);
113 var f32Array = F32Array.create(acc, width * height * 3);
114 var menuBar = new JMenuBar();
115 ((JButton) menuBar.add(new JButton("Exit"))).addActionListener(_ -> System.exit(0));
116 menuBar.add(new JLabel("FPS:"));
117 var fps = (SevenSegmentDisplay) menuBar.add(new SevenSegmentDisplay(4, 20, menuBar.getForeground(), menuBar.getBackground()));
118 menuBar.add(new JLabel("Shader Time ms:"));
119 var shaderTimeMs = (SevenSegmentDisplay) menuBar.add(new SevenSegmentDisplay(6, 20, menuBar.getForeground(), menuBar.getBackground()));
120 var runWith = (JComboBox<String>) menuBar.add(new JComboBox(new String[]{"HAT", "Java MT", "Seq"}));
121 Texture[] textures = new Texture[textureInputStreams.length];
122 for (int i = 0; i < textureInputStreams.length; i++) {
123
124
125 //Files.readAllBytes(paths[i]);
126 Image textureImage = ImageIO.read(textureInputStreams[i]);
127 int tw = textureImage.getWidth(null);
128 int th = textureImage.getHeight(null);
129 SampleModel sampleModel = new PixelInterleavedSampleModel(DataBuffer.TYPE_FLOAT,
130 tw, th, 3, tw * 3, new int[]{0, 1, 2});
131 // Create the DataBuffer (an actual heap allocated float array)
132 DataBufferFloat dataBufferFloat = new DataBufferFloat(tw * th * 3);
133 // Create the Raster
134 WritableRaster raster = Raster.createWritableRaster(sampleModel, dataBufferFloat, null);
135 BufferedImage buffer = new BufferedImage(colorModel, raster, false, null);
136 var graphics = buffer.createGraphics();
137 graphics.drawImage(textureImage,0,0, null);
138 var floats = dataBufferFloat.getData();
139 textures[i] = new Texture(F32Array.create(acc, tw* th*3), tw, th);
140 textures[i].f32Array.copyFrom(floats);
141 }
142
143 return new Config(acc,shaderClass,mainImageMethod,uniforms,f32Array,width,height,textures,menuBar,fps, shaderTimeMs,runWith);
144 }catch (Throwable t){
145 throw new RuntimeException(t);
146
147 }
148 }
149
150 }
151 ShaderViewer(Config config){
152 this.config = config;
153 this.view = new JComponent() {};
154 this.view.setSize(config.width, config.height);
155 this.view.addComponentListener(new ComponentAdapter() {
156 @Override
157 public void componentResized(ComponentEvent e) {
158 resized = true;
159 }
160 });
161 }
162
163 public void startLoop(Shader shader) {
164
165 // Create the Sample Model (Pixel Interleaved) bands for RGB, scanline stride is width * 3
166 SampleModel sampleModel = new PixelInterleavedSampleModel(DataBuffer.TYPE_FLOAT,
167 config.width, config.height, 3, config.width * 3, new int[]{0, 1, 2});
168 // Create the DataBuffer (an actual heap allocated float array)
169 DataBufferFloat dataBufferFloat = new DataBufferFloat(config.width * config.height * 3);
170 // Create the Raster
171 WritableRaster raster = Raster.createWritableRaster(sampleModel, dataBufferFloat, null);
172 buffer = new BufferedImage(colorModel, raster, false, null);
173 f32x3Arr = dataBufferFloat.getData();
174 volatileImage = view.createVolatileImage(config.width, config.height);
175 new Thread(() -> {
176 config.uniforms.iFrame(0);
177 long startTimeNs = System.nanoTime();
178 while (true) {
179 do {
180 // Check if the volatileImage content was lost (resize or invalid)
181 // if ( resized ){
182 // throw new RuntimeException("Dont resize");
183 // }
184
185 long startNs = System.nanoTime();
186 var mouse = view.getMousePosition() instanceof Point point ? point : new Point(0, 0);
187 config.uniforms.iTime((System.nanoTime() - startTimeNs) / 1000000000f);
188 config.uniforms.iMouse().x(mouse.x);
189 config.uniforms.iMouse().y(mouse.y);
190 config.uniforms.iResolution().x(config.width);
191 config.uniforms.iResolution().y(config.height);
192 IntConsumer intConsumer = idx -> {
193 vec2 fragCoord = vec2.vec2((float) (idx % config.width), (float) (config.height - (idx / config.width)));
194 try {
195 vec4 fragColor = (vec4) config.mainImageMethod.invoke(null, config.uniforms, vec4.vec4(1f), fragCoord);
196 config.f32Array.array(idx * 3, fragColor.x());
197 config.f32Array.array(idx * 3 + 1, fragColor.y());
198 config.f32Array.array(idx * 3 + 2, fragColor.z());
199 } catch (IllegalAccessException | InvocationTargetException e) {
200 throw new RuntimeException(e);
201 }
202 };
203 switch (config.runWith.getSelectedIndex()) {
204 case 2 -> IntStream.range(0, config.width * config.height).forEach(intConsumer);
205 case 1 -> IntStream.range(0, config.width * config.height).parallel().forEach(intConsumer);
206 case 0 -> shader.update(config.uniforms, config.f32Array);
207 }
208 config.f32Array.copyTo(f32x3Arr);
209 config.uniforms.iFrame(config.uniforms.iFrame() + 1);
210 long shaderMs = (System.nanoTime() - startNs) / 1000000;
211 config.fps.set((int) (1000 / shaderMs));
212 config.shaderTimeMs.set((int) shaderMs);
213 Graphics2D volatileGraphics2D = volatileImage.createGraphics();
214 volatileGraphics2D.drawImage(buffer, 0, 0, null);
215 volatileGraphics2D.dispose();
216 Graphics g = view.getGraphics();
217 g.drawImage(volatileImage, 0, 0, null);
218 g.dispose();
219 Toolkit.getDefaultToolkit().sync();// Ensure smooth rendering on Linux/macOS
220 } while (volatileImage == null || volatileImage.contentsLost());
221 }
222 }).start();
223 }
224
225
226
227 public static ShaderViewer of(Config config){
228 JFrame frame = new JFrame(config.shaderClass.getSimpleName());
229 frame.setJMenuBar(config.menuBar);
230 var shaderViewer = new ShaderViewer(config);
231 frame.setSize(shaderViewer.view.getWidth(),shaderViewer.view.getHeight());
232 frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
233 frame.add(shaderViewer.view);
234 frame.setVisible(true);
235 return shaderViewer;
236 }
237
238 public static ShaderViewer of(Accelerator acc, Class<?> shaderClass, int width, int height){
239 return of(Config.of(acc,shaderClass,width,height));
240 }
241 }