From 821a6234eec7221e2baed485ff4e3286a320279a Mon Sep 17 00:00:00 2001 From: Palaash Atri Date: Tue, 24 Feb 2026 16:12:28 +0530 Subject: [PATCH 1/5] fix model load --- .../java/atri/palaash/lumenforge/ui/ModelManagerPanel.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java index 1ba0f76..2325a74 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java +++ b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java @@ -203,6 +203,9 @@ private void startDownload(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Downloaded: " + path); + if (onModelsUpdated != null) { + onModelsUpdated.run(); + } })); } @@ -258,6 +261,9 @@ private void importModelFromFile(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Imported: " + target); + if (onModelsUpdated != null) { + onModelsUpdated.run(); + } } catch (IOException ex) { statusLabel.setText("Import failed: " + ex.getMessage()); } From 208d99e206dc3ffe6810fd1f250192a055d0628c Mon Sep 17 00:00:00 2001 From: Palaash Atri Date: Tue, 24 Feb 2026 16:31:57 +0530 Subject: [PATCH 2/5] fix gpu support --- pom.xml | 47 +++---- .../inference/GenericOnnxService.java | 130 ++++++++++++++---- 2 files changed, 127 insertions(+), 50 deletions(-) diff --git a/pom.xml b/pom.xml index 3ecd726..44bfea1 100644 --- a/pom.xml +++ b/pom.xml @@ -11,13 +11,16 @@ UTF-8 21 + + onnxruntime + 1.22.0 com.microsoft.onnxruntime - onnxruntime - 1.22.0 + ${ort.artifactId} + ${ort.version} com.formdev @@ -109,24 +112,22 @@ + gpu-windows - - Windows - - - onnx.gpu - true - + Windows - - - com.microsoft.onnxruntime - onnxruntime_gpu - 1.22.0 - - + + onnxruntime_gpu + gpu-linux @@ -135,18 +136,10 @@ unix Linux - - 
onnx.gpu - true - - - - com.microsoft.onnxruntime - onnxruntime_gpu - 1.22.0 - - + + onnxruntime_gpu + diff --git a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java index 58b8c75..4bc549b 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java +++ b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java @@ -2600,8 +2600,22 @@ private List viterbi(String text) { } } + /** + * Build the EP preference list for the current platform, try each in order, + * and return the first one that the loaded ONNX Runtime native library supports. + * + *

Priority order per platform (highest → lowest): + *

    + *
  • macOS: CoreML (GPU+ANE+CPU) → CPU
  • + *
  • Windows: TensorRT-RTX → TensorRT → CUDA → DirectML → OpenVINO → CPU
  • + *
  • Linux: TensorRT → CUDA → ROCm → OpenVINO → CPU
  • + *
+ * + *

Override with {@code -Dlumenforge.ep=cuda} (or any key) to force a specific EP. + */ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions options, boolean preferGpu) { String os = System.getProperty("os.name", "unknown").toLowerCase(); + String arch = System.getProperty("os.arch", "unknown").toLowerCase(); String forced = System.getProperty("lumenforge.ep", "").trim().toLowerCase(); List preference = new ArrayList<>(); @@ -2614,51 +2628,100 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o preference.add("coreml"); preference.add("cpu"); } else if (os.contains("win")) { - preference.add("directml"); - preference.add("cuda"); + preference.add("tensorrt_rtx"); // RTX 30xx+ (Ampere+) + preference.add("tensorrt"); // any NVIDIA with TensorRT libs + preference.add("cuda"); // CUDA fallback + preference.add("directml"); // AMD / Intel / any DX12 GPU + preference.add("openvino"); // Intel CPUs/GPUs/NPUs preference.add("cpu"); } else { + // Linux + preference.add("tensorrt"); preference.add("cuda"); - preference.add("rocm"); + preference.add("rocm"); // AMD GPUs + preference.add("openvino"); preference.add("cpu"); } StringBuilder notes = new StringBuilder(); + System.out.println("[LumenForge] EP preference order: " + preference); + for (String candidate : preference) { if ("cpu".equals(candidate)) { + System.out.println("[LumenForge] Using CPUExecutionProvider"); return new ProviderSelection("CPUExecutionProvider", notes.toString()); } if (tryEnableProvider(options, candidate, notes)) { - return new ProviderSelection(providerDisplayName(candidate), notes.toString()); + String display = providerDisplayName(candidate); + System.out.println("[LumenForge] \u2713 Enabled " + display); + return new ProviderSelection(display, notes.toString()); } } + System.out.println("[LumenForge] No GPU EP available, falling back to CPU"); return new ProviderSelection("CPUExecutionProvider", notes.toString()); } private boolean 
tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) { + boolean ok = false; try { - return switch (candidate) { - case "cuda" -> invokeNoArg(options, "addCUDA"); - case "directml" -> invokeNoArg(options, "addDirectML") || invokeIntArg(options, "addDirectML", 0); + ok = switch (candidate) { + + /* ── NVIDIA ───────────────────────────────────────────── */ + case "tensorrt_rtx" -> + // NvTensorRtRtxExecutionProvider – RTX 30xx+ only + // Java method not yet in stock Maven artifacts; try reflection just in case + invokeNoArg(options, "addNvTensorRtRtx") + || invokeIntArg(options, "addNvTensorRtRtx", 0); + + case "tensorrt" -> + // TensorrtExecutionProvider – available in onnxruntime_gpu + invokeIntArg(options, "addTensorrt", 0) + || invokeNoArg(options, "addTensorrt"); + + case "cuda" -> + invokeNoArg(options, "addCUDA"); + + /* ── Apple ────────────────────────────────────────────── */ case "coreml" -> { - // Flags: COREML_FLAG_USE_CPU_AND_GPU (1) enables all ANE/GPU subgraphs - boolean ok = invokeIntArg(options, "addCoreML", 1); - if (!ok) { ok = invokeNoArg(options, "addCoreML"); } - if (!ok) { ok = invokeIntArg(options, "addCoreML", 0); } - yield ok; + // addCoreML(long flags) — note: parameter is long, not int! 
+ // Flag 0x0 = ALL compute units (CPU+GPU+ANE — best for M-series) + boolean coreOk = invokeLongArg(options, "addCoreML", 0L); + if (!coreOk) { coreOk = invokeNoArg(options, "addCoreML"); } + yield coreOk; } - case "rocm" -> invokeNoArg(options, "addROCM"); + + /* ── Microsoft ────────────────────────────────────────── */ + case "directml" -> + invokeIntArg(options, "addDirectML", 0) + || invokeNoArg(options, "addDirectML"); + + /* ── Intel ────────────────────────────────────────────── */ + case "openvino" -> + // addOpenVINO(String) — device type "GPU" preferred, "CPU" fallback + invokeStringArg(options, "addOpenVINO", "GPU") + || invokeStringArg(options, "addOpenVINO", "CPU") + || invokeNoArg(options, "addOpenVINO"); + + /* ── AMD ──────────────────────────────────────────────── */ + case "rocm" -> + invokeNoArg(options, "addROCM"); + default -> false; }; } catch (Exception ex) { - if (!notes.isEmpty()) { - notes.append("; "); - } - notes.append(candidate).append(" unavailable"); - return false; + ok = false; } + + if (!ok) { + if (!notes.isEmpty()) { notes.append("; "); } + notes.append(candidate).append(" not available"); + System.out.println("[LumenForge] \u2717 " + candidate + " not available"); + } + return ok; } + /* ── Reflection helpers for EP registration ──────────────────────── */ + private boolean invokeNoArg(OrtSession.SessionOptions options, String methodName) { try { options.getClass().getMethod(methodName).invoke(options); @@ -2677,13 +2740,34 @@ private boolean invokeIntArg(OrtSession.SessionOptions options, String methodNam } } + private boolean invokeLongArg(OrtSession.SessionOptions options, String methodName, long arg) { + try { + options.getClass().getMethod(methodName, long.class).invoke(options, arg); + return true; + } catch (Exception ex) { + return false; + } + } + + private boolean invokeStringArg(OrtSession.SessionOptions options, String methodName, String arg) { + try { + options.getClass().getMethod(methodName, 
String.class).invoke(options, arg); + return true; + } catch (Exception ex) { + return false; + } + } + private String providerDisplayName(String candidate) { return switch (candidate) { - case "cuda" -> "CUDAExecutionProvider"; - case "directml" -> "DmlExecutionProvider"; - case "coreml" -> "CoreMLExecutionProvider"; - case "rocm" -> "ROCMExecutionProvider"; - default -> "CPUExecutionProvider"; + case "tensorrt_rtx" -> "NvTensorRtRtxExecutionProvider"; + case "tensorrt" -> "TensorrtExecutionProvider"; + case "cuda" -> "CUDAExecutionProvider"; + case "coreml" -> "CoreMLExecutionProvider"; + case "directml" -> "DmlExecutionProvider"; + case "openvino" -> "OpenVINOExecutionProvider"; + case "rocm" -> "ROCMExecutionProvider"; + default -> "CPUExecutionProvider"; }; } From d9d3b65bfafc75a4b45b4d52928bebf07ae75746 Mon Sep 17 00:00:00 2001 From: Palaash Atri Date: Tue, 24 Feb 2026 16:34:30 +0530 Subject: [PATCH 3/5] add github action to publish related jars --- .github/workflows/build.yml | 91 +++++++++++++++++++++++++++++++++++++ pom.xml | 3 ++ 2 files changed, 94 insertions(+) create mode 100644 .github/workflows/build.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1950e51 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,91 @@ +name: Build LumenForge + +on: + push: + branches: [ main ] + tags: [ 'v*' ] + pull_request: + branches: [ main ] + workflow_dispatch: + +permissions: + contents: write # needed for release uploads + +jobs: + build: + name: ${{ matrix.label }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + # ── Universal (CPU + CoreML on macOS) ────────────────────── + # Works on macOS (CoreML GPU/ANE), Windows (CPU), Linux (CPU). + # AMD DirectML / Intel OpenVINO users: install native libs + # separately and this JAR will auto-detect them at runtime. 
+ - label: "Universal (CPU / macOS CoreML)" + ort_artifact: onnxruntime + classifier: "-universal" + asset_name: lumenforge-universal.jar + + # ── NVIDIA GPU (CUDA + TensorRT) ─────────────────────────── + # For Windows and Linux with NVIDIA GPUs. + # Requires CUDA toolkit + cuDNN on the host system. + - label: "NVIDIA GPU (CUDA + TensorRT)" + ort_artifact: onnxruntime_gpu + classifier: "-nvidia" + asset_name: lumenforge-nvidia.jar + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: '21' + cache: maven + + - name: Build fat JAR + run: | + mvn clean package -DskipTests \ + -Dort.artifactId=${{ matrix.ort_artifact }} \ + -Djar.classifier=${{ matrix.classifier }} + + - name: Rename artifact + run: | + JAR=$(find target -maxdepth 1 -name "lumenforge-*${{ matrix.classifier }}.jar" | head -1) + cp "$JAR" "target/${{ matrix.asset_name }}" + echo "Built: ${{ matrix.asset_name }} ($(du -h target/${{ matrix.asset_name }} | cut -f1))" + + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.asset_name }} + path: target/${{ matrix.asset_name }} + retention-days: 14 + + # ── Release (only on version tags) ────────────────────────────────── + release: + name: Create Release + needs: build + if: startsWith(github.ref, 'refs/tags/v') + runs-on: ubuntu-latest + steps: + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + merge-multiple: true + + - name: List artifacts + run: ls -lh artifacts/ + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + files: | + artifacts/lumenforge-universal.jar + artifacts/lumenforge-nvidia.jar diff --git a/pom.xml b/pom.xml index 44bfea1..e4839e6 100644 --- a/pom.xml +++ b/pom.xml @@ -14,6 +14,8 @@ onnxruntime 1.22.0 + + @@ -72,6 +74,7 @@ shade false + 
${project.artifactId}-${project.version}${jar.classifier} atri.palaash.lumenforge.app.LumenForgeApp From 4febf09e20694c074d3f45bd9afbc2c0eaff8394 Mon Sep 17 00:00:00 2001 From: Palaash Atri Date: Tue, 24 Feb 2026 17:03:02 +0530 Subject: [PATCH 4/5] fix gpu acceleration + better logging --- README.md | 113 ++++++++++++----- .../palaash/lumenforge/app/LumenForgeApp.java | 6 + .../inference/GenericOnnxService.java | 72 +++++++++-- .../lumenforge/storage/ModelDownloader.java | 9 ++ .../atri/palaash/lumenforge/ui/AppLogger.java | 88 +++++++++++++ .../atri/palaash/lumenforge/ui/LogsPanel.java | 118 ++++++++++++++++++ .../atri/palaash/lumenforge/ui/MainFrame.java | 66 ++++++++-- .../lumenforge/ui/ModelManagerPanel.java | 5 + 8 files changed, 424 insertions(+), 53 deletions(-) create mode 100644 src/main/java/atri/palaash/lumenforge/ui/AppLogger.java create mode 100644 src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java diff --git a/README.md b/README.md index 95061e1..72e0b02 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,117 @@ # LumenForge -Desktop Java Swing application for Java-native ONNX Runtime workflows with automatic GPU/CPU provider fallback. +[![Build LumenForge](https://github.com/palaashatri/lumenforge/actions/workflows/build.yml/badge.svg)](https://github.com/palaashatri/lumenforge/actions/workflows/build.yml) -- Text → Image: Stable Diffusion v1.5 UNet + Real-ESRGAN with fully automatic downloads +Desktop Java Swing application for ONNX Runtime inference with intelligent GPU acceleration across NVIDIA, Apple, Intel, and AMD hardware. 
image - ## Features -- Native look-and-feel handling: - - macOS: system-native look-and-feel - - Windows/Linux: FlatLaf with automatic system dark/light detection -- High-performance async execution using virtual threads -- Model downloader module with progress reporting -- Auto-download when a selected model is missing in task tabs -- Model Manager downloads only on explicit **Download / Redownload** button click -- Model Manager includes additional ONNX pipeline assets (UNet/Text Encoder/VAE/Safety Checker) for one-click download -- Manual ONNX model import per selected row in Model Manager -- Prompt library: saved presets + tags + search +### Image Generation +- **Text → Image**: Stable Diffusion v1.5, SD Turbo, SDXL Turbo, SDXL Base 1.0, SD 3.5 (MMDiT) +- **Image → Image**: Img2Img with adjustable strength + inpainting +- **Image Upscaling**: Real-ESRGAN 4× super-resolution with before/after preview +- **Prompt library**: saved presets, tags, and search - Negative prompts + prompt weighting -- Seed control + reproducibility presets -- Batch generation settings (N images per prompt) +- Seed control + reproducibility +- Batch generation (N images per prompt) - Aspect ratio presets + custom size fields -- Upscale toggle (Real-ESRGAN) with before/after preview (UI wiring) - Style presets (cinematic, sketch, product, etc.) 
- History gallery with metadata (seed, model, settings) - Export pipeline logs for debugging + +### SD 3.5 Support +- MMDiT transformer architecture with Flow Matching Euler scheduler +- Triple text encoding: CLIP-L (768d) + CLIP-G (1280d) + T5-XXL (4096d) +- Built-in T5 tokenizer (SentencePiece/Unigram with Viterbi segmentation) +- 16-channel latent space + +### Model Management +- **Automatic HuggingFace discovery**: finds ONNX and PyTorch Stable Diffusion + ESRGAN models +- **One-click download** with resume, retry, and stall detection +- **PyTorch → ONNX auto-conversion**: downloading a PyTorch model triggers automatic conversion via managed Python venv +- Manual ONNX model import +- Gated model support with HuggingFace token authentication - Local model storage in `~/.lumenforge-models` -- Execution provider fallback by OS: - - macOS: CoreML → CPU - - Windows: DirectML → CUDA → CPU - - Linux: CUDA → ROCm → CPU -## Build +### GPU Acceleration +Intelligent execution provider selection — LumenForge probes available EPs at runtime and picks the best one: + +| Platform | Priority (highest → lowest) | +|---|---| +| **macOS** | CoreML (GPU + ANE + CPU) → CPU | +| **Windows** | TensorRT-RTX → TensorRT → CUDA → DirectML → OpenVINO → CPU | +| **Linux** | TensorRT → CUDA → ROCm → OpenVINO → CPU | + +Override with `-Dlumenforge.ep=cuda` (or any EP key) to force a specific provider. 
+ +### UI +- Native look-and-feel: system-native on macOS, FlatLaf with dark/light detection on Windows/Linux +- High-performance async execution using virtual threads +- Per-step progress with timing and ETA +- Session and tokenizer caching for fast repeated inference + +## Downloads + +Pre-built fat JARs are available from [GitHub Releases](https://github.com/palaashatri/lumenforge/releases): + +| JAR | GPU Support | Use When | +|---|---|---| +| `lumenforge-universal.jar` | macOS CoreML (M-series GPU/ANE), CPU everywhere | macOS, or Windows/Linux without NVIDIA GPU | +| `lumenforge-nvidia.jar` | CUDA + TensorRT (Windows/Linux) | Windows/Linux with NVIDIA GPU + CUDA installed | + +> **Note**: DirectML (AMD/Intel on Windows), OpenVINO (Intel), and ROCm (AMD on Linux) are auto-detected at runtime if the native libraries are installed on the system. The universal JAR handles this automatically. ```bash -mvn -DskipTests compile +# Run any variant +java -jar lumenforge-universal.jar +java -jar lumenforge-nvidia.jar ``` -Enable GPU runtime artifact (Windows/Linux): +## Build from Source + +Requires **Java 21+** and **Maven 3.8+**. ```bash -mvn -Donnx.gpu=true -DskipTests compile +# Universal build (CPU + CoreML) +mvn clean package -DskipTests + +# NVIDIA GPU build (CUDA + TensorRT) +mvn clean package -DskipTests -Dort.artifactId=onnxruntime_gpu + +# Force CPU-only +mvn clean package -DskipTests -Dort.artifactId=onnxruntime ``` -## Run +### Run from source ```bash -mvn exec:java +mvn clean compile exec:java ``` -## Test +### Run tests ```bash mvn clean test ``` +## CI / CD + +GitHub Actions builds both JAR variants on every push to `main` and PR. Pushing a version tag (e.g. `v1.0.0`) creates a GitHub Release with both JARs attached. + +See [.github/workflows/build.yml](.github/workflows/build.yml) for details. 
+ +## Requirements + +- **Java 21** or later +- **macOS 10.15+** for CoreML acceleration (M-series recommended) +- **CUDA 12 + cuDNN** for NVIDIA GPU acceleration (RTX 30xx+ recommended) +- **Python 3.8+** (optional) for PyTorch → ONNX model conversion + ## Notes -- Runtime is Java-only ONNX Runtime execution with GPU fallback (no Python bridge). -- Override provider order via JVM property: `-Dlumenforge.ep=cpu|coreml|directml|cuda|rocm`. -- Task tabs show preview images when an output artifact is generated, with **Open Output** to launch the file. -- SD Turbo UNet remains experimental and may fail on CPU-only environments. +- Runtime is pure Java — ONNX Runtime execution with GPU fallback. No Python bridge needed at inference time. +- Override EP order via JVM property: `-Dlumenforge.ep=cpu|coreml|cuda|tensorrt|directml|openvino|rocm` +- Task tabs show preview images when output is generated, with **Open Output** to launch the file. - If a model requires external tensor files (e.g. `weights.pb`), import the complete ONNX bundle into the model directory. 
diff --git a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java index 1ba6d45..6342325 100644 --- a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java +++ b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java @@ -6,6 +6,7 @@ import atri.palaash.lumenforge.model.TaskType; import atri.palaash.lumenforge.storage.ModelDownloader; import atri.palaash.lumenforge.storage.ModelStorage; +import atri.palaash.lumenforge.ui.AppLogger; import atri.palaash.lumenforge.ui.MainFrame; import atri.palaash.lumenforge.ui.NativeLookAndFeel; @@ -22,6 +23,11 @@ public class LumenForgeApp { public static void main(String[] args) { configureDesktopIntegration(); + AppLogger.app("LumenForge starting \u2014 Java " + Runtime.version() + + ", " + System.getProperty("os.name") + " " + System.getProperty("os.arch")); + AppLogger.app("Max heap: " + (Runtime.getRuntime().maxMemory() / (1024 * 1024)) + " MB" + + ", CPUs: " + Runtime.getRuntime().availableProcessors()); + ExecutorService workerPool = Executors.newVirtualThreadPerTaskExecutor(); ModelStorage storage = new ModelStorage(); ModelRegistry registry = new ModelRegistry(); diff --git a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java index 4bc549b..a2e654e 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java +++ b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import atri.palaash.lumenforge.model.TaskType; import atri.palaash.lumenforge.storage.ModelStorage; +import atri.palaash.lumenforge.ui.AppLogger; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; @@ -71,10 +72,13 @@ private OrtSession getOrCreateSession(OrtEnvironment env, Path modelPath, String key = modelPath.toAbsolutePath().toString(); OrtSession existing 
= SESSION_CACHE.get(key); if (existing != null) { + AppLogger.model("Session cache hit: " + modelPath.getFileName()); return existing; } + AppLogger.model("Loading ONNX session: " + modelPath.getFileName()); OrtSession session = env.createSession(modelPath.toString(), opts); SESSION_CACHE.put(key, session); + AppLogger.model("Session loaded: " + modelPath.getFileName()); return session; } @@ -126,10 +130,13 @@ public static void clearCache() { public CompletableFuture run(InferenceRequest request) { return CompletableFuture.supplyAsync(() -> { if (!storage.isAvailable(request.model())) { + AppLogger.modelError("Model not found locally: " + request.model().displayName()); return InferenceResult.fail("Model not found locally. Open Model Manager from the menu bar and download it first."); } Path modelPath = storage.modelPath(request.model()); + AppLogger.model("Starting inference: " + request.model().displayName() + + " (" + request.model().id() + ")"); // Temporarily intercept stderr so ONNX Runtime native warnings // appear in the application's Log tab instead of only in the console. @@ -146,6 +153,8 @@ public CompletableFuture run(InferenceRequest request) { sessionOptions.setIntraOpNumThreads(Math.max(1, cpus - 1)); sessionOptions.setInterOpNumThreads(Math.max(1, Math.min(cpus / 2, 4))); ProviderSelection providerSelection = configureExecutionProvider(sessionOptions, request.preferGpu()); + AppLogger.model("Using EP: " + providerSelection.provider() + + (providerSelection.notes().isBlank() ? "" : " (" + providerSelection.notes() + ")")); // Invalidate session cache when the execution provider changes. 
String epKey = providerSelection.provider() + "|" + request.preferGpu(); @@ -192,9 +201,11 @@ public CompletableFuture run(InferenceRequest request) { + request.model().displayName() + " | task=" + taskType.displayName() + " | EP=" + providerSelection.provider() + providerSelection.noteSuffix(); + AppLogger.modelWarn(details); return InferenceResult.fail(details); } catch (OrtException ex) { String message = ex.getMessage() == null ? "Unknown ONNX Runtime error" : ex.getMessage(); + AppLogger.modelError("ONNX Runtime error: " + message); if (message.contains("NhwcConv")) { return InferenceResult.fail( "This ONNX model uses custom ops (e.g., NhwcConv) not available in CPUExecutionProvider. " @@ -2644,27 +2655,59 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o } StringBuilder notes = new StringBuilder(); - System.out.println("[LumenForge] EP preference order: " + preference); + List failReasons = new ArrayList<>(); + AppLogger.app("EP preference order: " + preference + + " (os=" + os + ", arch=" + arch + ")"); + + if (!preferGpu) { + AppLogger.app("GPU not requested for this session — using CPUExecutionProvider"); + return new ProviderSelection("CPUExecutionProvider", "GPU not requested"); + } for (String candidate : preference) { if ("cpu".equals(candidate)) { - System.out.println("[LumenForge] Using CPUExecutionProvider"); + // All GPU EPs exhausted — log summary + AppLogger.appWarn("Falling back to CPUExecutionProvider"); + if (!failReasons.isEmpty()) { + AppLogger.appWarn("Reason: every GPU execution provider was unavailable:"); + for (String reason : failReasons) { + AppLogger.appWarn(" - " + reason); + } + AppLogger.appWarn("Tip: install the matching GPU runtime " + + "(e.g. 
CUDA/cuDNN for NVIDIA, ROCm for AMD) or use " + + "-Dlumenforge.ep= to force a specific EP."); + } return new ProviderSelection("CPUExecutionProvider", notes.toString()); } - if (tryEnableProvider(options, candidate, notes)) { + String failReason = tryEnableProvider(options, candidate, notes); + if (failReason == null) { String display = providerDisplayName(candidate); - System.out.println("[LumenForge] \u2713 Enabled " + display); + AppLogger.app("\u2713 Enabled " + display); return new ProviderSelection(display, notes.toString()); } + failReasons.add(failReason); + } + // Preference list didn't include "cpu" explicitly — shouldn't happen, but handle it + AppLogger.appWarn("No EP available, falling back to CPUExecutionProvider"); + if (!failReasons.isEmpty()) { + AppLogger.appWarn("Reasons:"); + for (String reason : failReasons) { + AppLogger.appWarn(" - " + reason); + } } - System.out.println("[LumenForge] No GPU EP available, falling back to CPU"); return new ProviderSelection("CPUExecutionProvider", notes.toString()); } - private boolean tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) { - boolean ok = false; + /** + * Attempt to enable a single EP via reflection. + * + * @return {@code null} if the provider was enabled successfully, + * or a human-readable reason string if it could not be enabled. 
+ */ + private String tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) { + String failDetail = null; try { - ok = switch (candidate) { + boolean ok = switch (candidate) { /* ── NVIDIA ───────────────────────────────────────────── */ case "tensorrt_rtx" -> @@ -2708,16 +2751,21 @@ private boolean tryEnableProvider(OrtSession.SessionOptions options, String cand default -> false; }; + if (!ok) { + failDetail = candidate + ": native library not found in classpath"; + } } catch (Exception ex) { - ok = false; + String msg = ex.getMessage(); + if (msg == null && ex.getCause() != null) { msg = ex.getCause().getMessage(); } + failDetail = candidate + ": " + (msg != null ? msg : ex.getClass().getSimpleName()); } - if (!ok) { + if (failDetail != null) { if (!notes.isEmpty()) { notes.append("; "); } notes.append(candidate).append(" not available"); - System.out.println("[LumenForge] \u2717 " + candidate + " not available"); + AppLogger.appWarn("\u2717 " + failDetail); } - return ok; + return failDetail; } /* ── Reflection helpers for EP registration ──────────────────────── */ diff --git a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java index 6c16e7a..1c42e1b 100644 --- a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java +++ b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java @@ -2,6 +2,7 @@ import atri.palaash.lumenforge.model.ModelDescriptor; import atri.palaash.lumenforge.model.TaskType; +import atri.palaash.lumenforge.ui.AppLogger; import java.io.IOException; import java.io.InputStream; @@ -56,6 +57,7 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer { Path target = storage.modelPath(descriptor); List writtenFiles = new ArrayList<>(); + AppLogger.model("Download started: " + descriptor.displayName()); try { String downloadUrl = descriptor.sourceUrl(); if 
(!downloadUrl.toLowerCase(Locale.ROOT).contains(".onnx") @@ -71,8 +73,12 @@ public CompletableFuture download(ModelDescriptor descriptor, ConsumerTwo independent channels: + *

    + *
  • {@link Channel#APPLICATION} — startup, EP selection, configuration, errors
  • + *
  • {@link Channel#MODEL} — model download, import, conversion, inference
  • + *
+ */ +public final class AppLogger { + + public enum Channel { APPLICATION, MODEL } + public enum Level { INFO, WARN, ERROR } + + public record LogEntry(LocalDateTime timestamp, Channel channel, Level level, String message) { } + + private static final DateTimeFormatter TS_FORMAT = DateTimeFormatter.ofPattern("HH:mm:ss.SSS"); + private static final List> listeners = new CopyOnWriteArrayList<>(); + + private AppLogger() { } + + /* ── Listener management ─────────────────────────────────────── */ + + public static void addListener(Consumer listener) { + listeners.add(listener); + } + + public static void removeListener(Consumer listener) { + listeners.remove(listener); + } + + /* ── Convenience logging ─────────────────────────────────────── */ + + public static void app(String message) { + log(Channel.APPLICATION, Level.INFO, message); + } + + public static void appWarn(String message) { + log(Channel.APPLICATION, Level.WARN, message); + } + + public static void appError(String message) { + log(Channel.APPLICATION, Level.ERROR, message); + } + + public static void model(String message) { + log(Channel.MODEL, Level.INFO, message); + } + + public static void modelWarn(String message) { + log(Channel.MODEL, Level.WARN, message); + } + + public static void modelError(String message) { + log(Channel.MODEL, Level.ERROR, message); + } + + /* ── Core ────────────────────────────────────────────────────── */ + + public static void log(Channel channel, Level level, String message) { + LogEntry entry = new LogEntry(LocalDateTime.now(), channel, level, message); + for (Consumer listener : listeners) { + try { + listener.accept(entry); + } catch (Exception ignored) { } + } + } + + /** Format a log entry for display in the Logs panel. 
*/ + public static String format(LogEntry entry) { + String prefix = switch (entry.level()) { + case INFO -> ""; + case WARN -> "\u26a0 "; + case ERROR -> "\u274c "; + }; + return TS_FORMAT.format(entry.timestamp()) + " " + prefix + entry.message(); + } +} diff --git a/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java b/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java new file mode 100644 index 0000000..ea92369 --- /dev/null +++ b/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java @@ -0,0 +1,118 @@ +package atri.palaash.lumenforge.ui; + +import atri.palaash.lumenforge.ui.AppLogger.Channel; +import atri.palaash.lumenforge.ui.AppLogger.LogEntry; + +import javax.swing.BorderFactory; +import javax.swing.Box; +import javax.swing.BoxLayout; +import javax.swing.JButton; +import javax.swing.JPanel; +import javax.swing.JScrollPane; +import javax.swing.JTextArea; +import javax.swing.JToggleButton; +import javax.swing.SwingUtilities; +import javax.swing.Timer; +import java.awt.BorderLayout; +import java.awt.Dimension; +import java.awt.Font; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +/** + * Logs panel with two toggle-filtered views: Application and Model. + *

+ * Entries are appended in real-time via {@link AppLogger} and auto-scroll + * to the bottom. A Clear button wipes the visible log. + */ +public class LogsPanel extends JPanel { + + private final JTextArea logArea; + private final List allEntries = new ArrayList<>(); + private boolean showApp = true; + private boolean showModel = true; + + /* Coalesce rapid-fire updates (e.g. per-step progress) to ~60 fps. */ + private boolean dirty = false; + private final Timer coalesceTimer; + + public LogsPanel() { + super(new BorderLayout(0, 0)); + + /* ── Log text area (init first — referenced by toolbar buttons) ── */ + logArea = new JTextArea(); + logArea.setEditable(false); + logArea.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12)); + logArea.setLineWrap(true); + logArea.setWrapStyleWord(true); + JScrollPane scroll = new JScrollPane(logArea); + scroll.setBorder(BorderFactory.createEmptyBorder()); + + /* ── Toolbar ─────────────────────────────────────────────── */ + JPanel toolbar = new JPanel(); + toolbar.setLayout(new BoxLayout(toolbar, BoxLayout.X_AXIS)); + toolbar.setBorder(BorderFactory.createEmptyBorder(6, 12, 6, 12)); + + JToggleButton appBtn = new JToggleButton("Application", showApp); + JToggleButton modelBtn = new JToggleButton("Model", showModel); + appBtn.setFocusable(false); + modelBtn.setFocusable(false); + appBtn.addActionListener(e -> { showApp = appBtn.isSelected(); rebuild(); }); + modelBtn.addActionListener(e -> { showModel = modelBtn.isSelected(); rebuild(); }); + + JButton clearBtn = new JButton("Clear"); + clearBtn.setFocusable(false); + clearBtn.addActionListener(e -> { + synchronized (allEntries) { allEntries.clear(); } + logArea.setText(""); + }); + + toolbar.add(appBtn); + toolbar.add(Box.createRigidArea(new Dimension(4, 0))); + toolbar.add(modelBtn); + toolbar.add(Box.createHorizontalGlue()); + toolbar.add(clearBtn); + + add(toolbar, BorderLayout.NORTH); + add(scroll, BorderLayout.CENTER); + + /* ── Wire up listener 
────────────────────────────────────── */ + Consumer listener = entry -> { + synchronized (allEntries) { allEntries.add(entry); } + scheduleDirty(); + }; + AppLogger.addListener(listener); + + /* Coalesce timer fires every 16 ms (~60 fps) if dirty */ + coalesceTimer = new Timer(16, e -> { + if (dirty) { + dirty = false; + rebuild(); + } + }); + coalesceTimer.setRepeats(true); + coalesceTimer.start(); + } + + private void scheduleDirty() { + dirty = true; + } + + /** Rebuild the text area contents from the filtered entry list. */ + private void rebuild() { + SwingUtilities.invokeLater(() -> { + StringBuilder sb = new StringBuilder(); + List snapshot; + synchronized (allEntries) { snapshot = new ArrayList<>(allEntries); } + for (LogEntry entry : snapshot) { + if (entry.channel() == Channel.APPLICATION && !showApp) continue; + if (entry.channel() == Channel.MODEL && !showModel) continue; + String tag = entry.channel() == Channel.APPLICATION ? "[APP] " : "[MDL] "; + sb.append(tag).append(AppLogger.format(entry)).append('\n'); + } + logArea.setText(sb.toString()); + logArea.setCaretPosition(logArea.getDocument().getLength()); + }); + } +} diff --git a/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java b/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java index da90fb5..e40012c 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java +++ b/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java @@ -55,6 +55,7 @@ public class MainFrame extends JFrame { private static final String CARD_IMG2IMG = "Img2Img"; private static final String CARD_UPSCALE = "Upscale"; private static final String CARD_MODELS = "Models"; + private static final String CARD_LOGS = "Logs"; private final CardLayout cardLayout = new CardLayout(); private final JPanel contentPanel = new JPanel(cardLayout); @@ -63,11 +64,13 @@ public class MainFrame extends JFrame { private final ImageUpscalePanel imageUpscalePanel; private final Img2ImgPanel img2ImgPanel; private final 
ModelManagerPanel modelManagerPanel; + private final LogsPanel logsPanel; private final JLabel statusBarLabel; /* Sidebar lists */ private JList workflowList; private JList managementList; + private JList diagnosticsList; /* GPU state (shared across panels via supplier) */ private boolean gpuEnabled = true; @@ -137,11 +140,14 @@ public MainFrame(ModelRegistry registry, .collect(Collectors.toList())); }); + logsPanel = new LogsPanel(); + /* ── Content cards ───────────────────────────────────────── */ contentPanel.add(textToImagePanel, CARD_GENERATE); contentPanel.add(img2ImgPanel, CARD_IMG2IMG); contentPanel.add(imageUpscalePanel, CARD_UPSCALE); contentPanel.add(modelManagerPanel, CARD_MODELS); + contentPanel.add(logsPanel, CARD_LOGS); /* ── Sidebar ─────────────────────────────────────────────── */ // Create both lists first, then wire listeners @@ -159,19 +165,35 @@ public MainFrame(ModelRegistry registry, managementList.setCellRenderer(new SidebarRenderer()); managementList.setOpaque(false); - // Wire listeners after both lists exist + String[] diagnosticsItems = {CARD_LOGS}; + diagnosticsList = new JList<>(diagnosticsItems); + diagnosticsList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + diagnosticsList.setFixedCellHeight(32); + diagnosticsList.setCellRenderer(new SidebarRenderer()); + diagnosticsList.setOpaque(false); + + // Wire listeners after all lists exist workflowList.addListSelectionListener(e -> { if (!e.getValueIsAdjusting() && workflowList.getSelectedIndex() >= 0) { managementList.clearSelection(); + diagnosticsList.clearSelection(); cardLayout.show(contentPanel, workflowList.getSelectedValue()); } }); managementList.addListSelectionListener(e -> { if (!e.getValueIsAdjusting() && managementList.getSelectedIndex() >= 0) { workflowList.clearSelection(); + diagnosticsList.clearSelection(); cardLayout.show(contentPanel, managementList.getSelectedValue()); } }); + diagnosticsList.addListSelectionListener(e -> { + if (!e.getValueIsAdjusting() && 
diagnosticsList.getSelectedIndex() >= 0) { + workflowList.clearSelection(); + managementList.clearSelection(); + cardLayout.show(contentPanel, diagnosticsList.getSelectedValue()); + } + }); // Set initial selection after listeners are wired workflowList.setSelectedIndex(0); @@ -201,7 +223,7 @@ public MainFrame(ModelRegistry registry, workflowList.setAlignmentX(Component.LEFT_ALIGNMENT); workflowSection.add(workflowList); - // Management section (pinned to bottom) + // Management section JPanel managementSection = new JPanel(); managementSection.setLayout(new BoxLayout(managementSection, BoxLayout.Y_AXIS)); managementSection.setOpaque(false); @@ -214,13 +236,34 @@ public MainFrame(ModelRegistry registry, managementSection.add(managementHeader); managementList.setAlignmentX(Component.LEFT_ALIGNMENT); managementSection.add(managementList); - managementSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0)); - // Combine: workflow on top, management at bottom with glue between + // Diagnostics section + JPanel diagnosticsSection = new JPanel(); + diagnosticsSection.setLayout(new BoxLayout(diagnosticsSection, BoxLayout.Y_AXIS)); + diagnosticsSection.setOpaque(false); + + JLabel diagnosticsHeader = new JLabel("DIAGNOSTICS"); + diagnosticsHeader.setFont(diagnosticsHeader.getFont().deriveFont(Font.BOLD, 10f)); + diagnosticsHeader.setForeground(UIManager.getColor("Label.disabledForeground")); + diagnosticsHeader.setBorder(BorderFactory.createEmptyBorder(8, 20, 4, 20)); + diagnosticsHeader.setAlignmentX(Component.LEFT_ALIGNMENT); + diagnosticsSection.add(diagnosticsHeader); + diagnosticsList.setAlignmentX(Component.LEFT_ALIGNMENT); + diagnosticsSection.add(diagnosticsList); + + // Bottom panel: management + diagnostics + JPanel bottomSection = new JPanel(); + bottomSection.setLayout(new BoxLayout(bottomSection, BoxLayout.Y_AXIS)); + bottomSection.setOpaque(false); + bottomSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0)); + 
bottomSection.add(managementSection); + bottomSection.add(diagnosticsSection); + + // Combine: workflow on top, management+diagnostics at bottom JPanel sidebarContent = new JPanel(new BorderLayout()); sidebarContent.setOpaque(false); sidebarContent.add(workflowSection, BorderLayout.NORTH); - sidebarContent.add(managementSection, BorderLayout.SOUTH); + sidebarContent.add(bottomSection, BorderLayout.SOUTH); sidebarPanel.add(sidebarContent, BorderLayout.CENTER); @@ -290,6 +333,11 @@ private JMenuBar buildMenuBar() { showModels.addActionListener(e -> switchToCard(CARD_MODELS, managementList, 0)); viewMenu.add(showModels); + JMenuItem showLogs = new JMenuItem("Logs"); + showLogs.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_L, menuMask)); + showLogs.addActionListener(e -> switchToCard(CARD_LOGS, diagnosticsList, 0)); + viewMenu.add(showLogs); + bar.add(viewMenu); /* Inference */ @@ -322,11 +370,9 @@ private JMenuBar buildMenuBar() { private void switchToCard(String card, JList targetList, int index) { cardLayout.show(contentPanel, card); - if (targetList == workflowList) { - managementList.clearSelection(); - } else { - workflowList.clearSelection(); - } + if (targetList != workflowList) workflowList.clearSelection(); + if (targetList != managementList) managementList.clearSelection(); + if (targetList != diagnosticsList) diagnosticsList.clearSelection(); targetList.setSelectedIndex(index); } diff --git a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java index 2325a74..dfa92fb 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java +++ b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java @@ -261,10 +261,12 @@ private void importModelFromFile(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Imported: " + target); + AppLogger.model("Model imported: " + 
descriptor.displayName() + " from " + source); if (onModelsUpdated != null) { onModelsUpdated.run(); } } catch (IOException ex) { + AppLogger.modelError("Import failed for " + descriptor.displayName() + ": " + ex.getMessage()); statusLabel.setText("Import failed: " + ex.getMessage()); } } @@ -319,6 +321,7 @@ private void startConversion(String modelId, String displayName, String mode) { statusLabel.setText("Converting " + displayName + " to ONNX\u2026"); progressBar.setVisible(true); progressBar.setIndeterminate(true); + AppLogger.model("Starting PyTorch \u2192 ONNX conversion: " + modelId); // 5. Run on background thread final String conversionMode = mode; @@ -343,6 +346,7 @@ private void startConversion(String modelId, String displayName, String mode) { PyTorchToOnnxConverter.openPythonDownloadPage(); } } else { + AppLogger.modelError("Conversion failed for " + modelId + ": " + cause.getMessage()); statusLabel.setText("Conversion failed: " + cause.getMessage()); JOptionPane.showMessageDialog(this, "Conversion failed:\n" + cause.getMessage(), @@ -391,6 +395,7 @@ private void registerConvertedModel(String modelId, String sanitized, Path outpu modelRegistry.mergeDownloadableAssets(List.of(desc)); refreshTable(); statusLabel.setText("\u2713 Converted and registered: " + modelId); + AppLogger.model("Conversion complete: " + modelId + " \u2192 " + relativePath); if (onModelsUpdated != null) { onModelsUpdated.run(); } From f1d2a131e790a4a85e5b681fec596ae3c1e87113 Mon Sep 17 00:00:00 2001 From: Palaash Atri Date: Tue, 24 Feb 2026 17:54:01 +0530 Subject: [PATCH 5/5] rm unncessary stuff. 
--- pom.xml | 37 +- .../palaash/lumenforge/app/LumenForgeApp.java | 5 +- .../inference/DjlPyTorchService.java | 396 ------------------ .../inference/GenericOnnxService.java | 44 +- .../lumenforge/inference/ServiceFactory.java | 27 -- .../lumenforge/storage/ModelDownloader.java | 9 +- .../atri/palaash/lumenforge/ui/AppLogger.java | 88 ---- .../atri/palaash/lumenforge/ui/LogsPanel.java | 118 ------ .../atri/palaash/lumenforge/ui/MainFrame.java | 55 +-- .../lumenforge/ui/ModelManagerPanel.java | 10 +- 10 files changed, 37 insertions(+), 752 deletions(-) delete mode 100644 src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java delete mode 100644 src/main/java/atri/palaash/lumenforge/ui/AppLogger.java delete mode 100644 src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java diff --git a/pom.xml b/pom.xml index e4839e6..b834b1e 100644 --- a/pom.xml +++ b/pom.xml @@ -144,41 +144,6 @@ onnxruntime_gpu - - - djl-pytorch - - - djl - true - - - - - - ai.djl - bom - 0.31.0 - pom - import - - - - - - ai.djl - api - - - ai.djl.pytorch - pytorch-engine - runtime - - - ai.djl.huggingface - tokenizers - - - + diff --git a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java index 6342325..dfc00a8 100644 --- a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java +++ b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java @@ -6,7 +6,6 @@ import atri.palaash.lumenforge.model.TaskType; import atri.palaash.lumenforge.storage.ModelDownloader; import atri.palaash.lumenforge.storage.ModelStorage; -import atri.palaash.lumenforge.ui.AppLogger; import atri.palaash.lumenforge.ui.MainFrame; import atri.palaash.lumenforge.ui.NativeLookAndFeel; @@ -23,9 +22,9 @@ public class LumenForgeApp { public static void main(String[] args) { configureDesktopIntegration(); - AppLogger.app("LumenForge starting \u2014 Java " + Runtime.version() + System.out.println("[LumenForge] Starting \u2014 Java " + 
Runtime.version() + ", " + System.getProperty("os.name") + " " + System.getProperty("os.arch")); - AppLogger.app("Max heap: " + (Runtime.getRuntime().maxMemory() / (1024 * 1024)) + " MB" + System.out.println("[LumenForge] Max heap: " + (Runtime.getRuntime().maxMemory() / (1024 * 1024)) + " MB" + ", CPUs: " + Runtime.getRuntime().availableProcessors()); ExecutorService workerPool = Executors.newVirtualThreadPerTaskExecutor(); diff --git a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java b/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java deleted file mode 100644 index 56aad75..0000000 --- a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java +++ /dev/null @@ -1,396 +0,0 @@ -package atri.palaash.lumenforge.inference; - -import atri.palaash.lumenforge.storage.ModelStorage; - -import java.awt.image.BufferedImage; -import java.lang.reflect.Method; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; - -/** - * Stable Diffusion inference using DJL (Deep Java Library) with PyTorch backend. - * - *

This backend loads TorchScript-traced models (.pt files) and runs inference - * using PyTorch's native tensor operations (CPU, CUDA, or MPS). Activate by - * building with {@code mvn compile -Ddjl=true}. - * - *

<h3>Required model layout</h3>

- *
- *   ~/.lumenforge-models/text-image/sd-pytorch/
- *     clip_model.pt           — traced CLIP text encoder
- *     unet_model.pt           — traced UNet2DConditionModel
- *     vae_decoder_model.pt    — traced VAE decoder
- *     tokenizer.json          — HuggingFace tokenizer config
- * 
- * - *

Use the included {@code scripts/export_torchscript.py} helper to export - * HuggingFace SD checkpoints to TorchScript format. - */ -public class DjlPyTorchService implements InferenceService { - - private final ModelStorage storage; - private final Executor executor; - - /* ── DJL availability check ─────────────────────────────────────── */ - - private static final boolean DJL_AVAILABLE; - static { - boolean found; - try { - Class.forName("ai.djl.ndarray.NDManager"); - found = true; - } catch (ClassNotFoundException e) { - found = false; - } - DJL_AVAILABLE = found; - } - - /** Returns {@code true} if the DJL PyTorch runtime is on the classpath. */ - public static boolean isAvailable() { - return DJL_AVAILABLE; - } - - public DjlPyTorchService(ModelStorage storage, Executor executor) { - this.storage = storage; - this.executor = executor; - } - - @Override - public CompletableFuture run(InferenceRequest request) { - return CompletableFuture.supplyAsync(() -> { - if (!DJL_AVAILABLE) { - return InferenceResult.fail( - "DJL PyTorch is not on the classpath. 
Rebuild with: mvn compile -Ddjl=true"); - } - - Path base = storage.root().resolve("text-image").resolve("sd-pytorch"); - Path clipPath = base.resolve("clip_model.pt"); - Path unetPath = base.resolve("unet_model.pt"); - Path vaePath = base.resolve("vae_decoder_model.pt"); - Path tokPath = base.resolve("tokenizer.json"); - - for (Path p : new Path[]{clipPath, unetPath, vaePath, tokPath}) { - if (!Files.exists(p)) { - return InferenceResult.fail( - "Missing TorchScript component: " + p.getFileName() - + "\nExport models with scripts/export_torchscript.py first."); - } - } - - try { - return runPyTorchPipeline(request, clipPath, unetPath, vaePath, tokPath); - } catch (Exception ex) { - return InferenceResult.fail("DJL PyTorch pipeline failed: " + ex.getMessage()); - } - }, executor); - } - - /* ── Pipeline (reflection-based to compile without DJL on classpath) ─ */ - - @SuppressWarnings("unchecked") - private InferenceResult runPyTorchPipeline(InferenceRequest request, - Path clipPath, Path unetPath, - Path vaePath, Path tokPath) - throws Exception { - - // All DJL calls go through reflection so the class compiles even when - // the djl-pytorch profile is not active. 
- - Class clsNDManager = Class.forName("ai.djl.ndarray.NDManager"); - Class clsNDArray = Class.forName("ai.djl.ndarray.NDArray"); - Class clsNDList = Class.forName("ai.djl.ndarray.NDList"); - Class clsShape = Class.forName("ai.djl.ndarray.types.Shape"); - Class clsDataType = Class.forName("ai.djl.ndarray.types.DataType"); - Class clsDevice = Class.forName("ai.djl.Device"); - Class clsCriteria = Class.forName("ai.djl.repository.zoo.Criteria"); - Class clsZooModel = Class.forName("ai.djl.repository.zoo.ZooModel"); - Class clsPredictor = Class.forName("ai.djl.inference.Predictor"); - Class clsTokenizer = Class.forName("ai.djl.huggingface.tokenizers.HuggingFaceTokenizer"); - Class clsEncoding = Class.forName("ai.djl.huggingface.tokenizers.Encoding"); - Class clsTranslator = Class.forName("ai.djl.translate.NoopTranslator"); - - int width = Math.max(256, (request.width() / 8) * 8); - int height = Math.max(256, (request.height() / 8) * 8); - int latentW = width / 8; - int latentH = height / 8; - int steps = Math.max(1, Math.min(request.batch() > 0 ? 
request.batch() : 20, 50)); - - // Choose device — prefer GPU if available - Object device; - if (request.preferGpu()) { - Method gpuMethod = clsDevice.getMethod("gpu"); - device = gpuMethod.invoke(null); - } else { - Method cpuMethod = clsDevice.getMethod("cpu"); - device = cpuMethod.invoke(null); - } - - // Create NDManager - Method managerOf = clsNDManager.getMethod("newBaseManager", clsDevice); - Object manager = managerOf.invoke(null, device); - - try { - /* ── Tokenize ──────────────────────────────────────────────── */ - request.reportProgress("Tokenizing prompt (DJL HuggingFace)\u2026"); - Method tokLoad = clsTokenizer.getMethod("newInstance", Path.class); - Object tokenizer = tokLoad.invoke(null, tokPath); - Method tokEncode = clsTokenizer.getMethod("encode", String.class); - Object encoding = tokEncode.invoke(tokenizer, request.prompt()); - Method getIds = clsEncoding.getMethod("getIds"); - long[] tokenIds = (long[]) getIds.invoke(encoding); - - // Pad/truncate to 77 - long[] padded = new long[77]; - System.arraycopy(tokenIds, 0, padded, 0, Math.min(tokenIds.length, 77)); - if (tokenIds.length < 77) { - long padId = 49407L; // <|endoftext|> - for (int i = tokenIds.length; i < 77; i++) padded[i] = padId; - } - - // Create token tensor [1, 77] - Object shapeTokens = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1, 77}); - Method mgrCreate = clsNDManager.getMethod("create", long[].class, - Class.forName("ai.djl.ndarray.types.Shape")); - Object tokenTensor = mgrCreate.invoke(manager, padded, shapeTokens); - - /* ── CLIP text encoder ──────────────────────────────────── */ - request.reportProgress("Loading CLIP text encoder (PyTorch)\u2026"); - Object clipModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - clipPath, device); - Object clipPredictor = clsZooModel.getMethod("newPredictor").invoke(clipModel); - - Object clipInput = clsNDList.getConstructor(clsNDArray).newInstance(tokenTensor); - 
request.reportProgress("Running CLIP text encoder\u2026"); - Object clipOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(clipPredictor, clipInput); - Object textEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(clipOutput); - - /* ── Negative prompt (uncond) encoder ───────────────────── */ - long[] emptyTokens = new long[77]; - emptyTokens[0] = 49406L; // <|startoftext|> - for (int i = 1; i < 77; i++) emptyTokens[i] = 49407L; - Object uncondTensor = mgrCreate.invoke(manager, emptyTokens, shapeTokens); - Object uncondInput = clsNDList.getConstructor(clsNDArray).newInstance(uncondTensor); - request.reportProgress("Running uncond text encoder\u2026"); - Object uncondOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(clipPredictor, uncondInput); - Object uncondEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(uncondOutput); - - // Concatenate [uncond, cond] → [2, 77, 768] - Method concatMethod = clsNDArray.getMethod("concat", clsNDArray, int.class); - // Stack along batch dim: uncond embedding [1,77,768] + text embedding [1,77,768] - Object embeddings = concatMethod.invoke(uncondEmbedding, textEmbedding, 0); - - clsPredictor.getMethod("close").invoke(clipPredictor); - clsZooModel.getMethod("close").invoke(clipModel); - - /* ── UNet denoising loop ────────────────────────────────── */ - request.reportProgress("Loading UNet (PyTorch)\u2026"); - Object unetModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - unetPath, device); - Object unetPredictor = clsZooModel.getMethod("newPredictor").invoke(unetModel); - - // Initialize random latents [1, 4, latentH, latentW] - Object latentShape = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1, 4, latentH, latentW}); - Method mgrRandomNormal = clsNDManager.getMethod("randomNormal", - Class.forName("ai.djl.ndarray.types.Shape")); - Object latents = mgrRandomNormal.invoke(manager, latentShape); - - // Simple linear schedule for beta - 
double guidanceScale = request.promptWeight() > 0 ? request.promptWeight() : 7.5; - float[] timestepSchedule = ddimTimesteps(steps); - - // Initial scale - Method mulMethod = clsNDArray.getMethod("mul", Number.class); - latents = mulMethod.invoke(latents, (float) Math.sqrt(1.0 + sigmaForTimestep(timestepSchedule[0]))); - - long stepStart = System.currentTimeMillis(); - for (int i = 0; i < timestepSchedule.length; i++) { - if (request.isCancelled()) { - clsPredictor.getMethod("close").invoke(unetPredictor); - clsZooModel.getMethod("close").invoke(unetModel); - return InferenceResult.fail("Cancelled by user."); - } - - float t = timestepSchedule[i]; - - // Duplicate latents for CFG: [latents, latents] → [2, 4, H, W] - Object latentDouble = concatMethod.invoke(latents, latents, 0); - - // Create timestep tensor - Object tsShape = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1}); - Object tsTensor = mgrCreate.invoke(manager, new long[]{(long) t}, tsShape); - - // UNet forward: (noisy_latents, timestep, encoder_hidden_states) - Object unetInput = clsNDList.getConstructor(clsNDArray.arrayType()) - .newInstance((Object) new Object[]{latentDouble, tsTensor, embeddings}); - Object unetOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(unetPredictor, unetInput); - Object noisePred = clsNDList.getMethod("singletonOrThrow").invoke(unetOutput); - - // Classifier-free guidance: split [2,4,H,W] → uncond, cond - Method splitMethod = clsNDArray.getMethod("split", long.class, int.class); - Object splitResult = splitMethod.invoke(noisePred, 2L, 0); - Object noiseUncond = java.lang.reflect.Array.get( - clsNDList.getMethod("toArray").invoke(splitResult), 0); - Object noiseCond = java.lang.reflect.Array.get( - clsNDList.getMethod("toArray").invoke(splitResult), 1); - - // guided = uncond + scale * (cond - uncond) - Method subMethod = clsNDArray.getMethod("sub", clsNDArray); - Method addMethod = clsNDArray.getMethod("add", clsNDArray); - Object 
diff = subMethod.invoke(noiseCond, noiseUncond); - Object scaled = mulMethod.invoke(diff, (float) guidanceScale); - Object guided = addMethod.invoke(noiseUncond, scaled); - - // DDIM step - float alphaT = alphaForTimestep(t); - float alphaPrev = (i + 1 < timestepSchedule.length) - ? alphaForTimestep(timestepSchedule[i + 1]) : 1.0f; - - // predicted x0 = (latents - sqrt(1-alpha)*noise) / sqrt(alpha) - Object scaledNoise = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaT)); - Object x0 = subMethod.invoke(latents, scaledNoise); - x0 = mulMethod.invoke(x0, (float) (1.0 / Math.sqrt(alphaT))); - - // direction pointing to x_t - Object dirXt = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaPrev)); - // x_{t-1} = sqrt(alpha_prev) * x0 + sqrt(1-alpha_prev) * noise_pred - Object prevSample = mulMethod.invoke(x0, (float) Math.sqrt(alphaPrev)); - latents = addMethod.invoke(prevSample, dirXt); - - long elapsed = System.currentTimeMillis() - stepStart; - stepStart = System.currentTimeMillis(); - request.reportProgress("Denoising: " + (i + 1) + "/" + steps - + " steps (" + String.format("%.1f", elapsed / 1000.0) + "s/step) [PyTorch]"); - } - - clsPredictor.getMethod("close").invoke(unetPredictor); - clsZooModel.getMethod("close").invoke(unetModel); - - /* ── VAE decode ─────────────────────────────────────────── */ - request.reportProgress("Loading VAE decoder (PyTorch)\u2026"); - Object vaeModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - vaePath, device); - Object vaePredictor = clsZooModel.getMethod("newPredictor").invoke(vaeModel); - - // Scale latents - latents = mulMethod.invoke(latents, 1.0f / 0.18215f); - - Object vaeInput = clsNDList.getConstructor(clsNDArray).newInstance(latents); - request.reportProgress("Decoding with VAE\u2026"); - Object vaeOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(vaePredictor, vaeInput); - Object decoded = clsNDList.getMethod("singletonOrThrow").invoke(vaeOutput); - - 
clsPredictor.getMethod("close").invoke(vaePredictor); - clsZooModel.getMethod("close").invoke(vaeModel); - - /* ── Convert tensor to BufferedImage ───────────────────── */ - // decoded: [1, 3, H, W] float32, range roughly [-1, 1] - Method toFloatArray = clsNDArray.getMethod("toFloatArray"); - float[] pixels = (float[]) toFloatArray.invoke(decoded); - - Method getShapeMethod = clsNDArray.getMethod("getShape"); - Object decodedShape = getShapeMethod.invoke(decoded); - long[] dims = (long[]) clsShape.getMethod("getShape").invoke(decodedShape); - int imgH = (int) dims[2]; - int imgW = (int) dims[3]; - - BufferedImage image = new BufferedImage(imgW, imgH, BufferedImage.TYPE_INT_RGB); - for (int y = 0; y < imgH; y++) { - for (int x = 0; x < imgW; x++) { - int rIdx = 0 * imgH * imgW + y * imgW + x; - int gIdx = 1 * imgH * imgW + y * imgW + x; - int bIdx = 2 * imgH * imgW + y * imgW + x; - int r = clamp((int) ((pixels[rIdx] / 2 + 0.5f) * 255)); - int g = clamp((int) ((pixels[gIdx] / 2 + 0.5f) * 255)); - int b = clamp((int) ((pixels[bIdx] / 2 + 0.5f) * 255)); - image.setRGB(x, y, (r << 16) | (g << 8) | b); - } - } - - Path outputDir = storage.root().resolve("outputs"); - Files.createDirectories(outputDir); - String filename = "djl-pytorch_" + System.currentTimeMillis() + ".png"; - Path outputPath = outputDir.resolve(filename); - javax.imageio.ImageIO.write(image, "PNG", outputPath.toFile()); - - return InferenceResult.ok( - "Generated image for prompt: \"" + request.prompt() + "\"", - "DJL PyTorch pipeline completed (" + steps + " steps)", - outputPath.toString(), "image"); - - } finally { - // Close NDManager - clsNDManager.getMethod("close").invoke(manager); - } - } - - /* ── Helpers ──────────────────────────────────────────────────────── */ - - /** Load a TorchScript model via DJL Criteria (reflection). 
*/ - private Object loadTorchScriptModel(Class clsCriteria, Class clsNDList, - Class clsTranslator, Path modelPath, - Object device) throws Exception { - Method builder = clsCriteria.getMethod("builder"); - Object b = builder.invoke(null); - - Method setTypes = b.getClass().getMethod("setTypes", Class.class, Class.class); - b = setTypes.invoke(b, clsNDList, clsNDList); - - Method optModelPath = b.getClass().getMethod("optModelPath", Path.class); - b = optModelPath.invoke(b, modelPath); - - Method optEngine = b.getClass().getMethod("optEngine", String.class); - b = optEngine.invoke(b, "PyTorch"); - - Object translator = clsTranslator.getDeclaredConstructor().newInstance(); - Method optTranslator = b.getClass().getMethod("optTranslator", - Class.forName("ai.djl.translate.Translator")); - b = optTranslator.invoke(b, translator); - - Method optDevice = b.getClass().getMethod("optDevice", - Class.forName("ai.djl.Device")); - b = optDevice.invoke(b, device); - - Method build = b.getClass().getMethod("build"); - Object criteria = build.invoke(b); - - Method loadModel = clsCriteria.getMethod("loadModel"); - return loadModel.invoke(criteria); - } - - /** DDIM linear timestep schedule. */ - private float[] ddimTimesteps(int numSteps) { - float[] ts = new float[numSteps]; - for (int i = 0; i < numSteps; i++) { - ts[i] = 999.0f * (1.0f - (float) i / numSteps); - } - return ts; - } - - /** Linear beta schedule → alpha_bar for a given timestep. */ - private float alphaForTimestep(float t) { - // Simplified linear schedule: beta from 0.00085 to 0.012 over 1000 steps - double beta = 0.00085 + (0.012 - 0.00085) * t / 999.0; - double alpha = 1.0 - beta; - // Approximate cumulative product - return (float) Math.pow(alpha, t); - } - - /** Sigma for a given timestep (derived from alpha). 
*/ - private double sigmaForTimestep(float t) { - float a = alphaForTimestep(t); - return Math.sqrt((1.0 - a) / a); - } - - private static int clamp(int v) { - return Math.max(0, Math.min(255, v)); - } -} diff --git a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java index a2e654e..e4a1d65 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java +++ b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java @@ -10,7 +10,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import atri.palaash.lumenforge.model.TaskType; import atri.palaash.lumenforge.storage.ModelStorage; -import atri.palaash.lumenforge.ui.AppLogger; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; @@ -72,13 +71,13 @@ private OrtSession getOrCreateSession(OrtEnvironment env, Path modelPath, String key = modelPath.toAbsolutePath().toString(); OrtSession existing = SESSION_CACHE.get(key); if (existing != null) { - AppLogger.model("Session cache hit: " + modelPath.getFileName()); + System.out.println("[LumenForge] Session cache hit: " + modelPath.getFileName()); return existing; } - AppLogger.model("Loading ONNX session: " + modelPath.getFileName()); + System.out.println("[LumenForge] Loading ONNX session: " + modelPath.getFileName()); OrtSession session = env.createSession(modelPath.toString(), opts); SESSION_CACHE.put(key, session); - AppLogger.model("Session loaded: " + modelPath.getFileName()); + System.out.println("[LumenForge] Session loaded: " + modelPath.getFileName()); return session; } @@ -130,12 +129,12 @@ public static void clearCache() { public CompletableFuture run(InferenceRequest request) { return CompletableFuture.supplyAsync(() -> { if (!storage.isAvailable(request.model())) { - AppLogger.modelError("Model not found locally: " + request.model().displayName()); + System.out.println("[LumenForge] ERROR: Model not found locally: " 
+ request.model().displayName()); return InferenceResult.fail("Model not found locally. Open Model Manager from the menu bar and download it first."); } Path modelPath = storage.modelPath(request.model()); - AppLogger.model("Starting inference: " + request.model().displayName() + System.out.println("[LumenForge] Starting inference: " + request.model().displayName() + " (" + request.model().id() + ")"); // Temporarily intercept stderr so ONNX Runtime native warnings @@ -153,7 +152,7 @@ public CompletableFuture run(InferenceRequest request) { sessionOptions.setIntraOpNumThreads(Math.max(1, cpus - 1)); sessionOptions.setInterOpNumThreads(Math.max(1, Math.min(cpus / 2, 4))); ProviderSelection providerSelection = configureExecutionProvider(sessionOptions, request.preferGpu()); - AppLogger.model("Using EP: " + providerSelection.provider() + System.out.println("[LumenForge] Using EP: " + providerSelection.provider() + (providerSelection.notes().isBlank() ? "" : " (" + providerSelection.notes() + ")")); // Invalidate session cache when the execution provider changes. 
@@ -184,11 +183,6 @@ public CompletableFuture run(InferenceRequest request) { && request.model().relativePath().contains("transformer/")) { return runSd3(environment, sessionOptions, request, providerSelection.provider()); } - // DJL/PyTorch models — delegate to the DJL backend - if (request.model().id().contains("pytorch")) { - DjlPyTorchService djl = new DjlPyTorchService(storage, executor); - return djl.run(request).join(); - } // Img2Img pipelines if ("sd_v15_img2img".equals(request.model().id())) { return runImg2Img(environment, sessionOptions, request, providerSelection.provider(), false); @@ -201,11 +195,11 @@ public CompletableFuture run(InferenceRequest request) { + request.model().displayName() + " | task=" + taskType.displayName() + " | EP=" + providerSelection.provider() + providerSelection.noteSuffix(); - AppLogger.modelWarn(details); + System.out.println("[LumenForge] WARN: " + details); return InferenceResult.fail(details); } catch (OrtException ex) { String message = ex.getMessage() == null ? "Unknown ONNX Runtime error" : ex.getMessage(); - AppLogger.modelError("ONNX Runtime error: " + message); + System.out.println("[LumenForge] ERROR: ONNX Runtime error: " + message); if (message.contains("NhwcConv")) { return InferenceResult.fail( "This ONNX model uses custom ops (e.g., NhwcConv) not available in CPUExecutionProvider. 
" @@ -2656,24 +2650,24 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o StringBuilder notes = new StringBuilder(); List failReasons = new ArrayList<>(); - AppLogger.app("EP preference order: " + preference + System.out.println("[LumenForge] EP preference order: " + preference + " (os=" + os + ", arch=" + arch + ")"); if (!preferGpu) { - AppLogger.app("GPU not requested for this session — using CPUExecutionProvider"); + System.out.println("[LumenForge] GPU not requested for this session — using CPUExecutionProvider"); return new ProviderSelection("CPUExecutionProvider", "GPU not requested"); } for (String candidate : preference) { if ("cpu".equals(candidate)) { // All GPU EPs exhausted — log summary - AppLogger.appWarn("Falling back to CPUExecutionProvider"); + System.out.println("[LumenForge] WARN: Falling back to CPUExecutionProvider"); if (!failReasons.isEmpty()) { - AppLogger.appWarn("Reason: every GPU execution provider was unavailable:"); + System.out.println("[LumenForge] WARN: Reason: every GPU execution provider was unavailable:"); for (String reason : failReasons) { - AppLogger.appWarn(" - " + reason); + System.out.println("[LumenForge] WARN: - " + reason); } - AppLogger.appWarn("Tip: install the matching GPU runtime " + System.out.println("[LumenForge] WARN: Tip: install the matching GPU runtime " + "(e.g. 
CUDA/cuDNN for NVIDIA, ROCm for AMD) or use " + "-Dlumenforge.ep= to force a specific EP."); } @@ -2682,17 +2676,17 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o String failReason = tryEnableProvider(options, candidate, notes); if (failReason == null) { String display = providerDisplayName(candidate); - AppLogger.app("\u2713 Enabled " + display); + System.out.println("[LumenForge] \u2713 Enabled " + display); return new ProviderSelection(display, notes.toString()); } failReasons.add(failReason); } // Preference list didn't include "cpu" explicitly — shouldn't happen, but handle it - AppLogger.appWarn("No EP available, falling back to CPUExecutionProvider"); + System.out.println("[LumenForge] WARN: No EP available, falling back to CPUExecutionProvider"); if (!failReasons.isEmpty()) { - AppLogger.appWarn("Reasons:"); + System.out.println("[LumenForge] WARN: Reasons:"); for (String reason : failReasons) { - AppLogger.appWarn(" - " + reason); + System.out.println("[LumenForge] WARN: - " + reason); } } return new ProviderSelection("CPUExecutionProvider", notes.toString()); @@ -2763,7 +2757,7 @@ private String tryEnableProvider(OrtSession.SessionOptions options, String candi if (failDetail != null) { if (!notes.isEmpty()) { notes.append("; "); } notes.append(candidate).append(" not available"); - AppLogger.appWarn("\u2717 " + failDetail); + System.out.println("[LumenForge] WARN: \u2717 " + failDetail); } return failDetail; } diff --git a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java index 00a0e0a..bc2c32d 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java +++ b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java @@ -25,31 +25,4 @@ public InferenceService create(TaskType taskType) { } return new GenericOnnxService(taskType, storage, executor); } - - /** - * Creates an inference service for the given 
task using the DJL PyTorch backend. - * Falls back to the ONNX backend if DJL is not available. - */ - public InferenceService createDjl(TaskType taskType) { - if (DjlPyTorchService.isAvailable()) { - return new DjlPyTorchService(storage, executor); - } - return create(taskType); - } - - /** - * Creates an inference service for the given model, automatically selecting - * the DJL backend for PyTorch models and ONNX backend for everything else. - */ - public InferenceService createForModel(String modelId, TaskType taskType) { - if (modelId != null && modelId.contains("pytorch")) { - return createDjl(taskType); - } - return create(taskType); - } - - /** Returns {@code true} if the DJL PyTorch runtime is available. */ - public static boolean isDjlAvailable() { - return DjlPyTorchService.isAvailable(); - } } diff --git a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java index 1c42e1b..94b1a47 100644 --- a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java +++ b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java @@ -2,7 +2,6 @@ import atri.palaash.lumenforge.model.ModelDescriptor; import atri.palaash.lumenforge.model.TaskType; -import atri.palaash.lumenforge.ui.AppLogger; import java.io.IOException; import java.io.InputStream; @@ -57,7 +56,7 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer { Path target = storage.modelPath(descriptor); List writtenFiles = new ArrayList<>(); - AppLogger.model("Download started: " + descriptor.displayName()); + System.out.println("[LumenForge] Download started: " + descriptor.displayName()); try { String downloadUrl = descriptor.sourceUrl(); if (!downloadUrl.toLowerCase(Locale.ROOT).contains(".onnx") @@ -73,11 +72,11 @@ public CompletableFuture download(ModelDescriptor descriptor, ConsumerTwo independent channels: - *

    - *
  • {@link Channel#APPLICATION} — startup, EP selection, configuration, errors
  • - *
  • {@link Channel#MODEL} — model download, import, conversion, inference
  • - *
- */ -public final class AppLogger { - - public enum Channel { APPLICATION, MODEL } - public enum Level { INFO, WARN, ERROR } - - public record LogEntry(LocalDateTime timestamp, Channel channel, Level level, String message) { } - - private static final DateTimeFormatter TS_FORMAT = DateTimeFormatter.ofPattern("HH:mm:ss.SSS"); - private static final List> listeners = new CopyOnWriteArrayList<>(); - - private AppLogger() { } - - /* ── Listener management ─────────────────────────────────────── */ - - public static void addListener(Consumer listener) { - listeners.add(listener); - } - - public static void removeListener(Consumer listener) { - listeners.remove(listener); - } - - /* ── Convenience logging ─────────────────────────────────────── */ - - public static void app(String message) { - log(Channel.APPLICATION, Level.INFO, message); - } - - public static void appWarn(String message) { - log(Channel.APPLICATION, Level.WARN, message); - } - - public static void appError(String message) { - log(Channel.APPLICATION, Level.ERROR, message); - } - - public static void model(String message) { - log(Channel.MODEL, Level.INFO, message); - } - - public static void modelWarn(String message) { - log(Channel.MODEL, Level.WARN, message); - } - - public static void modelError(String message) { - log(Channel.MODEL, Level.ERROR, message); - } - - /* ── Core ────────────────────────────────────────────────────── */ - - public static void log(Channel channel, Level level, String message) { - LogEntry entry = new LogEntry(LocalDateTime.now(), channel, level, message); - for (Consumer listener : listeners) { - try { - listener.accept(entry); - } catch (Exception ignored) { } - } - } - - /** Format a log entry for display in the Logs panel. 
*/ - public static String format(LogEntry entry) { - String prefix = switch (entry.level()) { - case INFO -> ""; - case WARN -> "\u26a0 "; - case ERROR -> "\u274c "; - }; - return TS_FORMAT.format(entry.timestamp()) + " " + prefix + entry.message(); - } -} diff --git a/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java b/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java deleted file mode 100644 index ea92369..0000000 --- a/src/main/java/atri/palaash/lumenforge/ui/LogsPanel.java +++ /dev/null @@ -1,118 +0,0 @@ -package atri.palaash.lumenforge.ui; - -import atri.palaash.lumenforge.ui.AppLogger.Channel; -import atri.palaash.lumenforge.ui.AppLogger.LogEntry; - -import javax.swing.BorderFactory; -import javax.swing.Box; -import javax.swing.BoxLayout; -import javax.swing.JButton; -import javax.swing.JPanel; -import javax.swing.JScrollPane; -import javax.swing.JTextArea; -import javax.swing.JToggleButton; -import javax.swing.SwingUtilities; -import javax.swing.Timer; -import java.awt.BorderLayout; -import java.awt.Dimension; -import java.awt.Font; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Consumer; - -/** - * Logs panel with two toggle-filtered views: Application and Model. - *

- * Entries are appended in real-time via {@link AppLogger} and auto-scroll - * to the bottom. A Clear button wipes the visible log. - */ -public class LogsPanel extends JPanel { - - private final JTextArea logArea; - private final List allEntries = new ArrayList<>(); - private boolean showApp = true; - private boolean showModel = true; - - /* Coalesce rapid-fire updates (e.g. per-step progress) to ~60 fps. */ - private boolean dirty = false; - private final Timer coalesceTimer; - - public LogsPanel() { - super(new BorderLayout(0, 0)); - - /* ── Log text area (init first — referenced by toolbar buttons) ── */ - logArea = new JTextArea(); - logArea.setEditable(false); - logArea.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12)); - logArea.setLineWrap(true); - logArea.setWrapStyleWord(true); - JScrollPane scroll = new JScrollPane(logArea); - scroll.setBorder(BorderFactory.createEmptyBorder()); - - /* ── Toolbar ─────────────────────────────────────────────── */ - JPanel toolbar = new JPanel(); - toolbar.setLayout(new BoxLayout(toolbar, BoxLayout.X_AXIS)); - toolbar.setBorder(BorderFactory.createEmptyBorder(6, 12, 6, 12)); - - JToggleButton appBtn = new JToggleButton("Application", showApp); - JToggleButton modelBtn = new JToggleButton("Model", showModel); - appBtn.setFocusable(false); - modelBtn.setFocusable(false); - appBtn.addActionListener(e -> { showApp = appBtn.isSelected(); rebuild(); }); - modelBtn.addActionListener(e -> { showModel = modelBtn.isSelected(); rebuild(); }); - - JButton clearBtn = new JButton("Clear"); - clearBtn.setFocusable(false); - clearBtn.addActionListener(e -> { - synchronized (allEntries) { allEntries.clear(); } - logArea.setText(""); - }); - - toolbar.add(appBtn); - toolbar.add(Box.createRigidArea(new Dimension(4, 0))); - toolbar.add(modelBtn); - toolbar.add(Box.createHorizontalGlue()); - toolbar.add(clearBtn); - - add(toolbar, BorderLayout.NORTH); - add(scroll, BorderLayout.CENTER); - - /* ── Wire up listener 
────────────────────────────────────── */ - Consumer listener = entry -> { - synchronized (allEntries) { allEntries.add(entry); } - scheduleDirty(); - }; - AppLogger.addListener(listener); - - /* Coalesce timer fires every 16 ms (~60 fps) if dirty */ - coalesceTimer = new Timer(16, e -> { - if (dirty) { - dirty = false; - rebuild(); - } - }); - coalesceTimer.setRepeats(true); - coalesceTimer.start(); - } - - private void scheduleDirty() { - dirty = true; - } - - /** Rebuild the text area contents from the filtered entry list. */ - private void rebuild() { - SwingUtilities.invokeLater(() -> { - StringBuilder sb = new StringBuilder(); - List snapshot; - synchronized (allEntries) { snapshot = new ArrayList<>(allEntries); } - for (LogEntry entry : snapshot) { - if (entry.channel() == Channel.APPLICATION && !showApp) continue; - if (entry.channel() == Channel.MODEL && !showModel) continue; - String tag = entry.channel() == Channel.APPLICATION ? "[APP] " : "[MDL] "; - sb.append(tag).append(AppLogger.format(entry)).append('\n'); - } - logArea.setText(sb.toString()); - logArea.setCaretPosition(logArea.getDocument().getLength()); - }); - } -} diff --git a/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java b/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java index e40012c..fb050c8 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java +++ b/src/main/java/atri/palaash/lumenforge/ui/MainFrame.java @@ -55,7 +55,6 @@ public class MainFrame extends JFrame { private static final String CARD_IMG2IMG = "Img2Img"; private static final String CARD_UPSCALE = "Upscale"; private static final String CARD_MODELS = "Models"; - private static final String CARD_LOGS = "Logs"; private final CardLayout cardLayout = new CardLayout(); private final JPanel contentPanel = new JPanel(cardLayout); @@ -64,13 +63,11 @@ public class MainFrame extends JFrame { private final ImageUpscalePanel imageUpscalePanel; private final Img2ImgPanel img2ImgPanel; private final 
ModelManagerPanel modelManagerPanel; - private final LogsPanel logsPanel; private final JLabel statusBarLabel; /* Sidebar lists */ private JList workflowList; private JList managementList; - private JList diagnosticsList; /* GPU state (shared across panels via supplier) */ private boolean gpuEnabled = true; @@ -140,14 +137,11 @@ public MainFrame(ModelRegistry registry, .collect(Collectors.toList())); }); - logsPanel = new LogsPanel(); - - /* ── Content cards ───────────────────────────────────────── */ + /* ── Content cards ─────────────────────────────────────── */ contentPanel.add(textToImagePanel, CARD_GENERATE); contentPanel.add(img2ImgPanel, CARD_IMG2IMG); contentPanel.add(imageUpscalePanel, CARD_UPSCALE); contentPanel.add(modelManagerPanel, CARD_MODELS); - contentPanel.add(logsPanel, CARD_LOGS); /* ── Sidebar ─────────────────────────────────────────────── */ // Create both lists first, then wire listeners @@ -165,35 +159,19 @@ public MainFrame(ModelRegistry registry, managementList.setCellRenderer(new SidebarRenderer()); managementList.setOpaque(false); - String[] diagnosticsItems = {CARD_LOGS}; - diagnosticsList = new JList<>(diagnosticsItems); - diagnosticsList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - diagnosticsList.setFixedCellHeight(32); - diagnosticsList.setCellRenderer(new SidebarRenderer()); - diagnosticsList.setOpaque(false); - // Wire listeners after all lists exist workflowList.addListSelectionListener(e -> { if (!e.getValueIsAdjusting() && workflowList.getSelectedIndex() >= 0) { managementList.clearSelection(); - diagnosticsList.clearSelection(); cardLayout.show(contentPanel, workflowList.getSelectedValue()); } }); managementList.addListSelectionListener(e -> { if (!e.getValueIsAdjusting() && managementList.getSelectedIndex() >= 0) { workflowList.clearSelection(); - diagnosticsList.clearSelection(); cardLayout.show(contentPanel, managementList.getSelectedValue()); } }); - diagnosticsList.addListSelectionListener(e -> { - if 
(!e.getValueIsAdjusting() && diagnosticsList.getSelectedIndex() >= 0) { - workflowList.clearSelection(); - managementList.clearSelection(); - cardLayout.show(contentPanel, diagnosticsList.getSelectedValue()); - } - }); // Set initial selection after listeners are wired workflowList.setSelectedIndex(0); @@ -237,29 +215,16 @@ public MainFrame(ModelRegistry registry, managementList.setAlignmentX(Component.LEFT_ALIGNMENT); managementSection.add(managementList); - // Diagnostics section - JPanel diagnosticsSection = new JPanel(); - diagnosticsSection.setLayout(new BoxLayout(diagnosticsSection, BoxLayout.Y_AXIS)); - diagnosticsSection.setOpaque(false); - - JLabel diagnosticsHeader = new JLabel("DIAGNOSTICS"); - diagnosticsHeader.setFont(diagnosticsHeader.getFont().deriveFont(Font.BOLD, 10f)); - diagnosticsHeader.setForeground(UIManager.getColor("Label.disabledForeground")); - diagnosticsHeader.setBorder(BorderFactory.createEmptyBorder(8, 20, 4, 20)); - diagnosticsHeader.setAlignmentX(Component.LEFT_ALIGNMENT); - diagnosticsSection.add(diagnosticsHeader); - diagnosticsList.setAlignmentX(Component.LEFT_ALIGNMENT); - diagnosticsSection.add(diagnosticsList); - - // Bottom panel: management + diagnostics + // Diagnostics section removed + + // Bottom panel: management JPanel bottomSection = new JPanel(); bottomSection.setLayout(new BoxLayout(bottomSection, BoxLayout.Y_AXIS)); bottomSection.setOpaque(false); bottomSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0)); bottomSection.add(managementSection); - bottomSection.add(diagnosticsSection); - // Combine: workflow on top, management+diagnostics at bottom + // Combine: workflow on top, management at bottom JPanel sidebarContent = new JPanel(new BorderLayout()); sidebarContent.setOpaque(false); sidebarContent.add(workflowSection, BorderLayout.NORTH); @@ -333,11 +298,6 @@ private JMenuBar buildMenuBar() { showModels.addActionListener(e -> switchToCard(CARD_MODELS, managementList, 0)); viewMenu.add(showModels); - 
JMenuItem showLogs = new JMenuItem("Logs"); - showLogs.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_L, menuMask)); - showLogs.addActionListener(e -> switchToCard(CARD_LOGS, diagnosticsList, 0)); - viewMenu.add(showLogs); - bar.add(viewMenu); /* Inference */ @@ -372,7 +332,6 @@ private void switchToCard(String card, JList targetList, int index) { cardLayout.show(contentPanel, card); if (targetList != workflowList) workflowList.clearSelection(); if (targetList != managementList) managementList.clearSelection(); - if (targetList != diagnosticsList) diagnosticsList.clearSelection(); targetList.setSelectedIndex(index); } @@ -451,10 +410,8 @@ private String detectEpInfo() { } int cores = Runtime.getRuntime().availableProcessors(); long mem = Runtime.getRuntime().maxMemory() / (1024 * 1024); - String djl = atri.palaash.lumenforge.inference.DjlPyTorchService.isAvailable() - ? " \u2502 DJL \u2713" : ""; return "EP: " + ep + " \u2502 Cores: " + cores + " \u2502 Heap: " + mem - + " MB \u2502 Java " + Runtime.version() + djl; + + " MB \u2502 Java " + Runtime.version(); } /* ================================================================== */ diff --git a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java index dfa92fb..c77ae84 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java +++ b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java @@ -261,12 +261,12 @@ private void importModelFromFile(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Imported: " + target); - AppLogger.model("Model imported: " + descriptor.displayName() + " from " + source); + System.out.println("[LumenForge] Model imported: " + descriptor.displayName() + " from " + source); if (onModelsUpdated != null) { onModelsUpdated.run(); } } catch (IOException ex) { - AppLogger.modelError("Import failed 
for " + descriptor.displayName() + ": " + ex.getMessage()); + System.out.println("[LumenForge] ERROR: Import failed for " + descriptor.displayName() + ": " + ex.getMessage()); statusLabel.setText("Import failed: " + ex.getMessage()); } } @@ -321,7 +321,7 @@ private void startConversion(String modelId, String displayName, String mode) { statusLabel.setText("Converting " + displayName + " to ONNX\u2026"); progressBar.setVisible(true); progressBar.setIndeterminate(true); - AppLogger.model("Starting PyTorch \u2192 ONNX conversion: " + modelId); + System.out.println("[LumenForge] Starting PyTorch \u2192 ONNX conversion: " + modelId); // 5. Run on background thread final String conversionMode = mode; @@ -346,7 +346,7 @@ private void startConversion(String modelId, String displayName, String mode) { PyTorchToOnnxConverter.openPythonDownloadPage(); } } else { - AppLogger.modelError("Conversion failed for " + modelId + ": " + cause.getMessage()); + System.out.println("[LumenForge] ERROR: Conversion failed for " + modelId + ": " + cause.getMessage()); statusLabel.setText("Conversion failed: " + cause.getMessage()); JOptionPane.showMessageDialog(this, "Conversion failed:\n" + cause.getMessage(), @@ -395,7 +395,7 @@ private void registerConvertedModel(String modelId, String sanitized, Path outpu modelRegistry.mergeDownloadableAssets(List.of(desc)); refreshTable(); statusLabel.setText("\u2713 Converted and registered: " + modelId); - AppLogger.model("Conversion complete: " + modelId + " \u2192 " + relativePath); + System.out.println("[LumenForge] Conversion complete: " + modelId + " \u2192 " + relativePath); if (onModelsUpdated != null) { onModelsUpdated.run(); }