
[feat] JoyAI-JoyImage-Edit support #13444

Open
Moran232 wants to merge 6 commits into huggingface:main from Moran232:joyimage_edit

Conversation


@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Key Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image provides high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

spatial-editing-showcase

@github-actions github-actions bot added models pipelines size/L PR with diff > 200 LOC labels Apr 10, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks for the PR! I left some initial feedback

return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator


ohh what's going on here? is this some legacy code? can we remove?

Author


We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Collaborator


Suggested change
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
attn_output, text_attn_output = self.attn(...)

can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example; I think it is the same): https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75
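As a rough illustration of the requested split (all class and attribute names here are assumptions modeled on the Flux pattern, not the PR's final API), the weights live on an attention module while a stateless processor owns the math:

```python
import torch
from torch import nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    """Stateless attention math; the module below owns the weights (sketch only)."""

    def __call__(self, attn, img, txt):
        b = img.shape[0]

        def split_heads(x):
            # (B, L, H*D) -> (B, H, L, D)
            return x.view(b, -1, attn.heads, attn.head_dim).transpose(1, 2)

        img_q, img_k, img_v = map(split_heads, attn.img_qkv(img).chunk(3, dim=-1))
        txt_q, txt_k, txt_v = map(split_heads, attn.txt_qkv(txt).chunk(3, dim=-1))

        # Joint attention over the concatenated image + text streams,
        # mirroring the torch.cat((img_q, txt_q), dim=1) in the PR.
        q = torch.cat([img_q, txt_q], dim=2)
        k = torch.cat([img_k, txt_k], dim=2)
        v = torch.cat([img_v, txt_v], dim=2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, -1, attn.heads * attn.head_dim)
        # Split back into the image and text halves.
        return out[:, : img.shape[1]], out[:, img.shape[1]:]


class JoyImageAttention(nn.Module):
    def __init__(self, dim, heads, processor=None):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.img_qkv = nn.Linear(dim, 3 * dim)
        self.txt_qkv = nn.Linear(dim, 3 * dim)
        self.processor = processor or JoyImageAttnProcessor()

    def forward(self, img, txt):
        return self.processor(self, img, txt)
```

The point of the split is that swapping attention backends (or adding RoPE/QK-norm handling) only touches the processor, never the module's state dict.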

Author


Thanks for the reminder. I'll clean up this messy code.

Author


Fix in d397b68

Comment on lines +242 to +250
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor

def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Collaborator


Suggested change
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)

def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Collaborator


Suggested change
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)

Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX.
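For context, the zero-initialized adaLN-style modulation in ModulateDiT amounts to the following sketch (names illustrative; the actual JoyImage/Wan modulation layers may differ). The zero init makes each block start out as an identity transform:

```python
import torch
from torch import nn


class Modulate(nn.Module):
    """AdaLN-style modulation: project a conditioning vector into
    per-block shift/scale/gate chunks (sketch, not the PR's class)."""

    def __init__(self, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, factor * hidden_size))
        # Zero init -> shift = scale = 0 at the start of training.
        nn.init.zeros_(self.proj[1].weight)
        nn.init.zeros_(self.proj[1].bias)

    def forward(self, temb):
        return self.proj(temb).chunk(self.factor, dim=-1)


# Usage: modulate hidden states with the first two of six chunks.
mod = Modulate(64, 6)
shift, scale, *rest = mod(torch.randn(2, 64))
x = torch.randn(2, 10, 64)
x_mod = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```

With the zero init, `x_mod` equals `x` exactly until the projection learns non-zero weights.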

Author


Fix in f557113

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Collaborator


Suggested change
self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
self.img_mod = JoyImageModulate(...)

let's remove the load_modulation function and use the layer directly, better to rename to JoyImageModulate too

Author


Ok, I will refactor modulation and use ModulateWan

tacos8me added a commit to tacos8me/taco-desktop-backend that referenced this pull request Apr 11, 2026
) # reshape
return output

class RMSNorm(nn.Module):
Collaborator


Can we reuse the existing diffusers.models.normalization.RMSNorm implementation here? It should already implement the FP32 upcast:

variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
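For reference, a minimal RMSNorm with that FP32 upcast looks like this (a sketch for illustration, not the diffusers class itself):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm computing the variance in float32 for stability,
    then casting back to the input dtype (illustrative sketch)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # FP32 upcast: half-precision squares can under/overflow.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)
```

Reusing the existing `diffusers.models.normalization.RMSNorm` avoids duplicating this logic and keeps the FP32 behavior consistent across models.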

Author


Fix in f557113

return img, txt


class WanTimeTextImageEmbedding(nn.Module):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class WanTimeTextImageEmbedding(nn.Module):
# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
class WanTimeTextImageEmbedding(nn.Module):

Is this intended to be identical to the Wan implementation? If so, we can add a # Copied from statement here to ensure the two implementations are synced.

Author

@Moran232 Moran232 Apr 14, 2026


yes, I modified it to 'import from wanxxx'

Comment on lines +454 to +459
self.args = SimpleNamespace(
enable_activation_checkpointing=enable_activation_checkpointing,
is_repa=is_repa,
repa_layer=repa_layer,
)

Collaborator


Suggested change
self.args = SimpleNamespace(
enable_activation_checkpointing=enable_activation_checkpointing,
is_repa=is_repa,
repa_layer=repa_layer,
)

I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
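For illustration, `self.config` on diffusers models is populated automatically from the `__init__` arguments by the `@register_to_config` decorator. This toy stand-in (stdlib only, not the real diffusers implementation) shows the idea of why a separate SimpleNamespace is redundant:

```python
import inspect
from types import SimpleNamespace


def register_to_config(init):
    """Toy stand-in for diffusers' @register_to_config:
    capture the bound __init__ kwargs as self.config."""

    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        params = dict(bound.arguments)
        params.pop("self")
        self.config = SimpleNamespace(**params)
        return init(self, *args, **kwargs)

    return wrapper


class ToyTransformer:
    @register_to_config
    def __init__(self, is_repa=False, repa_layer=4):
        # No extra namespace needed: self.config.is_repa etc. already exist.
        pass


m = ToyTransformer(is_repa=True)
```

In real diffusers code the decorator also serializes the config to `config.json`, which is another reason to prefer it over an ad-hoc `SimpleNamespace`.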

Author

@Moran232 Moran232 Apr 14, 2026


I deleted this repaxxx, see f557113

tokenizer: Qwen2Tokenizer,
transformer: JoyImageEditTransformer3DModel,
processor: Qwen3VLProcessor,
args: Any = None,
Collaborator


Similar to #13444 (comment), I think it would be better if we had individual pipeline arguments here instead of a separate namespace, e.g. something like

class JoyImageEditPipeline(DiffusionPipeline):
    ...
    def __init__(
        self,
        ...,
        enable_multi_task_training: bool = False,
        text_token_max_length: int = 2048,
        ...,
    ):
        ...
        self.enable_multi_task_training = enable_multi_task_training
        self.text_token_max_length = text_token_max_length
        ...

Comment on lines +900 to +901
timesteps: List[int] = None,
sigmas: List[float] = None,
Collaborator


Suggested change
timesteps: List[int] = None,
sigmas: List[float] = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,

nit: could we switch to Python 3.10+ style built-in type hints (`list[int]`, `X | None`) here and elsewhere?

Comment on lines +1003 to +1004
height, width = _dynamic_resize_from_bucket(image_size, basesize=1024)
processed_image = _resize_center_crop(image, (height, width))
Collaborator


I think it would be cleaner to refactor the image pre-processing logic into a separate VaeImageProcessor subclass (which self.image_processor would then be an instance of). See WanAnimateImageProcessor for an example:

class WanAnimateImageProcessor(VaeImageProcessor):

CC @yiyixuxu
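The PR does not show `_dynamic_resize_from_bucket` itself, but a typical area-preserving bucketed resize (which such a helper might implement; the name, `multiple` stride, and rounding here are all assumptions) looks like:

```python
import math


def dynamic_resize_from_bucket(image_size, basesize=1024, multiple=16):
    """Pick a target (height, width) with roughly basesize**2 pixels,
    preserving aspect ratio and snapping each side to a stride-friendly
    multiple. Illustrative only -- the PR's helper may differ."""
    width, height = image_size
    # Scale so that the output area is approximately basesize * basesize.
    scale = math.sqrt((basesize * basesize) / (width * height))
    new_width = max(multiple, round(width * scale / multiple) * multiple)
    new_height = max(multiple, round(height * scale / multiple) * multiple)
    return new_height, new_width
```

Folding logic like this into a `VaeImageProcessor` subclass keeps the pipeline's `__call__` free of resizing details.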

]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
enable_tiling: bool = False,
Collaborator


I think we don't need to support an enable_tiling argument here as users can always call pipe.vae.enable_tiling().

Author


Fix in 9d78e4e

latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
t_expand = t.repeat(latent_model_input.shape[0])

with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
Collaborator


I think we should handle the device placement and dtypes explicitly here instead of using torch.autocast.
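A hedged sketch of what explicit handling might look like (the function and parameter names are illustrative, not the PR's API): instead of relying on `torch.autocast` to pick operation dtypes, cast the inputs to the transformer's dtype up front so the forward pass runs in a single, predictable precision.

```python
import torch


def prepare_inputs(latents, prompt_embeds, transformer_dtype=torch.bfloat16, device="cpu"):
    """Explicit dtype/device placement instead of wrapping the forward
    pass in torch.autocast (illustrative sketch)."""
    latents = latents.to(device=device, dtype=transformer_dtype)
    prompt_embeds = prompt_embeds.to(device=device, dtype=transformer_dtype)
    return latents, prompt_embeds
```

This matches the convention in existing diffusers pipelines, which cast model inputs explicitly rather than relying on autocast context managers.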

Author


Fix in 9d78e4e

if enable_denormalization:
latents = self.denormalize_latents(latents)

with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
Collaborator


Same comment as #13444 (comment).

Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@Moran232
Author

@yiyixuxu @dg845
Thank you very much for your valuable feedback. I've made some modifications. See my latest commits.

Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since it is quite complex, I directly copied it over from the training code for now.

If you have any further suggestions, please feel free to share. Thank you so much!


Labels

models pipelines size/L PR with diff > 200 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support for JoyAI-Image-Edit

3 participants