atri.palaash.lumenforge.app.LumenForgeApp
@@ -109,24 +115,22 @@
+
gpu-windows
-
- Windows
-
-
- onnx.gpu
- true
-
+ Windows
-
-
- com.microsoft.onnxruntime
- onnxruntime_gpu
- 1.22.0
-
-
+
+ onnxruntime_gpu
+
gpu-linux
@@ -135,54 +139,11 @@
unix
Linux
-
- onnx.gpu
- true
-
-
-
-
- com.microsoft.onnxruntime
- onnxruntime_gpu
- 1.22.0
-
-
-
-
-
- djl-pytorch
-
-
- djl
- true
-
-
-
-
- ai.djl
- bom
- 0.31.0
- pom
- import
-
-
-
-
-
- ai.djl
- api
-
-
- ai.djl.pytorch
- pytorch-engine
- runtime
-
-
- ai.djl.huggingface
- tokenizers
-
-
+
+ onnxruntime_gpu
+
+
diff --git a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java
index 1ba6d45..dfc00a8 100644
--- a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java
+++ b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java
@@ -22,6 +22,11 @@ public class LumenForgeApp {
public static void main(String[] args) {
configureDesktopIntegration();
+ System.out.println("[LumenForge] Starting \u2014 Java " + Runtime.version()
+ + ", " + System.getProperty("os.name") + " " + System.getProperty("os.arch"));
+ System.out.println("[LumenForge] Max heap: " + (Runtime.getRuntime().maxMemory() / (1024 * 1024)) + " MB"
+ + ", CPUs: " + Runtime.getRuntime().availableProcessors());
+
ExecutorService workerPool = Executors.newVirtualThreadPerTaskExecutor();
ModelStorage storage = new ModelStorage();
ModelRegistry registry = new ModelRegistry();
diff --git a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java b/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java
deleted file mode 100644
index 56aad75..0000000
--- a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java
+++ /dev/null
@@ -1,396 +0,0 @@
-package atri.palaash.lumenforge.inference;
-
-import atri.palaash.lumenforge.storage.ModelStorage;
-
-import java.awt.image.BufferedImage;
-import java.lang.reflect.Method;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Executor;
-
-/**
- * Stable Diffusion inference using DJL (Deep Java Library) with PyTorch backend.
- *
- * This backend loads TorchScript-traced models (.pt files) and runs inference
- * using PyTorch's native tensor operations (CPU, CUDA, or MPS). Activate by
- * building with {@code mvn compile -Ddjl=true}.
- *
- *
Required model layout
- *
- * ~/.lumenforge-models/text-image/sd-pytorch/
- * clip_model.pt — traced CLIP text encoder
- * unet_model.pt — traced UNet2DConditionModel
- * vae_decoder_model.pt — traced VAE decoder
- * tokenizer.json — HuggingFace tokenizer config
- *
- *
- * Use the included {@code scripts/export_torchscript.py} helper to export
- * HuggingFace SD checkpoints to TorchScript format.
- */
-public class DjlPyTorchService implements InferenceService {
-
- private final ModelStorage storage;
- private final Executor executor;
-
- /* ── DJL availability check ─────────────────────────────────────── */
-
- private static final boolean DJL_AVAILABLE;
- static {
- boolean found;
- try {
- Class.forName("ai.djl.ndarray.NDManager");
- found = true;
- } catch (ClassNotFoundException e) {
- found = false;
- }
- DJL_AVAILABLE = found;
- }
-
- /** Returns {@code true} if the DJL PyTorch runtime is on the classpath. */
- public static boolean isAvailable() {
- return DJL_AVAILABLE;
- }
-
- public DjlPyTorchService(ModelStorage storage, Executor executor) {
- this.storage = storage;
- this.executor = executor;
- }
-
- @Override
- public CompletableFuture run(InferenceRequest request) {
- return CompletableFuture.supplyAsync(() -> {
- if (!DJL_AVAILABLE) {
- return InferenceResult.fail(
- "DJL PyTorch is not on the classpath. Rebuild with: mvn compile -Ddjl=true");
- }
-
- Path base = storage.root().resolve("text-image").resolve("sd-pytorch");
- Path clipPath = base.resolve("clip_model.pt");
- Path unetPath = base.resolve("unet_model.pt");
- Path vaePath = base.resolve("vae_decoder_model.pt");
- Path tokPath = base.resolve("tokenizer.json");
-
- for (Path p : new Path[]{clipPath, unetPath, vaePath, tokPath}) {
- if (!Files.exists(p)) {
- return InferenceResult.fail(
- "Missing TorchScript component: " + p.getFileName()
- + "\nExport models with scripts/export_torchscript.py first.");
- }
- }
-
- try {
- return runPyTorchPipeline(request, clipPath, unetPath, vaePath, tokPath);
- } catch (Exception ex) {
- return InferenceResult.fail("DJL PyTorch pipeline failed: " + ex.getMessage());
- }
- }, executor);
- }
-
- /* ── Pipeline (reflection-based to compile without DJL on classpath) ─ */
-
- @SuppressWarnings("unchecked")
- private InferenceResult runPyTorchPipeline(InferenceRequest request,
- Path clipPath, Path unetPath,
- Path vaePath, Path tokPath)
- throws Exception {
-
- // All DJL calls go through reflection so the class compiles even when
- // the djl-pytorch profile is not active.
-
- Class> clsNDManager = Class.forName("ai.djl.ndarray.NDManager");
- Class> clsNDArray = Class.forName("ai.djl.ndarray.NDArray");
- Class> clsNDList = Class.forName("ai.djl.ndarray.NDList");
- Class> clsShape = Class.forName("ai.djl.ndarray.types.Shape");
- Class> clsDataType = Class.forName("ai.djl.ndarray.types.DataType");
- Class> clsDevice = Class.forName("ai.djl.Device");
- Class> clsCriteria = Class.forName("ai.djl.repository.zoo.Criteria");
- Class> clsZooModel = Class.forName("ai.djl.repository.zoo.ZooModel");
- Class> clsPredictor = Class.forName("ai.djl.inference.Predictor");
- Class> clsTokenizer = Class.forName("ai.djl.huggingface.tokenizers.HuggingFaceTokenizer");
- Class> clsEncoding = Class.forName("ai.djl.huggingface.tokenizers.Encoding");
- Class> clsTranslator = Class.forName("ai.djl.translate.NoopTranslator");
-
- int width = Math.max(256, (request.width() / 8) * 8);
- int height = Math.max(256, (request.height() / 8) * 8);
- int latentW = width / 8;
- int latentH = height / 8;
- int steps = Math.max(1, Math.min(request.batch() > 0 ? request.batch() : 20, 50));
-
- // Choose device — prefer GPU if available
- Object device;
- if (request.preferGpu()) {
- Method gpuMethod = clsDevice.getMethod("gpu");
- device = gpuMethod.invoke(null);
- } else {
- Method cpuMethod = clsDevice.getMethod("cpu");
- device = cpuMethod.invoke(null);
- }
-
- // Create NDManager
- Method managerOf = clsNDManager.getMethod("newBaseManager", clsDevice);
- Object manager = managerOf.invoke(null, device);
-
- try {
- /* ── Tokenize ──────────────────────────────────────────────── */
- request.reportProgress("Tokenizing prompt (DJL HuggingFace)\u2026");
- Method tokLoad = clsTokenizer.getMethod("newInstance", Path.class);
- Object tokenizer = tokLoad.invoke(null, tokPath);
- Method tokEncode = clsTokenizer.getMethod("encode", String.class);
- Object encoding = tokEncode.invoke(tokenizer, request.prompt());
- Method getIds = clsEncoding.getMethod("getIds");
- long[] tokenIds = (long[]) getIds.invoke(encoding);
-
- // Pad/truncate to 77
- long[] padded = new long[77];
- System.arraycopy(tokenIds, 0, padded, 0, Math.min(tokenIds.length, 77));
- if (tokenIds.length < 77) {
- long padId = 49407L; // <|endoftext|>
- for (int i = tokenIds.length; i < 77; i++) padded[i] = padId;
- }
-
- // Create token tensor [1, 77]
- Object shapeTokens = clsShape.getConstructor(long[].class)
- .newInstance((Object) new long[]{1, 77});
- Method mgrCreate = clsNDManager.getMethod("create", long[].class,
- Class.forName("ai.djl.ndarray.types.Shape"));
- Object tokenTensor = mgrCreate.invoke(manager, padded, shapeTokens);
-
- /* ── CLIP text encoder ──────────────────────────────────── */
- request.reportProgress("Loading CLIP text encoder (PyTorch)\u2026");
- Object clipModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator,
- clipPath, device);
- Object clipPredictor = clsZooModel.getMethod("newPredictor").invoke(clipModel);
-
- Object clipInput = clsNDList.getConstructor(clsNDArray).newInstance(tokenTensor);
- request.reportProgress("Running CLIP text encoder\u2026");
- Object clipOutput = clsPredictor.getMethod("predict", Object.class)
- .invoke(clipPredictor, clipInput);
- Object textEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(clipOutput);
-
- /* ── Negative prompt (uncond) encoder ───────────────────── */
- long[] emptyTokens = new long[77];
- emptyTokens[0] = 49406L; // <|startoftext|>
- for (int i = 1; i < 77; i++) emptyTokens[i] = 49407L;
- Object uncondTensor = mgrCreate.invoke(manager, emptyTokens, shapeTokens);
- Object uncondInput = clsNDList.getConstructor(clsNDArray).newInstance(uncondTensor);
- request.reportProgress("Running uncond text encoder\u2026");
- Object uncondOutput = clsPredictor.getMethod("predict", Object.class)
- .invoke(clipPredictor, uncondInput);
- Object uncondEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(uncondOutput);
-
- // Concatenate [uncond, cond] → [2, 77, 768]
- Method concatMethod = clsNDArray.getMethod("concat", clsNDArray, int.class);
- // Stack along batch dim: uncond embedding [1,77,768] + text embedding [1,77,768]
- Object embeddings = concatMethod.invoke(uncondEmbedding, textEmbedding, 0);
-
- clsPredictor.getMethod("close").invoke(clipPredictor);
- clsZooModel.getMethod("close").invoke(clipModel);
-
- /* ── UNet denoising loop ────────────────────────────────── */
- request.reportProgress("Loading UNet (PyTorch)\u2026");
- Object unetModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator,
- unetPath, device);
- Object unetPredictor = clsZooModel.getMethod("newPredictor").invoke(unetModel);
-
- // Initialize random latents [1, 4, latentH, latentW]
- Object latentShape = clsShape.getConstructor(long[].class)
- .newInstance((Object) new long[]{1, 4, latentH, latentW});
- Method mgrRandomNormal = clsNDManager.getMethod("randomNormal",
- Class.forName("ai.djl.ndarray.types.Shape"));
- Object latents = mgrRandomNormal.invoke(manager, latentShape);
-
- // Simple linear schedule for beta
- double guidanceScale = request.promptWeight() > 0 ? request.promptWeight() : 7.5;
- float[] timestepSchedule = ddimTimesteps(steps);
-
- // Initial scale
- Method mulMethod = clsNDArray.getMethod("mul", Number.class);
- latents = mulMethod.invoke(latents, (float) Math.sqrt(1.0 + sigmaForTimestep(timestepSchedule[0])));
-
- long stepStart = System.currentTimeMillis();
- for (int i = 0; i < timestepSchedule.length; i++) {
- if (request.isCancelled()) {
- clsPredictor.getMethod("close").invoke(unetPredictor);
- clsZooModel.getMethod("close").invoke(unetModel);
- return InferenceResult.fail("Cancelled by user.");
- }
-
- float t = timestepSchedule[i];
-
- // Duplicate latents for CFG: [latents, latents] → [2, 4, H, W]
- Object latentDouble = concatMethod.invoke(latents, latents, 0);
-
- // Create timestep tensor
- Object tsShape = clsShape.getConstructor(long[].class)
- .newInstance((Object) new long[]{1});
- Object tsTensor = mgrCreate.invoke(manager, new long[]{(long) t}, tsShape);
-
- // UNet forward: (noisy_latents, timestep, encoder_hidden_states)
- Object unetInput = clsNDList.getConstructor(clsNDArray.arrayType())
- .newInstance((Object) new Object[]{latentDouble, tsTensor, embeddings});
- Object unetOutput = clsPredictor.getMethod("predict", Object.class)
- .invoke(unetPredictor, unetInput);
- Object noisePred = clsNDList.getMethod("singletonOrThrow").invoke(unetOutput);
-
- // Classifier-free guidance: split [2,4,H,W] → uncond, cond
- Method splitMethod = clsNDArray.getMethod("split", long.class, int.class);
- Object splitResult = splitMethod.invoke(noisePred, 2L, 0);
- Object noiseUncond = java.lang.reflect.Array.get(
- clsNDList.getMethod("toArray").invoke(splitResult), 0);
- Object noiseCond = java.lang.reflect.Array.get(
- clsNDList.getMethod("toArray").invoke(splitResult), 1);
-
- // guided = uncond + scale * (cond - uncond)
- Method subMethod = clsNDArray.getMethod("sub", clsNDArray);
- Method addMethod = clsNDArray.getMethod("add", clsNDArray);
- Object diff = subMethod.invoke(noiseCond, noiseUncond);
- Object scaled = mulMethod.invoke(diff, (float) guidanceScale);
- Object guided = addMethod.invoke(noiseUncond, scaled);
-
- // DDIM step
- float alphaT = alphaForTimestep(t);
- float alphaPrev = (i + 1 < timestepSchedule.length)
- ? alphaForTimestep(timestepSchedule[i + 1]) : 1.0f;
-
- // predicted x0 = (latents - sqrt(1-alpha)*noise) / sqrt(alpha)
- Object scaledNoise = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaT));
- Object x0 = subMethod.invoke(latents, scaledNoise);
- x0 = mulMethod.invoke(x0, (float) (1.0 / Math.sqrt(alphaT)));
-
- // direction pointing to x_t
- Object dirXt = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaPrev));
- // x_{t-1} = sqrt(alpha_prev) * x0 + sqrt(1-alpha_prev) * noise_pred
- Object prevSample = mulMethod.invoke(x0, (float) Math.sqrt(alphaPrev));
- latents = addMethod.invoke(prevSample, dirXt);
-
- long elapsed = System.currentTimeMillis() - stepStart;
- stepStart = System.currentTimeMillis();
- request.reportProgress("Denoising: " + (i + 1) + "/" + steps
- + " steps (" + String.format("%.1f", elapsed / 1000.0) + "s/step) [PyTorch]");
- }
-
- clsPredictor.getMethod("close").invoke(unetPredictor);
- clsZooModel.getMethod("close").invoke(unetModel);
-
- /* ── VAE decode ─────────────────────────────────────────── */
- request.reportProgress("Loading VAE decoder (PyTorch)\u2026");
- Object vaeModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator,
- vaePath, device);
- Object vaePredictor = clsZooModel.getMethod("newPredictor").invoke(vaeModel);
-
- // Scale latents
- latents = mulMethod.invoke(latents, 1.0f / 0.18215f);
-
- Object vaeInput = clsNDList.getConstructor(clsNDArray).newInstance(latents);
- request.reportProgress("Decoding with VAE\u2026");
- Object vaeOutput = clsPredictor.getMethod("predict", Object.class)
- .invoke(vaePredictor, vaeInput);
- Object decoded = clsNDList.getMethod("singletonOrThrow").invoke(vaeOutput);
-
- clsPredictor.getMethod("close").invoke(vaePredictor);
- clsZooModel.getMethod("close").invoke(vaeModel);
-
- /* ── Convert tensor to BufferedImage ───────────────────── */
- // decoded: [1, 3, H, W] float32, range roughly [-1, 1]
- Method toFloatArray = clsNDArray.getMethod("toFloatArray");
- float[] pixels = (float[]) toFloatArray.invoke(decoded);
-
- Method getShapeMethod = clsNDArray.getMethod("getShape");
- Object decodedShape = getShapeMethod.invoke(decoded);
- long[] dims = (long[]) clsShape.getMethod("getShape").invoke(decodedShape);
- int imgH = (int) dims[2];
- int imgW = (int) dims[3];
-
- BufferedImage image = new BufferedImage(imgW, imgH, BufferedImage.TYPE_INT_RGB);
- for (int y = 0; y < imgH; y++) {
- for (int x = 0; x < imgW; x++) {
- int rIdx = 0 * imgH * imgW + y * imgW + x;
- int gIdx = 1 * imgH * imgW + y * imgW + x;
- int bIdx = 2 * imgH * imgW + y * imgW + x;
- int r = clamp((int) ((pixels[rIdx] / 2 + 0.5f) * 255));
- int g = clamp((int) ((pixels[gIdx] / 2 + 0.5f) * 255));
- int b = clamp((int) ((pixels[bIdx] / 2 + 0.5f) * 255));
- image.setRGB(x, y, (r << 16) | (g << 8) | b);
- }
- }
-
- Path outputDir = storage.root().resolve("outputs");
- Files.createDirectories(outputDir);
- String filename = "djl-pytorch_" + System.currentTimeMillis() + ".png";
- Path outputPath = outputDir.resolve(filename);
- javax.imageio.ImageIO.write(image, "PNG", outputPath.toFile());
-
- return InferenceResult.ok(
- "Generated image for prompt: \"" + request.prompt() + "\"",
- "DJL PyTorch pipeline completed (" + steps + " steps)",
- outputPath.toString(), "image");
-
- } finally {
- // Close NDManager
- clsNDManager.getMethod("close").invoke(manager);
- }
- }
-
- /* ── Helpers ──────────────────────────────────────────────────────── */
-
- /** Load a TorchScript model via DJL Criteria (reflection). */
- private Object loadTorchScriptModel(Class> clsCriteria, Class> clsNDList,
- Class> clsTranslator, Path modelPath,
- Object device) throws Exception {
- Method builder = clsCriteria.getMethod("builder");
- Object b = builder.invoke(null);
-
- Method setTypes = b.getClass().getMethod("setTypes", Class.class, Class.class);
- b = setTypes.invoke(b, clsNDList, clsNDList);
-
- Method optModelPath = b.getClass().getMethod("optModelPath", Path.class);
- b = optModelPath.invoke(b, modelPath);
-
- Method optEngine = b.getClass().getMethod("optEngine", String.class);
- b = optEngine.invoke(b, "PyTorch");
-
- Object translator = clsTranslator.getDeclaredConstructor().newInstance();
- Method optTranslator = b.getClass().getMethod("optTranslator",
- Class.forName("ai.djl.translate.Translator"));
- b = optTranslator.invoke(b, translator);
-
- Method optDevice = b.getClass().getMethod("optDevice",
- Class.forName("ai.djl.Device"));
- b = optDevice.invoke(b, device);
-
- Method build = b.getClass().getMethod("build");
- Object criteria = build.invoke(b);
-
- Method loadModel = clsCriteria.getMethod("loadModel");
- return loadModel.invoke(criteria);
- }
-
- /** DDIM linear timestep schedule. */
- private float[] ddimTimesteps(int numSteps) {
- float[] ts = new float[numSteps];
- for (int i = 0; i < numSteps; i++) {
- ts[i] = 999.0f * (1.0f - (float) i / numSteps);
- }
- return ts;
- }
-
- /** Linear beta schedule → alpha_bar for a given timestep. */
- private float alphaForTimestep(float t) {
- // Simplified linear schedule: beta from 0.00085 to 0.012 over 1000 steps
- double beta = 0.00085 + (0.012 - 0.00085) * t / 999.0;
- double alpha = 1.0 - beta;
- // Approximate cumulative product
- return (float) Math.pow(alpha, t);
- }
-
- /** Sigma for a given timestep (derived from alpha). */
- private double sigmaForTimestep(float t) {
- float a = alphaForTimestep(t);
- return Math.sqrt((1.0 - a) / a);
- }
-
- private static int clamp(int v) {
- return Math.max(0, Math.min(255, v));
- }
-}
diff --git a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java
index 58b8c75..e4a1d65 100644
--- a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java
+++ b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java
@@ -71,10 +71,13 @@ private OrtSession getOrCreateSession(OrtEnvironment env, Path modelPath,
String key = modelPath.toAbsolutePath().toString();
OrtSession existing = SESSION_CACHE.get(key);
if (existing != null) {
+ System.out.println("[LumenForge] Session cache hit: " + modelPath.getFileName());
return existing;
}
+ System.out.println("[LumenForge] Loading ONNX session: " + modelPath.getFileName());
OrtSession session = env.createSession(modelPath.toString(), opts);
SESSION_CACHE.put(key, session);
+ System.out.println("[LumenForge] Session loaded: " + modelPath.getFileName());
return session;
}
@@ -126,10 +129,13 @@ public static void clearCache() {
public CompletableFuture run(InferenceRequest request) {
return CompletableFuture.supplyAsync(() -> {
if (!storage.isAvailable(request.model())) {
+ System.out.println("[LumenForge] ERROR: Model not found locally: " + request.model().displayName());
return InferenceResult.fail("Model not found locally. Open Model Manager from the menu bar and download it first.");
}
Path modelPath = storage.modelPath(request.model());
+ System.out.println("[LumenForge] Starting inference: " + request.model().displayName()
+ + " (" + request.model().id() + ")");
// Temporarily intercept stderr so ONNX Runtime native warnings
// appear in the application's Log tab instead of only in the console.
@@ -146,6 +152,8 @@ public CompletableFuture run(InferenceRequest request) {
sessionOptions.setIntraOpNumThreads(Math.max(1, cpus - 1));
sessionOptions.setInterOpNumThreads(Math.max(1, Math.min(cpus / 2, 4)));
ProviderSelection providerSelection = configureExecutionProvider(sessionOptions, request.preferGpu());
+ System.out.println("[LumenForge] Using EP: " + providerSelection.provider()
+ + (providerSelection.notes().isBlank() ? "" : " (" + providerSelection.notes() + ")"));
// Invalidate session cache when the execution provider changes.
String epKey = providerSelection.provider() + "|" + request.preferGpu();
@@ -175,11 +183,6 @@ public CompletableFuture run(InferenceRequest request) {
&& request.model().relativePath().contains("transformer/")) {
return runSd3(environment, sessionOptions, request, providerSelection.provider());
}
- // DJL/PyTorch models — delegate to the DJL backend
- if (request.model().id().contains("pytorch")) {
- DjlPyTorchService djl = new DjlPyTorchService(storage, executor);
- return djl.run(request).join();
- }
// Img2Img pipelines
if ("sd_v15_img2img".equals(request.model().id())) {
return runImg2Img(environment, sessionOptions, request, providerSelection.provider(), false);
@@ -192,9 +195,11 @@ public CompletableFuture run(InferenceRequest request) {
+ request.model().displayName() + " | task=" + taskType.displayName()
+ " | EP=" + providerSelection.provider()
+ providerSelection.noteSuffix();
+ System.out.println("[LumenForge] WARN: " + details);
return InferenceResult.fail(details);
} catch (OrtException ex) {
String message = ex.getMessage() == null ? "Unknown ONNX Runtime error" : ex.getMessage();
+ System.out.println("[LumenForge] ERROR: ONNX Runtime error: " + message);
if (message.contains("NhwcConv")) {
return InferenceResult.fail(
"This ONNX model uses custom ops (e.g., NhwcConv) not available in CPUExecutionProvider. "
@@ -2600,8 +2605,22 @@ private List viterbi(String text) {
}
}
+ /**
+ * Build the EP preference list for the current platform, try each in order,
+ * and return the first one that the loaded ONNX Runtime native library supports.
+ *
+ * Priority order per platform (highest → lowest):
+ *
+ * - macOS: CoreML (GPU+ANE+CPU) → CPU
+ * - Windows: TensorRT-RTX → TensorRT → CUDA → DirectML → OpenVINO → CPU
+ * - Linux: TensorRT → CUDA → ROCm → OpenVINO → CPU
+ *
+ *
+ * Override with {@code -Dlumenforge.ep=cuda} (or any key) to force a specific EP.
+ */
private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions options, boolean preferGpu) {
String os = System.getProperty("os.name", "unknown").toLowerCase();
+ String arch = System.getProperty("os.arch", "unknown").toLowerCase();
String forced = System.getProperty("lumenforge.ep", "").trim().toLowerCase();
List preference = new ArrayList<>();
@@ -2614,51 +2633,137 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o
preference.add("coreml");
preference.add("cpu");
} else if (os.contains("win")) {
- preference.add("directml");
- preference.add("cuda");
+ preference.add("tensorrt_rtx"); // RTX 30xx+ (Ampere+)
+ preference.add("tensorrt"); // any NVIDIA with TensorRT libs
+ preference.add("cuda"); // CUDA fallback
+ preference.add("directml"); // AMD / Intel / any DX12 GPU
+ preference.add("openvino"); // Intel CPUs/GPUs/NPUs
preference.add("cpu");
} else {
+ // Linux
+ preference.add("tensorrt");
preference.add("cuda");
- preference.add("rocm");
+ preference.add("rocm"); // AMD GPUs
+ preference.add("openvino");
preference.add("cpu");
}
StringBuilder notes = new StringBuilder();
+ List failReasons = new ArrayList<>();
+ System.out.println("[LumenForge] EP preference order: " + preference
+ + " (os=" + os + ", arch=" + arch + ")");
+
+ if (!preferGpu) {
+ System.out.println("[LumenForge] GPU not requested for this session — using CPUExecutionProvider");
+ return new ProviderSelection("CPUExecutionProvider", "GPU not requested");
+ }
+
for (String candidate : preference) {
if ("cpu".equals(candidate)) {
+ // All GPU EPs exhausted — log summary
+ System.out.println("[LumenForge] WARN: Falling back to CPUExecutionProvider");
+ if (!failReasons.isEmpty()) {
+ System.out.println("[LumenForge] WARN: Reason: every GPU execution provider was unavailable:");
+ for (String reason : failReasons) {
+ System.out.println("[LumenForge] WARN: - " + reason);
+ }
+ System.out.println("[LumenForge] WARN: Tip: install the matching GPU runtime "
+ + "(e.g. CUDA/cuDNN for NVIDIA, ROCm for AMD) or use "
+ + "-Dlumenforge.ep= to force a specific EP.");
+ }
return new ProviderSelection("CPUExecutionProvider", notes.toString());
}
- if (tryEnableProvider(options, candidate, notes)) {
- return new ProviderSelection(providerDisplayName(candidate), notes.toString());
+ String failReason = tryEnableProvider(options, candidate, notes);
+ if (failReason == null) {
+ String display = providerDisplayName(candidate);
+ System.out.println("[LumenForge] \u2713 Enabled " + display);
+ return new ProviderSelection(display, notes.toString());
+ }
+ failReasons.add(failReason);
+ }
+ // Preference list didn't include "cpu" explicitly — shouldn't happen, but handle it
+ System.out.println("[LumenForge] WARN: No EP available, falling back to CPUExecutionProvider");
+ if (!failReasons.isEmpty()) {
+ System.out.println("[LumenForge] WARN: Reasons:");
+ for (String reason : failReasons) {
+ System.out.println("[LumenForge] WARN: - " + reason);
}
}
return new ProviderSelection("CPUExecutionProvider", notes.toString());
}
- private boolean tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) {
+ /**
+ * Attempt to enable a single EP via reflection.
+ *
+ * @return {@code null} if the provider was enabled successfully,
+ * or a human-readable reason string if it could not be enabled.
+ */
+ private String tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) {
+ String failDetail = null;
try {
- return switch (candidate) {
- case "cuda" -> invokeNoArg(options, "addCUDA");
- case "directml" -> invokeNoArg(options, "addDirectML") || invokeIntArg(options, "addDirectML", 0);
+ boolean ok = switch (candidate) {
+
+ /* ── NVIDIA ───────────────────────────────────────────── */
+ case "tensorrt_rtx" ->
+ // NvTensorRtRtxExecutionProvider – RTX 30xx+ only
+ // Java method not yet in stock Maven artifacts; try reflection just in case
+ invokeNoArg(options, "addNvTensorRtRtx")
+ || invokeIntArg(options, "addNvTensorRtRtx", 0);
+
+ case "tensorrt" ->
+ // TensorrtExecutionProvider – available in onnxruntime_gpu
+ invokeIntArg(options, "addTensorrt", 0)
+ || invokeNoArg(options, "addTensorrt");
+
+ case "cuda" ->
+ invokeNoArg(options, "addCUDA");
+
+ /* ── Apple ────────────────────────────────────────────── */
case "coreml" -> {
- // Flags: COREML_FLAG_USE_CPU_AND_GPU (1) enables all ANE/GPU subgraphs
- boolean ok = invokeIntArg(options, "addCoreML", 1);
- if (!ok) { ok = invokeNoArg(options, "addCoreML"); }
- if (!ok) { ok = invokeIntArg(options, "addCoreML", 0); }
- yield ok;
+ // addCoreML(long flags) — note: parameter is long, not int!
+ // Flag 0x0 = ALL compute units (CPU+GPU+ANE — best for M-series)
+ boolean coreOk = invokeLongArg(options, "addCoreML", 0L);
+ if (!coreOk) { coreOk = invokeNoArg(options, "addCoreML"); }
+ yield coreOk;
}
- case "rocm" -> invokeNoArg(options, "addROCM");
+
+ /* ── Microsoft ────────────────────────────────────────── */
+ case "directml" ->
+ invokeIntArg(options, "addDirectML", 0)
+ || invokeNoArg(options, "addDirectML");
+
+ /* ── Intel ────────────────────────────────────────────── */
+ case "openvino" ->
+ // addOpenVINO(String) — device type "GPU" preferred, "CPU" fallback
+ invokeStringArg(options, "addOpenVINO", "GPU")
+ || invokeStringArg(options, "addOpenVINO", "CPU")
+ || invokeNoArg(options, "addOpenVINO");
+
+ /* ── AMD ──────────────────────────────────────────────── */
+ case "rocm" ->
+ invokeNoArg(options, "addROCM");
+
default -> false;
};
- } catch (Exception ex) {
- if (!notes.isEmpty()) {
- notes.append("; ");
+ if (!ok) {
+ failDetail = candidate + ": native library not found in classpath";
}
- notes.append(candidate).append(" unavailable");
- return false;
+ } catch (Exception ex) {
+ String msg = ex.getMessage();
+ if (msg == null && ex.getCause() != null) { msg = ex.getCause().getMessage(); }
+ failDetail = candidate + ": " + (msg != null ? msg : ex.getClass().getSimpleName());
}
+
+ if (failDetail != null) {
+ if (!notes.isEmpty()) { notes.append("; "); }
+ notes.append(candidate).append(" not available");
+ System.out.println("[LumenForge] WARN: \u2717 " + failDetail);
+ }
+ return failDetail;
}
+ /* ── Reflection helpers for EP registration ──────────────────────── */
+
private boolean invokeNoArg(OrtSession.SessionOptions options, String methodName) {
try {
options.getClass().getMethod(methodName).invoke(options);
@@ -2677,13 +2782,34 @@ private boolean invokeIntArg(OrtSession.SessionOptions options, String methodNam
}
}
+ private boolean invokeLongArg(OrtSession.SessionOptions options, String methodName, long arg) {
+ try {
+ options.getClass().getMethod(methodName, long.class).invoke(options, arg);
+ return true;
+ } catch (Exception ex) {
+ return false;
+ }
+ }
+
+ private boolean invokeStringArg(OrtSession.SessionOptions options, String methodName, String arg) {
+ try {
+ options.getClass().getMethod(methodName, String.class).invoke(options, arg);
+ return true;
+ } catch (Exception ex) {
+ return false;
+ }
+ }
+
private String providerDisplayName(String candidate) {
return switch (candidate) {
- case "cuda" -> "CUDAExecutionProvider";
- case "directml" -> "DmlExecutionProvider";
- case "coreml" -> "CoreMLExecutionProvider";
- case "rocm" -> "ROCMExecutionProvider";
- default -> "CPUExecutionProvider";
+ case "tensorrt_rtx" -> "NvTensorRtRtxExecutionProvider";
+ case "tensorrt" -> "TensorrtExecutionProvider";
+ case "cuda" -> "CUDAExecutionProvider";
+ case "coreml" -> "CoreMLExecutionProvider";
+ case "directml" -> "DmlExecutionProvider";
+ case "openvino" -> "OpenVINOExecutionProvider";
+ case "rocm" -> "ROCMExecutionProvider";
+ default -> "CPUExecutionProvider";
};
}
diff --git a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java
index 00a0e0a..bc2c32d 100644
--- a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java
+++ b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java
@@ -25,31 +25,4 @@ public InferenceService create(TaskType taskType) {
}
return new GenericOnnxService(taskType, storage, executor);
}
-
- /**
- * Creates an inference service for the given task using the DJL PyTorch backend.
- * Falls back to the ONNX backend if DJL is not available.
- */
- public InferenceService createDjl(TaskType taskType) {
- if (DjlPyTorchService.isAvailable()) {
- return new DjlPyTorchService(storage, executor);
- }
- return create(taskType);
- }
-
- /**
- * Creates an inference service for the given model, automatically selecting
- * the DJL backend for PyTorch models and ONNX backend for everything else.
- */
- public InferenceService createForModel(String modelId, TaskType taskType) {
- if (modelId != null && modelId.contains("pytorch")) {
- return createDjl(taskType);
- }
- return create(taskType);
- }
-
- /** Returns {@code true} if the DJL PyTorch runtime is available. */
- public static boolean isDjlAvailable() {
- return DjlPyTorchService.isAvailable();
- }
}
diff --git a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java
index 6c16e7a..94b1a47 100644
--- a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java
+++ b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java
@@ -56,6 +56,7 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer {
Path target = storage.modelPath(descriptor);
List writtenFiles = new ArrayList<>();
+ System.out.println("[LumenForge] Download started: " + descriptor.displayName());
try {
String downloadUrl = descriptor.sourceUrl();
if (!downloadUrl.toLowerCase(Locale.ROOT).contains(".onnx")
@@ -71,8 +72,12 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer {
if (!e.getValueIsAdjusting() && workflowList.getSelectedIndex() >= 0) {
managementList.clearSelection();
@@ -201,7 +201,7 @@ public MainFrame(ModelRegistry registry,
workflowList.setAlignmentX(Component.LEFT_ALIGNMENT);
workflowSection.add(workflowList);
- // Management section (pinned to bottom)
+ // Management section
JPanel managementSection = new JPanel();
managementSection.setLayout(new BoxLayout(managementSection, BoxLayout.Y_AXIS));
managementSection.setOpaque(false);
@@ -214,13 +214,21 @@ public MainFrame(ModelRegistry registry,
managementSection.add(managementHeader);
managementList.setAlignmentX(Component.LEFT_ALIGNMENT);
managementSection.add(managementList);
- managementSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0));
- // Combine: workflow on top, management at bottom with glue between
+ // Diagnostics section removed
+
+ // Bottom panel: management
+ JPanel bottomSection = new JPanel();
+ bottomSection.setLayout(new BoxLayout(bottomSection, BoxLayout.Y_AXIS));
+ bottomSection.setOpaque(false);
+ bottomSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0));
+ bottomSection.add(managementSection);
+
+ // Combine: workflow on top, management at bottom
JPanel sidebarContent = new JPanel(new BorderLayout());
sidebarContent.setOpaque(false);
sidebarContent.add(workflowSection, BorderLayout.NORTH);
- sidebarContent.add(managementSection, BorderLayout.SOUTH);
+ sidebarContent.add(bottomSection, BorderLayout.SOUTH);
sidebarPanel.add(sidebarContent, BorderLayout.CENTER);
@@ -322,11 +330,8 @@ private JMenuBar buildMenuBar() {
private void switchToCard(String card, JList<?> targetList, int index) {
cardLayout.show(contentPanel, card);
- if (targetList == workflowList) {
- managementList.clearSelection();
- } else {
- workflowList.clearSelection();
- }
+ if (targetList != workflowList) workflowList.clearSelection();
+ if (targetList != managementList) managementList.clearSelection();
targetList.setSelectedIndex(index);
}
@@ -405,10 +410,8 @@ private String detectEpInfo() {
}
int cores = Runtime.getRuntime().availableProcessors();
long mem = Runtime.getRuntime().maxMemory() / (1024 * 1024);
- String djl = atri.palaash.lumenforge.inference.DjlPyTorchService.isAvailable()
- ? " \u2502 DJL \u2713" : "";
return "EP: " + ep + " \u2502 Cores: " + cores + " \u2502 Heap: " + mem
- + " MB \u2502 Java " + Runtime.version() + djl;
+ + " MB \u2502 Java " + Runtime.version();
}
/* ================================================================== */
diff --git a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java
index 1ba0f76..c77ae84 100644
--- a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java
+++ b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java
@@ -203,6 +203,9 @@ private void startDownload(ModelDescriptor descriptor, int row) {
tableModel.setAvailable(row, true);
tableModel.updateProgress(descriptor.id(), 100);
statusLabel.setText("Downloaded: " + path);
+ if (onModelsUpdated != null) {
+ onModelsUpdated.run();
+ }
}));
}
@@ -258,7 +261,12 @@ private void importModelFromFile(ModelDescriptor descriptor, int row) {
tableModel.setAvailable(row, true);
tableModel.updateProgress(descriptor.id(), 100);
statusLabel.setText("Imported: " + target);
+ System.out.println("[LumenForge] Model imported: " + descriptor.displayName() + " from " + source);
+ if (onModelsUpdated != null) {
+ onModelsUpdated.run();
+ }
} catch (IOException ex) {
+ System.out.println("[LumenForge] ERROR: Import failed for " + descriptor.displayName() + ": " + ex.getMessage());
statusLabel.setText("Import failed: " + ex.getMessage());
}
}
@@ -313,6 +321,7 @@ private void startConversion(String modelId, String displayName, String mode) {
statusLabel.setText("Converting " + displayName + " to ONNX\u2026");
progressBar.setVisible(true);
progressBar.setIndeterminate(true);
+ System.out.println("[LumenForge] Starting PyTorch \u2192 ONNX conversion: " + modelId);
// 5. Run on background thread
final String conversionMode = mode;
@@ -337,6 +346,7 @@ private void startConversion(String modelId, String displayName, String mode) {
PyTorchToOnnxConverter.openPythonDownloadPage();
}
} else {
+ System.out.println("[LumenForge] ERROR: Conversion failed for " + modelId + ": " + cause.getMessage());
statusLabel.setText("Conversion failed: " + cause.getMessage());
JOptionPane.showMessageDialog(this,
"Conversion failed:\n" + cause.getMessage(),
@@ -385,6 +395,7 @@ private void registerConvertedModel(String modelId, String sanitized, Path outpu
modelRegistry.mergeDownloadableAssets(List.of(desc));
refreshTable();
statusLabel.setText("\u2713 Converted and registered: " + modelId);
+ System.out.println("[LumenForge] Conversion complete: " + modelId + " \u2192 " + relativePath);
if (onModelsUpdated != null) {
onModelsUpdated.run();
}