Skip to content

Commit 08aec64

Browse files
committed
Added prediction argument
1 parent 1896b28 commit 08aec64

4 files changed

Lines changed: 132 additions & 35 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ arguments:
333333
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
334334
-b, --batch-count COUNT number of images to generate
335335
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
336+
--prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override (default: auto-detected from the model)
336337
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
337338
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
338339
--vae-tiling process vae in tiles to reduce memory usage

examples/cli/main.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct SDParams {
8585

8686
sample_method_t sample_method = EULER_A;
8787
schedule_t schedule = DEFAULT;
88+
prediction_t prediction = DEFAULT_PRED;
8889
int sample_steps = 20;
8990
float strength = 0.75f;
9091
float control_strength = 0.9f;
@@ -156,6 +157,7 @@ void print_params(SDParams params) {
156157
printf(" height: %d\n", params.height);
157158
printf(" sample_method: %s\n", sd_sample_method_name(params.sample_method));
158159
printf(" schedule: %s\n", sd_schedule_name(params.schedule));
160+
printf(" prediction: %s\n", sd_prediction_name(params.prediction));
159161
printf(" sample_steps: %d\n", params.sample_steps);
160162
printf(" strength(img2img): %.2f\n", params.strength);
161163
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
@@ -224,6 +226,7 @@ void print_usage(int argc, const char* argv[]) {
224226
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
225227
printf(" -b, --batch-count COUNT number of images to generate\n");
226228
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
229+
printf(" --prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override.\n");
227230
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
228231
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
229232
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -494,6 +497,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
494497
return 1;
495498
};
496499

500+
auto on_prediction_arg = [&](int argc, const char** argv, int index) {
501+
if (++index >= argc) {
502+
return -1;
503+
}
504+
const char* arg = argv[index];
505+
params.prediction = str_to_prediction(arg);
506+
if (params.prediction == PREDICTION_COUNT) {
507+
fprintf(stderr, "error: invalid prediction type %s\n",
508+
arg);
509+
return -1;
510+
}
511+
return 1;
512+
};
513+
497514
auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
498515
if (++index >= argc) {
499516
return -1;
@@ -564,6 +581,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
564581
{"-s", "--seed", "", on_seed_arg},
565582
{"", "--sampling-method", "", on_sample_method_arg},
566583
{"", "--schedule", "", on_schedule_arg},
584+
{"", "--prediction", "", on_prediction_arg},
567585
{"", "--skip-layers", "", on_skip_layers_arg},
568586
{"-r", "--ref-image", "", on_ref_image_arg},
569587
{"-h", "--help", "", on_help_arg},
@@ -883,6 +901,7 @@ int main(int argc, const char* argv[]) {
883901
params.wtype,
884902
params.rng_type,
885903
params.schedule,
904+
params.prediction,
886905
params.clip_on_cpu,
887906
params.control_net_cpu,
888907
params.vae_on_cpu,

stable-diffusion.cpp

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -561,47 +561,83 @@ class StableDiffusionGGML {
561561
int64_t t1 = ggml_time_ms();
562562
LOG_INFO("loading model from '%s' completed, taking %.2fs", SAFE_STR(sd_ctx_params->model_path), (t1 - t0) * 1.0f / 1000);
563563

564-
// check is_using_v_parameterization_for_sd2
565-
566-
if (sd_version_is_sd2(version)) {
567-
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
568-
is_using_v_parameterization = true;
569-
}
570-
} else if (sd_version_is_sdxl(version)) {
571-
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
572-
// CosXL models
573-
// TODO: get sigma_min and sigma_max values from file
574-
is_using_edm_v_parameterization = true;
564+
if (sd_ctx_params->prediction != DEFAULT_PRED) {
565+
switch (sd_ctx_params->prediction) {
566+
case EPS_PRED:
567+
LOG_INFO("running in eps-prediction mode");
568+
break;
569+
case V_PRED:
570+
LOG_INFO("running in v-prediction mode");
571+
denoiser = std::make_shared<CompVisVDenoiser>();
572+
break;
573+
case EDM_V_PRED:
574+
LOG_INFO("running in v-prediction EDM mode");
575+
denoiser = std::make_shared<EDMVDenoiser>();
576+
break;
577+
case SD3_FLOW_PRED:
578+
LOG_INFO("running in FLOW mode");
579+
denoiser = std::make_shared<DiscreteFlowDenoiser>();
580+
break;
581+
case FLUX_FLOW_PRED:
582+
{
583+
LOG_INFO("running in Flux FLOW mode");
584+
float shift = 1.0f; // TODO: validate
585+
for (auto pair : model_loader.tensor_storages_types) {
586+
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
587+
shift = 1.15f;
588+
break;
589+
}
590+
}
591+
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
592+
break;
593+
}
594+
default:
595+
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
596+
abort();
575597
}
576-
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
598+
} else {
599+
// check is_using_v_parameterization_for_sd2
600+
601+
if (sd_version_is_sd2(version)) {
602+
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
603+
is_using_v_parameterization = true;
604+
}
605+
} else if (sd_version_is_sdxl(version)) {
606+
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
607+
// CosXL models
608+
// TODO: get sigma_min and sigma_max values from file
609+
is_using_edm_v_parameterization = true;
610+
}
611+
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
612+
is_using_v_parameterization = true;
613+
}
614+
} else if (version == VERSION_SVD) {
615+
// TODO: V_PREDICTION_EDM
577616
is_using_v_parameterization = true;
578617
}
579-
} else if (version == VERSION_SVD) {
580-
// TODO: V_PREDICTION_EDM
581-
is_using_v_parameterization = true;
582-
}
583618

584-
if (sd_version_is_sd3(version)) {
585-
LOG_INFO("running in FLOW mode");
586-
denoiser = std::make_shared<DiscreteFlowDenoiser>();
587-
} else if (sd_version_is_flux(version)) {
588-
LOG_INFO("running in Flux FLOW mode");
589-
float shift = 1.0f; // TODO: validate
590-
for (auto pair : model_loader.tensor_storages_types) {
591-
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
592-
shift = 1.15f;
593-
break;
619+
if (sd_version_is_sd3(version)) {
620+
LOG_INFO("running in FLOW mode");
621+
denoiser = std::make_shared<DiscreteFlowDenoiser>();
622+
} else if (sd_version_is_flux(version)) {
623+
LOG_INFO("running in Flux FLOW mode");
624+
float shift = 1.0f; // TODO: validate
625+
for (auto pair : model_loader.tensor_storages_types) {
626+
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
627+
shift = 1.15f;
628+
break;
629+
}
594630
}
631+
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
632+
} else if (is_using_v_parameterization) {
633+
LOG_INFO("running in v-prediction mode");
634+
denoiser = std::make_shared<CompVisVDenoiser>();
635+
} else if (is_using_edm_v_parameterization) {
636+
LOG_INFO("running in v-prediction EDM mode");
637+
denoiser = std::make_shared<EDMVDenoiser>();
638+
} else {
639+
LOG_INFO("running in eps-prediction mode");
595640
}
596-
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
597-
} else if (is_using_v_parameterization) {
598-
LOG_INFO("running in v-prediction mode");
599-
denoiser = std::make_shared<CompVisVDenoiser>();
600-
} else if (is_using_edm_v_parameterization) {
601-
LOG_INFO("running in v-prediction EDM mode");
602-
denoiser = std::make_shared<EDMVDenoiser>();
603-
} else {
604-
LOG_INFO("running in eps-prediction mode");
605641
}
606642

607643
if (sd_ctx_params->schedule != DEFAULT) {
@@ -1290,6 +1326,31 @@ enum schedule_t str_to_schedule(const char* str) {
12901326
return SCHEDULE_COUNT;
12911327
}
12921328

1329+
const char* prediction_to_str[] = {
1330+
"default",
1331+
"eps",
1332+
"v",
1333+
"edm_v",
1334+
"sd3_flow",
1335+
"flux_flow",
1336+
};
1337+
1338+
const char* sd_prediction_name(enum prediction_t prediction) {
1339+
if (prediction < PREDICTION_COUNT) {
1340+
return prediction_to_str[prediction];
1341+
}
1342+
return NONE_STR;
1343+
}
1344+
1345+
enum prediction_t str_to_prediction(const char* str) {
1346+
for (int i = 0; i < PREDICTION_COUNT; i++) {
1347+
if (!strcmp(str, prediction_to_str[i])) {
1348+
return (enum prediction_t)i;
1349+
}
1350+
}
1351+
return PREDICTION_COUNT;
1352+
}
1353+
12931354
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
12941355
memset((void*)sd_ctx_params, 0, sizeof(sd_ctx_params_t));
12951356
sd_ctx_params->vae_decode_only = true;
@@ -1299,6 +1360,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
12991360
sd_ctx_params->wtype = SD_TYPE_COUNT;
13001361
sd_ctx_params->rng_type = CUDA_RNG;
13011362
sd_ctx_params->schedule = DEFAULT;
1363+
sd_ctx_params->prediction = DEFAULT_PRED;
13021364
sd_ctx_params->keep_clip_on_cpu = false;
13031365
sd_ctx_params->keep_control_net_on_cpu = false;
13041366
sd_ctx_params->keep_vae_on_cpu = false;
@@ -1333,6 +1395,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13331395
"wtype: %s\n"
13341396
"rng_type: %s\n"
13351397
"schedule: %s\n"
1398+
"prediction: %s\n"
13361399
"keep_clip_on_cpu: %s\n"
13371400
"keep_control_net_on_cpu: %s\n"
13381401
"keep_vae_on_cpu: %s\n"
@@ -1358,6 +1421,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13581421
sd_type_name(sd_ctx_params->wtype),
13591422
sd_rng_type_name(sd_ctx_params->rng_type),
13601423
sd_schedule_name(sd_ctx_params->schedule),
1424+
sd_prediction_name(sd_ctx_params->prediction),
13611425
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
13621426
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
13631427
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),

stable-diffusion.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ enum schedule_t {
6060
SCHEDULE_COUNT
6161
};
6262

63+
enum prediction_t {
64+
DEFAULT_PRED,
65+
EPS_PRED,
66+
V_PRED,
67+
EDM_V_PRED,
68+
SD3_FLOW_PRED,
69+
FLUX_FLOW_PRED,
70+
PREDICTION_COUNT
71+
};
72+
6373
// same as enum ggml_type
6474
enum sd_type_t {
6575
SD_TYPE_F32 = 0,
@@ -130,6 +140,7 @@ typedef struct {
130140
enum sd_type_t wtype;
131141
enum rng_type_t rng_type;
132142
enum schedule_t schedule;
143+
enum prediction_t prediction;
133144
bool keep_clip_on_cpu;
134145
bool keep_control_net_on_cpu;
135146
bool keep_vae_on_cpu;
@@ -219,6 +230,8 @@ SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
219230
SD_API enum sample_method_t str_to_sample_method(const char* str);
220231
SD_API const char* sd_schedule_name(enum schedule_t schedule);
221232
SD_API enum schedule_t str_to_schedule(const char* str);
233+
SD_API const char* sd_prediction_name(enum prediction_t prediction);
234+
SD_API enum prediction_t str_to_prediction(const char* str);
222235

223236
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
224237
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

0 commit comments

Comments
 (0)