diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1950e51 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,91 @@ +name: Build LumenForge + +on: + push: + branches: [ main ] + tags: [ 'v*' ] + pull_request: + branches: [ main ] + workflow_dispatch: + +permissions: + contents: write # needed for release uploads + +jobs: + build: + name: ${{ matrix.label }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + # ── Universal (CPU + CoreML on macOS) ────────────────────── + # Works on macOS (CoreML GPU/ANE), Windows (CPU), Linux (CPU). + # AMD DirectML / Intel OpenVINO users: install native libs + # separately and this JAR will auto-detect them at runtime. + - label: "Universal (CPU / macOS CoreML)" + ort_artifact: onnxruntime + classifier: "-universal" + asset_name: lumenforge-universal.jar + + # ── NVIDIA GPU (CUDA + TensorRT) ─────────────────────────── + # For Windows and Linux with NVIDIA GPUs. + # Requires CUDA toolkit + cuDNN on the host system. + - label: "NVIDIA GPU (CUDA + TensorRT)" + ort_artifact: onnxruntime_gpu + classifier: "-nvidia" + asset_name: lumenforge-nvidia.jar + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: '21' + cache: maven + + - name: Build fat JAR + run: | + mvn clean package -DskipTests \ + -Dort.artifactId=${{ matrix.ort_artifact }} \ + -Djar.classifier=${{ matrix.classifier }} + + - name: Rename artifact + run: | + JAR=$(find target -maxdepth 1 -name "lumenforge-*${{ matrix.classifier }}.jar" | head -1) + cp "$JAR" "target/${{ matrix.asset_name }}" + echo "Built: ${{ matrix.asset_name }} ($(du -h target/${{ matrix.asset_name }} | cut -f1))" + + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.asset_name }} + path: target/${{ matrix.asset_name }} + retention-days: 14 + + # ── Release (only on version tags) ────────────────────────────────── + release: + name: Create Release + needs: build + if: startsWith(github.ref, 'refs/tags/v') + runs-on: ubuntu-latest + steps: + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + merge-multiple: true + + - name: List artifacts + run: ls -lh artifacts/ + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + files: | + artifacts/lumenforge-universal.jar + artifacts/lumenforge-nvidia.jar diff --git a/README.md b/README.md index 95061e1..72e0b02 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,117 @@ # LumenForge -Desktop Java Swing application for Java-native ONNX Runtime workflows with automatic GPU/CPU provider fallback. +[![Build LumenForge](https://github.com/palaashatri/lumenforge/actions/workflows/build.yml/badge.svg)](https://github.com/palaashatri/lumenforge/actions/workflows/build.yml) -- Text → Image: Stable Diffusion v1.5 UNet + Real-ESRGAN with fully automatic downloads +Desktop Java Swing application for ONNX Runtime inference with intelligent GPU acceleration across NVIDIA, Apple, Intel, and AMD hardware. image - ## Features -- Native look-and-feel handling: - - macOS: system-native look-and-feel - - Windows/Linux: FlatLaf with automatic system dark/light detection -- High-performance async execution using virtual threads -- Model downloader module with progress reporting -- Auto-download when a selected model is missing in task tabs -- Model Manager downloads only on explicit **Download / Redownload** button click -- Model Manager includes additional ONNX pipeline assets (UNet/Text Encoder/VAE/Safety Checker) for one-click download -- Manual ONNX model import per selected row in Model Manager -- Prompt library: saved presets + tags + search +### Image Generation +- **Text → Image**: Stable Diffusion v1.5, SD Turbo, SDXL Turbo, SDXL Base 1.0, SD 3.5 (MMDiT) +- **Image → Image**: Img2Img with adjustable strength + inpainting +- **Image Upscaling**: Real-ESRGAN 4× super-resolution with before/after preview +- **Prompt library**: saved presets, tags, and search - Negative prompts + prompt weighting -- Seed control + reproducibility presets -- Batch generation settings (N images per prompt) +- Seed control + reproducibility +- Batch generation (N images per prompt) - Aspect ratio presets + custom size fields -- Upscale toggle (Real-ESRGAN) with before/after preview (UI wiring) - Style presets (cinematic, sketch, product, etc.) - History gallery with metadata (seed, model, settings) - Export pipeline logs for debugging + +### SD 3.5 Support +- MMDiT transformer architecture with Flow Matching Euler scheduler +- Triple text encoding: CLIP-L (768d) + CLIP-G (1280d) + T5-XXL (4096d) +- Built-in T5 tokenizer (SentencePiece/Unigram with Viterbi segmentation) +- 16-channel latent space + +### Model Management +- **Automatic HuggingFace discovery**: finds ONNX and PyTorch Stable Diffusion + ESRGAN models +- **One-click download** with resume, retry, and stall detection +- **PyTorch → ONNX auto-conversion**: downloading a PyTorch model triggers automatic conversion via managed Python venv +- Manual ONNX model import +- Gated model support with HuggingFace token authentication - Local model storage in `~/.lumenforge-models` -- Execution provider fallback by OS: - - macOS: CoreML → CPU - - Windows: DirectML → CUDA → CPU - - Linux: CUDA → ROCm → CPU -## Build +### GPU Acceleration +Intelligent execution provider selection — LumenForge probes available EPs at runtime and picks the best one: + +| Platform | Priority (highest → lowest) | +|---|---| +| **macOS** | CoreML (GPU + ANE + CPU) → CPU | +| **Windows** | TensorRT-RTX → TensorRT → CUDA → DirectML → OpenVINO → CPU | +| **Linux** | TensorRT → CUDA → ROCm → OpenVINO → CPU | + +Override with `-Dlumenforge.ep=cuda` (or any EP key) to force a specific provider. + +### UI +- Native look-and-feel: system-native on macOS, FlatLaf with dark/light detection on Windows/Linux +- High-performance async execution using virtual threads +- Per-step progress with timing and ETA +- Session and tokenizer caching for fast repeated inference + +## Downloads + +Pre-built fat JARs are available from [GitHub Releases](https://github.com/palaashatri/lumenforge/releases): + +| JAR | GPU Support | Use When | +|---|---|---| +| `lumenforge-universal.jar` | macOS CoreML (M-series GPU/ANE), CPU everywhere | macOS, or Windows/Linux without NVIDIA GPU | +| `lumenforge-nvidia.jar` | CUDA + TensorRT (Windows/Linux) | Windows/Linux with NVIDIA GPU + CUDA installed | + +> **Note**: DirectML (AMD/Intel on Windows), OpenVINO (Intel), and ROCm (AMD on Linux) are auto-detected at runtime if the native libraries are installed on the system. The universal JAR handles this automatically. ```bash -mvn -DskipTests compile +# Run any variant +java -jar lumenforge-universal.jar +java -jar lumenforge-nvidia.jar ``` -Enable GPU runtime artifact (Windows/Linux): +## Build from Source + +Requires **Java 21+** and **Maven 3.8+**. ```bash -mvn -Donnx.gpu=true -DskipTests compile +# Universal build (CPU + CoreML) +mvn clean package -DskipTests + +# NVIDIA GPU build (CUDA + TensorRT) +mvn clean package -DskipTests -Dort.artifactId=onnxruntime_gpu + +# Force CPU-only +mvn clean package -DskipTests -Dort.artifactId=onnxruntime ``` -## Run +### Run from source ```bash -mvn exec:java +mvn clean compile exec:java ``` -## Test +### Run tests ```bash mvn clean test ``` +## CI / CD + +GitHub Actions builds both JAR variants on every push to `main` and PR. Pushing a version tag (e.g. `v1.0.0`) creates a GitHub Release with both JARs attached. + +See [.github/workflows/build.yml](.github/workflows/build.yml) for details. + +## Requirements + +- **Java 21** or later +- **macOS 10.15+** for CoreML acceleration (M-series recommended) +- **CUDA 12 + cuDNN** for NVIDIA GPU acceleration (RTX 30xx+ recommended) +- **Python 3.8+** (optional) for PyTorch → ONNX model conversion + ## Notes -- Runtime is Java-only ONNX Runtime execution with GPU fallback (no Python bridge). -- Override provider order via JVM property: `-Dlumenforge.ep=cpu|coreml|directml|cuda|rocm`. -- Task tabs show preview images when an output artifact is generated, with **Open Output** to launch the file. -- SD Turbo UNet remains experimental and may fail on CPU-only environments. +- Runtime is pure Java — ONNX Runtime execution with GPU fallback. No Python bridge needed at inference time. +- Override EP order via JVM property: `-Dlumenforge.ep=cpu|coreml|cuda|tensorrt|directml|openvino|rocm` +- Task tabs show preview images when output is generated, with **Open Output** to launch the file. - If a model requires external tensor files (e.g. `weights.pb`), import the complete ONNX bundle into the model directory. diff --git a/pom.xml b/pom.xml index 3ecd726..b834b1e 100644 --- a/pom.xml +++ b/pom.xml @@ -11,13 +11,18 @@ UTF-8 21 + + onnxruntime + 1.22.0 + + com.microsoft.onnxruntime - onnxruntime - 1.22.0 + ${ort.artifactId} + ${ort.version} com.formdev @@ -69,6 +74,7 @@ shade false + ${project.artifactId}-${project.version}${jar.classifier} atri.palaash.lumenforge.app.LumenForgeApp @@ -109,24 +115,22 @@ + gpu-windows - - Windows - - - onnx.gpu - true - + Windows - - - com.microsoft.onnxruntime - onnxruntime_gpu - 1.22.0 - - + + onnxruntime_gpu + gpu-linux @@ -135,54 +139,11 @@ unix Linux - - onnx.gpu - true - - - - - com.microsoft.onnxruntime - onnxruntime_gpu - 1.22.0 - - - - - - djl-pytorch - - - djl - true - - - - - ai.djl - bom - 0.31.0 - pom - import - - - - - - ai.djl - api - - - ai.djl.pytorch - pytorch-engine - runtime - - - ai.djl.huggingface - tokenizers - - + + onnxruntime_gpu + + diff --git a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java index 1ba6d45..dfc00a8 100644 --- a/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java +++ b/src/main/java/atri/palaash/lumenforge/app/LumenForgeApp.java @@ -22,6 +22,11 @@ public class LumenForgeApp { public static void main(String[] args) { configureDesktopIntegration(); + System.out.println("[LumenForge] Starting \u2014 Java " + Runtime.version() + + ", " + System.getProperty("os.name") + " " + System.getProperty("os.arch")); + System.out.println("[LumenForge] Max heap: " + (Runtime.getRuntime().maxMemory() / (1024 * 1024)) + " MB" + + ", CPUs: " + Runtime.getRuntime().availableProcessors()); + ExecutorService workerPool = Executors.newVirtualThreadPerTaskExecutor(); ModelStorage storage = new ModelStorage(); ModelRegistry registry = new ModelRegistry(); diff --git a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java b/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java deleted file mode 100644 index 56aad75..0000000 --- a/src/main/java/atri/palaash/lumenforge/inference/DjlPyTorchService.java +++ /dev/null @@ -1,396 +0,0 @@ -package atri.palaash.lumenforge.inference; - -import atri.palaash.lumenforge.storage.ModelStorage; - -import java.awt.image.BufferedImage; -import java.lang.reflect.Method; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; - -/** - * Stable Diffusion inference using DJL (Deep Java Library) with PyTorch backend. - * - *

