4 changes: 4 additions & 0 deletions 02_ml_inference/02_text_to_image/.env.example
@@ -0,0 +1,4 @@
# RUNPOD_API_KEY=your_api_key_here
# FLASH_HOST=localhost
# FLASH_PORT=8888
# LOG_LEVEL=INFO
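
For reference, a minimal sketch (not part of this change) of how a script could read these settings once the template is copied to .env and the values are exported in the shell; only the variable names come from the file above, the helper name and defaults are illustrative:

import os

# Variable names mirror .env.example; the defaults here are assumptions, not Flash behavior.
def flash_base_url() -> str:
    host = os.environ.get("FLASH_HOST", "localhost")
    port = os.environ.get("FLASH_PORT", "8888")
    return f"http://{host}:{port}"

runpod_api_key = os.environ.get("RUNPOD_API_KEY", "")  # presumably used by Flash to authenticate with RunPod
log_level = os.environ.get("LOG_LEVEL", "INFO")
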
43 changes: 43 additions & 0 deletions 02_ml_inference/02_text_to_image/.flashignore
@@ -0,0 +1,43 @@
# Flash Build Ignore Patterns

# Python cache
__pycache__/
*.pyc

# Virtual environments
venv/
.venv/
env/

# IDE
.vscode/
.idea/

# Environment files
.env
.env.local

# Git
.git/
.gitignore

# Build artifacts
dist/
build/
*.egg-info/

# Flash resources
.flash_resources.pkl

# Tests
tests/
test_*.py
*_test.py

# Documentation
docs/
*.md
!README.md

# Demo output
generated.png
27 changes: 27 additions & 0 deletions 02_ml_inference/02_text_to_image/.gitignore
@@ -0,0 +1,27 @@
# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/

# Virtual environments
.venv/
venv/
env/

# Environment
.env
.env.local

# Flash
.flash_resources.pkl
.tetra_resources.pkl

# IDE
.vscode/
.idea/

# Demo output
generated.png
Empty file.
149 changes: 149 additions & 0 deletions 02_ml_inference/02_text_to_image/demo.py
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Flash Demo — Generate an image with Flux and display it in your terminal.

Usage:
1. Start the server: cd 02_ml_inference/02_text_to_image && flash run
2. Run this script: python demo.py
3. Or with a prompt: python demo.py "a cat astronaut on mars"
"""

import base64
import io
import json
import os
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request

API_URL = "http://localhost:8888/gpu/generate"
DEFAULT_PROMPT = "a tiny astronaut floating above earth, watercolor style"
OUTPUT_FILE = "generated.png"

# ── Terminal image rendering ─────────────────────────────────────────


def render_in_terminal(image_bytes: bytes, max_width: int | None = None):
    """Render an image in the terminal using ANSI true-color half-blocks.

    Works in any terminal that supports 24-bit color (iTerm2, Kitty,
    WezTerm, Windows Terminal, most modern terminals).
    """
    from PIL import Image

    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Fit to terminal width
    term_width = max_width or min(shutil.get_terminal_size().columns, 80)
    aspect = img.height / img.width
    w = term_width
    h = int(w * aspect)
    if h % 2 != 0:
        h += 1

    img = img.resize((w, h), Image.LANCZOS)
    px = img.load()

    lines = []
    for y in range(0, h, 2):
        row = []
        for x in range(w):
            r1, g1, b1 = px[x, y]
            r2, g2, b2 = px[x, y + 1] if y + 1 < h else (0, 0, 0)
            row.append(f"\033[38;2;{r1};{g1};{b1}m\033[48;2;{r2};{g2};{b2}m▀")
        lines.append("".join(row) + "\033[0m")

    print("\n".join(lines))


def try_imgcat(image_bytes: bytes) -> bool:
    """Try to display via imgcat (iTerm2), chafa, or viu."""
    for cmd in ("imgcat", "chafa", "viu"):
        if shutil.which(cmd):
            try:
                proc = subprocess.run(
                    [cmd, "-"],
                    input=image_bytes,
                    timeout=5,
                )
                # Fall through to the next tool if this one fails.
                if proc.returncode == 0:
                    return True
            except Exception:
                continue
    return False


def display_image(image_bytes: bytes):
    """Display an image in the terminal with the best available method."""
    # Try native image tools first (high-res)
    if try_imgcat(image_bytes):
        return

    # Fall back to ANSI half-block rendering (works everywhere)
    render_in_terminal(image_bytes)


# ── Main ─────────────────────────────────────────────────────────────


def main():
    prompt = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else DEFAULT_PROMPT

    print()
    print(" ⚡ Flash Demo — Flux Text-to-Image")
    print(" ─────────────────────────────────────")
    print(f' Prompt: "{prompt}"')
    print(f" Server: {API_URL}")
    print()

    # Build request
    hf_token = os.environ.get("HF_TOKEN", "")
    payload = json.dumps({"prompt": prompt, "hf_token": hf_token}).encode()
    req = urllib.request.Request(
        API_URL,
        data=payload,
        headers={"Content-Type": "application/json"},
    )

    # Send request with timing
    print(" Sending to RunPod GPU worker...", end="", flush=True)
    t0 = time.time()

    try:
        resp = urllib.request.urlopen(req, timeout=300)
    except urllib.error.HTTPError as e:
        # The worker surfaces generation failures as HTTP errors with a JSON detail body.
        print(f"\n\n Error: server returned HTTP {e.code}")
        print(f" {e.read().decode(errors='replace')}")
        sys.exit(1)
    except urllib.error.URLError as e:
        print(f"\n\n Error: Could not connect to {API_URL}")
        print(" Make sure the Flash server is running: flash run")
        print(f" ({e})")
        sys.exit(1)

    result = json.loads(resp.read())
    elapsed = time.time() - t0

    if result.get("status") != "success":
        print(f"\n\n Error from worker: {result}")
        sys.exit(1)

    # Decode image
    image_bytes = base64.b64decode(result["image_base64"])
    size_kb = len(image_bytes) / 1024

    print(f" done! ({elapsed:.1f}s)")
    print(f" Image: {result.get('width')}x{result.get('height')}px, {size_kb:.0f}KB")
    print()

    # Save to disk
    with open(OUTPUT_FILE, "wb") as f:
        f.write(image_bytes)
    print(f" Saved to {OUTPUT_FILE}")
    print()

    # Display in terminal
    display_image(image_bytes)
    print()


if __name__ == "__main__":
    main()
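
For a quick smoke test without the terminal rendering in demo.py, the same endpoint can be called directly. A minimal sketch, assuming the Flash server from this example is running on localhost:8888; the prompt and output filename are placeholders:

import base64
import json
import urllib.request

payload = json.dumps({"prompt": "a lighthouse at dawn", "num_steps": 4}).encode()
req = urllib.request.Request(
    "http://localhost:8888/gpu/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
)

# The worker replies with JSON containing a base64-encoded PNG (see gpu_worker.py below).
with urllib.request.urlopen(req, timeout=300) as resp:
    result = json.loads(resp.read())

with open("smoke_test.png", "wb") as f:
    f.write(base64.b64decode(result["image_base64"]))
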
137 changes: 137 additions & 0 deletions 02_ml_inference/02_text_to_image/gpu_worker.py
@@ -0,0 +1,137 @@
"""Flux Text-to-Image — GPU Worker

One warm worker. Cached FLUX pipeline.
"""

import os

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from runpod_flash import GpuGroup, LiveServerless, remote

# ── GPU Configuration ────────────────────────────────────────────────
# FLUX.1-schnell is a fast distilled model (~12GB VRAM).
# ADA_24 gives us an RTX 4090-class GPU with 24GB — plenty of room.
gpu_config = LiveServerless(
    name="02_02_flux_schnell",
    gpus=[GpuGroup.ADA_24],
    workersMin=1,
    workersMax=3,
    idleTimeout=5,
)


@remote(
    resource_config=gpu_config,
    dependencies=[
        "diffusers",
        "torch",
        "transformers",
        "accelerate",
        "sentencepiece",
        "protobuf",
    ],
)
class FluxWorker:
    """Warm FLUX worker that caches the pipeline between requests."""

    def __init__(self):
        import torch

        self._torch = torch
        self._model_name = "black-forest-labs/FLUX.1-schnell"
        self._pipe = None

    def _ensure_pipeline(self, hf_token: str):
        from diffusers import FluxPipeline
        from huggingface_hub import login

        if self._pipe is not None:
            return

        if hf_token:
            login(token=hf_token)

        self._pipe = FluxPipeline.from_pretrained(
            self._model_name,
            torch_dtype=self._torch.bfloat16,
        )
        self._pipe.enable_model_cpu_offload()

    async def generate(self, input_data: dict) -> dict:
        import base64
        import io

        hf_token = input_data.get("hf_token", "")
        prompt = input_data.get("prompt", "a lightning flash above a datacenter")
        width = int(input_data.get("width", 512))
        height = int(input_data.get("height", 512))
        num_steps = int(input_data.get("num_steps", 4))

        try:
            self._ensure_pipeline(hf_token=hf_token)
            image = self._pipe(
                prompt,
                num_inference_steps=num_steps,
                width=width,
                height=height,
                guidance_scale=0.0,
            ).images[0]
        except Exception as exc:
            return {"status": "error", "error": f"Image generation failed: {exc}"}

        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)

        return {
            "status": "success",
            "image_base64": base64.b64encode(buf.read()).decode(),
            "prompt": prompt,
            "width": width,
            "height": height,
        }


# ── FastAPI Router ───────────────────────────────────────────────────
gpu_router = APIRouter()
worker: FluxWorker | None = None


def get_worker() -> FluxWorker:
    global worker
    if worker is None:
        worker = FluxWorker()
    return worker


class ImageRequest(BaseModel):
    prompt: str = Field(
        default="a tiny astronaut floating in space, watercolor style",
        description="Text prompt describing the image to generate",
    )
    width: int = Field(default=512, description="Image width in pixels")
    height: int = Field(default=512, description="Image height in pixels")
    num_steps: int = Field(default=4, description="Number of diffusion steps (1-8)")
    hf_token: str = Field(
        default="",
        description="Optional Hugging Face token. Uses HF_TOKEN env var when omitted.",
    )


@gpu_router.post("/generate")
async def generate(request: ImageRequest):
    """Generate an image from a text prompt using FLUX.1-schnell."""
    hf_token = request.hf_token.strip() or os.environ.get("HF_TOKEN", "")
    result = await get_worker().generate(
        {
            "prompt": request.prompt,
            "width": request.width,
            "height": request.height,
            "num_steps": request.num_steps,
            "hf_token": hf_token,
        }
    )
    if result.get("status") != "success":
        raise HTTPException(status_code=400, detail=result.get("error", "Image generation failed"))
    return result
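
Note that gpu_router only defines POST /generate, while demo.py calls /gpu/generate, so the Flash entrypoint (not shown in this diff) presumably mounts the router under a /gpu prefix. A minimal sketch of what such a mount conventionally looks like in FastAPI, purely for orientation; the import path is an assumption:

from fastapi import FastAPI

from gpu_worker import gpu_router  # this module; import path is an assumption

app = FastAPI()
# Mounting with prefix="/gpu" yields POST /gpu/generate, matching API_URL in demo.py.
app.include_router(gpu_router, prefix="/gpu")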