1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx.fer;
25
26 import oracle.code.onnx.OnnxProvider;
27 import oracle.code.onnx.OnnxRuntime;
28
29 import javax.imageio.ImageIO;
30 import java.awt.*;
31 import java.awt.image.BufferedImage;
32 import java.io.IOException;
33 import java.lang.foreign.Arena;
34 import java.net.URL;
35 import java.util.Objects;
36
37 import static oracle.code.onnx.fer.FERCoreMLDemo.IMAGE_SIZE;
38
39 public class FERInference {
40
41
42 private final OnnxRuntime runtime;
43
44 public FERInference() {
45 runtime = OnnxRuntime.getInstance();
46 }
47
48 public float[] analyzeImage(Arena arena, OnnxRuntime.SessionOptions sessionOptions, URL url, boolean isCondensed) throws Exception {
49 float[] imageData = transformToFloatArray(url);
50 FERModel ferModel = new FERModel(arena);
51 float[] rawScores = ferModel.classify(imageData, sessionOptions, isCondensed);
52 return rawScores;
53 }
54
55 public OnnxRuntime.SessionOptions prepareSessionOptions(Arena arena, OnnxProvider provider) {
56 var sessionOptions = runtime.createSessionOptions(arena);
57 if (Objects.nonNull(provider)) {
58 runtime.appendExecutionProvider(arena, sessionOptions, provider);
59 }
60 return sessionOptions;
61 }
62
63 private float[] transformToFloatArray(URL imgUrl) throws IOException {
64 BufferedImage src = ImageIO.read(imgUrl);
65 if (src == null) {
66 throw new IOException("Unsupported or corrupt image: " + imgUrl);
67 }
68
69 BufferedImage graySrc = new BufferedImage(src.getWidth(), src.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
70 Graphics2D g0 = graySrc.createGraphics();
71 g0.drawImage(src, 0, 0, null);
72 g0.dispose();
73
74 BufferedImage gray = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_BYTE_GRAY);
75 Graphics2D g = gray.createGraphics();
76 g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
77 g.drawImage(graySrc, 0, 0, IMAGE_SIZE, IMAGE_SIZE, null);
78 g.dispose();
79
80 float[] data = new float[IMAGE_SIZE * IMAGE_SIZE];
81 gray.getData().getSamples(0, 0, IMAGE_SIZE, IMAGE_SIZE, 0, data);
82
83 return data;
84 }
85 }