This backend loads TorchScript-traced models (.pt files) and runs inference - * using PyTorch's native tensor operations (CPU, CUDA, or MPS). Activate by - * building with {@code mvn compile -Ddjl=true}. - * - *

Required model layout

- *
- *   ~/.lumenforge-models/text-image/sd-pytorch/
- *     clip_model.pt           — traced CLIP text encoder
- *     unet_model.pt           — traced UNet2DConditionModel
- *     vae_decoder_model.pt    — traced VAE decoder
- *     tokenizer.json          — HuggingFace tokenizer config
- * 
- * - *

Use the included {@code scripts/export_torchscript.py} helper to export - * HuggingFace SD checkpoints to TorchScript format. - */ -public class DjlPyTorchService implements InferenceService { - - private final ModelStorage storage; - private final Executor executor; - - /* ── DJL availability check ─────────────────────────────────────── */ - - private static final boolean DJL_AVAILABLE; - static { - boolean found; - try { - Class.forName("ai.djl.ndarray.NDManager"); - found = true; - } catch (ClassNotFoundException e) { - found = false; - } - DJL_AVAILABLE = found; - } - - /** Returns {@code true} if the DJL PyTorch runtime is on the classpath. */ - public static boolean isAvailable() { - return DJL_AVAILABLE; - } - - public DjlPyTorchService(ModelStorage storage, Executor executor) { - this.storage = storage; - this.executor = executor; - } - - @Override - public CompletableFuture run(InferenceRequest request) { - return CompletableFuture.supplyAsync(() -> { - if (!DJL_AVAILABLE) { - return InferenceResult.fail( - "DJL PyTorch is not on the classpath. Rebuild with: mvn compile -Ddjl=true"); - } - - Path base = storage.root().resolve("text-image").resolve("sd-pytorch"); - Path clipPath = base.resolve("clip_model.pt"); - Path unetPath = base.resolve("unet_model.pt"); - Path vaePath = base.resolve("vae_decoder_model.pt"); - Path tokPath = base.resolve("tokenizer.json"); - - for (Path p : new Path[]{clipPath, unetPath, vaePath, tokPath}) { - if (!Files.exists(p)) { - return InferenceResult.fail( - "Missing TorchScript component: " + p.getFileName() - + "\nExport models with scripts/export_torchscript.py first."); - } - } - - try { - return runPyTorchPipeline(request, clipPath, unetPath, vaePath, tokPath); - } catch (Exception ex) { - return InferenceResult.fail("DJL PyTorch pipeline failed: " + ex.getMessage()); - } - }, executor); - } - - /* ── Pipeline (reflection-based to compile without DJL on classpath) ─ */ - - @SuppressWarnings("unchecked") - private InferenceResult runPyTorchPipeline(InferenceRequest request, - Path clipPath, Path unetPath, - Path vaePath, Path tokPath) - throws Exception { - - // All DJL calls go through reflection so the class compiles even when - // the djl-pytorch profile is not active. - - Class clsNDManager = Class.forName("ai.djl.ndarray.NDManager"); - Class clsNDArray = Class.forName("ai.djl.ndarray.NDArray"); - Class clsNDList = Class.forName("ai.djl.ndarray.NDList"); - Class clsShape = Class.forName("ai.djl.ndarray.types.Shape"); - Class clsDataType = Class.forName("ai.djl.ndarray.types.DataType"); - Class clsDevice = Class.forName("ai.djl.Device"); - Class clsCriteria = Class.forName("ai.djl.repository.zoo.Criteria"); - Class clsZooModel = Class.forName("ai.djl.repository.zoo.ZooModel"); - Class clsPredictor = Class.forName("ai.djl.inference.Predictor"); - Class clsTokenizer = Class.forName("ai.djl.huggingface.tokenizers.HuggingFaceTokenizer"); - Class clsEncoding = Class.forName("ai.djl.huggingface.tokenizers.Encoding"); - Class clsTranslator = Class.forName("ai.djl.translate.NoopTranslator"); - - int width = Math.max(256, (request.width() / 8) * 8); - int height = Math.max(256, (request.height() / 8) * 8); - int latentW = width / 8; - int latentH = height / 8; - int steps = Math.max(1, Math.min(request.batch() > 0 ? request.batch() : 20, 50)); - - // Choose device — prefer GPU if available - Object device; - if (request.preferGpu()) { - Method gpuMethod = clsDevice.getMethod("gpu"); - device = gpuMethod.invoke(null); - } else { - Method cpuMethod = clsDevice.getMethod("cpu"); - device = cpuMethod.invoke(null); - } - - // Create NDManager - Method managerOf = clsNDManager.getMethod("newBaseManager", clsDevice); - Object manager = managerOf.invoke(null, device); - - try { - /* ── Tokenize ──────────────────────────────────────────────── */ - request.reportProgress("Tokenizing prompt (DJL HuggingFace)\u2026"); - Method tokLoad = clsTokenizer.getMethod("newInstance", Path.class); - Object tokenizer = tokLoad.invoke(null, tokPath); - Method tokEncode = clsTokenizer.getMethod("encode", String.class); - Object encoding = tokEncode.invoke(tokenizer, request.prompt()); - Method getIds = clsEncoding.getMethod("getIds"); - long[] tokenIds = (long[]) getIds.invoke(encoding); - - // Pad/truncate to 77 - long[] padded = new long[77]; - System.arraycopy(tokenIds, 0, padded, 0, Math.min(tokenIds.length, 77)); - if (tokenIds.length < 77) { - long padId = 49407L; // <|endoftext|> - for (int i = tokenIds.length; i < 77; i++) padded[i] = padId; - } - - // Create token tensor [1, 77] - Object shapeTokens = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1, 77}); - Method mgrCreate = clsNDManager.getMethod("create", long[].class, - Class.forName("ai.djl.ndarray.types.Shape")); - Object tokenTensor = mgrCreate.invoke(manager, padded, shapeTokens); - - /* ── CLIP text encoder ──────────────────────────────────── */ - request.reportProgress("Loading CLIP text encoder (PyTorch)\u2026"); - Object clipModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - clipPath, device); - Object clipPredictor = clsZooModel.getMethod("newPredictor").invoke(clipModel); - - Object clipInput = clsNDList.getConstructor(clsNDArray).newInstance(tokenTensor); - request.reportProgress("Running CLIP text encoder\u2026"); - Object clipOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(clipPredictor, clipInput); - Object textEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(clipOutput); - - /* ── Negative prompt (uncond) encoder ───────────────────── */ - long[] emptyTokens = new long[77]; - emptyTokens[0] = 49406L; // <|startoftext|> - for (int i = 1; i < 77; i++) emptyTokens[i] = 49407L; - Object uncondTensor = mgrCreate.invoke(manager, emptyTokens, shapeTokens); - Object uncondInput = clsNDList.getConstructor(clsNDArray).newInstance(uncondTensor); - request.reportProgress("Running uncond text encoder\u2026"); - Object uncondOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(clipPredictor, uncondInput); - Object uncondEmbedding = clsNDList.getMethod("singletonOrThrow").invoke(uncondOutput); - - // Concatenate [uncond, cond] → [2, 77, 768] - Method concatMethod = clsNDArray.getMethod("concat", clsNDArray, int.class); - // Stack along batch dim: uncond embedding [1,77,768] + text embedding [1,77,768] - Object embeddings = concatMethod.invoke(uncondEmbedding, textEmbedding, 0); - - clsPredictor.getMethod("close").invoke(clipPredictor); - clsZooModel.getMethod("close").invoke(clipModel); - - /* ── UNet denoising loop ────────────────────────────────── */ - request.reportProgress("Loading UNet (PyTorch)\u2026"); - Object unetModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - unetPath, device); - Object unetPredictor = clsZooModel.getMethod("newPredictor").invoke(unetModel); - - // Initialize random latents [1, 4, latentH, latentW] - Object latentShape = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1, 4, latentH, latentW}); - Method mgrRandomNormal = clsNDManager.getMethod("randomNormal", - Class.forName("ai.djl.ndarray.types.Shape")); - Object latents = mgrRandomNormal.invoke(manager, latentShape); - - // Simple linear schedule for beta - double guidanceScale = request.promptWeight() > 0 ? request.promptWeight() : 7.5; - float[] timestepSchedule = ddimTimesteps(steps); - - // Initial scale - Method mulMethod = clsNDArray.getMethod("mul", Number.class); - latents = mulMethod.invoke(latents, (float) Math.sqrt(1.0 + sigmaForTimestep(timestepSchedule[0]))); - - long stepStart = System.currentTimeMillis(); - for (int i = 0; i < timestepSchedule.length; i++) { - if (request.isCancelled()) { - clsPredictor.getMethod("close").invoke(unetPredictor); - clsZooModel.getMethod("close").invoke(unetModel); - return InferenceResult.fail("Cancelled by user."); - } - - float t = timestepSchedule[i]; - - // Duplicate latents for CFG: [latents, latents] → [2, 4, H, W] - Object latentDouble = concatMethod.invoke(latents, latents, 0); - - // Create timestep tensor - Object tsShape = clsShape.getConstructor(long[].class) - .newInstance((Object) new long[]{1}); - Object tsTensor = mgrCreate.invoke(manager, new long[]{(long) t}, tsShape); - - // UNet forward: (noisy_latents, timestep, encoder_hidden_states) - Object unetInput = clsNDList.getConstructor(clsNDArray.arrayType()) - .newInstance((Object) new Object[]{latentDouble, tsTensor, embeddings}); - Object unetOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(unetPredictor, unetInput); - Object noisePred = clsNDList.getMethod("singletonOrThrow").invoke(unetOutput); - - // Classifier-free guidance: split [2,4,H,W] → uncond, cond - Method splitMethod = clsNDArray.getMethod("split", long.class, int.class); - Object splitResult = splitMethod.invoke(noisePred, 2L, 0); - Object noiseUncond = java.lang.reflect.Array.get( - clsNDList.getMethod("toArray").invoke(splitResult), 0); - Object noiseCond = java.lang.reflect.Array.get( - clsNDList.getMethod("toArray").invoke(splitResult), 1); - - // guided = uncond + scale * (cond - uncond) - Method subMethod = clsNDArray.getMethod("sub", clsNDArray); - Method addMethod = clsNDArray.getMethod("add", clsNDArray); - Object diff = subMethod.invoke(noiseCond, noiseUncond); - Object scaled = mulMethod.invoke(diff, (float) guidanceScale); - Object guided = addMethod.invoke(noiseUncond, scaled); - - // DDIM step - float alphaT = alphaForTimestep(t); - float alphaPrev = (i + 1 < timestepSchedule.length) - ? alphaForTimestep(timestepSchedule[i + 1]) : 1.0f; - - // predicted x0 = (latents - sqrt(1-alpha)*noise) / sqrt(alpha) - Object scaledNoise = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaT)); - Object x0 = subMethod.invoke(latents, scaledNoise); - x0 = mulMethod.invoke(x0, (float) (1.0 / Math.sqrt(alphaT))); - - // direction pointing to x_t - Object dirXt = mulMethod.invoke(guided, (float) Math.sqrt(1.0 - alphaPrev)); - // x_{t-1} = sqrt(alpha_prev) * x0 + sqrt(1-alpha_prev) * noise_pred - Object prevSample = mulMethod.invoke(x0, (float) Math.sqrt(alphaPrev)); - latents = addMethod.invoke(prevSample, dirXt); - - long elapsed = System.currentTimeMillis() - stepStart; - stepStart = System.currentTimeMillis(); - request.reportProgress("Denoising: " + (i + 1) + "/" + steps - + " steps (" + String.format("%.1f", elapsed / 1000.0) + "s/step) [PyTorch]"); - } - - clsPredictor.getMethod("close").invoke(unetPredictor); - clsZooModel.getMethod("close").invoke(unetModel); - - /* ── VAE decode ─────────────────────────────────────────── */ - request.reportProgress("Loading VAE decoder (PyTorch)\u2026"); - Object vaeModel = loadTorchScriptModel(clsCriteria, clsNDList, clsTranslator, - vaePath, device); - Object vaePredictor = clsZooModel.getMethod("newPredictor").invoke(vaeModel); - - // Scale latents - latents = mulMethod.invoke(latents, 1.0f / 0.18215f); - - Object vaeInput = clsNDList.getConstructor(clsNDArray).newInstance(latents); - request.reportProgress("Decoding with VAE\u2026"); - Object vaeOutput = clsPredictor.getMethod("predict", Object.class) - .invoke(vaePredictor, vaeInput); - Object decoded = clsNDList.getMethod("singletonOrThrow").invoke(vaeOutput); - - clsPredictor.getMethod("close").invoke(vaePredictor); - clsZooModel.getMethod("close").invoke(vaeModel); - - /* ── Convert tensor to BufferedImage ───────────────────── */ - // decoded: [1, 3, H, W] float32, range roughly [-1, 1] - Method toFloatArray = clsNDArray.getMethod("toFloatArray"); - float[] pixels = (float[]) toFloatArray.invoke(decoded); - - Method getShapeMethod = clsNDArray.getMethod("getShape"); - Object decodedShape = getShapeMethod.invoke(decoded); - long[] dims = (long[]) clsShape.getMethod("getShape").invoke(decodedShape); - int imgH = (int) dims[2]; - int imgW = (int) dims[3]; - - BufferedImage image = new BufferedImage(imgW, imgH, BufferedImage.TYPE_INT_RGB); - for (int y = 0; y < imgH; y++) { - for (int x = 0; x < imgW; x++) { - int rIdx = 0 * imgH * imgW + y * imgW + x; - int gIdx = 1 * imgH * imgW + y * imgW + x; - int bIdx = 2 * imgH * imgW + y * imgW + x; - int r = clamp((int) ((pixels[rIdx] / 2 + 0.5f) * 255)); - int g = clamp((int) ((pixels[gIdx] / 2 + 0.5f) * 255)); - int b = clamp((int) ((pixels[bIdx] / 2 + 0.5f) * 255)); - image.setRGB(x, y, (r << 16) | (g << 8) | b); - } - } - - Path outputDir = storage.root().resolve("outputs"); - Files.createDirectories(outputDir); - String filename = "djl-pytorch_" + System.currentTimeMillis() + ".png"; - Path outputPath = outputDir.resolve(filename); - javax.imageio.ImageIO.write(image, "PNG", outputPath.toFile()); - - return InferenceResult.ok( - "Generated image for prompt: \"" + request.prompt() + "\"", - "DJL PyTorch pipeline completed (" + steps + " steps)", - outputPath.toString(), "image"); - - } finally { - // Close NDManager - clsNDManager.getMethod("close").invoke(manager); - } - } - - /* ── Helpers ──────────────────────────────────────────────────────── */ - - /** Load a TorchScript model via DJL Criteria (reflection). */ - private Object loadTorchScriptModel(Class clsCriteria, Class clsNDList, - Class clsTranslator, Path modelPath, - Object device) throws Exception { - Method builder = clsCriteria.getMethod("builder"); - Object b = builder.invoke(null); - - Method setTypes = b.getClass().getMethod("setTypes", Class.class, Class.class); - b = setTypes.invoke(b, clsNDList, clsNDList); - - Method optModelPath = b.getClass().getMethod("optModelPath", Path.class); - b = optModelPath.invoke(b, modelPath); - - Method optEngine = b.getClass().getMethod("optEngine", String.class); - b = optEngine.invoke(b, "PyTorch"); - - Object translator = clsTranslator.getDeclaredConstructor().newInstance(); - Method optTranslator = b.getClass().getMethod("optTranslator", - Class.forName("ai.djl.translate.Translator")); - b = optTranslator.invoke(b, translator); - - Method optDevice = b.getClass().getMethod("optDevice", - Class.forName("ai.djl.Device")); - b = optDevice.invoke(b, device); - - Method build = b.getClass().getMethod("build"); - Object criteria = build.invoke(b); - - Method loadModel = clsCriteria.getMethod("loadModel"); - return loadModel.invoke(criteria); - } - - /** DDIM linear timestep schedule. */ - private float[] ddimTimesteps(int numSteps) { - float[] ts = new float[numSteps]; - for (int i = 0; i < numSteps; i++) { - ts[i] = 999.0f * (1.0f - (float) i / numSteps); - } - return ts; - } - - /** Linear beta schedule → alpha_bar for a given timestep. */ - private float alphaForTimestep(float t) { - // Simplified linear schedule: beta from 0.00085 to 0.012 over 1000 steps - double beta = 0.00085 + (0.012 - 0.00085) * t / 999.0; - double alpha = 1.0 - beta; - // Approximate cumulative product - return (float) Math.pow(alpha, t); - } - - /** Sigma for a given timestep (derived from alpha). */ - private double sigmaForTimestep(float t) { - float a = alphaForTimestep(t); - return Math.sqrt((1.0 - a) / a); - } - - private static int clamp(int v) { - return Math.max(0, Math.min(255, v)); - } -} diff --git a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java index 58b8c75..e4a1d65 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java +++ b/src/main/java/atri/palaash/lumenforge/inference/GenericOnnxService.java @@ -71,10 +71,13 @@ private OrtSession getOrCreateSession(OrtEnvironment env, Path modelPath, String key = modelPath.toAbsolutePath().toString(); OrtSession existing = SESSION_CACHE.get(key); if (existing != null) { + System.out.println("[LumenForge] Session cache hit: " + modelPath.getFileName()); return existing; } + System.out.println("[LumenForge] Loading ONNX session: " + modelPath.getFileName()); OrtSession session = env.createSession(modelPath.toString(), opts); SESSION_CACHE.put(key, session); + System.out.println("[LumenForge] Session loaded: " + modelPath.getFileName()); return session; } @@ -126,10 +129,13 @@ public static void clearCache() { public CompletableFuture run(InferenceRequest request) { return CompletableFuture.supplyAsync(() -> { if (!storage.isAvailable(request.model())) { + System.out.println("[LumenForge] ERROR: Model not found locally: " + request.model().displayName()); return InferenceResult.fail("Model not found locally. Open Model Manager from the menu bar and download it first."); } Path modelPath = storage.modelPath(request.model()); + System.out.println("[LumenForge] Starting inference: " + request.model().displayName() + + " (" + request.model().id() + ")"); // Temporarily intercept stderr so ONNX Runtime native warnings // appear in the application's Log tab instead of only in the console. @@ -146,6 +152,8 @@ public CompletableFuture run(InferenceRequest request) { sessionOptions.setIntraOpNumThreads(Math.max(1, cpus - 1)); sessionOptions.setInterOpNumThreads(Math.max(1, Math.min(cpus / 2, 4))); ProviderSelection providerSelection = configureExecutionProvider(sessionOptions, request.preferGpu()); + System.out.println("[LumenForge] Using EP: " + providerSelection.provider() + + (providerSelection.notes().isBlank() ? "" : " (" + providerSelection.notes() + ")")); // Invalidate session cache when the execution provider changes. String epKey = providerSelection.provider() + "|" + request.preferGpu(); @@ -175,11 +183,6 @@ public CompletableFuture run(InferenceRequest request) { && request.model().relativePath().contains("transformer/")) { return runSd3(environment, sessionOptions, request, providerSelection.provider()); } - // DJL/PyTorch models — delegate to the DJL backend - if (request.model().id().contains("pytorch")) { - DjlPyTorchService djl = new DjlPyTorchService(storage, executor); - return djl.run(request).join(); - } // Img2Img pipelines if ("sd_v15_img2img".equals(request.model().id())) { return runImg2Img(environment, sessionOptions, request, providerSelection.provider(), false); @@ -192,9 +195,11 @@ public CompletableFuture run(InferenceRequest request) { + request.model().displayName() + " | task=" + taskType.displayName() + " | EP=" + providerSelection.provider() + providerSelection.noteSuffix(); + System.out.println("[LumenForge] WARN: " + details); return InferenceResult.fail(details); } catch (OrtException ex) { String message = ex.getMessage() == null ? "Unknown ONNX Runtime error" : ex.getMessage(); + System.out.println("[LumenForge] ERROR: ONNX Runtime error: " + message); if (message.contains("NhwcConv")) { return InferenceResult.fail( "This ONNX model uses custom ops (e.g., NhwcConv) not available in CPUExecutionProvider. " @@ -2600,8 +2605,22 @@ private List viterbi(String text) { } } + /** + * Build the EP preference list for the current platform, try each in order, + * and return the first one that the loaded ONNX Runtime native library supports. + * + *

Priority order per platform (highest → lowest): + *

    + *
  • macOS: CoreML (GPU+ANE+CPU) → CPU
  • + *
  • Windows: TensorRT-RTX → TensorRT → CUDA → DirectML → OpenVINO → CPU
  • + *
  • Linux: TensorRT → CUDA → ROCm → OpenVINO → CPU
  • + *
+ * + *

Override with {@code -Dlumenforge.ep=cuda} (or any key) to force a specific EP. + */ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions options, boolean preferGpu) { String os = System.getProperty("os.name", "unknown").toLowerCase(); + String arch = System.getProperty("os.arch", "unknown").toLowerCase(); String forced = System.getProperty("lumenforge.ep", "").trim().toLowerCase(); List preference = new ArrayList<>(); @@ -2614,51 +2633,137 @@ private ProviderSelection configureExecutionProvider(OrtSession.SessionOptions o preference.add("coreml"); preference.add("cpu"); } else if (os.contains("win")) { - preference.add("directml"); - preference.add("cuda"); + preference.add("tensorrt_rtx"); // RTX 30xx+ (Ampere+) + preference.add("tensorrt"); // any NVIDIA with TensorRT libs + preference.add("cuda"); // CUDA fallback + preference.add("directml"); // AMD / Intel / any DX12 GPU + preference.add("openvino"); // Intel CPUs/GPUs/NPUs preference.add("cpu"); } else { + // Linux + preference.add("tensorrt"); preference.add("cuda"); - preference.add("rocm"); + preference.add("rocm"); // AMD GPUs + preference.add("openvino"); preference.add("cpu"); } StringBuilder notes = new StringBuilder(); + List failReasons = new ArrayList<>(); + System.out.println("[LumenForge] EP preference order: " + preference + + " (os=" + os + ", arch=" + arch + ")"); + + if (!preferGpu) { + System.out.println("[LumenForge] GPU not requested for this session — using CPUExecutionProvider"); + return new ProviderSelection("CPUExecutionProvider", "GPU not requested"); + } + for (String candidate : preference) { if ("cpu".equals(candidate)) { + // All GPU EPs exhausted — log summary + System.out.println("[LumenForge] WARN: Falling back to CPUExecutionProvider"); + if (!failReasons.isEmpty()) { + System.out.println("[LumenForge] WARN: Reason: every GPU execution provider was unavailable:"); + for (String reason : failReasons) { + System.out.println("[LumenForge] WARN: - " + reason); + } + System.out.println("[LumenForge] WARN: Tip: install the matching GPU runtime " + + "(e.g. CUDA/cuDNN for NVIDIA, ROCm for AMD) or use " + + "-Dlumenforge.ep= to force a specific EP."); + } return new ProviderSelection("CPUExecutionProvider", notes.toString()); } - if (tryEnableProvider(options, candidate, notes)) { - return new ProviderSelection(providerDisplayName(candidate), notes.toString()); + String failReason = tryEnableProvider(options, candidate, notes); + if (failReason == null) { + String display = providerDisplayName(candidate); + System.out.println("[LumenForge] \u2713 Enabled " + display); + return new ProviderSelection(display, notes.toString()); + } + failReasons.add(failReason); + } + // Preference list didn't include "cpu" explicitly — shouldn't happen, but handle it + System.out.println("[LumenForge] WARN: No EP available, falling back to CPUExecutionProvider"); + if (!failReasons.isEmpty()) { + System.out.println("[LumenForge] WARN: Reasons:"); + for (String reason : failReasons) { + System.out.println("[LumenForge] WARN: - " + reason); } } return new ProviderSelection("CPUExecutionProvider", notes.toString()); } - private boolean tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) { + /** + * Attempt to enable a single EP via reflection. + * + * @return {@code null} if the provider was enabled successfully, + * or a human-readable reason string if it could not be enabled. + */ + private String tryEnableProvider(OrtSession.SessionOptions options, String candidate, StringBuilder notes) { + String failDetail = null; try { - return switch (candidate) { - case "cuda" -> invokeNoArg(options, "addCUDA"); - case "directml" -> invokeNoArg(options, "addDirectML") || invokeIntArg(options, "addDirectML", 0); + boolean ok = switch (candidate) { + + /* ── NVIDIA ───────────────────────────────────────────── */ + case "tensorrt_rtx" -> + // NvTensorRtRtxExecutionProvider – RTX 30xx+ only + // Java method not yet in stock Maven artifacts; try reflection just in case + invokeNoArg(options, "addNvTensorRtRtx") + || invokeIntArg(options, "addNvTensorRtRtx", 0); + + case "tensorrt" -> + // TensorrtExecutionProvider – available in onnxruntime_gpu + invokeIntArg(options, "addTensorrt", 0) + || invokeNoArg(options, "addTensorrt"); + + case "cuda" -> + invokeNoArg(options, "addCUDA"); + + /* ── Apple ────────────────────────────────────────────── */ case "coreml" -> { - // Flags: COREML_FLAG_USE_CPU_AND_GPU (1) enables all ANE/GPU subgraphs - boolean ok = invokeIntArg(options, "addCoreML", 1); - if (!ok) { ok = invokeNoArg(options, "addCoreML"); } - if (!ok) { ok = invokeIntArg(options, "addCoreML", 0); } - yield ok; + // addCoreML(long flags) — note: parameter is long, not int! + // Flag 0x0 = ALL compute units (CPU+GPU+ANE — best for M-series) + boolean coreOk = invokeLongArg(options, "addCoreML", 0L); + if (!coreOk) { coreOk = invokeNoArg(options, "addCoreML"); } + yield coreOk; } - case "rocm" -> invokeNoArg(options, "addROCM"); + + /* ── Microsoft ────────────────────────────────────────── */ + case "directml" -> + invokeIntArg(options, "addDirectML", 0) + || invokeNoArg(options, "addDirectML"); + + /* ── Intel ────────────────────────────────────────────── */ + case "openvino" -> + // addOpenVINO(String) — device type "GPU" preferred, "CPU" fallback + invokeStringArg(options, "addOpenVINO", "GPU") + || invokeStringArg(options, "addOpenVINO", "CPU") + || invokeNoArg(options, "addOpenVINO"); + + /* ── AMD ──────────────────────────────────────────────── */ + case "rocm" -> + invokeNoArg(options, "addROCM"); + default -> false; }; - } catch (Exception ex) { - if (!notes.isEmpty()) { - notes.append("; "); + if (!ok) { + failDetail = candidate + ": native library not found in classpath"; } - notes.append(candidate).append(" unavailable"); - return false; + } catch (Exception ex) { + String msg = ex.getMessage(); + if (msg == null && ex.getCause() != null) { msg = ex.getCause().getMessage(); } + failDetail = candidate + ": " + (msg != null ? msg : ex.getClass().getSimpleName()); } + + if (failDetail != null) { + if (!notes.isEmpty()) { notes.append("; "); } + notes.append(candidate).append(" not available"); + System.out.println("[LumenForge] WARN: \u2717 " + failDetail); + } + return failDetail; } + /* ── Reflection helpers for EP registration ──────────────────────── */ + private boolean invokeNoArg(OrtSession.SessionOptions options, String methodName) { try { options.getClass().getMethod(methodName).invoke(options); @@ -2677,13 +2782,34 @@ private boolean invokeIntArg(OrtSession.SessionOptions options, String methodNam } } + private boolean invokeLongArg(OrtSession.SessionOptions options, String methodName, long arg) { + try { + options.getClass().getMethod(methodName, long.class).invoke(options, arg); + return true; + } catch (Exception ex) { + return false; + } + } + + private boolean invokeStringArg(OrtSession.SessionOptions options, String methodName, String arg) { + try { + options.getClass().getMethod(methodName, String.class).invoke(options, arg); + return true; + } catch (Exception ex) { + return false; + } + } + private String providerDisplayName(String candidate) { return switch (candidate) { - case "cuda" -> "CUDAExecutionProvider"; - case "directml" -> "DmlExecutionProvider"; - case "coreml" -> "CoreMLExecutionProvider"; - case "rocm" -> "ROCMExecutionProvider"; - default -> "CPUExecutionProvider"; + case "tensorrt_rtx" -> "NvTensorRtRtxExecutionProvider"; + case "tensorrt" -> "TensorrtExecutionProvider"; + case "cuda" -> "CUDAExecutionProvider"; + case "coreml" -> "CoreMLExecutionProvider"; + case "directml" -> "DmlExecutionProvider"; + case "openvino" -> "OpenVINOExecutionProvider"; + case "rocm" -> "ROCMExecutionProvider"; + default -> "CPUExecutionProvider"; }; } diff --git a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java index 00a0e0a..bc2c32d 100644 --- a/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java +++ b/src/main/java/atri/palaash/lumenforge/inference/ServiceFactory.java @@ -25,31 +25,4 @@ public InferenceService create(TaskType taskType) { } return new GenericOnnxService(taskType, storage, executor); } - - /** - * Creates an inference service for the given task using the DJL PyTorch backend. - * Falls back to the ONNX backend if DJL is not available. - */ - public InferenceService createDjl(TaskType taskType) { - if (DjlPyTorchService.isAvailable()) { - return new DjlPyTorchService(storage, executor); - } - return create(taskType); - } - - /** - * Creates an inference service for the given model, automatically selecting - * the DJL backend for PyTorch models and ONNX backend for everything else. - */ - public InferenceService createForModel(String modelId, TaskType taskType) { - if (modelId != null && modelId.contains("pytorch")) { - return createDjl(taskType); - } - return create(taskType); - } - - /** Returns {@code true} if the DJL PyTorch runtime is available. */ - public static boolean isDjlAvailable() { - return DjlPyTorchService.isAvailable(); - } } diff --git a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java index 6c16e7a..94b1a47 100644 --- a/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java +++ b/src/main/java/atri/palaash/lumenforge/storage/ModelDownloader.java @@ -56,6 +56,7 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer { Path target = storage.modelPath(descriptor); List writtenFiles = new ArrayList<>(); + System.out.println("[LumenForge] Download started: " + descriptor.displayName()); try { String downloadUrl = descriptor.sourceUrl(); if (!downloadUrl.toLowerCase(Locale.ROOT).contains(".onnx") @@ -71,8 +72,12 @@ public CompletableFuture download(ModelDescriptor descriptor, Consumer { if (!e.getValueIsAdjusting() && workflowList.getSelectedIndex() >= 0) { managementList.clearSelection(); @@ -201,7 +201,7 @@ public MainFrame(ModelRegistry registry, workflowList.setAlignmentX(Component.LEFT_ALIGNMENT); workflowSection.add(workflowList); - // Management section (pinned to bottom) + // Management section JPanel managementSection = new JPanel(); managementSection.setLayout(new BoxLayout(managementSection, BoxLayout.Y_AXIS)); managementSection.setOpaque(false); @@ -214,13 +214,21 @@ public MainFrame(ModelRegistry registry, managementSection.add(managementHeader); managementList.setAlignmentX(Component.LEFT_ALIGNMENT); managementSection.add(managementList); - managementSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0)); - // Combine: workflow on top, management at bottom with glue between + // Diagnostics section removed + + // Bottom panel: management + JPanel bottomSection = new JPanel(); + bottomSection.setLayout(new BoxLayout(bottomSection, BoxLayout.Y_AXIS)); + bottomSection.setOpaque(false); + bottomSection.setBorder(BorderFactory.createEmptyBorder(0, 0, 12, 0)); + bottomSection.add(managementSection); + + // Combine: workflow on top, management at bottom JPanel sidebarContent = new JPanel(new BorderLayout()); sidebarContent.setOpaque(false); sidebarContent.add(workflowSection, BorderLayout.NORTH); - sidebarContent.add(managementSection, BorderLayout.SOUTH); + sidebarContent.add(bottomSection, BorderLayout.SOUTH); sidebarPanel.add(sidebarContent, BorderLayout.CENTER); @@ -322,11 +330,8 @@ private JMenuBar buildMenuBar() { private void switchToCard(String card, JList targetList, int index) { cardLayout.show(contentPanel, card); - if (targetList == workflowList) { - managementList.clearSelection(); - } else { - workflowList.clearSelection(); - } + if (targetList != workflowList) workflowList.clearSelection(); + if (targetList != managementList) managementList.clearSelection(); targetList.setSelectedIndex(index); } @@ -405,10 +410,8 @@ private String detectEpInfo() { } int cores = Runtime.getRuntime().availableProcessors(); long mem = Runtime.getRuntime().maxMemory() / (1024 * 1024); - String djl = atri.palaash.lumenforge.inference.DjlPyTorchService.isAvailable() - ? " \u2502 DJL \u2713" : ""; return "EP: " + ep + " \u2502 Cores: " + cores + " \u2502 Heap: " + mem - + " MB \u2502 Java " + Runtime.version() + djl; + + " MB \u2502 Java " + Runtime.version(); } /* ================================================================== */ diff --git a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java index 1ba0f76..c77ae84 100644 --- a/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java +++ b/src/main/java/atri/palaash/lumenforge/ui/ModelManagerPanel.java @@ -203,6 +203,9 @@ private void startDownload(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Downloaded: " + path); + if (onModelsUpdated != null) { + onModelsUpdated.run(); + } })); } @@ -258,7 +261,12 @@ private void importModelFromFile(ModelDescriptor descriptor, int row) { tableModel.setAvailable(row, true); tableModel.updateProgress(descriptor.id(), 100); statusLabel.setText("Imported: " + target); + System.out.println("[LumenForge] Model imported: " + descriptor.displayName() + " from " + source); + if (onModelsUpdated != null) { + onModelsUpdated.run(); + } } catch (IOException ex) { + System.out.println("[LumenForge] ERROR: Import failed for " + descriptor.displayName() + ": " + ex.getMessage()); statusLabel.setText("Import failed: " + ex.getMessage()); } } @@ -313,6 +321,7 @@ private void startConversion(String modelId, String displayName, String mode) { statusLabel.setText("Converting " + displayName + " to ONNX\u2026"); progressBar.setVisible(true); progressBar.setIndeterminate(true); + System.out.println("[LumenForge] Starting PyTorch \u2192 ONNX conversion: " + modelId); // 5. Run on background thread final String conversionMode = mode; @@ -337,6 +346,7 @@ private void startConversion(String modelId, String displayName, String mode) { PyTorchToOnnxConverter.openPythonDownloadPage(); } } else { + System.out.println("[LumenForge] ERROR: Conversion failed for " + modelId + ": " + cause.getMessage()); statusLabel.setText("Conversion failed: " + cause.getMessage()); JOptionPane.showMessageDialog(this, "Conversion failed:\n" + cause.getMessage(), @@ -385,6 +395,7 @@ private void registerConvertedModel(String modelId, String sanitized, Path outpu modelRegistry.mergeDownloadableAssets(List.of(desc)); refreshTable(); statusLabel.setText("\u2713 Converted and registered: " + modelId); + System.out.println("[LumenForge] Conversion complete: " + modelId + " \u2192 " + relativePath); if (onModelsUpdated != null) { onModelsUpdated.run(); }