A deep learning system for recolorizing images using a specified color palette, built as part of the Neural Networks and Deep Learning module at NUS. The project extends the PaletteNet architecture with attention layers and includes a multi-agent LLM-powered deployment pipeline.
Given an image and a color palette, recolorize the image with the palette in a way that is visually harmonious and aesthetically pleasing.
For setting up and experimenting with this repository, refer to Setup.md.
- Model Architecture
- Project Structure
- Setup
- Training
- Inference
- Deployment
- Tech Stack
- Limitations
- Applications
We extended the PaletteNet architecture by incorporating additional attention layers into both the Encoder and Decoder components. We represent the target palette as an image, enabling support for variable palette sizes while ensuring illumination adjustments are applied exclusively to the palette colors.
Forward pass: Source image -> Encoder (ResNet + self-attention) -> multi-scale features -> Decoder (cross-attention with palette conditioning) -> Recolorized image
- Input: Source image (LAB, 256x256), target palette (4x24x3 tensor, 6 colors in LAB), illumination (L channel)
- Encoder (`encoder_v3.py`): ResNet-based feature extraction with self-attention, outputs 4 multi-scale feature maps
- Decoder (`decoder.py`): Progressive upsampling with cross-attention at each stage, conditioning on the palette embedding
- Output: 3-channel LAB image
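The bullets above correspond roughly to the following tensor shapes. This is an illustrative sketch only: the tensor layout and call signatures are assumptions, not the actual `RecolorizerModel` API in `model.py`.

```python
# Illustrative shapes for one forward pass; treat this as pseudocode, since the
# actual call signatures live in model.py / encoder_v3.py / decoder.py.
import torch

src_lab = torch.randn(1, 3, 256, 256)   # source image in LAB, 256x256
palette = torch.randn(1, 3, 4, 24)      # 6 LAB colors tiled into a 4x24 "palette image"
                                        # (channel-first layout assumed here)
illum = src_lab[:, :1, :, :]            # L channel used for illumination conditioning

# features = encoder(src_lab)                   # 4 multi-scale feature maps (self-attention)
# out_lab  = decoder(features, palette, illum)  # cross-attention on the palette embedding
# out_lab would have shape (1, 3, 256, 256) in LAB
```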
We used Accelerate to train on a single A100 GPU. Images were resized to 256x256 to fit GPU memory. Training time was approximately 3 hours.
Recolorization/
├── src/ # Training source code
│ ├── custom_model/ # Main model implementation
│ │ ├── model.py # RecolorizerModel (encoder + decoder)
│ │ ├── encoder_v3.py # Feature encoder with self-attention
│ │ ├── decoder.py # Decoder with cross-attention
│ │ ├── attention.py # Self and cross-attention modules
│ │ ├── data.py # Dataset class (RecolorizeDataset)
│ │ ├── train_recolor.py # Trainer class
│ │ ├── run_recolor_training.py # Training entry point
│ │ ├── train_gpu.sh # GPU training launch script
│ │ └── requirements.txt
│ └── common_utils/ # Shared utilities
│ ├── configs/ # Accelerate configs (GPU/CPU)
│ └── train_utils/ # W&B logging
│
├── src_infer/ # Standalone inference testing
│ └── custom_model/
│ ├── test_model.py # Inference test script
│ └── benchmark_cpu_*.json # CPU benchmarking results
│
├── deployments/
│ ├── inference/ # FastAPI backend + agent system
│ │ ├── infer.py # Core inference utilities
│ │ ├── checkpoint/ # Model checkpoint directory
│ │ └── agents/
│ │ ├── server.py # FastAPI app + CORS + session cleanup
│ │ ├── agent_api.py # REST + WebSocket + SSE endpoints
│ │ ├── graph.py # LangGraph compiled state machine
│ │ ├── state.py # RecolorState TypedDict
│ │ ├── routing.py # Conditional edge routing functions
│ │ ├── session.py # In-memory session store (1-hour TTL)
│ │ ├── nodes/
│ │ │ ├── chat_agent.py # Conversational LLM + intent classifier
│ │ │ ├── input_analyzer.py # LLM-based dispatch + slot override detection
│ │ │ ├── image_agent.py # Image validation (PNG/JPEG/WEBP/AVIF/GIF)
│ │ │ ├── palette_agent.py # LLM tool-calling palette generator
│ │ │ ├── slot_checker.py # Checks image + palette readiness
│ │ │ └── recolor_agent.py # Model inference runner
│ │ ├── tools/
│ │ │ ├── palette_formation.py # generate_from_description, get_random, parse_user_colors
│ │ │ ├── palette_utils.py # Hex display, variation helpers
│ │ │ ├── color_extraction.py # ColorThief / Pylette extraction
│ │ │ └── colormind.py # Colormind API client
│ │ ├── tests/
│ │ │ ├── test_graph_flow.py
│ │ │ ├── test_palette_agent.py
│ │ │ ├── interactive_chat.py # Terminal REPL for direct graph testing
│ │ │ └── helpers.py
│ │ └── frontend/ # React (Vite) chat UI
│ │ ├── src/
│ │ │ ├── App.jsx
│ │ │ ├── api.js # sendChatStreaming, selectPalette
│ │ │ └── components/
│ │ │ ├── ProgressLog.jsx # Real-time pipeline log panel
│ │ │ ├── MessageBubble.jsx
│ │ │ ├── LeftPanel.jsx # Image upload + palette builder
│ │ │ ├── SidePanel.jsx # Result + palette display
│ │ │ ├── PaletteStrip.jsx
│ │ │ ├── PaletteCandidates.jsx
│ │ │ ├── InputBar.jsx
│ │ │ └── ColorWheel.jsx
│ │ └── package.json
│ └── streamlit_app/ # Interactive Streamlit UI (legacy)
│ ├── streamlit_app.py
│ └── requirements_deploy.txt
│
├── datasets/ # Training/test data (DVC-managed)
│ └── processed_palettenet_data_sample_v4/
├── assets/ # Architecture diagrams
├── Setup.md # Detailed setup instructions
└── README.md
- Python 3.12
- MiniConda (recommended)
- CUDA 12.x (for GPU training)
- Ollama with `llama3.1:8b` (for the agent system)
```bash
conda create -n recolor python=3.12
conda activate recolor
```

Download the pretrained model checkpoint from Google Drive and place it in the appropriate directory depending on your use case:
- Training/testing: `src_infer/custom_model/`
- Streamlit app: `deployments/streamlit_app/`
- FastAPI backend: `deployments/inference/checkpoint/checkpoint_epoch_90.pt`
For detailed setup instructions, refer to Setup.md.
```bash
# Pull the dataset
dvc pull datasets/processed_palettenet_data_sample_v4

# Install training dependencies
cd src/custom_model
pip install -r requirements.txt

# Visualize the data (optional)
python data.py

# Launch training
./train_gpu.sh
```

Training configuration (via `train_gpu.sh`):
- Batch size: 8 (train), 4 (validation)
- Learning rate: 0.0002 (Adam)
- Loss: MSE (L2)
- Epochs: 1000 with checkpointing every 5 epochs
- Hardware: Single A100 GPU with FP16 mixed precision
- Experiment tracking: Weights & Biases
Checkpoints are saved to src/custom_model/recolor_model_ckpts/.
The training entry point is train_gpu.sh, which launches training via HuggingFace Accelerate. The flow is:
- `run_recolor_training.py` - Initializes the trainer and starts training
- `train_recolor.py` - Trainer class with the training loop
- `data.py` - Dataset class that loads images, converts them to LAB color space, and prepares palette tensors
- `model.py` - Main model combining the encoder (`encoder_v3.py`) and decoder (`decoder.py`)
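For intuition about the data step above, a hypothetical version of the per-sample preparation might look like the following. Function and argument names are invented for illustration; the real `RecolorizeDataset` in `data.py` may organise things differently.

```python
# Hypothetical per-sample preparation; names are invented for illustration and
# the real RecolorizeDataset in data.py may organise things differently.
import numpy as np
from skimage import color, io, transform


def load_sample(image_path: str, palette_rgb: list):
    """palette_rgb: 6 RGB triples with values in [0, 1]."""
    rgb = io.imread(image_path)                     # H x W x 3
    rgb = transform.resize(rgb, (256, 256))         # floats in [0, 1]
    lab = color.rgb2lab(rgb)                        # 256 x 256 x 3 LAB image

    # Tile the 6 colors into a small 4 x 24 "palette image", then convert to LAB.
    palette = np.asarray(palette_rgb, dtype=np.float64)          # 6 x 3
    palette_img = np.repeat(palette[None, :, :], 4, axis=0)      # 4 x 6 x 3
    palette_img = np.repeat(palette_img, 4, axis=1)              # 4 x 24 x 3
    palette_lab = color.rgb2lab(palette_img)

    illumination = lab[..., :1]                     # L channel only
    return lab, palette_lab, illumination
```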
```bash
cd src_infer/custom_model
pip install torch torchvision scikit-image
python test_model.py
```

Results are saved to `src_infer/custom_model/test_results/`.
An interactive web UI for uploading images, picking 6 colors, and generating recolorized results.
```bash
pip install -r deployments/streamlit_app/requirements_deploy.txt
pip install watchdog
cd deployments/streamlit_app
streamlit run streamlit_app.py
```

The app opens at http://localhost:8501. Images are resized to at most 350x350 (dimensions rounded to the nearest multiple of 16) for inference.
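A rough sketch of that sizing rule is shown below; the exact rounding behaviour in `streamlit_app.py` may differ slightly (for instance, rounding to the nearest multiple of 16 can land just above the 350 px cap).

```python
# Sketch of the sizing rule: cap the longer side at 350 px and round each
# dimension to the nearest multiple of 16 (the actual app code may differ).
def snap16(x: float) -> int:
    return max(16, int(round(x / 16)) * 16)


def inference_size(width: int, height: int, max_side: int = 350) -> tuple:
    scale = min(1.0, max_side / max(width, height))
    return snap16(width * scale), snap16(height * scale)


print(inference_size(1024, 768))  # -> (352, 256) with this rounding rule
```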
A multi-agent system built with LangGraph that provides a conversational interface for recolorization. It uses Ollama (Llama 3.1:8b) for natural language understanding and palette generation.
User Message
|
v
chat_agent --> input_analyzer --> [routes to one or more agents]
| | |
v v v
image_agent palette_agent chat_agent
| | |
+-----+------+-----------+
|
v
join_slots (checks if image + palette ready)
|
+------+------+
| |
v v
recolor_agent respond
| |
+------+------+
|
v
respond --> User
Agents:
- input_analyzer: Deterministic intent detection (upload image, set palette, describe palette, extract colors, recolor, etc.)
- image_agent: Validates and stores uploaded images (format, size checks)
- palette_agent: Generates palettes via LLM tool-calling -- extract from images (ColorThief/Pylette), generate from text descriptions, fetch from Colormind API, parse hex/RGB input, create variations (warmer, cooler, bold, subtle, complementary, etc.)
- recolor_agent: Runs the recolorization model inference
- respond: Formats the final response with text, palette, and result image
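A heavily simplified sketch of wiring such a graph with LangGraph is shown below. Node names follow the diagram above, but the actual compiled graph in `agents/graph.py` adds conditional routing (`routing.py`), the slot-checking node, and the full `RecolorState` from `state.py`.

```python
# Simplified LangGraph wiring sketch. The real graph (agents/graph.py) uses
# conditional edges from routing.py and a richer RecolorState from state.py.
from typing import Optional, TypedDict

from langgraph.graph import END, START, StateGraph


class RecolorState(TypedDict, total=False):
    message: str
    image: Optional[bytes]
    palette: Optional[list]
    result: Optional[bytes]


def _stub(state: RecolorState) -> dict:
    # Placeholder node body; each real node calls the LLM, validates the image,
    # builds a palette, or runs model inference, then returns a state update.
    return {}


builder = StateGraph(RecolorState)
for name in ["chat_agent", "input_analyzer", "image_agent",
             "palette_agent", "recolor_agent", "respond"]:
    builder.add_node(name, _stub)

builder.add_edge(START, "chat_agent")
builder.add_edge("chat_agent", "input_analyzer")
# In the real graph, input_analyzer routes conditionally to image_agent,
# palette_agent, or straight to respond; a linear path is shown for brevity.
builder.add_edge("input_analyzer", "image_agent")
builder.add_edge("image_agent", "palette_agent")
builder.add_edge("palette_agent", "recolor_agent")
builder.add_edge("recolor_agent", "respond")
builder.add_edge("respond", END)

graph = builder.compile()
```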
```bash
cd deployments/inference
pip install -r requirements.txt

# Ensure Ollama is running with llama3.1:8b
ollama pull llama3.1:8b

python -m uvicorn agents.api:app --reload --host 0.0.0.0 --port 8000
```

The API is available at http://localhost:8000 (Swagger docs at /docs).
Endpoints:
| Method | Path | Description |
|---|---|---|
| POST | `/chat` | Send a message (with optional image) and get a response with palette suggestions or a recolorized result |
| GET | `/health` | Health check |
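As an illustration, the chat endpoint can be called from Python roughly as follows; the form-field names (`message`, `image`) are assumptions, so check the Swagger docs at `/docs` for the real request/response schema.

```python
# Example call to the /chat endpoint. The form-field names below are assumed;
# consult http://localhost:8000/docs for the actual request/response schema.
import requests

with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/chat",
        data={"message": "Recolor this photo with a warm sunset palette"},
        files={"image": f},
        timeout=120,
    )
resp.raise_for_status()
print(resp.json())
```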
For LangSmith tracing, create a .env file:
LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=<your_key>
LANGSMITH_PROJECT="recolor-workflow"
| Category | Technologies |
|---|---|
| Deep Learning | PyTorch 2.5.1, TorchVision |
| Training | HuggingFace Accelerate, Weights & Biases |
| Color Processing | scikit-image (LAB conversion), ColorThief, Pylette |
| Agent Framework | LangGraph, LangChain, LangSmith |
| LLM | Ollama (Llama 3.1:8b) |
| Backend | FastAPI, Uvicorn |
| Frontend | Streamlit |
| Data Management | DVC |
| Palette APIs | Colormind |
- Inference resolution: Attention layers increase memory usage, limiting CPU inference to ~256x256 images
- Single GPU training: Current setup supports only one A100 GPU
- Palette size: Fixed at 6 colors for the standard model (variable palette support is experimental)
- Marketing - Ensure brand assets follow specific color palettes
- Gaming & Animation - Recolor game assets, characters, and environments for different themes
- Education & Research - Experiment with color theory and simulate artistic effects
- Design Tools - Rapid color iteration in design workflows

