/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/

package oracle.code.onnx.fer;

import oracle.code.onnx.OnnxRuntime;
import oracle.code.onnx.provider.OnnxProvider;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.net.URL;
import java.util.Objects;

import static oracle.code.onnx.fer.FERCoreMLDemo.IMAGE_SIZE;

/**
 * Loads an image, converts it to a square grayscale float array and passes it
 * to {@link FERModel} for classification.
 */
public class FERInference {

    private final OnnxRuntime runtime;

    /**
     * Creates an inference helper bound to the instance returned by
     * {@link OnnxRuntime#getInstance()}.
     */
    public FERInference() {
        runtime = OnnxRuntime.getInstance();
    }

    /**
     * Reads the image at {@code url}, preprocesses it and classifies it with a
     * new {@link FERModel}.
     *
     * @param arena             arena used to create the session options and the model
     * @param provider          optional provider; when non-null it is asked to configure
     *                          the session options before classification
     * @param url               location of the image to analyze
     * @param useCondensedModel forwarded unchanged to {@code FERModel.classify}
     *                          (presumably selects a smaller model variant; confirm
     *                          against {@code FERModel})
     * @return the array returned by {@code FERModel.classify}, unmodified
     * @throws Exception if the image cannot be read or classification fails
     */
    public float[] analyzeImage(Arena arena, OnnxProvider provider, URL url, boolean useCondensedModel) throws Exception {
        float[] imageData = transformToFloatArray(url);
        var sessionOptions = runtime.createSessionOptions(arena);
        if (Objects.nonNull(provider)) {
            provider.configure(sessionOptions);
        }
        FERModel ferModel = new FERModel(arena);
        return ferModel.classify(imageData, sessionOptions, useCondensedModel);
    }

    /**
     * Decodes the image at {@code imgUrl}, converts it to 8-bit grayscale, resizes it
     * to {@code IMAGE_SIZE x IMAGE_SIZE} and returns the gray samples as floats.
     *
     * @param imgUrl location of the image to load
     * @return {@code IMAGE_SIZE * IMAGE_SIZE} gray samples; the values are the raw
     *         8-bit samples (no normalization is applied here)
     * @throws IOException if reading fails or no registered reader can decode the image
     */
    private float[] transformToFloatArray(URL imgUrl) throws IOException {
        BufferedImage src = ImageIO.read(imgUrl);
        if (src == null) {
            // ImageIO.read returns null when no registered ImageReader claims the stream.
            throw new IOException("Unsupported or corrupt image: " + imgUrl);
        }

        BufferedImage resized = scaleToModelInput(toGrayscale(src));

        float[] data = new float[IMAGE_SIZE * IMAGE_SIZE];
        // getRaster() exposes the backing raster directly; getData() would copy the
        // whole image first, which is unnecessary for a read-only sample fetch.
        resized.getRaster().getSamples(0, 0, IMAGE_SIZE, IMAGE_SIZE, 0, data);
        return data;
    }

    /** Renders {@code src} at its original size into a new TYPE_BYTE_GRAY image. */
    private static BufferedImage toGrayscale(BufferedImage src) {
        BufferedImage graySrc = new BufferedImage(src.getWidth(), src.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = graySrc.createGraphics();
        try {
            g.drawImage(src, 0, 0, null);
        } finally {
            // Release the graphics context even if drawing fails.
            g.dispose();
        }
        return graySrc;
    }

    /** Resizes {@code graySrc} to IMAGE_SIZE x IMAGE_SIZE using bilinear interpolation. */
    private static BufferedImage scaleToModelInput(BufferedImage graySrc) {
        BufferedImage gray = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(graySrc, 0, 0, IMAGE_SIZE, IMAGE_SIZE, null);
        } finally {
            // Release the graphics context even if drawing fails.
            g.dispose();
        }
        return gray;
    }
}