Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use diffusion_rs::{
preset::{
Anima2Weight, AnimaWeight, ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight,
ErnieImageWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight,
Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight,
NitroSDVibrantWeight, OvisImageWeight, Preset, PresetBuilder, PresetDiscriminants,
QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
WeightType, ZImageTurboWeight,
Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, LongCatImageWeight,
NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, Preset, PresetBuilder,
PresetDiscriminants, QwenImageWeight, SDXS512DreamShaperWeight, SSD1BWeight,
TwinFlowZImageTurboExpWeight, WeightType, ZImageTurboWeight,
},
util::set_hf_token,
};
Expand Down Expand Up @@ -412,6 +412,12 @@ fn get_preset(args: &Args) -> Preset {
),
PresetDiscriminants::HiDreamO1ImageDev => Preset::HiDreamO1ImageDev,
PresetDiscriminants::HiDreamO1Image => Preset::HiDreamO1Image,
PresetDiscriminants::LongCatImage => Preset::LongCatImage(
args.weights
.unwrap_or_else(|| LongCatImageWeight::default().into())
.try_into()
.unwrap(),
),
};
preset
}
35 changes: 34 additions & 1 deletion src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ pub struct HiresParams {
/// highres fix second pass denoising strength (default: 0.7)
#[builder(default = "0.7")]
denoising_strength: f32,
/// Custom sigma values for the highres fix second pass
#[builder(default = "None")]
hires_sigmas: Option<Vec<f32>>,
}

/// Config struct for a specific diffusion model
Expand Down Expand Up @@ -561,6 +564,10 @@ pub struct ModelConfig {
#[builder(default = "true")]
vae_temporal_tiling: bool,

/// Extra VAE tiling args, key=value list. LTX video VAE supports
#[builder(default = "(None, CLibString::default())", setter(custom))]
extra_tiling_args: (Option<HashMap<String, String>>, CLibString),

#[builder(default = "None", private)]
upscaler_ctx: Option<*mut upscaler_ctx_t>,

Expand Down Expand Up @@ -711,6 +718,22 @@ impl ModelConfigBuilder {
self.params_backend = Some((Some(backend_map), CLibString::from(params_backend_str)));
self
}

pub fn extra_tiling_args(
&mut self,
extra_tiling_args_map: HashMap<String, String>,
) -> &mut Self {
let extra_tiling_args_str = extra_tiling_args_map
.iter()
.map(|(key, value)| format!("{}={}", key, value))
.collect::<Vec<String>>()
.join(",");
self.extra_tiling_args = Some((
Some(extra_tiling_args_map),
CLibString::from(extra_tiling_args_str),
));
self
}
}

