@@ -561,47 +561,83 @@ class StableDiffusionGGML {
561561 int64_t t1 = ggml_time_ms ();
562562 LOG_INFO (" loading model from '%s' completed, taking %.2fs" , SAFE_STR (sd_ctx_params->model_path ), (t1 - t0) * 1 .0f / 1000 );
563563
564- // check is_using_v_parameterization_for_sd2
565-
566- if (sd_version_is_sd2 (version)) {
567- if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
568- is_using_v_parameterization = true ;
569- }
570- } else if (sd_version_is_sdxl (version)) {
571- if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
572- // CosXL models
573- // TODO: get sigma_min and sigma_max values from file
574- is_using_edm_v_parameterization = true ;
564+ if (sd_ctx_params->prediction != DEFAULT_PRED ) {
565+ switch (sd_ctx_params->prediction ) {
566+ case EPS_PRED :
567+ LOG_INFO (" running in eps-prediction mode" );
568+ break ;
569+ case V_PRED :
570+ LOG_INFO (" running in v-prediction mode" );
571+ denoiser = std::make_shared<CompVisVDenoiser>();
572+ break ;
573+ case EDM_V_PRED :
574+ LOG_INFO (" running in v-prediction EDM mode" );
575+ denoiser = std::make_shared<EDMVDenoiser>();
576+ break ;
577+ case SD3_FLOW_PRED :
578+ LOG_INFO (" running in FLOW mode" );
579+ denoiser = std::make_shared<DiscreteFlowDenoiser>();
580+ break ;
581+ case FLUX_FLOW_PRED :
582+ {
583+ LOG_INFO (" running in Flux FLOW mode" );
584+ float shift = 1 .0f ; // TODO: validate
585+ for (auto pair : model_loader.tensor_storages_types ) {
586+ if (pair.first .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
587+ shift = 1 .15f ;
588+ break ;
589+ }
590+ }
591+ denoiser = std::make_shared<FluxFlowDenoiser>(shift);
592+ break ;
593+ }
594+ default :
595+ LOG_ERROR (" Unknown parametrization %i" , sd_ctx_params->prediction );
596+ abort ();
575597 }
576- if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
598+ } else {
599+ // check is_using_v_parameterization_for_sd2
600+
601+ if (sd_version_is_sd2 (version)) {
602+ if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
603+ is_using_v_parameterization = true ;
604+ }
605+ } else if (sd_version_is_sdxl (version)) {
606+ if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
607+ // CosXL models
608+ // TODO: get sigma_min and sigma_max values from file
609+ is_using_edm_v_parameterization = true ;
610+ }
611+ if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
612+ is_using_v_parameterization = true ;
613+ }
614+ } else if (version == VERSION_SVD ) {
615+ // TODO: V_PREDICTION_EDM
577616 is_using_v_parameterization = true ;
578617 }
579- } else if (version == VERSION_SVD ) {
580- // TODO: V_PREDICTION_EDM
581- is_using_v_parameterization = true ;
582- }
583618
584- if (sd_version_is_sd3 (version)) {
585- LOG_INFO (" running in FLOW mode" );
586- denoiser = std::make_shared<DiscreteFlowDenoiser>();
587- } else if (sd_version_is_flux (version)) {
588- LOG_INFO (" running in Flux FLOW mode" );
589- float shift = 1 .0f ; // TODO: validate
590- for (auto pair : model_loader.tensor_storages_types ) {
591- if (pair.first .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
592- shift = 1 .15f ;
593- break ;
619+ if (sd_version_is_sd3 (version)) {
620+ LOG_INFO (" running in FLOW mode" );
621+ denoiser = std::make_shared<DiscreteFlowDenoiser>();
622+ } else if (sd_version_is_flux (version)) {
623+ LOG_INFO (" running in Flux FLOW mode" );
624+ float shift = 1 .0f ; // TODO: validate
625+ for (auto pair : model_loader.tensor_storages_types ) {
626+ if (pair.first .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
627+ shift = 1 .15f ;
628+ break ;
629+ }
594630 }
631+ denoiser = std::make_shared<FluxFlowDenoiser>(shift);
632+ } else if (is_using_v_parameterization) {
633+ LOG_INFO (" running in v-prediction mode" );
634+ denoiser = std::make_shared<CompVisVDenoiser>();
635+ } else if (is_using_edm_v_parameterization) {
636+ LOG_INFO (" running in v-prediction EDM mode" );
637+ denoiser = std::make_shared<EDMVDenoiser>();
638+ } else {
639+ LOG_INFO (" running in eps-prediction mode" );
595640 }
596- denoiser = std::make_shared<FluxFlowDenoiser>(shift);
597- } else if (is_using_v_parameterization) {
598- LOG_INFO (" running in v-prediction mode" );
599- denoiser = std::make_shared<CompVisVDenoiser>();
600- } else if (is_using_edm_v_parameterization) {
601- LOG_INFO (" running in v-prediction EDM mode" );
602- denoiser = std::make_shared<EDMVDenoiser>();
603- } else {
604- LOG_INFO (" running in eps-prediction mode" );
605641 }
606642
607643 if (sd_ctx_params->schedule != DEFAULT ) {
@@ -1290,6 +1326,31 @@ enum schedule_t str_to_schedule(const char* str) {
12901326 return SCHEDULE_COUNT ;
12911327}
12921328
1329+ const char * prediction_to_str[] = {
1330+ " default" ,
1331+ " eps" ,
1332+ " v" ,
1333+ " edm_v" ,
1334+ " sd3_flow" ,
1335+ " flux_flow" ,
1336+ };
1337+
1338+ const char * sd_prediction_name (enum prediction_t prediction) {
1339+ if (prediction < PREDICTION_COUNT ) {
1340+ return prediction_to_str[prediction];
1341+ }
1342+ return NONE_STR ;
1343+ }
1344+
1345+ enum prediction_t str_to_prediction (const char * str) {
1346+ for (int i = 0 ; i < PREDICTION_COUNT ; i++) {
1347+ if (!strcmp (str, prediction_to_str[i])) {
1348+ return (enum prediction_t )i;
1349+ }
1350+ }
1351+ return PREDICTION_COUNT ;
1352+ }
1353+
12931354void sd_ctx_params_init (sd_ctx_params_t * sd_ctx_params) {
12941355 memset ((void *)sd_ctx_params, 0 , sizeof (sd_ctx_params_t ));
12951356 sd_ctx_params->vae_decode_only = true ;
@@ -1299,6 +1360,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
12991360 sd_ctx_params->wtype = SD_TYPE_COUNT ;
13001361 sd_ctx_params->rng_type = CUDA_RNG ;
13011362 sd_ctx_params->schedule = DEFAULT ;
1363+ sd_ctx_params->prediction = DEFAULT_PRED ;
13021364 sd_ctx_params->keep_clip_on_cpu = false ;
13031365 sd_ctx_params->keep_control_net_on_cpu = false ;
13041366 sd_ctx_params->keep_vae_on_cpu = false ;
@@ -1333,6 +1395,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13331395 " wtype: %s\n "
13341396 " rng_type: %s\n "
13351397 " schedule: %s\n "
1398+ " prediction: %s\n "
13361399 " keep_clip_on_cpu: %s\n "
13371400 " keep_control_net_on_cpu: %s\n "
13381401 " keep_vae_on_cpu: %s\n "
@@ -1358,6 +1421,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13581421 sd_type_name (sd_ctx_params->wtype ),
13591422 sd_rng_type_name (sd_ctx_params->rng_type ),
13601423 sd_schedule_name (sd_ctx_params->schedule ),
1424+ sd_prediction_name (sd_ctx_params->prediction ),
13611425 BOOL_STR (sd_ctx_params->keep_clip_on_cpu ),
13621426 BOOL_STR (sd_ctx_params->keep_control_net_on_cpu ),
13631427 BOOL_STR (sd_ctx_params->keep_vae_on_cpu ),
0 commit comments