impl ModelConfig {
Expand Down Expand Up @@ -884,7 +907,8 @@ impl From<&ModelConfig> for ModelConfigBuilder {
)
.extra_sample_params(value.extra_sample_params.clone())
.backend(value.backend.0.clone().unwrap_or_default())
.params_backend(value.params_backend.0.clone().unwrap_or_default());
.params_backend(value.params_backend.0.clone().unwrap_or_default())
.extra_tiling_args(value.extra_tiling_args.0.clone().unwrap_or_default());

builder.lora_models_internal(value.lora_models.clone());

Expand Down Expand Up @@ -1368,6 +1392,7 @@ fn gen_img_maybe_progress(
rel_size_x: model_config.vae_relative_tile_size.0,
rel_size_y: model_config.vae_relative_tile_size.1,
temporal_tiling: model_config.vae_temporal_tiling,
extra_tiling_args: model_config.extra_tiling_args.1.as_ptr(),
};
let pm_params = sd_pm_params_t {
id_images: null_mut(),
Expand Down Expand Up @@ -1506,9 +1531,15 @@ fn gen_img_maybe_progress(
}

let mut hires_path = null();
let mut hires_sigmas = null_mut();
let mut hires_sigmas_count = 0;
if let Some(path) = &model_config.hires_params.2 {
hires_path = path.as_ptr();
}
if let Some(sigmas) = &mut model_config.hires_params.1.hires_sigmas {
hires_sigmas = sigmas.as_mut_ptr();
hires_sigmas_count = sigmas.len() as i32;
}

let hires = sd_hires_params_t {
enabled: model_config.hires_params.0 != Upscaler::SD_HIRES_UPSCALER_NONE,
Expand All @@ -1520,6 +1551,8 @@ fn gen_img_maybe_progress(
steps: model_config.hires_params.1.steps,
denoising_strength: model_config.hires_params.1.denoising_strength,
upscale_tile_size: model_config.hires_params.1.upscale_tile_size,
custom_sigmas: hires_sigmas,
custom_sigmas_count: hires_sigmas_count,
};

let sd_img_gen_params = sd_img_gen_params_t {
Expand Down
66 changes: 49 additions & 17 deletions src/preset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use crate::{
anima, anima2, chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo,
ernie_image, ernie_image_turbo, flux_1_dev, flux_1_mini, flux_1_schnell, flux_2_dev,
flux_2_klein_4b, flux_2_klein_9b, flux_2_klein_base_4b, flux_2_klein_base_9b,
hi_dream_o1_image, hi_dream_o1_image_dev, juggernaut_xl_11, nitro_sd_realism,
nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
z_image_turbo,
hi_dream_o1_image, hi_dream_o1_image_dev, juggernaut_xl_11, long_cat_image,
nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0,
sdxl_turbo_1_0, sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4,
stable_diffusion_1_5, stable_diffusion_2_1, stable_diffusion_3_5_large,
stable_diffusion_3_5_large_turbo, stable_diffusion_3_5_medium, stable_diffusion_3_medium,
twinflow_z_image_turbo, z_image_turbo,
},
};

Expand All @@ -41,7 +41,8 @@ use crate::{
AnimaWeight(derive(Default)),
Anima2Weight(derive(Default)),
SDXS512DreamShaperWeight(derive(Default)),
ErnieImageWeight(derive(Default))
ErnieImageWeight(derive(Default)),
LongCatImageWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
Expand Down Expand Up @@ -74,10 +75,17 @@ pub enum WeightType {
Flux2Klein9BWeight(default),
Flux2KleinBase9BWeight(default),
AnimaWeight,
ErnieImageWeight(default)
ErnieImageWeight(default),
LongCatImageWeight(default)
)]
Q4_0,
#[subenum(Flux2Weight, QwenImageWeight, AnimaWeight, ErnieImageWeight)]
#[subenum(
Flux2Weight,
QwenImageWeight,
AnimaWeight,
ErnieImageWeight,
LongCatImageWeight
)]
Q4_1,
#[subenum(
NitroSDRealismWeight,
Expand All @@ -88,10 +96,17 @@ pub enum WeightType {
QwenImageWeight,
TwinFlowZImageTurboExpWeight,
AnimaWeight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q5_0,
#[subenum(Flux2Weight, QwenImageWeight, AnimaWeight, ErnieImageWeight)]
#[subenum(
Flux2Weight,
QwenImageWeight,
AnimaWeight,
ErnieImageWeight,
LongCatImageWeight
)]
Q5_1,
#[subenum(
Flux1Weight,
Expand All @@ -112,7 +127,8 @@ pub enum WeightType {
AnimaWeight(default),
Anima2Weight(default),
SDXS512DreamShaperWeight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q8_0,
Q8_1,
Expand All @@ -139,7 +155,8 @@ pub enum WeightType {
QwenImageWeight,
TwinFlowZImageTurboExpWeight,
AnimaWeight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q3_K,
#[subenum(
Expand All @@ -149,7 +166,8 @@ pub enum WeightType {
QwenImageWeight,
AnimaWeight,
Anima2Weight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q4_K,
#[subenum(
Expand All @@ -158,7 +176,8 @@ pub enum WeightType {
QwenImageWeight,
AnimaWeight,
Anima2Weight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q5_K,
#[subenum(
Expand All @@ -172,7 +191,8 @@ pub enum WeightType {
TwinFlowZImageTurboExpWeight,
AnimaWeight,
Anima2Weight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
Q6_K,
Q8_K,
Expand Down Expand Up @@ -205,12 +225,15 @@ pub enum WeightType {
Flux2KleinBase9BWeight,
AnimaWeight,
Anima2Weight,
ErnieImageWeight
ErnieImageWeight,
LongCatImageWeight
)]
BF16,
TQ1_0,
TQ2_0,
MXFP4,
NVFP4,
Q1_0,
#[subenum(SSD1BWeight(default), QwenImageWeight)]
F8_E4M3,
}
Expand Down Expand Up @@ -313,6 +336,9 @@ pub enum Preset {
HiDreamO1ImageDev,
/// cfg_scale 1.0. 20 steps 1024x1024.
HiDreamO1Image,
/// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
/// cfg_scale 5.0. Enable [crate::api::SampleMethod::EULER_SAMPLE_METHOD] and Diffusion Flash attention. flow_shift 3.0. 512 x 512. 20 steps
LongCatImage(LongCatImageWeight),
}

impl Preset {
Expand Down Expand Up @@ -356,6 +382,7 @@ impl Preset {
Preset::ErnieImageTurbo(sd_type_t) => ernie_image_turbo(sd_type_t),
Preset::HiDreamO1ImageDev => hi_dream_o1_image_dev(),
Preset::HiDreamO1Image => hi_dream_o1_image(),
Preset::LongCatImage(sd_type_t) => long_cat_image(sd_type_t),
}
}
}
Expand Down Expand Up @@ -682,4 +709,9 @@ mod tests {
fn test_ernie_image_turbo() {
run(Preset::ErnieImageTurbo(super::ErnieImageWeight::Q4_0));
}
#[ignore]
#[test]
fn long_cat_image() {
run(Preset::LongCatImage(super::LongCatImageWeight::Q4_0));
}
}
Loading
Loading