diff --git a/common/arg.cpp b/common/arg.cpp index 55795d357d90..895a9eaadc69 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -57,6 +57,7 @@ static std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, + LLAMA_EXAMPLE_DIFFUSION, }; static std::string read_file(const std::string & fname) { @@ -2228,7 +2229,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.image.emplace_back(item); } } - ).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI})); + ).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION})); add_opt(common_arg( {"--image-min-tokens"}, "N", "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)", @@ -3864,6 +3865,116 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"), [](common_params & params) { params.diffusion.visual_mode = true; } ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--no-diffusion-gpu-sampling"}, + "disable CUDA block-diffusion sampling fast path", + [](common_params & params) { params.diffusion.gpu_sampling = false; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--no-diffusion-device-selfcond"}, + "disable device-resident block-diffusion self-conditioning", + [](common_params & params) { params.diffusion.device_self_cond = false; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--no-diffusion-device-denoise-loop"}, + "disable device-side block-diffusion canvas and stop-state updates", + [](common_params & params) { params.diffusion.device_denoise_loop = false; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-pin-host-outputs"}, + "register compact diffusion output buffers as pinned host memory", + [](common_params & params) { params.diffusion.pin_host_outputs = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-self-cond-top-k"}, "N", + string_format("block-diffusion sparse self-conditioning width (default: %d)", params.diffusion.self_cond_top_k), + [](common_params & params, int value) { params.diffusion.self_cond_top_k = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-input-gpu-groups"}, "N", + string_format("bitmask of block-diffusion decoder input groups assigned to GPU backend (default: %u)", params.diffusion.input_gpu_groups), + [](common_params & params, int value) { params.diffusion.input_gpu_groups = (uint32_t) std::max(value, 0); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-default-top-k"}, "N", + "block-diffusion top-k used when --top-k is not explicitly provided", + [](common_params & params, int value) { params.diffusion.default_top_k = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-force-top-k"}, "N", + "block-diffusion server: override per-request top_k when N > 0", + [](common_params & params, int value) { params.diffusion.force_top_k = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-fused-self-cond-embd"}, + "use fused device self-conditioning embedding input for block diffusion", + [](common_params & params) { params.diffusion.fused_self_cond_embd = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-fuse-final-softcap"}, + "move final logit softcap into the CUDA diffusion sampling kernel", + [](common_params & params) { params.diffusion.fuse_final_logit_softcap = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-separate-encoder-decoder"}, + "build separate block-diffusion encoder and decoder graph variants", + [](common_params & params) { params.diffusion.separate_encoder_decoder = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-direct-self-cond"}, + "write CUDA diffusion self-conditioning directly into decoder graph inputs", + [](common_params & params) { params.diffusion.cuda_direct_self_cond = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-final-tokens-on-stop"}, + "copy final diffusion tokens only when the device stop condition is reached", + [](common_params & params) { params.diffusion.cuda_final_tokens_on_stop = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-fused-top-k-sample"}, + "fuse CUDA diffusion top-k selection and sampling", + [](common_params & params) { params.diffusion.cuda_fused_top_k_sample = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-tight-top-k"}, + "avoid extra CUDA diffusion top-k scratch width when possible", + [](common_params & params) { params.diffusion.cuda_tight_top_k = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-parallel-full-softmax"}, + "parallelize CUDA diffusion full-vocab sampling when top-k is 0", + [](common_params & params) { params.diffusion.cuda_parallel_full_softmax = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-fused-full-softmax"}, + "fuse CUDA diffusion full-vocab softmax sampling and self-conditioning", + [](common_params & params) { params.diffusion.cuda_fused_full_softmax = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--diffusion-cuda-top-k-local-k"}, "N", + "CUDA diffusion local top-k candidates per thread (0 = backend default)", + [](common_params & params, int value) { params.diffusion.cuda_top_k_local_k = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--no-diffusion-cuda-fast-top-k"}, + "disable CUDA diffusion CUB/fast top-k selection path", + [](common_params & params) { params.diffusion.cuda_fast_top_k = false; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--top-k-start"}, "N", + "block-diffusion: anneal top-k from N at the first (high-entropy) denoising step (with --top-k-end)", + [](common_params & params, int value) { params.diffusion.top_k_start = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--top-k-end"}, "N", + "block-diffusion: anneal top-k to N at the last denoising step (with --top-k-start)", + [](common_params & params, int value) { params.diffusion.top_k_end = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + {"--top-k-tail-correction"}, + "block-diffusion: use the exact full-vocab entropy (logsumexp) for the accept/stop signal under top-k", + [](common_params & params) { params.diffusion.top_k_tail_correction = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); add_opt(common_arg( {"--diffusion-eps"}, "F", string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps), diff --git a/common/common.cpp b/common/common.cpp index b01772e1cbfe..23209395dbef 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1585,6 +1585,11 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.op_offload = !params.no_op_offload; cparams.swa_full = params.swa_full; cparams.kv_unified = params.kv_unified; + cparams.diffusion_self_cond_top_k = params.diffusion.self_cond_top_k; + cparams.diffusion_input_gpu_groups = params.diffusion.input_gpu_groups; + cparams.diffusion_fused_self_cond_embd = params.diffusion.fused_self_cond_embd; + cparams.diffusion_fuse_final_logit_softcap = params.diffusion.fuse_final_logit_softcap; + cparams.diffusion_separate_encoder_decoder = params.diffusion.separate_encoder_decoder; cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; diff --git a/common/common.h b/common/common.h index 4864186f6287..d702febf2137 100644 --- a/common/common.h +++ b/common/common.h @@ -381,6 +381,10 @@ struct common_params_vocoder { struct common_params_diffusion { int32_t steps = 128; bool visual_mode = false; + bool gpu_sampling = true; // use CUDA diffusion sampling fast path when available + bool device_self_cond = true; // keep diffusion self-conditioning state on device + bool device_denoise_loop = true; // update diffusion canvas/stop state on device + bool pin_host_outputs = false; // register compact D2H output buffers as pinned host memory float eps = 0; // epsilon for timesteps int32_t block_length = 0; // block length for generation @@ -390,6 +394,30 @@ struct common_params_diffusion { float cfg_scale = 0; // classifier-free guidance scale bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0 + + // block-diffusion (diffusion-gemma) top-k host sampling knobs + int32_t top_k_start = 0; // anneal top-k from this (first/high-entropy step) ... + int32_t top_k_end = 0; // ... to this (last step); both > 0 enables annealing + bool top_k_tail_correction = false; // use exact full-vocab entropy for accept/stop + int32_t default_top_k = 0; // top-k used when --top-k is not explicitly provided + int32_t force_top_k = 0; // server: override per-request top_k when > 0 + int32_t self_cond_top_k = 256; // sparse self-conditioning gather width + uint32_t input_gpu_groups = 63; // decoder input tensor groups assigned to GPU backend + + // CUDA diffusion sampling fast-path knobs. Defaults preserve behavior when no tuning flags are passed. + bool cuda_fast_top_k = true; + bool cuda_direct_self_cond = false; + bool cuda_final_tokens_on_stop = false; + bool cuda_fused_top_k_sample = false; + bool cuda_tight_top_k = false; + bool cuda_parallel_full_softmax = false; + bool cuda_fused_full_softmax = false; + int32_t cuda_top_k_local_k = 0; // 0 = backend default + + // Diffusion graph-shape knobs. + bool fused_self_cond_embd = false; + bool fuse_final_logit_softcap = false; + bool separate_encoder_decoder = false; }; // reasoning API response format (not to be confused as chat template's reasoning format) diff --git a/conversion/__init__.py b/conversion/__init__.py index 18162976f458..0269642ef32b 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -78,6 +78,7 @@ "Gemma4AssistantForCausalLM": "gemma", "Gemma4ForConditionalGeneration": "gemma", "Gemma4ForCausalLM": "gemma", + "DiffusionGemmaForBlockDiffusion": "gemma", "Gemma4UnifiedForConditionalGeneration": "gemma", "Gemma4UnifiedAssistantForCausalLM": "gemma", "GemmaForCausalLM": "gemma", @@ -245,6 +246,7 @@ "CogVLMForCausalLM": "cogvlm", "DeepseekOCR2ForCausalLM": "deepseek", "DeepseekOCRForCausalLM": "deepseek", + "DiffusionGemmaForBlockDiffusion": "gemma", "DotsOCRForCausalLM": "dotsocr", "Exaone4_5_ForConditionalGeneration": "exaone", "Gemma3ForConditionalGeneration": "gemma", diff --git a/conversion/gemma.py b/conversion/gemma.py index 5b4ca5c583df..b3880ef1b75a 100644 --- a/conversion/gemma.py +++ b/conversion/gemma.py @@ -655,7 +655,7 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - num_kv_shared_layers = self.hparams["num_kv_shared_layers"] + num_kv_shared_layers = self.hparams.get("num_kv_shared_layers", 0) self.gguf_writer.add_shared_kv_layers(num_kv_shared_layers) # per-layer embedding is optional @@ -764,7 +764,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) - @ModelBase.register("Gemma4UnifiedForConditionalGeneration") class Gemma4UnifiedModel(Gemma4Model): model_arch = gguf.MODEL_ARCH.GEMMA4 @@ -805,6 +804,32 @@ def set_gguf_parameters(self): self.gguf_writer.add_nextn_predict_layers(self.block_count) +@ModelBase.register("DiffusionGemmaForBlockDiffusion") +class DiffusionGemmaModel(Gemma4Model): + # Block-diffusion variant of Gemma 4. Reuses the gemma4 decoder block; adds the + # self-conditioning MLP and nests the language model under `model.decoder.`. + model_arch = gguf.MODEL_ARCH.DIFFUSION_GEMMA + + @classmethod + def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: + name, _ = item + # The text encoder shares every weight with the decoder except its own + # per-layer `layer_scalar`. The single-stack graph uses the decoder scalars, + # so the encoder-only tensors are dropped here. + if name.startswith("model.encoder."): + return None + return super().filter_tensors(item) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # diffusion_gemma nests the language model under `model.decoder.`; strip it so + # the shared gemma4 tensor mappings apply. `model.decoder.self_conditioning.*` + # then maps to the SELF_COND_* tensors. + if name.startswith("model.decoder."): + name = "model." + name[len("model.decoder."):] + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Gemma4ForConditionalGeneration") class Gemma4VisionAudioModel(MmprojModel): has_audio_encoder = True @@ -882,7 +907,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min")) yield (mapped_name, data_torch) - @ModelBase.register("Gemma4UnifiedForConditionalGeneration") class Gemma4UnifiedVisionAudioModel(Gemma4VisionAudioModel): has_audio_encoder = True @@ -945,3 +969,36 @@ def modify_tensors(self, data_torch, name, bid): perm = row * p * 3 + col * 3 + ch data_torch = data_torch[perm] return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("DiffusionGemmaForBlockDiffusion") +class DiffusionGemmaVisionModel(Gemma4VisionAudioModel): + # mmproj (vision) export for the v7 diffusion_gemma multimodal model. Reuses the gemma4 + # vision tower (GEMMA4V); the v7 checkpoint nests it under `model.encoder.*` and has no + # audio encoder, so only the vision tower + vision projector are exported. + has_audio_encoder = False + has_vision_encoder = True + + def set_gguf_parameters(self): + # MmprojModel base writes the generic vision params; do NOT call the gemma4 + # vision+audio set_gguf_parameters (it asserts an audio config, which v7 lacks). + MmprojModel.set_gguf_parameters(self) + assert self.hparams_vision is not None + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA4V) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6)) + + @classmethod + def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: + name, _ = item + # keep only the vision tower + vision projector; drop the diffusion decoder + # (the text-encoder language_model.* tensors are dropped by MmprojModel.filter_tensors) + if name.startswith("model.decoder."): + return None + return super().filter_tensors(item) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # v7 nests the vision tower / projector under `model.encoder.`; strip it so the gemma4 + # vision tensor mappings (model.vision_tower.* / model.embed_vision.*) apply. + if name.startswith("model.encoder."): + name = "model." + name[len("model.encoder."):] + yield from super().modify_tensors(data_torch, name, bid) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 39f802d250e1..db9276b5ee6c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -34,6 +34,7 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) + add_subdirectory(diffusion-gemma) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/diffusion-gemma/CMakeLists.txt b/examples/diffusion-gemma/CMakeLists.txt new file mode 100644 index 000000000000..b0f805717a7c --- /dev/null +++ b/examples/diffusion-gemma/CMakeLists.txt @@ -0,0 +1,17 @@ +set(TARGET llama-diffusion-gemma-cli) +add_executable(${TARGET} diffusion-gemma-cli.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama llama-common mtmd ${CMAKE_THREAD_LIBS_INIT}) +# mtmd (tools/) is added after examples/, so add its include dir explicitly for the headers +target_include_directories(${TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/mtmd) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +# OpenAI-compatible HTTP server for the block-diffusion models (llama-server analogue) +set(TARGET llama-diffusion-gemma-server) +add_executable(${TARGET} diffusion-gemma-server.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama llama-common mtmd cpp-httplib ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(${TARGET} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../tools/mtmd + ${CMAKE_SOURCE_DIR}/vendor) # cpp-httplib/httplib.h, nlohmann/json.hpp +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/diffusion-gemma/diffusion-gemma-cli.cpp b/examples/diffusion-gemma/diffusion-gemma-cli.cpp new file mode 100644 index 000000000000..20aff7e7b6e1 --- /dev/null +++ b/examples/diffusion-gemma/diffusion-gemma-cli.cpp @@ -0,0 +1,819 @@ +// Block-diffusion generation for diffusion_gemma. +// +// Implements the reference block-diffusion loop (EntropyBoundSampler + StableAndConfident +// stopping + linear temperature schedule) with KV-cache reuse: +// +// * ENCODER phase (causal, no self-conditioning): the prompt is prefilled once into the +// unified sliding-window KV cache. Its per-layer K/V become the read-only prefix. +// * DECODER phase (bidirectional, self-conditioned): each denoising step decodes only the +// canvas tokens at positions [n_past, n_past+canvas). They read the cached prefix and +// attend the canvas bidirectionally. After reading the logits the canvas K/V is rolled +// back (llama_memory_seq_rm) so the cache keeps only the committed prefix. +// +// This avoids re-encoding the prompt on every denoising step. Multi-block autoregressive +// generation (commit the finalized canvas via an encoder pass, then advance n_past) is +// layered on top of this single-block loop. + +#include "arg.h" +#include "chat.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "mtmd.h" +#include "mtmd-helper.h" +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// reference defaults from generation_config.json / DiffusionGemmaGenerationConfig +static constexpr int DEF_CANVAS_LENGTH = 256; +static constexpr int DEF_MAX_DENOISE_STEPS = 48; +static constexpr float ENTROPY_BOUND = 0.1f; // EntropyBoundSamplerConfig.entropy_bound +static constexpr float TEMP_MIN = 0.4f; // LinearTemperatureScheduleConfig.t_min +static constexpr float TEMP_MAX = 0.8f; // LinearTemperatureScheduleConfig.t_max +static constexpr float CONFIDENCE_THRESHOLD = 0.005f; // StableAndConfident.confidence_threshold +static constexpr int STABILITY_THRESHOLD = 1; // StableAndConfident.stability_threshold +static constexpr int GPU_SAMPLING_MAX_TOP_K = 1024; // small top-k sort kernel limit + +#ifdef GGML_USE_CUDA +struct diffusion_cuda_host_pin { + diffusion_cuda_host_pin(void * ptr, size_t size, bool enabled) : ptr(ptr), size(size) { + registered = enabled && ptr && size && ggml_backend_cuda_register_host_buffer(ptr, size); + } + + ~diffusion_cuda_host_pin() { + if (registered) { + ggml_backend_cuda_unregister_host_buffer(ptr); + } + } + + diffusion_cuda_host_pin(const diffusion_cuda_host_pin &) = delete; + diffusion_cuda_host_pin & operator=(const diffusion_cuda_host_pin &) = delete; + + void * ptr = nullptr; + size_t size = 0; + bool registered = false; +}; +#else +struct diffusion_cuda_host_pin { + diffusion_cuda_host_pin(void *, size_t, bool) {} +}; +#endif + +static int diffusion_self_cond_top_k(const common_params & params) { + constexpr int def = 256; + const int k = params.diffusion.self_cond_top_k; + if (k <= 0) { + return def; + } + return std::min(k, def); +} + +// apply the model's chat template to the user prompt (this is a chat-trained model) +static std::string format_chat(llama_model * model, const std::string & prompt) { + auto tmpls = common_chat_templates_init(model, ""); + common_chat_templates_inputs inputs; + common_chat_msg user; + user.role = "user"; + user.content = prompt; + inputs.messages.push_back(user); + inputs.add_generation_prompt = true; + return common_chat_templates_apply(tmpls.get(), inputs).prompt; +} + +static void diffusion_gemma_print_usage(int, char **) { + printf("\nDiffusion-Gemma options:\n"); + printf(" --diffusion-timing print diffusion decode/sample timing breakdown\n"); +} + +int main(int argc, char ** argv) { + bool log_step_timing = false; + std::vector fwd; + fwd.push_back(argv[0]); + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + if (arg == "--diffusion-timing") { + log_step_timing = true; + } else { + fwd.push_back(argv[i]); + } + } + + common_params params; + params.diffusion.steps = DEF_MAX_DENOISE_STEPS; + if (!common_params_parse((int) fwd.size(), fwd.data(), params, LLAMA_EXAMPLE_DIFFUSION, diffusion_gemma_print_usage)) { + return 1; + } + common_init(); + if (!params.diffusion.device_self_cond && params.diffusion.fused_self_cond_embd) { + LOG_ERR("--no-diffusion-device-selfcond cannot be used with --diffusion-fused-self-cond-embd\n"); + return 1; + } + + // diffusion config + // canvas_length is fixed at the trained block size (256). + const int canvas_length = DEF_CANVAS_LENGTH; + const int n_steps = std::max(params.diffusion.steps, 1); + // number of autoregressive canvas blocks = ceil(-n / canvas_length), i.e. -n is the total + // number of tokens to generate (e.g. -n 256 -> 1 canvas, -n 512 -> 2, -n 1024 -> 4). The + // model may stop earlier on an EOG token; default (no -n) is 1 canvas. + const int blocks_from_n = params.n_predict > 0 ? (params.n_predict + canvas_length - 1) / canvas_length : 1; + const int max_canvases = std::max(blocks_from_n, 1); // autoregressive blocks + const float entropy_bound = ENTROPY_BOUND; + + // top-k host sampling (CLI flags; default = full softmax over the whole vocab): + // --top-k k : top-k logits per position for softmax/sample/self-cond (0 = full). + // --top-k-start/--top-k-end : anneal k from START (first/high-entropy step) to END (last step). + // --top-k-tail-correction : exact full-vocab entropy (logsumexp) for the accept/stop signal, + // instead of the under-estimating top-k entropy. + // --top-k uses its own "0 = disabled" convention and is applied only when explicitly passed. + const int topk_fixed = (params.sampling.user_sampling_config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K) + ? params.sampling.top_k + : std::max(params.diffusion.default_top_k, 0); + const int topk_start = params.diffusion.top_k_start; + const int topk_end = params.diffusion.top_k_end; + const int topk_tail = params.diffusion.top_k_tail_correction ? 1 : 0; + const int SC_K = diffusion_self_cond_top_k(params); + + llama_backend_init(); + + llama_model_params model_params = llama_model_default_params(); + // Offload all layers to the GPU by default (when built with a GPU backend, e.g. + // -DGGML_CUDA=ON). Pass -ngl N to limit offload, or -ngl 0 to force CPU. With a + // CPU-only build this has no effect. (params.n_gpu_layers defaults to -1 = auto.) + model_params.n_gpu_layers = params.n_gpu_layers >= 0 ? params.n_gpu_layers : 999; + model_params.devices = params.devices.data(); + model_params.use_mmap = params.use_mmap; + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + if (!model) { + LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str()); + return 1; + } + if (!llama_model_is_diffusion(model)) { + LOG_ERR("error: not a diffusion model\n"); + llama_model_free(model); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_vocab_n_tokens(vocab); + + // Build the prompt prefix. Text-only path tokenizes the chat-formatted prompt. Multimodal + // path (--mmproj + --image) tokenizes via libmtmd: the image marker expands to the gemma + // image tokens and the vision embeddings are produced by the GEMMA4V mmproj. + const bool use_mm = !params.mmproj.path.empty() && !params.image.empty(); + + std::vector prompt_tokens; // text-only prefill + mtmd::context_ptr mctx_vision; // multimodal context + mtmd::input_chunks mm_chunks(mtmd_input_chunks_init()); + int prefix_len = 0; // total positions in the prompt prefix + + if (use_mm) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params.cpuparams.n_threads; + mctx_vision.reset(mtmd_init_from_file(params.mmproj.path.c_str(), model, mparams)); + if (!mctx_vision) { + LOG_ERR("error: failed to load mmproj '%s'\n", params.mmproj.path.c_str()); + llama_model_free(model); + return 1; + } + + // load image(s) and build one media marker per image + mtmd::bitmaps bitmaps; + std::string markers; + for (const auto & img : params.image) { + auto bmp_res = mtmd_helper_bitmap_init_from_file(mctx_vision.get(), img.c_str(), false); + if (!bmp_res.bitmap) { + LOG_ERR("error: failed to load image '%s'\n", img.c_str()); + llama_model_free(model); + return 1; + } + bitmaps.entries.emplace_back(bmp_res.bitmap); + markers += mtmd_default_marker(); + markers += "\n"; + } + + // chat-format with the image marker(s) prepended to the user content + const std::string formatted = format_chat(model, markers + params.prompt); + LOG_INF("formatted prompt: %s\n", formatted.c_str()); + + mtmd_input_text text; + text.text = formatted.c_str(); + text.add_special = false; + text.parse_special = true; + auto bmp_c = bitmaps.c_ptr(); + if (mtmd_tokenize(mctx_vision.get(), mm_chunks.ptr.get(), &text, bmp_c.data(), bmp_c.size()) != 0) { + LOG_ERR("error: mtmd_tokenize failed\n"); + llama_model_free(model); + return 1; + } + prefix_len = (int) mtmd_helper_get_n_pos(mm_chunks.ptr.get()); + } else { + // text-only: chat-format and tokenize (turn/channel special tokens) + if (!params.prompt.empty()) { + const std::string formatted = format_chat(model, params.prompt); + LOG_INF("formatted prompt: %s\n", formatted.c_str()); + prompt_tokens = common_tokenize(vocab, formatted, /*add_special*/ false, /*parse_special*/ true); + } + prefix_len = (int) prompt_tokens.size(); + } + + // Context holds the committed prefix (prompt + finalized canvases) plus the canvas being + // denoised, plus one extra canvas of headroom (the in-flight canvas's K/V is written then + // rolled back each denoising step, so the ring buffer needs room before cells are reused). + const int n_ctx_min = prefix_len + (max_canvases + 1) * canvas_length; + const int n_ctx = std::max(n_ctx_min, (int) params.n_ctx); + // Largest single decode is the prompt prefill (prefix_len) or a canvas pass (canvas_length). + const int n_ub = std::max(std::max(prefix_len, canvas_length), 1); + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ub; + ctx_params.n_ubatch = n_ub; + ctx_params.no_perf = params.no_perf; + ctx_params.diffusion_self_cond_top_k = SC_K; + ctx_params.diffusion_input_gpu_groups = params.diffusion.input_gpu_groups; + ctx_params.diffusion_fused_self_cond_embd = params.diffusion.fused_self_cond_embd; + ctx_params.diffusion_fuse_final_logit_softcap = params.diffusion.fuse_final_logit_softcap; + ctx_params.diffusion_separate_encoder_decoder = params.diffusion.separate_encoder_decoder; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + LOG_ERR("error: failed to create context\n"); + llama_model_free(model); + return 1; + } + llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads); + llama_set_diffusion_prompt_len(ctx, prefix_len); + llama_memory_t mem = llama_get_memory(ctx); + + const int topk_max_requested = + (topk_start > 0 && topk_end > 0) ? std::max(topk_start, topk_end) : topk_fixed; + const bool gpu_sampling_requested = params.diffusion.gpu_sampling; + const bool gpu_sampling_topk_ok = topk_max_requested <= 0 || topk_max_requested <= GPU_SAMPLING_MAX_TOP_K; + const bool use_gpu_sampling = gpu_sampling_requested && + gpu_sampling_topk_ok && + llama_diffusion_sample_topk_supported(ctx); + const bool use_device_self_cond = use_gpu_sampling && params.diffusion.device_self_cond; + const bool use_device_loop = use_device_self_cond && params.diffusion.device_denoise_loop; + const int device_early_stop_interval = use_device_loop ? 1 : 0; + llama_set_diffusion_gpu_sampling(ctx, use_gpu_sampling); + + LOG_INF("diffusion-gemma: prefix=%d canvas=%d max_canvases=%d steps=%d entropy_bound=%.3f temp=[%.2f,%.2f] n_ctx=%d mm=%d\n", + prefix_len, canvas_length, max_canvases, n_steps, entropy_bound, TEMP_MIN, TEMP_MAX, n_ctx, (int) use_mm); + LOG_INF("diffusion-gemma: gpu sampling: %s%s%s%s\n", + use_gpu_sampling ? "on" : "off", + (!gpu_sampling_topk_ok ? " (top-k exceeds CUDA fast-path limit)" : + (!gpu_sampling_requested ? " (disabled by --no-diffusion-gpu-sampling)" : "")), + use_device_self_cond ? " | device self-cond: on" : "", + use_device_loop ? " | device loop: on" : ""); + if (device_early_stop_interval > 0) { + LOG_INF("diffusion-gemma: device early-stop interval=%d\n", device_early_stop_interval); + } + if (topk_fixed > 0 || (topk_start > 0 && topk_end > 0)) { + LOG_INF("diffusion-gemma: top-k sampling: fixed=%d anneal=[%d->%d] tail_correction=%d (vocab=%d)\n", + topk_fixed, topk_start, topk_end, topk_tail, n_vocab); + } + + std::mt19937 rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? 1234u : params.sampling.seed); + std::uniform_int_distribution rand_tok(0, n_vocab - 1); + std::uniform_real_distribution rand_unif(0.0f, 1.0f); + + llama_batch batch = llama_batch_init(n_ub, 0, 1); + + // ---- ENCODER phase: prefill the prompt prefix into the KV cache (no self-conditioning) ---- + int n_past = 0; + llama_set_diffusion_decoder_phase(ctx, false); + llama_set_diffusion_self_cond_topk(ctx, nullptr, nullptr, 0, 0); + const auto t_prefill_start = std::chrono::steady_clock::now(); + if (use_mm) { + // mtmd_helper_eval_chunks decodes text chunks (tokens, causal) and the image chunk + // (vision embeddings, bidirectional for gemma) into the cache, managing causal_attn. + llama_pos new_n_past = 0; + if (mtmd_helper_eval_chunks(mctx_vision.get(), ctx, mm_chunks.ptr.get(), + /*n_past*/ 0, /*seq_id*/ 0, /*n_batch*/ n_ub, /*logits_last*/ true, &new_n_past)) { + LOG_ERR("error: multimodal prefill failed\n"); + llama_batch_free(batch); + llama_free(ctx); + llama_model_free(model); + return 1; + } + n_past = (int) new_n_past; // prompt + image K/V is now the committed read-only prefix + } else if (prefix_len > 0) { + llama_set_causal_attn(ctx, true); + batch.n_tokens = prefix_len; + for (int i = 0; i < prefix_len; ++i) { + batch.token[i] = prompt_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == prefix_len - 1) ? 1 : 0; // logits unused; keep n_outputs >= 1 + } + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("error: prompt prefill (encoder) decode failed\n"); + llama_batch_free(batch); + llama_free(ctx); + llama_model_free(model); + return 1; + } + n_past = prefix_len; // prompt K/V is now the committed read-only prefix + } + const double prefill_s = std::chrono::duration(std::chrono::steady_clock::now() - t_prefill_start).count(); + if (prefix_len > 0) { + // encoder-phase prefill: causal, NO self-conditioning -> same forward as the base + // gemma4 model (the self-cond matmul only runs in the decoder/denoise phase). + LOG_INF("prefill (encoder, no self-cond): %d tokens in %.3f s (%.1f tok/s)\n", + prefix_len, prefill_s, prefill_s > 0.0 ? prefix_len / prefill_s : 0.0); + } + + std::vector canvas(canvas_length); + std::vector argmax_canvas(canvas_length, -1); + std::vector prev_argmax(canvas_length, -1); + std::vector accepted(canvas_length); + int32_t stop_flag = 0; + // sparse self-conditioning over the canvas: the previous step's top-SC_K token ids + their + // (renormalized) softmax probabilities per position. Fed to the next decode (Option-2 graph + // gather: the decoder gathers just these SC_K embedding rows and blends them, instead of a + // dense full-vocab probs @ token_embd matmul). Zero probs => no self-conditioning (step 1). + // SC_K must match the graph's fixed gather width. Default is 256; use + // --diffusion-self-cond-top-k can be used for gated experiments with a smaller sparse blend. + std::vector sc_ids ((size_t) SC_K * canvas_length, 0); + std::vector sc_probs((size_t) SC_K * canvas_length, 0.0f); + + const bool pin_host_outputs = params.diffusion.pin_host_outputs; + diffusion_cuda_host_pin pin_argmax_canvas( + argmax_canvas.data(), argmax_canvas.size() * sizeof(argmax_canvas[0]), pin_host_outputs); + diffusion_cuda_host_pin pin_stop_flag(&stop_flag, sizeof(stop_flag), pin_host_outputs); + + // all generated tokens across the autoregressive canvas blocks + std::vector generated; + + // ---- autoregressive block loop: each block denoises a canvas against the cached prefix, + // then (if continuing) commits its finalized tokens to the cache as the next prefix ---- + int n_blocks_run = 0; + int n_steps_total = 0; + const auto t_gen_start = std::chrono::steady_clock::now(); + double decode_enqueue_s = 0.0; + double sample_sync_s = 0.0; + double host_loop_s = 0.0; + + bool done = false; + bool failed = false; + for (int block = 0; block < max_canvases && !done; ++block) { + ++n_blocks_run; + // 1. initialize canvas with random tokens + for (auto & t : canvas) t = rand_tok(rng); + std::fill(prev_argmax.begin(), prev_argmax.end(), -1); + llama_set_diffusion_self_cond_topk(ctx, nullptr, nullptr, 0, 0); // first step: zero self-conditioning + + // 2. denoising loop (DECODER phase): cur_step = n_steps .. 1 + int step_k = 0; // top-k used for the current step (0 = full softmax), for logging + for (int cur_step = n_steps; cur_step >= 1; --cur_step) { + ++n_steps_total; + // 2a. decode the canvas only, at positions [n_past, n_past+canvas). Bidirectional, with + // self-conditioning; it reads the cached prefix read-only. All canvas tokens are outputs. + llama_set_causal_attn(ctx, false); + llama_set_diffusion_decoder_phase(ctx, true); + batch.n_tokens = canvas_length; + for (int j = 0; j < canvas_length; ++j) { + batch.token[j] = canvas[j]; + batch.pos[j] = n_past + j; + batch.n_seq_id[j] = 1; + batch.seq_id[j][0] = 0; + batch.logits[j] = 1; + } + const auto t_decode_start = std::chrono::steady_clock::now(); + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("error: llama_decode failed at step %d\n", cur_step); + break; + } + decode_enqueue_s += std::chrono::duration(std::chrono::steady_clock::now() - t_decode_start).count(); + + // 2b. linear temperature schedule: t = t_min + (t_max - t_min) * (cur_step / n_steps) + const float temp = TEMP_MIN + (TEMP_MAX - TEMP_MIN) * ((float) cur_step / (float) n_steps); + + std::vector entropy(canvas_length); + std::vector sampled(canvas_length); + + // k for this step: 0 = full softmax. With annealing, k is high at the first (high-entropy) + // step and low at the last, since early canvases are flat (need many tokens) and late ones + // are peaked (a few suffice). + int k_step = 0; + if (topk_start > 0 && topk_end > 0) { + const float frac = (n_steps > 1) ? (float) (cur_step - 1) / (float) (n_steps - 1) : 0.0f; // 1 at first step (cur_step=n_steps), 0 at last (cur_step=1) + k_step = (int) lroundf(topk_end + (topk_start - topk_end) * frac); + } else if (topk_fixed > 0) { + k_step = topk_fixed; + } + if (k_step <= 0 || k_step >= n_vocab) k_step = 0; + step_k = k_step; + + if (use_device_loop) { + bool sampled_ok = true; + const int block_step = n_steps - cur_step + 1; + const bool final_step = cur_step == 1; + const bool use_device_early_stop = device_early_stop_interval > 0; + const bool reset_stop_state = use_device_early_stop && block_step == 1; + const bool check_stop = use_device_early_stop && !final_step; + stop_flag = 0; + const auto t_sample_start = std::chrono::steady_clock::now(); + llama_diffusion_sample_params sample_params = { + /* .n_tokens = */ canvas_length, + /* .top_k = */ k_step, + /* .self_cond_top_k = */ SC_K, + /* .temperature = */ temp, + /* .seed = */ params.sampling.seed == LLAMA_DEFAULT_SEED ? 1234u : params.sampling.seed, + /* .step = */ (uint32_t) n_steps_total, + /* .top_k_tail_correction = */ topk_tail != 0, + /* .cuda_fast_top_k = */ params.diffusion.cuda_fast_top_k, + /* .cuda_direct_self_cond = */ params.diffusion.cuda_direct_self_cond, + /* .cuda_final_tokens_on_stop = */ params.diffusion.cuda_final_tokens_on_stop, + /* .cuda_fused_top_k_sample = */ params.diffusion.cuda_fused_top_k_sample, + /* .cuda_tight_top_k = */ params.diffusion.cuda_tight_top_k, + /* .cuda_parallel_full_softmax = */ params.diffusion.cuda_parallel_full_softmax, + /* .cuda_fused_full_softmax = */ params.diffusion.cuda_fused_full_softmax, + /* .cuda_top_k_local_k = */ params.diffusion.cuda_top_k_local_k, + }; + llama_diffusion_sample_result sample_result = { + /* .sampled = */ nullptr, + /* .argmax = */ nullptr, + /* .entropy = */ nullptr, + /* .self_cond_ids = */ nullptr, + /* .self_cond_probs = */ nullptr, + /* .final_tokens = */ (final_step || check_stop) ? argmax_canvas.data() : nullptr, + /* .stop = */ check_stop ? &stop_flag : nullptr, + /* .entropy_bound = */ entropy_bound, + /* .confidence_threshold = */ CONFIDENCE_THRESHOLD, + /* .stability_threshold = */ STABILITY_THRESHOLD, + /* .update_canvas_on_device = */ cur_step > 1, + /* .update_stop_state_on_device = */ use_device_early_stop, + /* .check_stop_on_device = */ check_stop, + /* .reset_stop_state = */ reset_stop_state, + }; + sampled_ok = llama_diffusion_sample_topk(ctx, &sample_params, &sample_result); + sample_sync_s += std::chrono::duration(std::chrono::steady_clock::now() - t_sample_start).count(); + llama_memory_seq_rm(mem, 0, n_past, -1); + if (!sampled_ok) { + LOG_ERR("error: CUDA diffusion device loop failed at step %d (k=%d)\n", cur_step, k_step); + failed = true; + done = true; + break; + } + if (check_stop && stop_flag != 0) { + break; + } + continue; + } + + // sparse self-cond for the NEXT step: top-SC_K (id, prob) per position. Cleared each step; + // unused slots stay (id 0, prob 0) so the graph gather contributes nothing for them. + if (!use_device_self_cond) { + std::fill(sc_probs.begin(), sc_probs.end(), 0.0f); + std::fill(sc_ids.begin(), sc_ids.end(), 0); + } + + bool sampled_ok = true; + const auto t_sample_start = std::chrono::steady_clock::now(); + if (use_gpu_sampling) { + llama_diffusion_sample_params sample_params = { + /* .n_tokens = */ canvas_length, + /* .top_k = */ k_step, + /* .self_cond_top_k = */ SC_K, + /* .temperature = */ temp, + /* .seed = */ params.sampling.seed == LLAMA_DEFAULT_SEED ? 1234u : params.sampling.seed, + /* .step = */ (uint32_t) n_steps_total, + /* .top_k_tail_correction = */ topk_tail != 0, + /* .cuda_fast_top_k = */ params.diffusion.cuda_fast_top_k, + /* .cuda_direct_self_cond = */ params.diffusion.cuda_direct_self_cond, + /* .cuda_final_tokens_on_stop = */ params.diffusion.cuda_final_tokens_on_stop, + /* .cuda_fused_top_k_sample = */ params.diffusion.cuda_fused_top_k_sample, + /* .cuda_tight_top_k = */ params.diffusion.cuda_tight_top_k, + /* .cuda_parallel_full_softmax = */ params.diffusion.cuda_parallel_full_softmax, + /* .cuda_fused_full_softmax = */ params.diffusion.cuda_fused_full_softmax, + /* .cuda_top_k_local_k = */ params.diffusion.cuda_top_k_local_k, + }; + llama_diffusion_sample_result sample_result = { + /* .sampled = */ sampled.data(), + /* .argmax = */ argmax_canvas.data(), + /* .entropy = */ entropy.data(), + /* .self_cond_ids = */ use_device_self_cond ? nullptr : sc_ids.data(), + /* .self_cond_probs = */ use_device_self_cond ? nullptr : sc_probs.data(), + /* .final_tokens = */ nullptr, + /* .stop = */ nullptr, + /* .entropy_bound = */ 0.0f, + /* .confidence_threshold = */ 0.0f, + /* .stability_threshold = */ 0, + /* .update_canvas_on_device = */ false, + /* .update_stop_state_on_device = */ false, + /* .check_stop_on_device = */ false, + /* .reset_stop_state = */ false, + }; + sampled_ok = llama_diffusion_sample_topk(ctx, &sample_params, &sample_result); + if (!sampled_ok) { + LOG_ERR("error: CUDA diffusion sampling failed at step %d (k=%d)\n", cur_step, k_step); + } + } else if (k_step == 0) { + // canvas logits occupy rows [0, canvas_length) (canvas-only ubatch) + const float * logits = llama_get_logits(ctx); + // ---- full softmax over the whole vocabulary (reference behaviour) ---- + // self-cond still feeds only the top-SC_K tokens (full-normalized probs); the dropped + // tail carries negligible embedding weight and the post RMS norm absorbs the scale. + std::vector probs(n_vocab); + std::vector> scheap; scheap.reserve(SC_K); // min-heap of (x, idx), size SC_K + const auto cmp = [](const std::pair&a, const std::pair&b){ return a.first > b.first; }; + for (int j = 0; j < canvas_length; ++j) { + const float * lg = logits + (size_t) j * n_vocab; + float maxl = -INFINITY; + int amax = 0; + scheap.clear(); + for (int v = 0; v < n_vocab; ++v) { + const float x = lg[v] / temp; + if (x > maxl) { maxl = x; amax = v; } + if ((int) scheap.size() < SC_K) { + scheap.push_back({x, v}); + std::push_heap(scheap.begin(), scheap.end(), cmp); + } else if (x > scheap.front().first) { + std::pop_heap(scheap.begin(), scheap.end(), cmp); + scheap.back() = {x, v}; + std::push_heap(scheap.begin(), scheap.end(), cmp); + } + } + float sum = 0.0f; + for (int v = 0; v < n_vocab; ++v) { + const float p = expf(lg[v] / temp - maxl); + probs[v] = p; + sum += p; + } + float ent = 0.0f; + const float r = rand_unif(rng) * sum; + float cum = 0.0f; + int tok = amax; + bool picked = false; + for (int v = 0; v < n_vocab; ++v) { + const float p = probs[v] / sum; + if (p > 0.0f) ent -= p * logf(p); + cum += probs[v]; + if (!picked && cum >= r) { tok = v; picked = true; } + } + // store top-SC_K self-cond (full-normalized probability per selected token) + int32_t * sid = sc_ids.data() + (size_t) j * SC_K; + float * spr = sc_probs.data() + (size_t) j * SC_K; + int slot = 0; + for (auto & h : scheap) { sid[slot] = h.second; spr[slot] = expf(h.first - maxl) / sum; ++slot; } + entropy[j] = ent; + sampled[j] = tok; + argmax_canvas[j] = amax; + } + } else { + // canvas logits occupy rows [0, canvas_length) (canvas-only ubatch) + const float * logits = llama_get_logits(ctx); + // ---- top-k host sampling: softmax / entropy / sample / self-cond over the top-k + // logits only. Self-cond feeds the top min(k,SC_K) tokens (renormalized over the + // sampled top-k), gathered in-graph; the dropped tail carries negligible weight. ---- + const int heap_k = std::max(k_step, SC_K); // collect enough for both sampling and self-cond + std::vector> heap; // min-heap of (logit/temp, idx), size heap_k + heap.reserve(heap_k); + const auto cmp = [](const std::pair&a, const std::pair&b){ return a.first > b.first; }; + for (int j = 0; j < canvas_length; ++j) { + const float * lg = logits + (size_t) j * n_vocab; + float maxl = -INFINITY; + int amax = 0; + heap.clear(); + for (int v = 0; v < n_vocab; ++v) { + const float x = lg[v] / temp; + if (x > maxl) { maxl = x; amax = v; } + if ((int) heap.size() < heap_k) { + heap.push_back({x, v}); + std::push_heap(heap.begin(), heap.end(), cmp); + } else if (x > heap.front().first) { + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = {x, v}; + std::push_heap(heap.begin(), heap.end(), cmp); + } + } + // sort the collected entries by logit descending (exp is monotonic with x): the + // first k_step drive sampling/entropy, the first SC_K drive self-cond. + std::sort(heap.begin(), heap.end(), [](const std::pair&a, const std::pair&b){ return a.first > b.first; }); + + // softmax over the sampled top-k (renormalized); reuse .first to hold exp value + float Zk = 0.0f; + for (int i = 0; i < k_step; ++i) { const float e = expf(heap[i].first - maxl); heap[i].first = e; Zk += e; } + + float ent; + if (topk_tail) { + // exact full entropy via logsumexp over all logits (one expf pass, no per-token logf): + // H = ln(Z) - (sum_i (z_i-max) e_i)/Z + double Zf = 0.0, T = 0.0; + for (int v = 0; v < n_vocab; ++v) { + const double d = (double) (lg[v] / temp) - (double) maxl; + const double e = exp(d); + Zf += e; T += d * e; + } + ent = (float) (log(Zf) - T / Zf); + } else { + ent = 0.0f; + for (int i = 0; i < k_step; ++i) { const float q = heap[i].first / Zk; if (q > 0.0f) ent -= q * logf(q); } + } + + // multinomial sample over the sampled top-k + const float r = rand_unif(rng) * Zk; + float cum = 0.0f; + int tok = amax; + bool picked = false; + for (int i = 0; i < k_step; ++i) { + cum += heap[i].first; + if (!picked && cum >= r) { tok = heap[i].second; picked = true; } + } + + // store top-SC_K self-cond (renormalized over the sampled top-k) + int32_t * sid = sc_ids.data() + (size_t) j * SC_K; + float * spr = sc_probs.data() + (size_t) j * SC_K; + const int n_sc = std::min(k_step, SC_K); + for (int i = 0; i < n_sc; ++i) { sid[i] = heap[i].second; spr[i] = heap[i].first / Zk; } + entropy[j] = ent; + sampled[j] = tok; + argmax_canvas[j] = amax; + } + } + sample_sync_s += std::chrono::duration(std::chrono::steady_clock::now() - t_sample_start).count(); + if (!sampled_ok) { + llama_memory_seq_rm(mem, 0, n_past, -1); + failed = true; + done = true; + break; + } + + const auto t_host_start = std::chrono::steady_clock::now(); + + // roll back the canvas K/V written by this decode so the cache keeps only the committed + // prefix [0, n_past); the next step re-decodes the canvas fresh against that prefix. + llama_memory_seq_rm(mem, 0, n_past, -1); + + // 2c. entropy-bound accept: sort positions by entropy ascending, accept the prefix + // where sum(entropy of all-but-last) <= entropy_bound (monotonic -> prefix selection) + std::vector order(canvas_length); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](int a, int b) { return entropy[a] < entropy[b]; }); + + std::vector accept_mask(canvas_length, 0); + float prefix = 0.0f; + for (int k = 0; k < canvas_length; ++k) { + if (prefix <= entropy_bound) { + accept_mask[order[k]] = 1; + prefix += entropy[order[k]]; + } else { + break; + } + } + + // accepted canvas: accepted positions take the sampled token, others keep current + int n_accept = 0; + for (int i = 0; i < canvas_length; ++i) { + if (accept_mask[i]) { accepted[i] = sampled[i]; ++n_accept; } + } + + // mean entropy (confidence) + const float mean_ent = std::accumulate(entropy.begin(), entropy.end(), 0.0f) / canvas_length; + + // 2d. stopping: stable (argmax canvas unchanged for STABILITY_THRESHOLD steps) AND confident + bool stable = (STABILITY_THRESHOLD == 0) || (argmax_canvas == prev_argmax); + bool confident = mean_ent < CONFIDENCE_THRESHOLD; + LOG_INF("step %3d temp=%.3f k=%d accepted=%4d/%d mean_entropy=%.4f%s\n", + cur_step, temp, step_k, n_accept, canvas_length, mean_ent, + (stable && confident) ? " [STOP]" : ""); + if (stable && confident) { + break; + } + prev_argmax = argmax_canvas; + + // self-conditioning for the NEXT denoising step. With device self-cond this was already + // copied D2D into the reused graph input tensors by llama_diffusion_sample_topk(). + if (!use_device_self_cond) { + llama_set_diffusion_self_cond_topk(ctx, sc_ids.data(), sc_probs.data(), SC_K, canvas_length); + } + + // 2e. renoise non-accepted positions with fresh random tokens -> next canvas + for (int i = 0; i < canvas_length; ++i) { + canvas[i] = accept_mask[i] ? accepted[i] : rand_tok(rng); + } + host_loop_s += std::chrono::duration(std::chrono::steady_clock::now() - t_host_start).count(); + } + if (failed) { + break; + } + + // 3. block output = the inline argmax of the last (stable) denoising step's logits. + // This matches the reference (DiffusionGemma _denoising_step uses argmax(processed_logits) + // taken during the denoising forward, read once the canvas is stable + confident). There is + // no separate read-out: the never-accepted tail is the model's own prediction given the + // settled context, rather than a stale-random scratch buffer. + const std::vector & block_out = argmax_canvas; + + // accumulate this block's finalized tokens; stop after a block that contains an EOG token + generated.insert(generated.end(), block_out.begin(), block_out.end()); + for (int j = 0; j < canvas_length; ++j) { + if (llama_vocab_is_eog(vocab, block_out[j])) { done = true; break; } + } + + // 4. COMMIT (ENCODER phase): if another block follows, write the finalized canvas's plain + // (non-self-conditioned, causal) K/V into the cache and advance the prefix pointer, so the + // next block's canvas cross-attends to it. Skipped on the last block / on EOG. + if (!done && block + 1 < max_canvases) { + llama_set_causal_attn(ctx, true); + llama_set_diffusion_decoder_phase(ctx, false); + llama_set_diffusion_self_cond_topk(ctx, nullptr, nullptr, 0, 0); + batch.n_tokens = canvas_length; + for (int j = 0; j < canvas_length; ++j) { + batch.token[j] = block_out[j]; + batch.pos[j] = n_past + j; + batch.n_seq_id[j] = 1; + batch.seq_id[j][0] = 0; + batch.logits[j] = (j == canvas_length - 1) ? 1 : 0; + } + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("error: canvas commit (encoder) decode failed at block %d\n", block); + break; + } + n_past += canvas_length; // finalized canvas is now part of the read-only prefix + LOG_INF("committed block %d -> n_past=%d\n", block, n_past); + } + } // end autoregressive block loop + + const double gen_s = std::chrono::duration(std::chrono::steady_clock::now() - t_gen_start).count(); + + llama_batch_free(batch); + + if (failed) { + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + // the full denoised canvas (thought channel + response), for reference + LOG_INF("\n=== generated canvas ===\n%s\n", common_detokenize(vocab, generated, false).c_str()); + + // the model answers in a "<|channel>thought ... " block followed by the response; + // extract the final response (after the last channel-close), truncated at the first + // end-of-generation token, and drop trailing duplicate sentences. + llama_token chan_close = LLAMA_TOKEN_NULL; + { + auto t = common_tokenize(vocab, "", false, true); + if (t.size() == 1) chan_close = t[0]; + } + const int n_gen = (int) generated.size(); + int start = 0; + if (chan_close != LLAMA_TOKEN_NULL) { + for (int j = 0; j < n_gen; ++j) if (generated[j] == chan_close) start = j + 1; + } + std::vector answer; + for (int j = start; j < n_gen; ++j) { + if (llama_vocab_is_eog(vocab, generated[j])) break; + answer.push_back(generated[j]); + } + std::string ans = common_detokenize(vocab, answer, false); + // drop a trailing exact-duplicate of the answer if the canvas repeated it + { + std::string s = ans; + size_t h = s.find_first_not_of(" \n\t"); if (h != std::string::npos) s = s.substr(h); + const size_t half = s.size() / 2; + if (half > 0 && s.compare(0, half, s, s.size() - half, half) == 0) { + ans = s.substr(0, half); // "X X" -> "X" + } + } + LOG_INF("=== answer ===\n%s\n", ans.c_str()); + + // generation timing (excludes model load + prompt prefill): wall-clock of the denoising + // block loop, the canvas tokens produced, and effective throughput. + const int n_canvas_tok = n_blocks_run * canvas_length; + LOG_INF("=== perf ===\n"); + LOG_INF("generation: %d block(s), %d denoising steps, %d canvas tokens in %.2f s " + "(%.1f canvas tok/s, %.3f s/step); answer tokens=%d\n", + n_blocks_run, n_steps_total, n_canvas_tok, gen_s, + gen_s > 0.0 ? n_canvas_tok / gen_s : 0.0, + n_steps_total > 0 ? gen_s / n_steps_total : 0.0, + (int) answer.size()); + if (log_step_timing) { + LOG_INF("timing: decode enqueue %.3f s, sample/sync %.3f s, host loop %.3f s\n", + decode_enqueue_s, sample_sync_s, host_loop_s); + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return 0; +} diff --git a/examples/diffusion-gemma/diffusion-gemma-server.cpp b/examples/diffusion-gemma/diffusion-gemma-server.cpp new file mode 100644 index 000000000000..97f5b99ce65b --- /dev/null +++ b/examples/diffusion-gemma/diffusion-gemma-server.cpp @@ -0,0 +1,1133 @@ +// OpenAI-compatible HTTP server for the block-diffusion models (diffusion-gemma). +// +// This is the llama-server analogue for the diffusion family: it loads a block-diffusion model +// once and serves the same denoising generation loop as diffusion-gemma-cli over HTTP, exposing +// the OpenAI endpoints (/v1/chat/completions, /v1/completions, /v1/models) plus the llama-server +// observability surface (/health, /v1/health, /props, /metrics, /slots) with the same response +// `timings`/`usage` objects, the same per-request timing logs, the same access log, and a Prometheus +// /metrics endpoint -- all with the metric semantics re-mapped to block diffusion (canvas blocks, +// denoising steps, s/step) instead of autoregressive token-by-token decode. +// +// Generation is NOT autoregressive: each request denoises one or more 256-token canvases against a +// cached prompt prefix (see diffusion-gemma-cli.cpp for the full description). A single llama_context +// is reused across requests and is not thread-safe, so generation is serialized behind a mutex (one +// slot); the HTTP layer still accepts many connections concurrently. + +#include "arg.h" +#include "chat.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +// llama-server-style log macros (server-common.h): "srv : ..." / "slot : id N | ..." +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SLT_INF(id, fmt, ...) LOG_INF("slot %12.*s: id %2d | " fmt, 12, __func__, (id), __VA_ARGS__) + +// reference defaults from generation_config.json / DiffusionGemmaGenerationConfig (mirror the CLI) +static constexpr int DEF_CANVAS_LENGTH = 256; +static constexpr int DEF_MAX_DENOISE_STEPS = 48; +static constexpr float ENTROPY_BOUND = 0.1f; +static constexpr float TEMP_MIN = 0.4f; +static constexpr float TEMP_MAX = 0.8f; +static constexpr float CONFIDENCE_THRESHOLD = 0.005f; +static constexpr int STABILITY_THRESHOLD = 1; +static constexpr int DEF_SC_K = 256; +static constexpr int GPU_SAMPLING_MAX_TOP_K = 1024; // CUDA diffusion sampler limit + +#ifdef GGML_USE_CUDA +struct diffusion_cuda_host_pin { + diffusion_cuda_host_pin(void * ptr, size_t size, bool enabled) : ptr(ptr), size(size) { + registered = enabled && ptr && size && ggml_backend_cuda_register_host_buffer(ptr, size); + } + + ~diffusion_cuda_host_pin() { + if (registered) { + ggml_backend_cuda_unregister_host_buffer(ptr); + } + } + + diffusion_cuda_host_pin(const diffusion_cuda_host_pin &) = delete; + diffusion_cuda_host_pin & operator=(const diffusion_cuda_host_pin &) = delete; + + void * ptr = nullptr; + size_t size = 0; + bool registered = false; +}; +#else +struct diffusion_cuda_host_pin { + diffusion_cuda_host_pin(void *, size_t, bool) {} +}; +#endif + +static int diffusion_self_cond_top_k(const common_params & params) { + const int k = params.diffusion.self_cond_top_k; + if (k <= 0) { + return DEF_SC_K; + } + return std::min(k, DEF_SC_K); +} + +// ------------------------------------------------------------------------------------------------ +// per-request generation parameters and result +// ------------------------------------------------------------------------------------------------ +struct diffusion_request { + int canvas_length = DEF_CANVAS_LENGTH; + int n_steps = DEF_MAX_DENOISE_STEPS; + int max_canvases = 1; // ceil(max_tokens / canvas_length) + int topk_fixed = 0; // 0 = full softmax + int topk_start = 0; + int topk_end = 0; + bool topk_tail = false; + bool ignore_eos = false; // run all max_canvases blocks (don't stop at end-of-text) + uint32_t seed = 1234; +}; + +struct diffusion_result { + std::string answer; // final response text (post channel-split / eog truncation) + std::string full; // full detokenized canvas (thought + response) + int prompt_tokens = 0; // prompt prefix tokens processed (encoder prefill) + int answer_tokens = 0; // tokens in the extracted answer + int n_blocks = 0; // canvas blocks run + int n_steps_total = 0; // denoising steps across all blocks + int canvas_tokens = 0; // n_blocks * canvas_length + int n_decode = 0; // total llama_decode() calls (prefill + denoise + commit) + double prefill_ms = 0.0; // encoder-phase prompt prefill wall time + double gen_ms = 0.0; // denoising block-loop wall time + bool ok = false; + std::string error; +}; + +// ------------------------------------------------------------------------------------------------ +// server metrics (llama-server analogue; protected by its own mutex) +// ------------------------------------------------------------------------------------------------ +struct server_metrics { + std::mutex mu; + int64_t t_start = 0; // process start (unix seconds) + + // cumulative counters + uint64_t n_requests_total = 0; + uint64_t n_prompt_tokens_total = 0; + double t_prompt_ms_total = 0.0; + uint64_t n_tokens_predicted_total = 0; // answer tokens + double t_predicted_ms_total = 0.0; // denoise time + uint64_t n_decode_total = 0; + uint64_t n_blocks_total = 0; + uint64_t n_steps_total = 0; + uint64_t n_canvas_tokens_total = 0; + + // last-request gauges + double last_prompt_tps = 0.0; + double last_predicted_tps = 0.0; + double last_steps_per_second = 0.0; + double last_canvas_tps = 0.0; + double last_ms_per_step = 0.0; + + std::atomic n_processing{0}; + + void add(const diffusion_result & r) { + std::lock_guard lk(mu); + n_requests_total += 1; + n_prompt_tokens_total += r.prompt_tokens; + t_prompt_ms_total += r.prefill_ms; + n_tokens_predicted_total += r.answer_tokens; + t_predicted_ms_total += r.gen_ms; + n_decode_total += r.n_decode; + n_blocks_total += r.n_blocks; + n_steps_total += r.n_steps_total; + n_canvas_tokens_total += r.canvas_tokens; + last_prompt_tps = r.prefill_ms > 0 ? r.prompt_tokens * 1e3 / r.prefill_ms : 0.0; + last_predicted_tps = r.gen_ms > 0 ? r.answer_tokens * 1e3 / r.gen_ms : 0.0; + last_steps_per_second= r.gen_ms > 0 ? r.n_steps_total * 1e3 / r.gen_ms : 0.0; + last_canvas_tps = r.gen_ms > 0 ? r.canvas_tokens * 1e3 / r.gen_ms : 0.0; + last_ms_per_step = r.n_steps_total > 0 ? r.gen_ms / r.n_steps_total : 0.0; + } +}; + +// ------------------------------------------------------------------------------------------------ +// server state (one model + one reused context == one slot, guarded by a mutex) +// ------------------------------------------------------------------------------------------------ +struct diffusion_server { + llama_model * model = nullptr; + llama_context * ctx = nullptr; + const llama_vocab * vocab = nullptr; + llama_memory_t mem = nullptr; + llama_batch batch{}; + common_chat_templates_ptr templates; + + int canvas_length = DEF_CANVAS_LENGTH; + int n_steps = DEF_MAX_DENOISE_STEPS; + int n_ctx = 0; + int n_ub = 0; // ubatch == canvas_length (keeps the decoder gather graph small) + int n_vocab = 0; + int topk_fixed = 0; + int topk_start = 0; + int topk_end = 0; + bool topk_tail = false; + bool use_gpu_sampling = false; + bool use_device_self_cond = false; + bool use_device_loop = false; + int device_early_stop_interval = 0; + common_params_diffusion diffusion; + std::string model_id; + std::string model_path; + std::string build_info = "llama.cpp diffusion-gemma-server"; + + server_metrics metrics; + std::mutex gen_mutex; // serialize generation (ctx is single-threaded == one slot) + + std::string format_messages(const json & messages) const { + common_chat_templates_inputs inputs; + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.add_generation_prompt = true; + return common_chat_templates_apply(templates.get(), inputs).prompt; + } + + // causal prefill of `toks` starting at position pos0, chunked to n_ub (encoder phase). + // increments *n_dec with the number of llama_decode() calls made. + bool prefill_causal(const std::vector & toks, int pos0, int * n_dec) { + llama_set_causal_attn(ctx, true); + llama_set_diffusion_decoder_phase(ctx, false); + llama_set_diffusion_self_cond_topk(ctx, nullptr, nullptr, 0, 0); + const int n = (int) toks.size(); + for (int off = 0; off < n; off += n_ub) { + const int cnt = std::min(n_ub, n - off); + batch.n_tokens = cnt; + for (int i = 0; i < cnt; ++i) { + batch.token[i] = toks[off + i]; + batch.pos[i] = pos0 + off + i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (off + i == n - 1) ? 1 : 0; + } + if (llama_decode(ctx, batch) != 0) return false; + ++(*n_dec); + } + return true; + } + + // run the block-diffusion denoising loop for one request (mirrors diffusion-gemma-cli.cpp) + diffusion_result generate(const std::vector & prompt_tokens, const diffusion_request & rq) { + diffusion_result out; + const int prefix_len = (int) prompt_tokens.size(); + int n_decode = 0; + + int max_canvases = rq.max_canvases; + const int fit = (n_ctx - prefix_len) / canvas_length - 1; + if (fit < 1) { + out.error = "prompt too long for context (n_ctx=" + std::to_string(n_ctx) + ")"; + return out; + } + if (max_canvases > fit) max_canvases = fit; + if (max_canvases < 1) max_canvases = 1; + + llama_memory_clear(mem, true); + + std::mt19937 rng(rq.seed); + std::uniform_int_distribution rand_tok(0, n_vocab - 1); + std::uniform_real_distribution rand_unif(0.0f, 1.0f); + + // ---- ENCODER phase: prefill the prompt prefix (causal, no self-conditioning) ---- + const auto t_prefill0 = std::chrono::steady_clock::now(); + int n_past = 0; + if (prefix_len > 0) { + if (!prefill_causal(prompt_tokens, 0, &n_decode)) { + out.error = "prompt prefill (encoder) decode failed"; + return out; + } + n_past = prefix_len; + } + out.prefill_ms = std::chrono::duration(std::chrono::steady_clock::now() - t_prefill0).count(); + + std::vector canvas(canvas_length); + std::vector argmax_canvas(canvas_length, -1); + std::vector prev_argmax(canvas_length, -1); + std::vector accepted(canvas_length); + int32_t stop_flag = 0; + const int SC_K = diffusion.self_cond_top_k; + std::vector sc_ids ((size_t) SC_K * canvas_length, 0); + std::vector sc_probs((size_t) SC_K * canvas_length, 0.0f); + std::vector generated; + + const bool pin_host_outputs = diffusion.pin_host_outputs; + diffusion_cuda_host_pin pin_argmax_canvas( + argmax_canvas.data(), argmax_canvas.size() * sizeof(argmax_canvas[0]), pin_host_outputs); + diffusion_cuda_host_pin pin_stop_flag(&stop_flag, sizeof(stop_flag), pin_host_outputs); + + int n_blocks_run = 0, n_steps_total = 0; + bool done = false; + const int topk_max_requested = + (rq.topk_start > 0 && rq.topk_end > 0) ? std::max(rq.topk_start, rq.topk_end) : rq.topk_fixed; + const bool use_device_loop_request = use_device_loop && + (topk_max_requested <= 0 || topk_max_requested <= GPU_SAMPLING_MAX_TOP_K); + + const auto t_gen0 = std::chrono::steady_clock::now(); + for (int block = 0; block < max_canvases && !done; ++block) { + ++n_blocks_run; + for (auto & t : canvas) t = rand_tok(rng); + std::fill(prev_argmax.begin(), prev_argmax.end(), -1); + llama_set_diffusion_self_cond_topk(ctx, nullptr, nullptr, 0, 0); + + for (int cur_step = n_steps; cur_step >= 1; --cur_step) { + ++n_steps_total; + const float temp = TEMP_MIN + (TEMP_MAX - TEMP_MIN) * ((float) cur_step / (float) n_steps); + + int k_step = 0; + if (rq.topk_start > 0 && rq.topk_end > 0) { + const float frac = (n_steps > 1) ? (float) (cur_step - 1) / (float) (n_steps - 1) : 0.0f; + k_step = (int) lroundf(rq.topk_end + (rq.topk_start - rq.topk_end) * frac); + } else if (rq.topk_fixed > 0) { + k_step = rq.topk_fixed; + } + if (k_step <= 0 || k_step >= n_vocab) k_step = 0; + const bool use_gpu_sampling_step = use_gpu_sampling && + (k_step <= 0 || k_step <= GPU_SAMPLING_MAX_TOP_K); + const bool use_device_self_cond_step = use_gpu_sampling_step && use_device_self_cond; + const bool use_device_loop_step = use_device_loop_request && use_gpu_sampling_step; + + llama_set_causal_attn(ctx, false); + llama_set_diffusion_decoder_phase(ctx, true); + llama_set_diffusion_gpu_sampling(ctx, use_gpu_sampling_step); + batch.n_tokens = canvas_length; + for (int j = 0; j < canvas_length; ++j) { + batch.token[j] = canvas[j]; + batch.pos[j] = n_past + j; + batch.n_seq_id[j] = 1; + batch.seq_id[j][0] = 0; + batch.logits[j] = 1; + } + if (llama_decode(ctx, batch) != 0) { + out.error = "llama_decode failed at denoising step " + std::to_string(cur_step); + return out; + } + ++n_decode; + + std::vector entropy(canvas_length); + std::vector sampled(canvas_length); + + if (use_device_loop_step) { + const int block_step = n_steps - cur_step + 1; + const bool final_step = cur_step == 1; + const bool use_device_early_stop = device_early_stop_interval > 0; + const bool reset_stop_state = use_device_early_stop && block_step == 1; + const bool check_stop = use_device_early_stop && !final_step; + stop_flag = 0; + llama_diffusion_sample_params sample_params = { + /* .n_tokens = */ canvas_length, + /* .top_k = */ k_step, + /* .self_cond_top_k = */ SC_K, + /* .temperature = */ temp, + /* .seed = */ rq.seed, + /* .step = */ (uint32_t) n_steps_total, + /* .top_k_tail_correction = */ rq.topk_tail, + /* .cuda_fast_top_k = */ diffusion.cuda_fast_top_k, + /* .cuda_direct_self_cond = */ diffusion.cuda_direct_self_cond, + /* .cuda_final_tokens_on_stop = */ diffusion.cuda_final_tokens_on_stop, + /* .cuda_fused_top_k_sample = */ diffusion.cuda_fused_top_k_sample, + /* .cuda_tight_top_k = */ diffusion.cuda_tight_top_k, + /* .cuda_parallel_full_softmax = */ diffusion.cuda_parallel_full_softmax, + /* .cuda_fused_full_softmax = */ diffusion.cuda_fused_full_softmax, + /* .cuda_top_k_local_k = */ diffusion.cuda_top_k_local_k, + }; + llama_diffusion_sample_result sample_result = { + /* .sampled = */ nullptr, + /* .argmax = */ nullptr, + /* .entropy = */ nullptr, + /* .self_cond_ids = */ nullptr, + /* .self_cond_probs = */ nullptr, + /* .final_tokens = */ (final_step || check_stop) ? argmax_canvas.data() : nullptr, + /* .stop = */ check_stop ? &stop_flag : nullptr, + /* .entropy_bound = */ ENTROPY_BOUND, + /* .confidence_threshold = */ CONFIDENCE_THRESHOLD, + /* .stability_threshold = */ STABILITY_THRESHOLD, + /* .update_canvas_on_device = */ cur_step > 1, + /* .update_stop_state_on_device = */ use_device_early_stop, + /* .check_stop_on_device = */ check_stop, + /* .reset_stop_state = */ reset_stop_state, + }; + if (!llama_diffusion_sample_topk(ctx, &sample_params, &sample_result)) { + llama_memory_seq_rm(mem, 0, n_past, -1); + out.error = "CUDA diffusion device loop failed at denoising step " + + std::to_string(cur_step) + " (k=" + std::to_string(k_step) + ")"; + return out; + } + llama_memory_seq_rm(mem, 0, n_past, -1); + if (check_stop && stop_flag != 0) { + break; + } + continue; + } + + if (!use_device_self_cond_step) { + std::fill(sc_probs.begin(), sc_probs.end(), 0.0f); + std::fill(sc_ids.begin(), sc_ids.end(), 0); + } + + if (use_gpu_sampling_step) { + llama_diffusion_sample_params sample_params = { + /* .n_tokens = */ canvas_length, + /* .top_k = */ k_step, + /* .self_cond_top_k = */ SC_K, + /* .temperature = */ temp, + /* .seed = */ rq.seed, + /* .step = */ (uint32_t) n_steps_total, + /* .top_k_tail_correction = */ rq.topk_tail, + /* .cuda_fast_top_k = */ diffusion.cuda_fast_top_k, + /* .cuda_direct_self_cond = */ diffusion.cuda_direct_self_cond, + /* .cuda_final_tokens_on_stop = */ diffusion.cuda_final_tokens_on_stop, + /* .cuda_fused_top_k_sample = */ diffusion.cuda_fused_top_k_sample, + /* .cuda_tight_top_k = */ diffusion.cuda_tight_top_k, + /* .cuda_parallel_full_softmax = */ diffusion.cuda_parallel_full_softmax, + /* .cuda_fused_full_softmax = */ diffusion.cuda_fused_full_softmax, + /* .cuda_top_k_local_k = */ diffusion.cuda_top_k_local_k, + }; + llama_diffusion_sample_result sample_result = { + /* .sampled = */ sampled.data(), + /* .argmax = */ argmax_canvas.data(), + /* .entropy = */ entropy.data(), + /* .self_cond_ids = */ use_device_self_cond_step ? nullptr : sc_ids.data(), + /* .self_cond_probs = */ use_device_self_cond_step ? nullptr : sc_probs.data(), + /* .final_tokens = */ nullptr, + /* .stop = */ nullptr, + /* .entropy_bound = */ 0.0f, + /* .confidence_threshold = */ 0.0f, + /* .stability_threshold = */ 0, + /* .update_canvas_on_device = */ false, + /* .update_stop_state_on_device = */ false, + /* .check_stop_on_device = */ false, + /* .reset_stop_state = */ false, + }; + if (!llama_diffusion_sample_topk(ctx, &sample_params, &sample_result)) { + llama_memory_seq_rm(mem, 0, n_past, -1); + out.error = "CUDA diffusion sampling failed at denoising step " + + std::to_string(cur_step) + " (k=" + std::to_string(k_step) + ")"; + return out; + } + } else if (k_step == 0) { + const float * logits = llama_get_logits(ctx); + const auto cmp = [](const std::pair&a, const std::pair&b){ return a.first > b.first; }; + std::vector probs(n_vocab); + std::vector> scheap; scheap.reserve(SC_K); + for (int j = 0; j < canvas_length; ++j) { + const float * lg = logits + (size_t) j * n_vocab; + float maxl = -INFINITY; int amax = 0; + scheap.clear(); + for (int v = 0; v < n_vocab; ++v) { + const float x = lg[v] / temp; + if (x > maxl) { maxl = x; amax = v; } + if ((int) scheap.size() < SC_K) { + scheap.push_back({x, v}); + std::push_heap(scheap.begin(), scheap.end(), cmp); + } else if (x > scheap.front().first) { + std::pop_heap(scheap.begin(), scheap.end(), cmp); + scheap.back() = {x, v}; + std::push_heap(scheap.begin(), scheap.end(), cmp); + } + } + float sum = 0.0f; + for (int v = 0; v < n_vocab; ++v) { const float p = expf(lg[v] / temp - maxl); probs[v] = p; sum += p; } + float ent = 0.0f; + const float r = rand_unif(rng) * sum; + float cum = 0.0f; int tok = amax; bool picked = false; + for (int v = 0; v < n_vocab; ++v) { + const float p = probs[v] / sum; + if (p > 0.0f) ent -= p * logf(p); + cum += probs[v]; + if (!picked && cum >= r) { tok = v; picked = true; } + } + int32_t * sid = sc_ids.data() + (size_t) j * SC_K; + float * spr = sc_probs.data() + (size_t) j * SC_K; + int slot = 0; + for (auto & h : scheap) { sid[slot] = h.second; spr[slot] = expf(h.first - maxl) / sum; ++slot; } + entropy[j] = ent; sampled[j] = tok; argmax_canvas[j] = amax; + } + } else { + const float * logits = llama_get_logits(ctx); + const auto cmp = [](const std::pair&a, const std::pair&b){ return a.first > b.first; }; + const int heap_k = std::max(k_step, SC_K); + std::vector> heap; heap.reserve(heap_k); + for (int j = 0; j < canvas_length; ++j) { + const float * lg = logits + (size_t) j * n_vocab; + float maxl = -INFINITY; int amax = 0; + heap.clear(); + for (int v = 0; v < n_vocab; ++v) { + const float x = lg[v] / temp; + if (x > maxl) { maxl = x; amax = v; } + if ((int) heap.size() < heap_k) { + heap.push_back({x, v}); + std::push_heap(heap.begin(), heap.end(), cmp); + } else if (x > heap.front().first) { + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = {x, v}; + std::push_heap(heap.begin(), heap.end(), cmp); + } + } + std::sort(heap.begin(), heap.end(), [](const std::pair&a, const std::pair&b){ return a.first > b.first; }); + float Zk = 0.0f; + for (int i = 0; i < k_step; ++i) { const float e = expf(heap[i].first - maxl); heap[i].first = e; Zk += e; } + float ent; + if (rq.topk_tail) { + double Zf = 0.0, T = 0.0; + for (int v = 0; v < n_vocab; ++v) { + const double d = (double) (lg[v] / temp) - (double) maxl; + const double e = exp(d); Zf += e; T += d * e; + } + ent = (float) (log(Zf) - T / Zf); + } else { + ent = 0.0f; + for (int i = 0; i < k_step; ++i) { const float q = heap[i].first / Zk; if (q > 0.0f) ent -= q * logf(q); } + } + const float r = rand_unif(rng) * Zk; + float cum = 0.0f; int tok = amax; bool picked = false; + for (int i = 0; i < k_step; ++i) { cum += heap[i].first; if (!picked && cum >= r) { tok = heap[i].second; picked = true; } } + int32_t * sid = sc_ids.data() + (size_t) j * SC_K; + float * spr = sc_probs.data() + (size_t) j * SC_K; + const int n_sc = std::min(k_step, SC_K); + for (int i = 0; i < n_sc; ++i) { sid[i] = heap[i].second; spr[i] = heap[i].first / Zk; } + entropy[j] = ent; sampled[j] = tok; argmax_canvas[j] = amax; + } + } + + llama_memory_seq_rm(mem, 0, n_past, -1); + + std::vector order(canvas_length); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](int a, int b) { return entropy[a] < entropy[b]; }); + std::vector accept_mask(canvas_length, 0); + float prefix = 0.0f; + for (int k = 0; k < canvas_length; ++k) { + if (prefix <= ENTROPY_BOUND) { accept_mask[order[k]] = 1; prefix += entropy[order[k]]; } + else break; + } + for (int i = 0; i < canvas_length; ++i) if (accept_mask[i]) accepted[i] = sampled[i]; + + const float mean_ent = std::accumulate(entropy.begin(), entropy.end(), 0.0f) / canvas_length; + const bool stable = (STABILITY_THRESHOLD == 0) || (argmax_canvas == prev_argmax); + const bool confident = mean_ent < CONFIDENCE_THRESHOLD; + if (stable && confident) break; + prev_argmax = argmax_canvas; + + if (!use_device_self_cond_step) { + llama_set_diffusion_self_cond_topk(ctx, sc_ids.data(), sc_probs.data(), SC_K, canvas_length); + } + + for (int i = 0; i < canvas_length; ++i) canvas[i] = accept_mask[i] ? accepted[i] : rand_tok(rng); + } + + const std::vector & block_out = argmax_canvas; + generated.insert(generated.end(), block_out.begin(), block_out.end()); + if (!rq.ignore_eos) { + for (int j = 0; j < canvas_length; ++j) { + if (llama_vocab_is_eog(vocab, block_out[j])) { done = true; break; } + } + } + if (!done && block + 1 < max_canvases) { + if (!prefill_causal(block_out, n_past, &n_decode)) { + out.error = "canvas commit (encoder) decode failed"; + return out; + } + n_past += canvas_length; + } + } + out.gen_ms = std::chrono::duration(std::chrono::steady_clock::now() - t_gen0).count(); + out.full = common_detokenize(vocab, generated, false); + + // extract the final response: after the last "" close, until the first eog token + llama_token chan_close = LLAMA_TOKEN_NULL; + { + auto t = common_tokenize(vocab, "", false, true); + if (t.size() == 1) chan_close = t[0]; + } + const int n_gen = (int) generated.size(); + int start = 0; + if (chan_close != LLAMA_TOKEN_NULL) { + for (int j = 0; j < n_gen; ++j) if (generated[j] == chan_close) start = j + 1; + } + std::vector answer; + for (int j = start; j < n_gen; ++j) { + if (llama_vocab_is_eog(vocab, generated[j])) break; + answer.push_back(generated[j]); + } + std::string ans = common_detokenize(vocab, answer, false); + { + std::string s = ans; + size_t h = s.find_first_not_of(" \n\t"); if (h != std::string::npos) s = s.substr(h); + const size_t half = s.size() / 2; + if (half > 0 && s.compare(0, half, s, s.size() - half, half) == 0) ans = s.substr(0, half); + } + + out.answer = ans; + out.prompt_tokens = prefix_len; + out.answer_tokens = (int) answer.size(); + out.n_blocks = n_blocks_run; + out.n_steps_total = n_steps_total; + out.canvas_tokens = n_blocks_run * canvas_length; + out.n_decode = n_decode; + out.ok = true; + return out; + } +}; + +// ------------------------------------------------------------------------------------------------ +// JSON / log helpers +// ------------------------------------------------------------------------------------------------ +static std::atomic g_req_counter{0}; +static std::string gen_id(const char * prefix) { return std::string(prefix) + "-" + std::to_string(g_req_counter.fetch_add(1)); } +static json error_json(const std::string & msg, const std::string & type, int code) { + return json{{"error", {{"message", msg}, {"type", type}, {"code", code}}}}; +} + +// llama-server `timings` object, with diffusion fields. prompt_* describes the encoder prefill; +// predicted_* describes the answer tokens; the nested `diffusion` object describes the denoising. +static json timings_json(const diffusion_result & r) { + const double ppt = r.prompt_tokens > 0 ? r.prefill_ms / r.prompt_tokens : 0.0; + const double pps = r.prefill_ms > 0 ? r.prompt_tokens * 1e3 / r.prefill_ms : 0.0; + const double dpt = r.answer_tokens > 0 ? r.gen_ms / r.answer_tokens : 0.0; + const double dps = r.gen_ms > 0 ? r.answer_tokens * 1e3 / r.gen_ms : 0.0; + const double sps = r.gen_ms > 0 ? r.n_steps_total * 1e3 / r.gen_ms : 0.0; + const double cps = r.gen_ms > 0 ? r.canvas_tokens * 1e3 / r.gen_ms : 0.0; + const double mps = r.n_steps_total > 0 ? r.gen_ms / r.n_steps_total : 0.0; + return json{ + {"cache_n", 0}, + {"prompt_n", r.prompt_tokens}, + {"prompt_ms", r.prefill_ms}, + {"prompt_per_token_ms", ppt}, + {"prompt_per_second", pps}, + {"predicted_n", r.answer_tokens}, + {"predicted_ms", r.gen_ms}, + {"predicted_per_token_ms", dpt}, + {"predicted_per_second", dps}, + {"diffusion", { + {"n_blocks", r.n_blocks}, + {"n_steps", r.n_steps_total}, + {"canvas_tokens", r.canvas_tokens}, + {"ms_per_step", mps}, + {"steps_per_second", sps}, + {"canvas_tokens_per_second", cps}, + {"n_decode", r.n_decode}, + }}, + }; +} + +static json usage_json(const diffusion_result & r) { + return json{ + {"prompt_tokens", r.prompt_tokens}, + {"completion_tokens", r.answer_tokens}, + {"total_tokens", r.prompt_tokens + r.answer_tokens}, + {"prompt_tokens_details", {{"cached_tokens", 0}}}, + }; +} + +// per-request timing log, llama_perf-style but with the denoise phase substituted for AR eval +static void log_timings(const diffusion_result & r) { + const double ppt = r.prompt_tokens > 0 ? r.prefill_ms / r.prompt_tokens : 0.0; + const double pps = r.prefill_ms > 0 ? r.prompt_tokens * 1e3 / r.prefill_ms : 0.0; + const double mps = r.n_steps_total > 0 ? r.gen_ms / r.n_steps_total : 0.0; + const double sps = r.gen_ms > 0 ? r.n_steps_total * 1e3 / r.gen_ms : 0.0; + const double cps = r.gen_ms > 0 ? r.canvas_tokens * 1e3 / r.gen_ms : 0.0; + SLT_INF(0, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + r.prefill_ms, r.prompt_tokens, ppt, pps); + SLT_INF(0, " denoise time = %10.2f ms / %5d steps (%8.2f ms per step, %8.2f steps per second)\n", + r.gen_ms, r.n_steps_total, mps, sps); + SLT_INF(0, " gen tokens = %5d answer | %5d canvas over %d block(s) (%8.2f canvas tok/s)\n", + r.answer_tokens, r.canvas_tokens, r.n_blocks, cps); + SLT_INF(0, " total time = %10.2f ms / %5d decode call(s)\n", + r.prefill_ms + r.gen_ms, r.n_decode); +} + +static diffusion_request request_from_body(const json & body, const diffusion_server & srv, uint32_t default_seed) { + diffusion_request rq; + rq.canvas_length = srv.canvas_length; + rq.n_steps = srv.n_steps; + rq.seed = default_seed; + rq.topk_fixed = srv.topk_fixed; + rq.topk_start = srv.topk_start; + rq.topk_end = srv.topk_end; + rq.topk_tail = srv.topk_tail; + int max_tokens = 0; + if (body.contains("max_tokens") && body["max_tokens"].is_number_integer()) max_tokens = body["max_tokens"].get(); + else if (body.contains("max_completion_tokens") && body["max_completion_tokens"].is_number_integer()) max_tokens = body["max_completion_tokens"].get(); + rq.max_canvases = max_tokens > 0 ? (max_tokens + rq.canvas_length - 1) / rq.canvas_length : 1; + if (body.contains("diffusion_steps") && body["diffusion_steps"].is_number_integer()) { + rq.n_steps = std::max(body["diffusion_steps"].get(), 1); + } else if (body.contains("n_denoise_steps") && body["n_denoise_steps"].is_number_integer()) { + rq.n_steps = std::max(body["n_denoise_steps"].get(), 1); + } + if (body.contains("top_k") && body["top_k"].is_number_integer()) { + rq.topk_fixed = body["top_k"].get(); + rq.topk_start = 0; + rq.topk_end = 0; + } + if (body.contains("top_k_start") && body["top_k_start"].is_number_integer()) rq.topk_start = body["top_k_start"].get(); + if (body.contains("top_k_end") && body["top_k_end"].is_number_integer()) rq.topk_end = body["top_k_end"].get(); + if (body.contains("top_k_tail_correction") && body["top_k_tail_correction"].is_boolean()) { + rq.topk_tail = body["top_k_tail_correction"].get(); + } + const int forced_top_k = std::max(srv.diffusion.force_top_k, 0); + if (forced_top_k > 0) { + rq.topk_fixed = forced_top_k; + rq.topk_start = 0; + rq.topk_end = 0; + } + if (body.contains("seed") && body["seed"].is_number_integer()) rq.seed = (uint32_t) body["seed"].get(); + if (body.contains("ignore_eos") && body["ignore_eos"].is_boolean()) rq.ignore_eos = body["ignore_eos"].get(); + return rq; +} + +// ------------------------------------------------------------------------------------------------ +int main(int argc, char ** argv) { + std::string hostname = "127.0.0.1"; + int port = 8080; + std::string api_key; + bool enable_metrics = false; + bool enable_slots = false; + + // Pull server-only flags out of argv before common_params_parse (which validates against + // LLAMA_EXAMPLE_DIFFUSION and would reject these). Everything else (-m, --mmproj, -ngl, -c, + // --seed, --top-k*, -t, ...) is forwarded and parsed normally. + std::vector fwd; + fwd.push_back(argv[0]); + for (int i = 1; i < argc; ++i) { + const std::string a = argv[i]; + auto next = [&](const char * def) -> std::string { return (i + 1 < argc) ? argv[++i] : def; }; + if (a == "--host") { hostname = next("127.0.0.1"); } + else if (a == "--port") { port = atoi(next("8080").c_str()); } + else if (a == "--api-key") { api_key = next(""); } + else if (a == "--metrics") { enable_metrics = true; } + else if (a == "--slots") { enable_slots = true; } + else if (a == "--no-metrics") { enable_metrics = false; } + else if (a == "--no-slots") { enable_slots = false; } + else { fwd.push_back(argv[i]); } + } + + common_params params; + params.diffusion.steps = DEF_MAX_DENOISE_STEPS; + if (!common_params_parse((int) fwd.size(), fwd.data(), params, LLAMA_EXAMPLE_DIFFUSION)) { + return 1; + } + common_init(); + if (!params.diffusion.device_self_cond && params.diffusion.fused_self_cond_embd) { + SRV_ERR("%s\n", "--no-diffusion-device-selfcond cannot be used with --diffusion-fused-self-cond-embd"); + return 1; + } + params.diffusion.self_cond_top_k = diffusion_self_cond_top_k(params); + + llama_backend_init(); + llama_numa_init(params.numa); + + SRV_INF("%s\n", "loading model"); + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = params.n_gpu_layers >= 0 ? params.n_gpu_layers : 999; + model_params.devices = params.devices.data(); + model_params.use_mmap = params.use_mmap; + + diffusion_server srv; + srv.model = llama_model_load_from_file(params.model.path.c_str(), model_params); + if (!srv.model) { + SRV_ERR("failed to load model '%s'\n", params.model.path.c_str()); + return 1; + } + if (!llama_model_is_diffusion(srv.model)) { + SRV_ERR("'%s' is not a diffusion model\n", params.model.path.c_str()); + llama_model_free(srv.model); + return 1; + } + + srv.vocab = llama_model_get_vocab(srv.model); + srv.n_vocab = llama_vocab_n_tokens(srv.vocab); + srv.canvas_length = DEF_CANVAS_LENGTH; + srv.n_steps = std::max(params.diffusion.steps, 1); + srv.diffusion = params.diffusion; + srv.n_ub = srv.canvas_length; + srv.n_ctx = params.n_ctx > 0 ? (int) params.n_ctx : 4096; + if (srv.n_ctx < 2 * srv.canvas_length) srv.n_ctx = 2 * srv.canvas_length; + srv.topk_fixed = (params.sampling.user_sampling_config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K) + ? params.sampling.top_k + : std::max(params.diffusion.default_top_k, 0); + srv.topk_start = params.diffusion.top_k_start; + srv.topk_end = params.diffusion.top_k_end; + srv.topk_tail = params.diffusion.top_k_tail_correction; + srv.model_path = params.model.path; + srv.metrics.t_start = (int64_t) std::time(nullptr); + + { + std::string p = params.model.path; + size_t slash = p.find_last_of("/\\"); + srv.model_id = (slash == std::string::npos) ? p : p.substr(slash + 1); + size_t dot = srv.model_id.rfind(".gguf"); + if (dot != std::string::npos) srv.model_id = srv.model_id.substr(0, dot); + if (srv.model_id.empty()) srv.model_id = "diffusion-gemma"; + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = srv.n_ctx; + ctx_params.n_batch = srv.n_ub; + ctx_params.n_ubatch = srv.n_ub; + ctx_params.no_perf = params.no_perf; + ctx_params.diffusion_self_cond_top_k = srv.diffusion.self_cond_top_k; + ctx_params.diffusion_input_gpu_groups = srv.diffusion.input_gpu_groups; + ctx_params.diffusion_fused_self_cond_embd = srv.diffusion.fused_self_cond_embd; + ctx_params.diffusion_fuse_final_logit_softcap = srv.diffusion.fuse_final_logit_softcap; + ctx_params.diffusion_separate_encoder_decoder = srv.diffusion.separate_encoder_decoder; + + srv.ctx = llama_init_from_model(srv.model, ctx_params); + if (!srv.ctx) { + SRV_ERR("%s\n", "failed to create context"); + llama_model_free(srv.model); + return 1; + } + llama_set_n_threads(srv.ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads); + srv.mem = llama_get_memory(srv.ctx); + srv.batch = llama_batch_init(srv.n_ub, 0, 1); + srv.templates = common_chat_templates_init(srv.model, ""); + + const int topk_max_requested = + (srv.topk_start > 0 && srv.topk_end > 0) ? std::max(srv.topk_start, srv.topk_end) : srv.topk_fixed; + const bool gpu_sampling_requested = params.diffusion.gpu_sampling; + srv.use_gpu_sampling = gpu_sampling_requested && + llama_diffusion_sample_topk_supported(srv.ctx); + srv.use_device_self_cond = srv.use_gpu_sampling && params.diffusion.device_self_cond; + srv.use_device_loop = srv.use_device_self_cond && params.diffusion.device_denoise_loop; + srv.device_early_stop_interval = srv.use_device_loop ? 1 : 0; + llama_set_diffusion_gpu_sampling(srv.ctx, srv.use_gpu_sampling); + const bool default_topk_gpu_ok = topk_max_requested <= 0 || topk_max_requested <= GPU_SAMPLING_MAX_TOP_K; + + const uint32_t default_seed = params.sampling.seed == LLAMA_DEFAULT_SEED ? 1234u : params.sampling.seed; + + SRV_INF("%s\n", llama_print_system_info()); + SRV_INF("model loaded: '%s' | n_ctx = %d | canvas = %d | denoise steps = %d | 1 slot\n", + srv.model_id.c_str(), srv.n_ctx, srv.canvas_length, srv.n_steps); + SRV_INF("gpu sampling: %s%s%s%s | top-k fixed=%d anneal=[%d->%d] tail_correction=%d\n", + srv.use_gpu_sampling ? "on" : "off", + (!default_topk_gpu_ok ? " (default top-k will use CPU fallback until k <= CUDA limit)" : + (!gpu_sampling_requested ? " (disabled by --no-diffusion-gpu-sampling)" : "")), + srv.use_device_self_cond ? " | device self-cond: on" : "", + srv.use_device_loop ? " | device loop: on" : "", + srv.topk_fixed, srv.topk_start, srv.topk_end, srv.topk_tail ? 1 : 0); + if (srv.device_early_stop_interval > 0) { + SRV_INF("device early-stop interval: %d\n", srv.device_early_stop_interval); + } + if (srv.diffusion.force_top_k > 0) { + SRV_INF("forcing request top-k to %d via --diffusion-force-top-k\n", srv.diffusion.force_top_k); + } + + // ------------------------------------------------------------------------------------------ + // HTTP server + // ------------------------------------------------------------------------------------------ + httplib::Server http; + http.set_default_headers({ + {"Server", "llama.cpp-diffusion-gemma"}, + {"Access-Control-Allow-Origin", "*"}, + {"Access-Control-Allow-Headers", "Content-Type, Authorization"}, + {"Access-Control-Allow-Methods", "GET, POST, OPTIONS"}, + }); + + // access log (llama-server: SRV_TRC "done request: METHOD PATH ADDR STATUS"; skip noisy paths) + http.set_logger([](const httplib::Request & req, const httplib::Response & res) { + if (req.path == "/health" || req.path == "/v1/health" || req.path == "/metrics" || + req.path == "/props" || req.path == "/models" || req.path == "/v1/models") { + return; + } + SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); + SRV_DBG("request body: %s\n", req.body.c_str()); + }); + + if (!api_key.empty()) { + http.set_pre_routing_handler([api_key](const httplib::Request & req, httplib::Response & res) { + if (req.method == "OPTIONS" || req.path == "/health" || req.path == "/v1/health") { + return httplib::Server::HandlerResponse::Unhandled; + } + if (req.get_header_value("Authorization") != "Bearer " + api_key) { + res.status = 401; + res.set_content(error_json("invalid api key", "authentication_error", 401).dump(), "application/json"); + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + } + + http.Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { res.status = 204; }); + + auto health = [](const httplib::Request &, httplib::Response & res) { + res.set_content(json{{"status", "ok"}}.dump(), "application/json"); + }; + http.Get("/health", health); + http.Get("/v1/health", health); + + http.Get("/v1/models", [&srv](const httplib::Request &, httplib::Response & res) { + json models = json::array(); + models.push_back({{"id", srv.model_id}, {"object", "model"}, {"created", (int64_t) std::time(nullptr)}, {"owned_by", "local"}}); + res.set_content(json{{"object", "list"}, {"data", models}}.dump(), "application/json"); + }); + http.Get("/models", [&srv](const httplib::Request &, httplib::Response & res) { + json models = json::array(); + models.push_back({{"id", srv.model_id}, {"object", "model"}, {"created", (int64_t) std::time(nullptr)}, {"owned_by", "local"}}); + res.set_content(json{{"object", "list"}, {"data", models}}.dump(), "application/json"); + }); + + // /props (llama-server analogue): server + model + default generation settings + http.Get("/props", [&](const httplib::Request &, httplib::Response & res) { + json props{ + {"default_generation_settings", { + {"n_ctx", srv.n_ctx}, + {"canvas_length", srv.canvas_length}, + {"n_denoise_steps", srv.n_steps}, + {"temperature_min", TEMP_MIN}, + {"temperature_max", TEMP_MAX}, + {"entropy_bound", ENTROPY_BOUND}, + {"confidence_threshold", CONFIDENCE_THRESHOLD}, + {"stability_threshold", STABILITY_THRESHOLD}, + {"self_cond_topk", srv.diffusion.self_cond_top_k}, + {"gpu_sampling", srv.use_gpu_sampling}, + {"device_self_cond", srv.use_device_self_cond}, + {"device_loop", srv.use_device_loop}, + {"device_early_stop_interval", srv.device_early_stop_interval}, + {"top_k", srv.topk_fixed}, + {"top_k_start", srv.topk_start}, + {"top_k_end", srv.topk_end}, + {"top_k_tail_correction", srv.topk_tail}, + }}, + {"total_slots", 1}, + {"model_path", srv.model_path}, + {"model_alias", srv.model_id}, + {"n_ctx", srv.n_ctx}, + {"build_info", srv.build_info}, + {"generation_type", "block-diffusion"}, + {"endpoint_slots", enable_slots}, + {"endpoint_metrics", enable_metrics}, + }; + res.set_content(props.dump(), "application/json"); + }); + + // /metrics (Prometheus). Gated behind --metrics, like llama-server. + if (enable_metrics) { + http.Get("/metrics", [&srv](const httplib::Request &, httplib::Response & res) { + server_metrics & m = srv.metrics; + std::lock_guard lk(m.mu); + const int n_proc = m.n_processing.load(); + char buf[8192]; + int n = snprintf(buf, sizeof(buf), + // --- counters (re-mapped to diffusion where the AR concept differs) --- + "# HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.\n" + "# TYPE llamacpp:prompt_tokens_total counter\n" + "llamacpp:prompt_tokens_total %llu\n" + "# HELP llamacpp:prompt_seconds_total Prompt (encoder prefill) process time.\n" + "# TYPE llamacpp:prompt_seconds_total counter\n" + "llamacpp:prompt_seconds_total %.3f\n" + "# HELP llamacpp:tokens_predicted_total Number of answer tokens generated.\n" + "# TYPE llamacpp:tokens_predicted_total counter\n" + "llamacpp:tokens_predicted_total %llu\n" + "# HELP llamacpp:tokens_predicted_seconds_total Denoising (generation) process time.\n" + "# TYPE llamacpp:tokens_predicted_seconds_total counter\n" + "llamacpp:tokens_predicted_seconds_total %.3f\n" + "# HELP llamacpp:n_decode_total Total number of llama_decode() calls (prefill + denoise + commit).\n" + "# TYPE llamacpp:n_decode_total counter\n" + "llamacpp:n_decode_total %llu\n" + "# HELP llamacpp:requests_total Total number of completed generation requests.\n" + "# TYPE llamacpp:requests_total counter\n" + "llamacpp:requests_total %llu\n" + "# HELP llamacpp:diffusion_blocks_total Total number of canvas blocks denoised.\n" + "# TYPE llamacpp:diffusion_blocks_total counter\n" + "llamacpp:diffusion_blocks_total %llu\n" + "# HELP llamacpp:diffusion_steps_total Total number of denoising steps.\n" + "# TYPE llamacpp:diffusion_steps_total counter\n" + "llamacpp:diffusion_steps_total %llu\n" + "# HELP llamacpp:diffusion_canvas_tokens_total Total canvas tokens denoised (blocks * canvas_length).\n" + "# TYPE llamacpp:diffusion_canvas_tokens_total counter\n" + "llamacpp:diffusion_canvas_tokens_total %llu\n" + // --- gauges (last request) --- + "# HELP llamacpp:prompt_tokens_seconds Average prompt throughput in tokens/s.\n" + "# TYPE llamacpp:prompt_tokens_seconds gauge\n" + "llamacpp:prompt_tokens_seconds %.3f\n" + "# HELP llamacpp:predicted_tokens_seconds Average answer-token throughput in tokens/s.\n" + "# TYPE llamacpp:predicted_tokens_seconds gauge\n" + "llamacpp:predicted_tokens_seconds %.3f\n" + "# HELP llamacpp:diffusion_steps_per_second Denoising steps per second (last request).\n" + "# TYPE llamacpp:diffusion_steps_per_second gauge\n" + "llamacpp:diffusion_steps_per_second %.3f\n" + "# HELP llamacpp:diffusion_canvas_tokens_per_second Canvas tokens per second (last request).\n" + "# TYPE llamacpp:diffusion_canvas_tokens_per_second gauge\n" + "llamacpp:diffusion_canvas_tokens_per_second %.3f\n" + "# HELP llamacpp:diffusion_ms_per_step Milliseconds per denoising step (last request).\n" + "# TYPE llamacpp:diffusion_ms_per_step gauge\n" + "llamacpp:diffusion_ms_per_step %.3f\n" + "# HELP llamacpp:requests_processing Number of requests currently generating.\n" + "# TYPE llamacpp:requests_processing gauge\n" + "llamacpp:requests_processing %d\n", + (unsigned long long) m.n_prompt_tokens_total, + m.t_prompt_ms_total / 1000.0, + (unsigned long long) m.n_tokens_predicted_total, + m.t_predicted_ms_total / 1000.0, + (unsigned long long) m.n_decode_total, + (unsigned long long) m.n_requests_total, + (unsigned long long) m.n_blocks_total, + (unsigned long long) m.n_steps_total, + (unsigned long long) m.n_canvas_tokens_total, + m.last_prompt_tps, + m.last_predicted_tps, + m.last_steps_per_second, + m.last_canvas_tps, + m.last_ms_per_step, + n_proc); + res.set_header("Process-Start-Time-Unix", std::to_string(m.t_start)); + res.set_content(std::string(buf, n > 0 ? (size_t) n : 0), "text/plain; version=0.0.4"); + }); + SRV_INF("%s\n", "metrics endpoint enabled at /metrics"); + } + + // /slots (single slot). Gated behind --slots, like llama-server. + if (enable_slots) { + http.Get("/slots", [&srv](const httplib::Request &, httplib::Response & res) { + json slot{ + {"id", 0}, + {"n_ctx", srv.n_ctx}, + {"is_processing", srv.metrics.n_processing.load() > 0}, + {"generation_type", "block-diffusion"}, + {"params", { + {"canvas_length", srv.canvas_length}, + {"n_denoise_steps", srv.n_steps}, + {"self_cond_topk", srv.diffusion.self_cond_top_k}, + {"gpu_sampling", srv.use_gpu_sampling}, + {"device_self_cond", srv.use_device_self_cond}, + {"device_loop", srv.use_device_loop}, + {"device_early_stop_interval", srv.device_early_stop_interval}, + {"top_k", srv.topk_fixed}, + {"top_k_start", srv.topk_start}, + {"top_k_end", srv.topk_end}, + {"top_k_tail_correction", srv.topk_tail}, + }}, + }; + res.set_content(json::array({slot}).dump(), "application/json"); + }); + SRV_INF("%s\n", "slots endpoint enabled at /slots"); + } + + // shared: run one generation (under the slot mutex), update metrics, print timing log + auto run_for_body = [&srv, default_seed](const json & body, const std::vector & prompt_tokens) { + diffusion_request rq = request_from_body(body, srv, default_seed); + srv.metrics.n_processing.fetch_add(1); + std::unique_lock lock(srv.gen_mutex); + diffusion_result r = srv.generate(prompt_tokens, rq); + lock.unlock(); + srv.metrics.n_processing.fetch_sub(1); + if (r.ok) { srv.metrics.add(r); log_timings(r); } + return r; + }; + + // POST /v1/chat/completions + http.Post("/v1/chat/completions", [&](const httplib::Request & req, httplib::Response & res) { + json body; + try { body = json::parse(req.body); } + catch (const std::exception & e) { res.status = 400; res.set_content(error_json(std::string("invalid JSON: ") + e.what(), "invalid_request_error", 400).dump(), "application/json"); return; } + if (!body.contains("messages") || !body["messages"].is_array()) { + res.status = 400; res.set_content(error_json("'messages' (array) is required", "invalid_request_error", 400).dump(), "application/json"); return; + } + std::string prompt; + try { prompt = srv.format_messages(body["messages"]); } + catch (const std::exception & e) { res.status = 400; res.set_content(error_json(std::string("failed to format messages: ") + e.what(), "invalid_request_error", 400).dump(), "application/json"); return; } + const std::vector prompt_tokens = common_tokenize(srv.vocab, prompt, false, true); + const bool stream = body.value("stream", false); + + const diffusion_result r = run_for_body(body, prompt_tokens); + if (!r.ok) { res.status = 500; res.set_content(error_json(r.error, "server_error", 500).dump(), "application/json"); return; } + + if (!stream) { + json out{ + {"id", gen_id("chatcmpl")}, {"object", "chat.completion"}, {"created", (int64_t) std::time(nullptr)}, {"model", srv.model_id}, + {"choices", json::array({ json{{"index", 0}, {"message", {{"role", "assistant"}, {"content", r.answer}}}, {"finish_reason", "stop"}} })}, + {"usage", usage_json(r)}, + {"timings", timings_json(r)}, + }; + res.set_content(out.dump(), "application/json"); + return; + } + + const std::string id = gen_id("chatcmpl"); + const int64_t created = (int64_t) std::time(nullptr); + const std::string model_id = srv.model_id; + const std::string answer = r.answer; + const json timings = timings_json(r); + res.set_chunked_content_provider("text/event-stream", + [id, created, model_id, answer, timings](size_t, httplib::DataSink & sink) { + auto send = [&](const json & c) { std::string s = "data: " + c.dump() + "\n\n"; return sink.write(s.data(), s.size()); }; + auto chunk = [&](const json & delta, const char * finish, const json * extra) { + json c{{"id", id}, {"object", "chat.completion.chunk"}, {"created", created}, {"model", model_id}, + {"choices", json::array({ json{{"index", 0}, {"delta", delta}, {"finish_reason", finish ? json(finish) : json(nullptr)}} })}}; + if (extra) c["timings"] = *extra; + return send(c); + }; + if (!chunk(json{{"role", "assistant"}}, nullptr, nullptr)) return false; + if (!chunk(json{{"content", answer}}, nullptr, nullptr)) return false; + if (!chunk(json::object(), "stop", &timings)) return false; + const std::string done = "data: [DONE]\n\n"; + sink.write(done.data(), done.size()); + sink.done(); + return true; + }); + }); + + // POST /v1/completions (plain prompt) + http.Post("/v1/completions", [&](const httplib::Request & req, httplib::Response & res) { + json body; + try { body = json::parse(req.body); } + catch (const std::exception & e) { res.status = 400; res.set_content(error_json(std::string("invalid JSON: ") + e.what(), "invalid_request_error", 400).dump(), "application/json"); return; } + if (!body.contains("prompt") || !body["prompt"].is_string()) { + res.status = 400; res.set_content(error_json("'prompt' (string) is required", "invalid_request_error", 400).dump(), "application/json"); return; + } + const std::string prompt = body["prompt"].get(); + const std::vector prompt_tokens = common_tokenize(srv.vocab, prompt, true, true); + + const diffusion_result r = run_for_body(body, prompt_tokens); + if (!r.ok) { res.status = 500; res.set_content(error_json(r.error, "server_error", 500).dump(), "application/json"); return; } + json out{ + {"id", gen_id("cmpl")}, {"object", "text_completion"}, {"created", (int64_t) std::time(nullptr)}, {"model", srv.model_id}, + {"choices", json::array({ json{{"index", 0}, {"text", r.answer}, {"finish_reason", "stop"}} })}, + {"usage", usage_json(r)}, + {"timings", timings_json(r)}, + }; + res.set_content(out.dump(), "application/json"); + }); + + SRV_INF("server is listening on http://%s:%d\n", hostname.c_str(), port); + SRV_INF("%s\n", "all slots are idle"); + if (!http.listen(hostname, port)) { + SRV_ERR("failed to bind to %s:%d\n", hostname.c_str(), port); + llama_batch_free(srv.batch); + llama_free(srv.ctx); + llama_model_free(srv.model); + llama_backend_free(); + return 1; + } + + llama_batch_free(srv.batch); + llama_free(srv.ctx); + llama_model_free(srv.model); + llama_backend_free(); + return 0; +} diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 5436c7ef579c..0a72747c3838 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -43,6 +43,54 @@ GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * f GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer); +struct ggml_cuda_diffusion_sample_params { + int32_t n_vocab; + int32_t n_tokens; + int32_t top_k; + int32_t self_cond_top_k; + float temperature; + uint32_t seed; + uint32_t step; + bool top_k_tail_correction; + float logit_softcap; + bool fast_top_k; + bool direct_self_cond; + bool final_tokens_on_stop; + bool fused_top_k_sample; + bool tight_top_k; + bool parallel_full_softmax; + bool fused_full_softmax; + int32_t top_k_local_k; +}; + +struct ggml_cuda_diffusion_sample_result { + int32_t * sampled; + int32_t * argmax; + float * entropy; + int32_t * self_cond_ids; + float * self_cond_probs; + struct ggml_tensor * self_cond_ids_tensor; + struct ggml_tensor * self_cond_probs_tensor; + struct ggml_tensor * self_cond_embd_tensor; + const struct ggml_tensor * token_embd_tensor; + struct ggml_tensor * canvas_tokens_tensor; + int32_t * final_tokens; + int32_t * stop; + float entropy_bound; + float confidence_threshold; + int32_t stability_threshold; + bool update_canvas_on_device; + bool update_stop_state_on_device; + bool check_stop_on_device; + bool reset_stop_state; +}; + +typedef bool (*ggml_backend_cuda_diffusion_sample_topk_t)( + ggml_backend_t backend, + const struct ggml_tensor * logits, + const struct ggml_cuda_diffusion_sample_params * params, + struct ggml_cuda_diffusion_sample_result * result); + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void); #ifdef __cplusplus diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 87615921c09b..bbdbc40a0513 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1405,6 +1405,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } sched->graph.n_nodes = 0; sched->graph.n_leafs = 0; + sched->graph.flags = graph->flags; struct ggml_cgraph * graph_copy = &sched->graph; @@ -2134,6 +2135,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s graph_copy->nodes[i] = node_copy; } graph_copy->n_nodes = graph->n_nodes; + graph_copy->flags = graph->flags; ggml_hash_set_free(&hash_set); free(node_copies); diff --git a/ggml/src/ggml-cuda/diffusion-sampling.cu b/ggml/src/ggml-cuda/diffusion-sampling.cu new file mode 100644 index 000000000000..93a66ad3aac2 --- /dev/null +++ b/ggml/src/ggml-cuda/diffusion-sampling.cu @@ -0,0 +1,1450 @@ +#include "argsort.cuh" +#include "diffusion-sampling.cuh" +#include "../ggml-backend-impl.h" + +#include +#include +#include +#include + +#ifdef GGML_CUDA_USE_CUB +# include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) +# define CUB_DIFFUSION_TOP_K_AVAILABLE +# include +using namespace cub; +# endif +#endif + +static int next_power_of_2_host(int x) { + int n = 1; + while (n < x) { + n <<= 1; + } + return n; +} + +static int diffusion_topk_local_k(const int heap_k, int requested) { + if (requested <= 0) { + requested = heap_k <= 256 ? 8 : 16; + } + + const int min_local_k = std::max(1, (heap_k + 255) / 256); + if (requested <= 1) { + requested = 1; + } else if (requested <= 2) { + requested = 2; + } else if (requested <= 4) { + requested = 4; + } else if (requested <= 8) { + requested = 8; + } else { + requested = 16; + } + + while (requested < min_local_k && requested < 16) { + requested <<= 1; + } + return requested; +} + +struct diffusion_sample_scratch { + int * top_ids = nullptr; + int * sampled = nullptr; + int * argmax = nullptr; + int * prev_argmax = nullptr; + int * stop = nullptr; + float * entropy = nullptr; + int * sc_ids = nullptr; + float * sc_probs = nullptr; + + size_t top_ids_cap = 0; + size_t sampled_cap = 0; + size_t argmax_cap = 0; + size_t prev_argmax_cap = 0; + size_t stop_cap = 0; + size_t entropy_cap = 0; + size_t sc_ids_cap = 0; + size_t sc_probs_cap = 0; +}; + +static std::mutex g_diffusion_scratch_mutex; +static std::map g_diffusion_scratch; + +template +static void diffusion_scratch_reserve(cudaStream_t stream, T ** ptr, size_t * cap, const size_t need, bool * synced) { + if (*cap >= need) { + return; + } + if (*ptr) { + if (!*synced) { + CUDA_CHECK(cudaStreamSynchronize(stream)); + *synced = true; + } + CUDA_CHECK(cudaFree(*ptr)); + *ptr = nullptr; + *cap = 0; + } + CUDA_CHECK(cudaMalloc((void **) ptr, need * sizeof(T))); + *cap = need; +} + +static diffusion_sample_scratch * diffusion_get_scratch( + cudaStream_t stream, + const int n_tokens, + const int heap_k, + const int sc_k) { + std::lock_guard lock(g_diffusion_scratch_mutex); + diffusion_sample_scratch & scratch = g_diffusion_scratch[stream]; + bool synced = false; + + diffusion_scratch_reserve(stream, &scratch.top_ids, &scratch.top_ids_cap, (size_t) n_tokens * heap_k, &synced); + diffusion_scratch_reserve(stream, &scratch.sampled, &scratch.sampled_cap, (size_t) n_tokens, &synced); + diffusion_scratch_reserve(stream, &scratch.argmax, &scratch.argmax_cap, (size_t) n_tokens, &synced); + diffusion_scratch_reserve(stream, &scratch.prev_argmax, &scratch.prev_argmax_cap, (size_t) n_tokens, &synced); + diffusion_scratch_reserve(stream, &scratch.stop, &scratch.stop_cap, (size_t) 1, &synced); + diffusion_scratch_reserve(stream, &scratch.entropy, &scratch.entropy_cap, (size_t) n_tokens, &synced); + diffusion_scratch_reserve(stream, &scratch.sc_ids, &scratch.sc_ids_cap, (size_t) n_tokens * sc_k, &synced); + diffusion_scratch_reserve(stream, &scratch.sc_probs, &scratch.sc_probs_cap, (size_t) n_tokens * sc_k, &synced); + + return &scratch; +} + +static __device__ __forceinline__ bool diffusion_should_swap_desc( + const float a_val, const int a_id, + const float b_val, const int b_id) { + return a_val < b_val || (a_val == b_val && a_id > b_id); +} + +template +static __global__ void diffusion_select_topk_local_kernel( + const float * __restrict__ logits, + int * __restrict__ top_ids, + const int n_vocab, + const int heap_k) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + float vals[LOCAL_K]; + int ids[LOCAL_K]; +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + vals[i] = -FLT_MAX; + ids[i] = 0; + } + + const float * row_logits = logits + (size_t) row * n_vocab; + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float x = row_logits[v]; + if (diffusion_should_swap_desc(vals[LOCAL_K - 1], ids[LOCAL_K - 1], x, v)) { + int pos = LOCAL_K - 1; +#pragma unroll + for (int i = LOCAL_K - 1; i > 0; --i) { + if (pos == i && diffusion_should_swap_desc(vals[i - 1], ids[i - 1], x, v)) { + vals[i] = vals[i - 1]; + ids[i] = ids[i - 1]; + --pos; + } + } + vals[pos] = x; + ids[pos] = v; + } + } + + constexpr int candidate_count = LOCAL_K * 256; + + extern __shared__ unsigned char smem[]; + float * s_vals = (float *) smem; + int * s_ids = (int *) (s_vals + candidate_count); + +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + const int dst = tid * LOCAL_K + i; + s_vals[dst] = vals[i]; + s_ids[dst] = ids[i]; + } + __syncthreads(); + + for (int k = 2; k <= candidate_count; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = tid; i < candidate_count; i += blockDim.x) { + const int ixj = i ^ j; + if (ixj > i) { + const bool descending = (i & k) == 0; + const bool swap = descending + ? diffusion_should_swap_desc(s_vals[i], s_ids[i], s_vals[ixj], s_ids[ixj]) + : diffusion_should_swap_desc(s_vals[ixj], s_ids[ixj], s_vals[i], s_ids[i]); + if (swap) { + const float tv = s_vals[i]; + s_vals[i] = s_vals[ixj]; + s_vals[ixj] = tv; + const int ti = s_ids[i]; + s_ids[i] = s_ids[ixj]; + s_ids[ixj] = ti; + } + } + } + __syncthreads(); + } + } + + for (int i = tid; i < heap_k; i += blockDim.x) { + top_ids[(size_t) row * heap_k + i] = s_ids[i]; + } +} + +static void diffusion_select_topk_local( + const float * logits, + int * top_ids, + const int n_vocab, + const int n_tokens, + const int heap_k, + const int top_k_local_k, + cudaStream_t stream) { + constexpr int block_size = 256; + const int local_k = diffusion_topk_local_k(heap_k, top_k_local_k); + + switch (local_k) { + case 1: { + constexpr int local_k_t = 1; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)); + diffusion_select_topk_local_kernel<<>>( + logits, top_ids, n_vocab, heap_k); + } break; + case 2: { + constexpr int local_k_t = 2; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)); + diffusion_select_topk_local_kernel<<>>( + logits, top_ids, n_vocab, heap_k); + } break; + case 4: { + constexpr int local_k_t = 4; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)); + diffusion_select_topk_local_kernel<<>>( + logits, top_ids, n_vocab, heap_k); + } break; + case 8: { + constexpr int local_k_t = 8; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)); + diffusion_select_topk_local_kernel<<>>( + logits, top_ids, n_vocab, heap_k); + } break; + default: { + constexpr int local_k_t = 16; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)); + diffusion_select_topk_local_kernel<<>>( + logits, top_ids, n_vocab, heap_k); + } break; + } +} + +#ifdef CUB_DIFFUSION_TOP_K_AVAILABLE +static void diffusion_top_k_cub( + ggml_cuda_pool & pool, + const float * src, + int * dst, + const int ncols, + const int k, + cudaStream_t stream) { + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); + auto stream_env = cuda::stream_ref{ stream }; + auto env = cuda::std::execution::env{ stream_env, requirements }; + auto indexes_in = cuda::make_counting_iterator(0); + + size_t temp_storage_bytes = 0; + CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, env)); + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, env)); +} +#endif + +static __global__ void diffusion_sort_top_ids_kernel( + const float * __restrict__ logits, + int * __restrict__ top_ids, + const int n_vocab, + const int heap_k, + const int heap_k_pad, + const float inv_temp) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + extern __shared__ unsigned char smem[]; + int * ids = (int *) smem; + float * vals = (float *) (ids + heap_k_pad); + + const int base = row * heap_k; + for (int i = tid; i < heap_k_pad; i += blockDim.x) { + if (i < heap_k) { + const int id = top_ids[base + i]; + ids[i] = id; + vals[i] = logits[(size_t) row * n_vocab + id] * inv_temp; + } else { + ids[i] = 0; + vals[i] = -FLT_MAX; + } + } + __syncthreads(); + + for (int k = 2; k <= heap_k_pad; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + const int ixj = tid ^ j; + if (ixj > tid && ixj < heap_k_pad) { + const bool up = (tid & k) == 0; + const bool swap = up ? (vals[tid] < vals[ixj]) : (vals[tid] > vals[ixj]); + if (swap) { + const float tv = vals[tid]; + vals[tid] = vals[ixj]; + vals[ixj] = tv; + const int ti = ids[tid]; + ids[tid] = ids[ixj]; + ids[ixj] = ti; + } + } + __syncthreads(); + } + } + + for (int i = tid; i < heap_k; i += blockDim.x) { + top_ids[base + i] = ids[i]; + } +} + +static __device__ __forceinline__ uint32_t diffusion_rng_u32(uint32_t x) { + x ^= x >> 16; + x *= 0x7feb352du; + x ^= x >> 15; + x *= 0x846ca68bu; + x ^= x >> 16; + return x; +} + +static __device__ __forceinline__ float diffusion_rng_uniform(uint32_t seed, uint32_t step, uint32_t row) { + const uint32_t x = diffusion_rng_u32(seed ^ (step * 0x9e3779b9u) ^ (row * 0x85ebca6bu)); + return ((x >> 8) + 0.5f) * (1.0f / 16777216.0f); +} + +static __device__ __forceinline__ float diffusion_apply_logit_softcap(const float x, const float softcap) { + return softcap > 0.0f ? softcap * tanhf(x / softcap) : x; +} + +template +static __device__ __forceinline__ float diffusion_embd_to_float(const T x) { + return (float) x; +} + +template<> +__device__ __forceinline__ float diffusion_embd_to_float(const half x) { + return __half2float(x); +} + +template +static __global__ void diffusion_build_selfcond_embd_kernel( + const T * __restrict__ token_embd, + const int64_t embd_stride, + const int n_embd, + const int n_tokens, + const int sc_k, + const int * __restrict__ sc_ids, + const float * __restrict__ sc_probs, + float * __restrict__ dst) { + extern __shared__ unsigned char smem[]; + int * s_ids = (int *) smem; + float * s_probs = (float *) (s_ids + sc_k); + + const int token = blockIdx.y; + const int dim = blockIdx.x * blockDim.x + threadIdx.x; + + for (int i = threadIdx.x; i < sc_k; i += blockDim.x) { + const int off = token * sc_k + i; + s_ids[i] = sc_ids[off]; + s_probs[i] = sc_probs[off]; + } + __syncthreads(); + + if (token >= n_tokens || dim >= n_embd) { + return; + } + + float sum = 0.0f; + for (int i = 0; i < sc_k; ++i) { + const float p = s_probs[i]; + if (p != 0.0f) { + const int id = s_ids[i]; + sum += p * diffusion_embd_to_float(token_embd[(int64_t) id * embd_stride + dim]); + } + } + dst[(int64_t) token * n_embd + dim] = sum; +} + +static bool diffusion_build_selfcond_embd( + const ggml_tensor * token_embd, + const int * sc_ids, + const float * sc_probs, + const int n_tokens, + const int sc_k, + ggml_tensor * dst, + cudaStream_t stream) { + if (!token_embd || !dst || !sc_ids || !sc_probs) { + return false; + } + if (!ggml_is_contiguous(token_embd) || !ggml_is_contiguous(dst) || + token_embd->data == nullptr || dst->data == nullptr || + token_embd->buffer == nullptr || dst->buffer == nullptr || + ggml_backend_buffer_is_host(token_embd->buffer) || + ggml_backend_buffer_is_host(dst->buffer) || + dst->type != GGML_TYPE_F32 || + token_embd->ne[1] <= 0 || + dst->ne[0] != token_embd->ne[0] || + dst->ne[1] < n_tokens) { + return false; + } + + const int n_embd = (int) token_embd->ne[0]; + const int block_size = 256; + const dim3 block(block_size, 1, 1); + const dim3 grid((n_embd + block_size - 1) / block_size, n_tokens, 1); + const size_t smem = (size_t) sc_k * (sizeof(int) + sizeof(float)); + + switch (token_embd->type) { + case GGML_TYPE_F16: { + const int64_t embd_stride = token_embd->nb[1] / (int64_t) sizeof(half); + diffusion_build_selfcond_embd_kernel<<>>( + (const half *) token_embd->data, embd_stride, n_embd, n_tokens, sc_k, + sc_ids, sc_probs, (float *) dst->data); + } break; + case GGML_TYPE_F32: { + const int64_t embd_stride = token_embd->nb[1] / (int64_t) sizeof(float); + diffusion_build_selfcond_embd_kernel<<>>( + (const float *) token_embd->data, embd_stride, n_embd, n_tokens, sc_k, + sc_ids, sc_probs, (float *) dst->data); + } break; + default: + return false; + } + return true; +} + +template +static __global__ void diffusion_sample_topk_fused_local_kernel( + const float * __restrict__ logits, + const int n_vocab, + const int n_tokens, + const int top_k, + const int sc_k, + const float inv_temp, + const float logit_softcap, + const uint32_t seed, + const uint32_t step, + int * __restrict__ sampled, + int * __restrict__ argmax, + float * __restrict__ entropy, + int * __restrict__ sc_ids, + float * __restrict__ sc_probs) { + constexpr int block_size = 256; + constexpr int candidate_count = LOCAL_K * block_size; + + const int row = blockIdx.x; + const int tid = threadIdx.x; + if (row >= n_tokens) { + return; + } + + float vals[LOCAL_K]; + int ids[LOCAL_K]; +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + vals[i] = -FLT_MAX; + ids[i] = 0; + } + + const float * row_logits = logits + (size_t) row * n_vocab; + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float x = row_logits[v]; + if (diffusion_should_swap_desc(vals[LOCAL_K - 1], ids[LOCAL_K - 1], x, v)) { + int pos = LOCAL_K - 1; +#pragma unroll + for (int i = LOCAL_K - 1; i > 0; --i) { + if (pos == i && diffusion_should_swap_desc(vals[i - 1], ids[i - 1], x, v)) { + vals[i] = vals[i - 1]; + ids[i] = ids[i - 1]; + --pos; + } + } + vals[pos] = x; + ids[pos] = v; + } + } + + extern __shared__ unsigned char smem[]; + float * s_vals = (float *) smem; + int * s_ids = (int *) (s_vals + candidate_count); + float * s_sum = (float *) (s_ids + candidate_count); + float * s_t = s_sum + block_size; + +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + const int dst = tid * LOCAL_K + i; + s_vals[dst] = vals[i]; + s_ids[dst] = ids[i]; + } + __syncthreads(); + + for (int k = 2; k <= candidate_count; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = tid; i < candidate_count; i += blockDim.x) { + const int ixj = i ^ j; + if (ixj > i) { + const bool descending = (i & k) == 0; + const bool swap = descending + ? diffusion_should_swap_desc(s_vals[i], s_ids[i], s_vals[ixj], s_ids[ixj]) + : diffusion_should_swap_desc(s_vals[ixj], s_ids[ixj], s_vals[i], s_ids[i]); + if (swap) { + const float tv = s_vals[i]; + s_vals[i] = s_vals[ixj]; + s_vals[ixj] = tv; + const int ti = s_ids[i]; + s_ids[i] = s_ids[ixj]; + s_ids[ixj] = ti; + } + } + } + __syncthreads(); + } + } + + if (tid < top_k) { + s_vals[tid] = diffusion_apply_logit_softcap(s_vals[tid], logit_softcap); + } + __syncthreads(); + + const float max_l = s_vals[0] * inv_temp; + const int amax = s_ids[0]; + + float local_sum = 0.0f; + float local_t = 0.0f; + if (tid < top_k) { + const float d = s_vals[tid] * inv_temp - max_l; + const float e = expf(d); + local_sum = e; + local_t = d * e; + } + + s_sum[tid] = local_sum; + s_t[tid] = local_t; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride) { + s_sum[tid] += s_sum[tid + stride]; + s_t[tid] += s_t[tid + stride]; + } + __syncthreads(); + } + + const float z = s_sum[0]; + const float t = s_t[0]; + + if (tid == 0) { + argmax[row] = amax; + entropy[row] = logf(z) - t / z; + + const float r = diffusion_rng_uniform(seed, step, row) * z; + float cum = 0.0f; + int tok = amax; + for (int i = 0; i < top_k; ++i) { + const float d = s_vals[i] * inv_temp - max_l; + cum += expf(d); + if (cum >= r) { + tok = s_ids[i]; + break; + } + } + sampled[row] = tok; + + const int n_sc = min(sc_k, top_k); + for (int i = 0; i < sc_k; ++i) { + const int out = row * sc_k + i; + if (i < n_sc) { + sc_ids[out] = s_ids[i]; + sc_probs[out] = expf(s_vals[i] * inv_temp - max_l) / z; + } else { + sc_ids[out] = 0; + sc_probs[out] = 0.0f; + } + } + } +} + +static void diffusion_sample_topk_fused_local( + const float * logits, + const int n_vocab, + const int n_tokens, + const int top_k, + const int sc_k, + const int top_k_local_k, + const float inv_temp, + const float logit_softcap, + const uint32_t seed, + const uint32_t step, + int * sampled, + int * argmax, + float * entropy, + int * sc_ids, + float * sc_probs, + cudaStream_t stream) { + constexpr int block_size = 256; + const int local_k = diffusion_topk_local_k(top_k, top_k_local_k); + + switch (local_k) { + case 1: { + constexpr int local_k_t = 1; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 2 * block_size * sizeof(float); + diffusion_sample_topk_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, top_k, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 2: { + constexpr int local_k_t = 2; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 2 * block_size * sizeof(float); + diffusion_sample_topk_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, top_k, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 4: { + constexpr int local_k_t = 4; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 2 * block_size * sizeof(float); + diffusion_sample_topk_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, top_k, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 8: { + constexpr int local_k_t = 8; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 2 * block_size * sizeof(float); + diffusion_sample_topk_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, top_k, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + default: { + constexpr int local_k_t = 16; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 2 * block_size * sizeof(float); + diffusion_sample_topk_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, top_k, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + } +} + +template +static __global__ void diffusion_sample_full_softmax_fused_local_kernel( + const float * __restrict__ logits, + const int n_vocab, + const int n_tokens, + const int sc_k, + const float inv_temp, + const float logit_softcap, + const uint32_t seed, + const uint32_t step, + int * __restrict__ sampled, + int * __restrict__ argmax, + float * __restrict__ entropy, + int * __restrict__ sc_ids, + float * __restrict__ sc_probs) { + constexpr int block_size = 256; + constexpr int candidate_count = LOCAL_K * block_size; + + const int row = blockIdx.x; + const int tid = threadIdx.x; + if (row >= n_tokens) { + return; + } + + float vals[LOCAL_K]; + int ids[LOCAL_K]; +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + vals[i] = -FLT_MAX; + ids[i] = 0; + } + + const float * row_logits = logits + (size_t) row * n_vocab; + float local_max = -FLT_MAX; + int local_idx = 0; + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float raw = row_logits[v]; + const float x = diffusion_apply_logit_softcap(raw, logit_softcap) * inv_temp; + if (x > local_max) { + local_max = x; + local_idx = v; + } + if (diffusion_should_swap_desc(vals[LOCAL_K - 1], ids[LOCAL_K - 1], raw, v)) { + int pos = LOCAL_K - 1; +#pragma unroll + for (int i = LOCAL_K - 1; i > 0; --i) { + if (pos == i && diffusion_should_swap_desc(vals[i - 1], ids[i - 1], raw, v)) { + vals[i] = vals[i - 1]; + ids[i] = ids[i - 1]; + --pos; + } + } + vals[pos] = raw; + ids[pos] = v; + } + } + + extern __shared__ unsigned char smem[]; + float * s_vals = (float *) smem; + int * s_ids = (int *) (s_vals + candidate_count); + int * s_max_ids = s_ids + candidate_count; + float * s_sum = (float *) (s_max_ids + block_size); + float * s_t = s_sum + block_size; + float * s_cdf = s_t + block_size; + __shared__ int s_sampled; + +#pragma unroll + for (int i = 0; i < LOCAL_K; ++i) { + const int dst = tid * LOCAL_K + i; + s_vals[dst] = vals[i]; + s_ids[dst] = ids[i]; + } + s_sum[tid] = local_max; + s_max_ids[tid] = local_idx; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride && s_sum[tid + stride] > s_sum[tid]) { + s_sum[tid] = s_sum[tid + stride]; + s_max_ids[tid] = s_max_ids[tid + stride]; + } + __syncthreads(); + } + + const float max_l = s_sum[0]; + const int amax = s_max_ids[0]; + + for (int k = 2; k <= candidate_count; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + for (int i = tid; i < candidate_count; i += blockDim.x) { + const int ixj = i ^ j; + if (ixj > i) { + const bool descending = (i & k) == 0; + const bool swap = descending + ? diffusion_should_swap_desc(s_vals[i], s_ids[i], s_vals[ixj], s_ids[ixj]) + : diffusion_should_swap_desc(s_vals[ixj], s_ids[ixj], s_vals[i], s_ids[i]); + if (swap) { + const float tv = s_vals[i]; + s_vals[i] = s_vals[ixj]; + s_vals[ixj] = tv; + const int ti = s_ids[i]; + s_ids[i] = s_ids[ixj]; + s_ids[ixj] = ti; + } + } + } + __syncthreads(); + } + } + + float local_sum = 0.0f; + float local_t = 0.0f; + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float d = diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l; + const float e = expf(d); + local_sum += e; + local_t += d * e; + } + + s_sum[tid] = local_sum; + s_t[tid] = local_t; + s_cdf[tid] = local_sum; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride) { + s_sum[tid] += s_sum[tid + stride]; + s_t[tid] += s_t[tid + stride]; + } + __syncthreads(); + } + + const float z = s_sum[0]; + const float t = s_t[0]; + + if (tid == 0) { + argmax[row] = amax; + entropy[row] = logf(z) - t / z; + s_sampled = amax; + } + __syncthreads(); + + for (int stride = 1; stride < blockDim.x; stride <<= 1) { + const float add = tid >= stride ? s_cdf[tid - stride] : 0.0f; + __syncthreads(); + s_cdf[tid] += add; + __syncthreads(); + } + + const float r = diffusion_rng_uniform(seed, step, row) * z; + const float chunk_begin = tid == 0 ? 0.0f : s_cdf[tid - 1]; + const float chunk_end = s_cdf[tid]; + if (r > chunk_begin && r <= chunk_end) { + float thread_cum = chunk_begin; + for (int v = tid; v < n_vocab; v += blockDim.x) { + thread_cum += expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l); + if (thread_cum >= r) { + s_sampled = v; + break; + } + } + } + __syncthreads(); + + if (tid == 0) { + sampled[row] = s_sampled; + for (int i = 0; i < sc_k; ++i) { + const int out = row * sc_k + i; + sc_ids[out] = s_ids[i]; + sc_probs[out] = expf(diffusion_apply_logit_softcap(s_vals[i], logit_softcap) * inv_temp - max_l) / z; + } + } +} + +static void diffusion_sample_full_softmax_fused_local( + const float * logits, + const int n_vocab, + const int n_tokens, + const int sc_k, + const int top_k_local_k, + const float inv_temp, + const float logit_softcap, + const uint32_t seed, + const uint32_t step, + int * sampled, + int * argmax, + float * entropy, + int * sc_ids, + float * sc_probs, + cudaStream_t stream) { + constexpr int block_size = 256; + const int local_k = diffusion_topk_local_k(sc_k, top_k_local_k); + + switch (local_k) { + case 1: { + constexpr int local_k_t = 1; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 3 * block_size * sizeof(float) + block_size * sizeof(int); + diffusion_sample_full_softmax_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 2: { + constexpr int local_k_t = 2; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 3 * block_size * sizeof(float) + block_size * sizeof(int); + diffusion_sample_full_softmax_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 4: { + constexpr int local_k_t = 4; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 3 * block_size * sizeof(float) + block_size * sizeof(int); + diffusion_sample_full_softmax_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + case 8: { + constexpr int local_k_t = 8; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 3 * block_size * sizeof(float) + block_size * sizeof(int); + diffusion_sample_full_softmax_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + default: { + constexpr int local_k_t = 16; + const size_t smem = (size_t) local_k_t * block_size * (sizeof(float) + sizeof(int)) + + 3 * block_size * sizeof(float) + block_size * sizeof(int); + diffusion_sample_full_softmax_fused_local_kernel<<>>( + logits, n_vocab, n_tokens, sc_k, inv_temp, logit_softcap, seed, step, + sampled, argmax, entropy, sc_ids, sc_probs); + } break; + } +} + +static __device__ __forceinline__ bool diffusion_pair_gt(float a_val, int a_id, float b_val, int b_id) { + return a_val > b_val || (a_val == b_val && a_id > b_id); +} + +static __global__ void diffusion_update_canvas_kernel( + const float * __restrict__ entropy, + const int * __restrict__ sampled, + int * __restrict__ canvas_tokens, + const int n_tokens, + const int n_vocab, + const float entropy_bound, + const uint32_t seed, + const uint32_t step) { + const int tid = threadIdx.x; + + __shared__ float s_entropy[1024]; + __shared__ int s_index[1024]; + __shared__ unsigned char s_accept[1024]; + + if (tid < n_tokens) { + s_entropy[tid] = entropy[tid]; + s_index[tid] = tid; + s_accept[tid] = 0; + } else { + s_entropy[tid] = FLT_MAX; + s_index[tid] = tid; + } + __syncthreads(); + + for (int k = 2; k <= blockDim.x; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + const int ixj = tid ^ j; + if (ixj > tid) { + const bool ascending = (tid & k) == 0; + const float a_val = s_entropy[tid]; + const int a_idx = s_index[tid]; + const float b_val = s_entropy[ixj]; + const int b_idx = s_index[ixj]; + const bool swap = ascending ? + diffusion_pair_gt(a_val, a_idx, b_val, b_idx) : + diffusion_pair_gt(b_val, b_idx, a_val, a_idx); + if (swap) { + s_entropy[tid] = b_val; + s_index[tid] = b_idx; + s_entropy[ixj] = a_val; + s_index[ixj] = a_idx; + } + } + __syncthreads(); + } + } + + if (tid == 0) { + float prefix = 0.0f; + for (int i = 0; i < n_tokens; ++i) { + const int pos = s_index[i]; + if (prefix <= entropy_bound) { + s_accept[pos] = 1; + prefix += s_entropy[i]; + } else { + break; + } + } + } + __syncthreads(); + + if (tid < n_tokens) { + if (s_accept[tid]) { + canvas_tokens[tid] = sampled[tid]; + } else { + const uint32_t r = diffusion_rng_u32(seed ^ ((step + 1u) * 0x9e3779b9u) ^ ((uint32_t) tid * 0x7f4a7c15u) ^ 0xa5a5a5a5u); + canvas_tokens[tid] = (int) (r % (uint32_t) n_vocab); + } + } +} + +static __global__ void diffusion_stop_state_kernel( + const float * __restrict__ entropy, + const int * __restrict__ argmax, + int * __restrict__ prev_argmax, + int * __restrict__ stop, + const int n_tokens, + const float confidence_threshold, + const int stability_threshold, + const int check_stop, + const int reset_state) { + const int tid = threadIdx.x; + + __shared__ float s_entropy[1024]; + __shared__ int s_diff[1024]; + + if (tid < n_tokens) { + s_entropy[tid] = entropy[tid]; + s_diff[tid] = reset_state ? 1 : (prev_argmax[tid] != argmax[tid]); + } else { + s_entropy[tid] = 0.0f; + s_diff[tid] = 0; + } + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride) { + s_entropy[tid] += s_entropy[tid + stride]; + s_diff[tid] += s_diff[tid + stride]; + } + __syncthreads(); + } + + if (tid == 0 && check_stop) { + const bool stable = !reset_state && (stability_threshold == 0 || s_diff[0] == 0); + const bool confident = confidence_threshold > 0.0f && + (s_entropy[0] / (float) n_tokens) < confidence_threshold; + stop[0] = (stable && confident) ? 1 : 0; + } + __syncthreads(); + + if (tid < n_tokens) { + prev_argmax[tid] = argmax[tid]; + } +} + +static __global__ void diffusion_sample_kernel( + const float * __restrict__ logits, + const int * __restrict__ top_ids, + const int n_vocab, + const int n_tokens, + const int top_k, + const int heap_k, + const int sc_k, + const float inv_temp, + const float logit_softcap, + const uint32_t seed, + const uint32_t step, + const bool tail_correction, + const bool parallel_full_softmax_sample, + int * __restrict__ sampled, + int * __restrict__ argmax, + float * __restrict__ entropy, + int * __restrict__ sc_ids, + float * __restrict__ sc_probs) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + if (row >= n_tokens) { + return; + } + + __shared__ float s_val[1024]; + __shared__ float s_sum[1024]; + __shared__ float s_cdf[1024]; + __shared__ int s_idx[1024]; + __shared__ int s_sampled; + + const float * row_logits = logits + (size_t) row * n_vocab; + const int * row_top = top_ids + (size_t) row * heap_k; + + float local_max = -FLT_MAX; + int local_idx = 0; + + if (top_k == 0 || tail_correction) { + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float x = diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp; + if (x > local_max) { + local_max = x; + local_idx = v; + } + } + } else { + for (int i = tid; i < top_k; i += blockDim.x) { + const int v = row_top[i]; + const float x = diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp; + if (x > local_max) { + local_max = x; + local_idx = v; + } + } + } + + s_val[tid] = local_max; + s_idx[tid] = local_idx; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride && s_val[tid + stride] > s_val[tid]) { + s_val[tid] = s_val[tid + stride]; + s_idx[tid] = s_idx[tid + stride]; + } + __syncthreads(); + } + + const float max_l = s_val[0]; + const int amax = s_idx[0]; + + float local_sum = 0.0f; + float local_t = 0.0f; + + if (top_k == 0 || tail_correction) { + for (int v = tid; v < n_vocab; v += blockDim.x) { + const float d = diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l; + const float e = expf(d); + local_sum += e; + local_t += d * e; + } + } else { + for (int i = tid; i < top_k; i += blockDim.x) { + const int v = row_top[i]; + const float d = diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l; + const float e = expf(d); + local_sum += e; + local_t += d * e; + } + } + + s_sum[tid] = local_sum; + s_val[tid] = local_t; + s_cdf[tid] = local_sum; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (tid < stride) { + s_sum[tid] += s_sum[tid + stride]; + s_val[tid] += s_val[tid + stride]; + } + __syncthreads(); + } + + const float z = s_sum[0]; + const float t = s_val[0]; + + if (tid == 0) { + argmax[row] = amax; + entropy[row] = logf(z) - t / z; + + float sample_z = z; + if (top_k > 0 && tail_correction) { + sample_z = 0.0f; + for (int i = 0; i < top_k; ++i) { + const int v = row_top[i]; + sample_z += expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l); + } + } + + const float r = diffusion_rng_uniform(seed, step, row) * sample_z; + float cum = 0.0f; + int tok = amax; + + if (top_k == 0 && !parallel_full_softmax_sample) { + for (int v = 0; v < n_vocab; ++v) { + cum += expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l); + if (cum >= r) { + tok = v; + break; + } + } + } else { + for (int i = 0; i < top_k; ++i) { + const int v = row_top[i]; + cum += expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l); + if (cum >= r) { + tok = v; + break; + } + } + } + sampled[row] = tok; + + const int n_sc = top_k == 0 ? sc_k : min(sc_k, top_k); + for (int i = 0; i < sc_k; ++i) { + const int out = row * sc_k + i; + if (i < n_sc) { + const int v = row_top[i]; + sc_ids[out] = v; + sc_probs[out] = expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l) / sample_z; + } else { + sc_ids[out] = 0; + sc_probs[out] = 0.0f; + } + } + } + + if (top_k == 0 && parallel_full_softmax_sample) { + for (int stride = 1; stride < blockDim.x; stride <<= 1) { + const float add = tid >= stride ? s_cdf[tid - stride] : 0.0f; + __syncthreads(); + s_cdf[tid] += add; + __syncthreads(); + } + + const float z = s_sum[0]; + const float r = diffusion_rng_uniform(seed, step, row) * z; + if (tid == 0) { + s_sampled = amax; + } + __syncthreads(); + + const float chunk_begin = tid == 0 ? 0.0f : s_cdf[tid - 1]; + const float chunk_end = s_cdf[tid]; + if (r > chunk_begin && r <= chunk_end) { + float thread_cum = chunk_begin; + for (int v = tid; v < n_vocab; v += blockDim.x) { + thread_cum += expf(diffusion_apply_logit_softcap(row_logits[v], logit_softcap) * inv_temp - max_l); + if (thread_cum >= r) { + s_sampled = v; + break; + } + } + } + __syncthreads(); + + if (tid == 0) { + sampled[row] = s_sampled; + } + } +} + +bool ggml_cuda_diffusion_sample_topk( + ggml_backend_t backend, + const ggml_tensor * logits, + const ggml_cuda_diffusion_sample_params * params, + ggml_cuda_diffusion_sample_result * result) { + if (!backend || !logits || !params || !result) { + return false; + } + if (!ggml_backend_is_cuda(backend) || logits->type != GGML_TYPE_F32 || !ggml_is_contiguous(logits)) { + return false; + } + const int n_vocab = params->n_vocab > 0 ? params->n_vocab : (int) logits->ne[0]; + const int n_tokens = params->n_tokens > 0 ? params->n_tokens : (int) ggml_nrows(logits); + if (n_vocab <= 0 || n_tokens <= 0 || logits->ne[0] != n_vocab || ggml_nrows(logits) < n_tokens) { + return false; + } + + int top_k = params->top_k; + if (top_k <= 0 || top_k >= n_vocab) { + top_k = 0; + } + + const int sc_k = params->self_cond_top_k; + if (sc_k <= 0 || sc_k > 1024) { + return false; + } + + if (result->update_canvas_on_device) { + if (n_tokens > 1024 || + result->canvas_tokens_tensor == nullptr || + result->canvas_tokens_tensor->type != GGML_TYPE_I32 || + !ggml_is_contiguous(result->canvas_tokens_tensor) || + result->canvas_tokens_tensor->data == nullptr || + result->canvas_tokens_tensor->buffer == nullptr || + ggml_backend_buffer_is_host(result->canvas_tokens_tensor->buffer) || + ggml_nbytes(result->canvas_tokens_tensor) != (size_t) n_tokens * sizeof(int)) { + return false; + } + } + if (result->update_stop_state_on_device || result->check_stop_on_device || result->reset_stop_state) { + if (n_tokens > 1024 || (result->check_stop_on_device && result->stop == nullptr)) { + return false; + } + } + + const bool have_self_cond_host = result->self_cond_ids != nullptr && result->self_cond_probs != nullptr; + const bool have_self_cond_tensor = result->self_cond_ids_tensor != nullptr && result->self_cond_probs_tensor != nullptr; + const bool have_self_cond_embd_tensor = result->self_cond_embd_tensor != nullptr; + if (!have_self_cond_host && !have_self_cond_tensor && !have_self_cond_embd_tensor) { + return false; + } + if ((result->self_cond_ids == nullptr) != (result->self_cond_probs == nullptr)) { + return false; + } + if ((result->self_cond_ids_tensor == nullptr) != (result->self_cond_probs_tensor == nullptr)) { + return false; + } + if (have_self_cond_tensor) { + if (result->self_cond_ids_tensor->type != GGML_TYPE_I32 || + result->self_cond_probs_tensor->type != GGML_TYPE_F32 || + !ggml_is_contiguous(result->self_cond_ids_tensor) || + !ggml_is_contiguous(result->self_cond_probs_tensor) || + result->self_cond_ids_tensor->data == nullptr || + result->self_cond_probs_tensor->data == nullptr || + result->self_cond_ids_tensor->buffer == nullptr || + result->self_cond_probs_tensor->buffer == nullptr || + ggml_backend_buffer_is_host(result->self_cond_ids_tensor->buffer) || + ggml_backend_buffer_is_host(result->self_cond_probs_tensor->buffer)) { + return false; + } + if (ggml_nbytes(result->self_cond_ids_tensor) != (size_t) n_tokens * sc_k * sizeof(int) || + ggml_nbytes(result->self_cond_probs_tensor) != (size_t) n_tokens * sc_k * sizeof(float)) { + return false; + } + } + if (have_self_cond_embd_tensor) { + if (result->token_embd_tensor == nullptr || + result->token_embd_tensor->data == nullptr || + result->token_embd_tensor->buffer == nullptr || + result->self_cond_embd_tensor->type != GGML_TYPE_F32 || + result->self_cond_embd_tensor->data == nullptr || + result->self_cond_embd_tensor->buffer == nullptr || + ggml_backend_buffer_is_host(result->token_embd_tensor->buffer) || + ggml_backend_buffer_is_host(result->self_cond_embd_tensor->buffer) || + !ggml_is_contiguous(result->token_embd_tensor) || + !ggml_is_contiguous(result->self_cond_embd_tensor) || + result->token_embd_tensor->ne[1] < n_vocab || + result->self_cond_embd_tensor->ne[0] != result->token_embd_tensor->ne[0] || + result->self_cond_embd_tensor->ne[1] < n_tokens) { + return false; + } + if (result->token_embd_tensor->type != GGML_TYPE_F16 && + result->token_embd_tensor->type != GGML_TYPE_F32) { + return false; + } + } + + const int heap_k = top_k == 0 || !params->tight_top_k ? std::max(top_k, sc_k) : top_k; + if (heap_k <= 0 || heap_k > 1024 || heap_k > n_vocab) { + return false; + } + + const float temp = params->temperature > 0.0f ? params->temperature : 1.0f; + const float inv_temp = 1.0f / temp; + const float logit_softcap = params->logit_softcap > 0.0f ? params->logit_softcap : 0.0f; + + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *) backend->context; + ggml_cuda_set_device(ctx->device); + ggml_cuda_pool & pool = ctx->pool(); + cudaStream_t stream = ctx->stream(); + + const float * logits_d = (const float *) logits->data; + + diffusion_sample_scratch * scratch = diffusion_get_scratch(stream, n_tokens, heap_k, sc_k); + int * top_ids = scratch->top_ids; + const bool direct_self_cond_tensor = have_self_cond_tensor && !have_self_cond_host && params->direct_self_cond; + int * sc_ids_out = direct_self_cond_tensor ? (int *) result->self_cond_ids_tensor->data : scratch->sc_ids; + float * sc_probs_out = direct_self_cond_tensor ? (float *) result->self_cond_probs_tensor->data : scratch->sc_probs; + + constexpr int block_size = 256; + bool sync_required = false; + const bool use_fused_topk_sample = params->fused_top_k_sample && + top_k > 0 && top_k <= block_size && !params->top_k_tail_correction && n_vocab >= top_k; + const bool use_fused_full_softmax_sample = + top_k == 0 && !params->top_k_tail_correction && params->fused_full_softmax; + const bool use_parallel_full_softmax_sample = + top_k == 0 && !use_fused_full_softmax_sample && params->parallel_full_softmax; + + if (use_fused_topk_sample) { + diffusion_sample_topk_fused_local(logits_d, n_vocab, n_tokens, top_k, sc_k, params->top_k_local_k, inv_temp, logit_softcap, + params->seed, params->step, scratch->sampled, scratch->argmax, scratch->entropy, + sc_ids_out, sc_probs_out, stream); + } else if (use_fused_full_softmax_sample) { + diffusion_sample_full_softmax_fused_local(logits_d, n_vocab, n_tokens, sc_k, params->top_k_local_k, inv_temp, logit_softcap, + params->seed, params->step, scratch->sampled, scratch->argmax, scratch->entropy, + sc_ids_out, sc_probs_out, stream); + } else { + bool top_ids_sorted = false; + const bool use_fast_topk = params->fast_top_k && heap_k <= 1024 && n_vocab >= heap_k; + sync_required = !use_fast_topk; + if (use_fast_topk) { + diffusion_select_topk_local(logits_d, top_ids, n_vocab, n_tokens, heap_k, params->top_k_local_k, stream); + top_ids_sorted = true; + } else { +#ifdef CUB_DIFFUSION_TOP_K_AVAILABLE + for (int row = 0; row < n_tokens; ++row) { + diffusion_top_k_cub(pool, logits_d + (size_t) row * n_vocab, top_ids + (size_t) row * heap_k, n_vocab, heap_k, stream); + } +#elif defined(GGML_CUDA_USE_CUB) + ggml_cuda_pool_alloc sorted_ids_alloc(pool, (size_t) n_tokens * n_vocab); + int * sorted_ids = sorted_ids_alloc.get(); + argsort_f32_i32_cuda_cub(pool, logits_d, sorted_ids, n_vocab, n_tokens, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(top_ids, heap_k * sizeof(int), sorted_ids, n_vocab * sizeof(int), + heap_k * sizeof(int), n_tokens, cudaMemcpyDeviceToDevice, stream)); + top_ids_sorted = true; +#else + if (n_vocab > 1024) { + return false; + } + ggml_cuda_pool_alloc sorted_ids_alloc(pool, (size_t) n_tokens * n_vocab); + int * sorted_ids = sorted_ids_alloc.get(); + argsort_f32_i32_cuda_bitonic(logits_d, sorted_ids, n_vocab, n_tokens, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(top_ids, heap_k * sizeof(int), sorted_ids, n_vocab * sizeof(int), + heap_k * sizeof(int), n_tokens, cudaMemcpyDeviceToDevice, stream)); + top_ids_sorted = true; +#endif + } + + if (!top_ids_sorted) { + const int heap_k_pad = next_power_of_2_host(heap_k); + const int sort_threads = std::max(32, heap_k_pad); + const size_t sort_smem = (size_t) heap_k_pad * (sizeof(int) + sizeof(float)); + diffusion_sort_top_ids_kernel<<>>( + logits_d, top_ids, n_vocab, heap_k, heap_k_pad, inv_temp); + } + + diffusion_sample_kernel<<>>( + logits_d, top_ids, n_vocab, n_tokens, top_k, heap_k, sc_k, inv_temp, + logit_softcap, params->seed, params->step, params->top_k_tail_correction, + use_parallel_full_softmax_sample, + scratch->sampled, scratch->argmax, scratch->entropy, + sc_ids_out, sc_probs_out); + } + + if (result->update_canvas_on_device) { + const int update_threads = next_power_of_2_host(n_tokens); + diffusion_update_canvas_kernel<<<1, update_threads, 0, stream>>>( + scratch->entropy, scratch->sampled, (int *) result->canvas_tokens_tensor->data, + n_tokens, n_vocab, result->entropy_bound, params->seed, params->step); + } + if (result->update_stop_state_on_device || result->check_stop_on_device || result->reset_stop_state) { + const int stop_threads = next_power_of_2_host(n_tokens); + diffusion_stop_state_kernel<<<1, stop_threads, 0, stream>>>( + scratch->entropy, scratch->argmax, scratch->prev_argmax, scratch->stop, + n_tokens, result->confidence_threshold, result->stability_threshold, + result->check_stop_on_device ? 1 : 0, result->reset_stop_state ? 1 : 0); + } + + if (result->sampled) { + CUDA_CHECK(cudaMemcpyAsync(result->sampled, scratch->sampled, (size_t) n_tokens * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + } + if (result->argmax) { + CUDA_CHECK(cudaMemcpyAsync(result->argmax, scratch->argmax, (size_t) n_tokens * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + } + if (result->entropy) { + CUDA_CHECK(cudaMemcpyAsync(result->entropy, scratch->entropy, (size_t) n_tokens * sizeof(float), + cudaMemcpyDeviceToHost, stream)); + } + const bool final_tokens_after_stop = + result->final_tokens && result->stop && result->check_stop_on_device && + params->final_tokens_on_stop; + + if (result->final_tokens && !final_tokens_after_stop) { + CUDA_CHECK(cudaMemcpyAsync(result->final_tokens, scratch->argmax, (size_t) n_tokens * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + } + if (result->stop) { + CUDA_CHECK(cudaMemcpyAsync(result->stop, scratch->stop, sizeof(int), + cudaMemcpyDeviceToHost, stream)); + } + if (have_self_cond_host) { + CUDA_CHECK(cudaMemcpyAsync(result->self_cond_ids, scratch->sc_ids, (size_t) n_tokens * sc_k * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(result->self_cond_probs, scratch->sc_probs, (size_t) n_tokens * sc_k * sizeof(float), + cudaMemcpyDeviceToHost, stream)); + } + if (have_self_cond_tensor && !direct_self_cond_tensor) { + CUDA_CHECK(cudaMemcpyAsync(result->self_cond_ids_tensor->data, scratch->sc_ids, (size_t) n_tokens * sc_k * sizeof(int), + cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(result->self_cond_probs_tensor->data, scratch->sc_probs, (size_t) n_tokens * sc_k * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (have_self_cond_embd_tensor) { + if (!diffusion_build_selfcond_embd(result->token_embd_tensor, sc_ids_out, sc_probs_out, n_tokens, sc_k, + result->self_cond_embd_tensor, stream)) { + return false; + } + } + + const bool host_outputs_requested = result->sampled || result->argmax || result->entropy || + (result->final_tokens && !final_tokens_after_stop) || result->stop || have_self_cond_host; + CUDA_CHECK(cudaGetLastError()); + if (sync_required || host_outputs_requested) { + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + if (final_tokens_after_stop && *result->stop != 0) { + CUDA_CHECK(cudaMemcpyAsync(result->final_tokens, scratch->argmax, (size_t) n_tokens * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + return true; +} diff --git a/ggml/src/ggml-cuda/diffusion-sampling.cuh b/ggml/src/ggml-cuda/diffusion-sampling.cuh new file mode 100644 index 000000000000..dfce406d9303 --- /dev/null +++ b/ggml/src/ggml-cuda/diffusion-sampling.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "common.cuh" + +bool ggml_cuda_diffusion_sample_topk( + ggml_backend_t backend, + const ggml_tensor * logits, + const ggml_cuda_diffusion_sample_params * params, + ggml_cuda_diffusion_sample_result * result); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e779a9be9e95..1b3ba4216227 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -23,6 +23,7 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" +#include "ggml-cuda/diffusion-sampling.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/fwht.cuh" #include "ggml-cuda/getrows.cuh" @@ -3296,6 +3297,12 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { bool use_cuda_graph = true; + static const bool log_graph_incompat = [] { + const char * env = getenv("GGML_CUDA_GRAPH_DIAG"); + return env != nullptr && atoi(env) != 0; + }(); + const bool is_diffusion_decoder_graph = ggml_graph_has_flag(cgraph, GGML_CGRAPH_FLAG_DIFFUSION_DECODER); + // Loop over nodes in GGML graph to obtain info needed for CUDA graph for (int i = 0; i < cgraph->n_nodes; i++) { @@ -3307,6 +3314,10 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture + if (log_graph_incompat) { + GGML_LOG_INFO("%s: disabling CUDA graphs due to split buffer at node %d %s (%s)\n", + __func__, i, node->name, ggml_op_name(node->op)); + } #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); #endif @@ -3315,12 +3326,23 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { // [TAG_MUL_MAT_ID_CUDA_GRAPHS] if (node->op == GGML_OP_MUL_MAT_ID) { const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); - if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + const ggml_tensor * src0 = node->src[0]; + const ggml_tensor * src1 = node->src[1]; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + const bool use_mmvq = ggml_is_quantized(src0->type) && node->ne[2] <= mmvq_mmid_max; + const bool use_mmq = is_diffusion_decoder_graph && ggml_is_quantized(src0->type) && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[2], /*n_experts=*/src0->ne[2]); + const bool use_mmf = is_diffusion_decoder_graph && ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true); + if (!use_mmvq && !use_mmq && !use_mmf) { // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs // TODO: figure out a way to enable for larger batch sizes, without hurting performance // ref: https://github.com/ggml-org/llama.cpp/pull/18958 use_cuda_graph = false; + if (log_graph_incompat) { + GGML_LOG_INFO("%s: disabling CUDA graphs due to MUL_MAT_ID at node %d %s: src0=%s ne2=%lld src1_ne2=%lld n_experts=%lld mmvq_mmid_max=%d quantized=%d diffusion_decoder=%d mmq=%d mmf=%d\n", + __func__, i, node->name, ggml_type_name(src0->type), + (long long) node->ne[2], (long long) src1->ne[2], (long long) src0->ne[2], + mmvq_mmid_max, (int) ggml_is_quantized(src0->type), (int) is_diffusion_decoder_graph, (int) use_mmq, (int) use_mmf); + } #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif @@ -5664,6 +5686,9 @@ static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, con if (strcmp(name, "ggml_backend_get_features") == 0) { return (void *)ggml_backend_cuda_get_features; } + if (strcmp(name, "ggml_backend_cuda_diffusion_sample_topk") == 0) { + return (void *)ggml_cuda_diffusion_sample_topk; + } return nullptr; } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index e1add5e03316..a07e49493e06 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -3,6 +3,13 @@ #include "quantize.cuh" #include "mmid.cuh" +#include + +static bool ggml_cuda_mmq_stream_k_enabled() { + const char * v = std::getenv("GGML_CUDA_MMQ_STREAM_K"); + return !v || std::atoi(v) != 0; +} + static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q1_0: @@ -118,8 +125,8 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || GGML_CUDA_CC_IS_CDNA(cc); + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc)) && ggml_cuda_mmq_stream_k_enabled(); // TODO: tighter pool buffer size vs q8 path const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4); @@ -251,7 +258,7 @@ void ggml_cuda_op_mul_mat_q( // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_CDNA(cc)) - && src1_ncols == ne11; + && src1_ncols == ne11 && ggml_cuda_mmq_stream_k_enabled(); const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index edf546d8f1e2..cf2c2d2f26c2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -4,7 +4,9 @@ #include "vecdotq.cuh" #include "mma.cuh" +#include #include +#include #include using namespace ggml_cuda_mma; @@ -14,6 +16,25 @@ using namespace ggml_cuda_mma; #define MMQ_ITER_K_FP4 512 #define MMQ_NWARPS 8 +static bool mmq_env_enabled(const char * name, const bool default_value = false) { + const char * v = std::getenv(name); + return v ? std::atoi(v) != 0 : default_value; +} + +static int mmq_env_int(const char * name, const int default_value) { + const char * v = std::getenv(name); + return v ? std::atoi(v) : default_value; +} + +static int mmq_largest_divisor_leq(const int value, const int limit) { + for (int d = std::min(value, limit); d > 0; --d) { + if (value % d == 0) { + return d; + } + } + return 1; +} + typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, @@ -3972,7 +3993,31 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio); const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio); - if (!args.use_stream_k) { + const int ntiles_dst = ntx * nty * ntzw; + const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; + const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); + + bool use_stream_k = args.use_stream_k; + int nblocks_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm; + + if (use_stream_k && mmq_env_enabled("GGML_CUDA_MMQ_STREAM_K_DIVISOR")) { + const int divisor_min_pct = std::max(0, std::min(100, mmq_env_int("GGML_CUDA_MMQ_STREAM_K_DIVISOR_MIN_PCT", 80))); + const int divisor_min_blocks = std::max(1, nsm * divisor_min_pct / 100); + const int divisor_blocks = mmq_largest_divisor_leq(ntiles_dst, nsm); + if (divisor_blocks >= divisor_min_blocks) { + nblocks_stream_k = divisor_blocks; + } + } + + const bool stream_k_fixup_needed = ntiles_dst % nblocks_stream_k != 0; + if (use_stream_k && stream_k_fixup_needed && mmq_env_enabled("GGML_CUDA_MMQ_AVOID_FIXUP")) { + const int avoid_min_eff = std::max(0, std::min(100, mmq_env_int("GGML_CUDA_MMQ_AVOID_FIXUP_MIN_EFF", 80))); + if (tiles_efficiency_percent >= avoid_min_eff) { + use_stream_k = false; + } + } + + if (!use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> @@ -3995,14 +4040,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles. // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important. - const int ntiles_dst = ntx * nty * ntzw; - const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; - const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); - const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1); + const dim3 block_nums_stream_k(nblocks_stream_k, 1, 1); GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow. - const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0; + const bool fixup_needed = stream_k_fixup_needed; ggml_cuda_pool & pool = ctx.pool(id); ggml_cuda_pool_alloc tmp_fixup(pool); @@ -4060,7 +4102,13 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda const int warp_size = ggml_cuda_info().devices[id].warp_size; const int nwarps = mmq_get_nwarps_host(cc, warp_size); - const int mmq_x_max = get_mmq_x_max_host(cc); + int mmq_x_max = get_mmq_x_max_host(cc); + if (mmq_env_enabled("GGML_CUDA_MMQ_MAX_X")) { + const int mmq_x_max_env = mmq_env_int("GGML_CUDA_MMQ_MAX_X", mmq_x_max); + if (mmq_x_max_env > 0) { + mmq_x_max = std::min(mmq_x_max, std::max(8, 8 * (mmq_x_max_env / 8))); + } + } const int mmq_y = get_mmq_y_host(cc); int mmq_x_best = 0; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 62b76abbcec9..81d083be30da 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -326,6 +326,10 @@ enum ggml_cgraph_eval_order { GGML_CGRAPH_EVAL_ORDER_COUNT }; +enum ggml_cgraph_flag { + GGML_CGRAPH_FLAG_DIFFUSION_DECODER = 1u << 0, +}; + struct ggml_cgraph { int size; // maximum number of nodes/leafs/grads/grad_accs int n_nodes; // number of nodes currently in use @@ -341,11 +345,25 @@ struct ggml_cgraph { enum ggml_cgraph_eval_order order; + uint32_t flags; + // an optional identifier that can be utilized to recognize same graphs if two non-zero values match // a value of 0 means it is not set and should be ignored uint64_t uid; }; +static inline void ggml_graph_set_flag(struct ggml_cgraph * cgraph, enum ggml_cgraph_flag flag, bool enabled) { + if (enabled) { + cgraph->flags |= (uint32_t) flag; + } else { + cgraph->flags &= ~((uint32_t) flag); + } +} + +static inline bool ggml_graph_has_flag(const struct ggml_cgraph * cgraph, enum ggml_cgraph_flag flag) { + return (cgraph->flags & (uint32_t) flag) != 0; +} + // returns a slice of cgraph with nodes [i0, i1) // the slice does not have leafs or gradients // if you need the gradients, get them from the original graph diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 18a5ebd2ab0c..9a841c200aa9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7162,6 +7162,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.flags =*/ 0, /*.uid =*/ 0, }; @@ -7190,6 +7191,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.use_counts =*/ cgraph0->use_counts, /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, + /*.flags =*/ cgraph0->flags, /*.uid =*/ 0 }; @@ -7204,6 +7206,7 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { dst->n_leafs = src->n_leafs; dst->n_nodes = src->n_nodes; dst->order = src->order; + dst->flags = src->flags; for (int i = 0; i < src->n_leafs; ++i) { dst->leafs[i] = src->leafs[i]; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 584594097346..5323f515147f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -441,6 +441,7 @@ class MODEL_ARCH(IntEnum): GEMMA3N = auto() GEMMA4 = auto() GEMMA4_ASSISTANT = auto() + DIFFUSION_GEMMA = auto() GEMMA_EMBEDDING = auto() STARCODER2 = auto() RWKV6 = auto() @@ -591,6 +592,10 @@ class MODEL_TENSOR(IntEnum): ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() LAYER_OUT_SCALE = auto() + SELF_COND_NORM = auto() # diffusion-gemma (self-conditioning pre-norm) + SELF_COND_GATE = auto() # diffusion-gemma + SELF_COND_UP = auto() # diffusion-gemma + SELF_COND_DOWN = auto() # diffusion-gemma PER_LAYER_TOKEN_EMBD = auto() # gemma3n PER_LAYER_MODEL_PROJ = auto() # gemma3n PER_LAYER_INP_GATE = auto() # gemma3n @@ -992,6 +997,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.GEMMA4: "gemma4", MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant", + MODEL_ARCH.DIFFUSION_GEMMA: "diffusion-gemma", MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", @@ -1141,6 +1147,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.LAYER_OUT_SCALE: "blk.{bid}.layer_output_scale", + MODEL_TENSOR.SELF_COND_NORM: "self_cond_norm", # diffusion-gemma + MODEL_TENSOR.SELF_COND_GATE: "self_cond_gate", # diffusion-gemma + MODEL_TENSOR.SELF_COND_UP: "self_cond_up", # diffusion-gemma + MODEL_TENSOR.SELF_COND_DOWN: "self_cond_down", # diffusion-gemma MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n @@ -2607,6 +2617,38 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_POST_NORM, MODEL_TENSOR.LAYER_OUT_SCALE, ], + MODEL_ARCH.DIFFUSION_GEMMA: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_PRE_NORM_2, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_POST_NORM_1, + MODEL_TENSOR.FFN_POST_NORM_2, + MODEL_TENSOR.LAYER_OUT_SCALE, + MODEL_TENSOR.SELF_COND_NORM, + MODEL_TENSOR.SELF_COND_GATE, + MODEL_TENSOR.SELF_COND_UP, + MODEL_TENSOR.SELF_COND_DOWN, + ], MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5f1e28818509..dd23debeb36b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -736,6 +736,22 @@ class TensorNameMap: "model.blocks.{bid}.embed_skip.a_g", # talkie ), + MODEL_TENSOR.SELF_COND_NORM: ( + "model.self_conditioning.pre_norm", # diffusion-gemma + ), + + MODEL_TENSOR.SELF_COND_GATE: ( + "model.self_conditioning.gate_proj", # diffusion-gemma + ), + + MODEL_TENSOR.SELF_COND_UP: ( + "model.self_conditioning.up_proj", # diffusion-gemma + ), + + MODEL_TENSOR.SELF_COND_DOWN: ( + "model.self_conditioning.down_proj", # diffusion-gemma + ), + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( "model.embed_tokens_per_layer", # gemma3n ), diff --git a/include/llama.h b/include/llama.h index 27e480674282..1fef35a215cc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -340,6 +340,8 @@ extern "C" { uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) + int32_t diffusion_self_cond_top_k; // sparse self-conditioning width for diffusion models, 0 = model default [EXPERIMENTAL] + uint32_t diffusion_input_gpu_groups; // bitmask of diffusion decoder inputs assigned to GPU backend [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -382,6 +384,9 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + bool diffusion_fused_self_cond_embd; // use fused diffusion self-conditioning embedding input [EXPERIMENTAL] + bool diffusion_fuse_final_logit_softcap; // move diffusion final softcap into sampling [EXPERIMENTAL] + bool diffusion_separate_encoder_decoder; // build separate diffusion encoder/decoder graph variants [EXPERIMENTAL] // [EXPERIMENTAL] // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) @@ -978,6 +983,97 @@ extern "C" { // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); + // Set diffusion self-conditioning: the previous denoising step's per-token probability + // distribution (softmax of the processed logits), as a [n_vocab * n_tokens] row-major + // float array. The decoder forms soft-embeddings (probs @ token_embd * embed_scale) and + // adds them to the input embeddings. Pass probs=NULL or n_tokens=0 to clear (-> zeros, + // i.e. the first denoising step). Used by diffusion_gemma. + LLAMA_API void llama_set_diffusion_self_cond( + struct llama_context * ctx, + const float * probs, + int64_t n_vocab, + int64_t n_tokens); + + // Diffusion prompt conditioning: set the length of the causal prompt prefix in the + // [prompt ; canvas] sequence. The prompt tokens attend causally among themselves; the + // canvas attends to everything. Pass 0 for unconditioned generation. Used by diffusion_gemma. + LLAMA_API void llama_set_diffusion_prompt_len(struct llama_context * ctx, int64_t n_prompt); + + // Diffusion KV-cache reuse phase selector (block-diffusion, e.g. diffusion_gemma): + // false = encoder phase: plain token embeddings, no self-conditioning; the decoded + // tokens' KV is committed to the cache (prompt prefill / finalized-canvas + // commit). Use with llama_set_causal_attn(ctx, true). + // true = decoder phase: self-conditioned canvas input that reads the cached prefix; + // the caller rolls back the canvas KV (llama_memory_seq_rm) after reading the + // logits. Use with llama_set_causal_attn(ctx, false). + LLAMA_API void llama_set_diffusion_decoder_phase(struct llama_context * ctx, bool decoder_phase); + + // Sparse (top-k) self-conditioning for block diffusion: instead of the dense per-token probs + // (llama_set_diffusion_self_cond), feed only the top-k token ids + their probabilities per + // position. The graph then gathers just those k token embeddings and blends them, avoiding the + // full-vocab soft-embedding matmul and the [n_vocab x n_tokens] input. ids and probs are each + // [k * n_tokens] (token-major: position outer, k inner). Pass k=0 to clear. + LLAMA_API void llama_set_diffusion_self_cond_topk( + struct llama_context * ctx, + const int32_t * ids, + const float * probs, + int64_t k, + int64_t n_tokens); + + struct llama_diffusion_sample_params { + int32_t n_tokens; + int32_t top_k; + int32_t self_cond_top_k; + float temperature; + uint32_t seed; + uint32_t step; + bool top_k_tail_correction; + bool cuda_fast_top_k; + bool cuda_direct_self_cond; + bool cuda_final_tokens_on_stop; + bool cuda_fused_top_k_sample; + bool cuda_tight_top_k; + bool cuda_parallel_full_softmax; + bool cuda_fused_full_softmax; + int32_t cuda_top_k_local_k; + }; + + struct llama_diffusion_sample_result { + llama_token * sampled; + llama_token * argmax; + float * entropy; + // Optional host outputs for self-conditioning. If both are null, the CUDA fast path writes + // sparse self-conditioning directly into the reused decoder graph input tensors. + int32_t * self_cond_ids; + float * self_cond_probs; + // Optional device-resident denoise update. When update_canvas_on_device is true, the CUDA + // fast path performs entropy accept/renoise and writes the next canvas into the reused + // decoder token input tensor. final_tokens is copied only when requested, typically on + // the final denoising step. + llama_token * final_tokens; + // Optional host stop flag for device-loop early stopping. When requested, CUDA computes + // the stable+confident condition on device and copies only this flag plus final_tokens. + int32_t * stop; + float entropy_bound; + float confidence_threshold; + int32_t stability_threshold; + bool update_canvas_on_device; + bool update_stop_state_on_device; + bool check_stop_on_device; + bool reset_stop_state; + }; + + // CUDA-only fast path for block-diffusion sampling. When enabled, llama_decode() + // keeps the dense diffusion logits on the backend and does not copy them to the + // host output buffer. Call llama_diffusion_sample_topk() after llama_decode() to + // sample the most recent logits tensor and retrieve compact row results. + LLAMA_API bool llama_diffusion_sample_topk_supported(struct llama_context * ctx); + LLAMA_API void llama_set_diffusion_gpu_sampling(struct llama_context * ctx, bool enabled); + LLAMA_API bool llama_diffusion_sample_topk( + struct llama_context * ctx, + const struct llama_diffusion_sample_params * params, + struct llama_diffusion_sample_result * result); + // Set whether the model is in warmup mode or not // If true, all model tensors are activated during llama_decode() to load and cache their weights. // diff --git a/scripts/collect_diffusion_ncu_admin.ps1 b/scripts/collect_diffusion_ncu_admin.ps1 new file mode 100644 index 000000000000..f42a6b9e1864 --- /dev/null +++ b/scripts/collect_diffusion_ncu_admin.ps1 @@ -0,0 +1,186 @@ +param( + [string] $Model = "D:\Day0-1\northbloom\diffusion-gemma-26b-v10-q4_k_m.gguf", + [string] $Exe = "D:\Day0-1\northbloom\llama.cpp-codex-sampling\pr-changes\llama.cpp\build-review\bin\Release\llama-diffusion-gemma-cli.exe", + [string] $OutDir = "D:\Day0-1\northbloom\llama.cpp-codex-sampling\pr-changes\llama.cpp\profiles\diffusion-kernel-efficiency\ncu-admin", + [string] $Prompt = "Write one short sentence about CUDA graphs.", + [int] $TopK = 64 +) + +$ErrorActionPreference = "Stop" + +function Resolve-Ncu { + $cmd = Get-Command ncu.exe -ErrorAction SilentlyContinue + if ($cmd) { + return $cmd.Source + } + + $candidates = @( + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2026.2.0\target\windows-desktop-win7-x64\ncu.exe", + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2026.2.0\host\target-windows-x64\ncu.exe", + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2026.1.0\target\windows-desktop-win7-x64\ncu.exe", + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2026.1.0\host\target-windows-x64\ncu.exe", + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2025.3.1\target\windows-desktop-win7-x64\ncu.exe", + "C:\Program Files\NVIDIA Corporation\Nsight Compute 2025.3.1\host\target-windows-x64\ncu.exe" + ) + foreach ($candidate in $candidates) { + if (Test-Path $candidate) { + return $candidate + } + } + + $found = Get-ChildItem "C:\Program Files\NVIDIA Corporation" -Recurse -Filter ncu.exe -ErrorAction SilentlyContinue | + Select-Object -First 1 + if ($found) { + return $found.FullName + } + + throw "Could not find ncu.exe. Add Nsight Compute to PATH or edit Resolve-Ncu in this script." +} + +if (-not (Test-Path $Exe)) { + throw "Executable not found: $Exe" +} +if (-not (Test-Path $Model)) { + throw "Model not found: $Model" +} + +New-Item -ItemType Directory -Force -Path $OutDir | Out-Null + +$ncu = Resolve-Ncu +Write-Host "Using NCU: $ncu" +Write-Host "Output : $OutDir" + +$runArgs = @( + $Exe, + "-m", $Model, + "-p", $Prompt, + "-n", "256", + "--diffusion-steps", "48", + "--top-k", "$TopK", + "-ngl", "all", + "-sm", "none", + "-mg", "0", + "-fa", "on", + "--ctx-size", "4096", + "--seed", "42", + "--no-mmproj", + "--diffusion-cuda-direct-self-cond", + "--diffusion-cuda-final-tokens-on-stop", + "--diffusion-cuda-tight-top-k", + "--diffusion-cuda-top-k-local-k", "4" +) + +function Set-EnvMap($map) { + foreach ($kv in $map.GetEnumerator()) { + [Environment]::SetEnvironmentVariable($kv.Key, $kv.Value, "Process") + } +} + +function Clear-ExperimentEnv { + $names = @( + "GGML_CUDA_MMQ_STREAM_K", + "GGML_CUDA_MMQ_STREAM_K_DIVISOR", + "GGML_CUDA_MMQ_STREAM_K_DIVISOR_MIN_PCT", + "GGML_CUDA_MMQ_AVOID_FIXUP", + "GGML_CUDA_MMQ_AVOID_FIXUP_MIN_EFF", + "GGML_CUDA_DISABLE_GRAPHS" + ) + foreach ($name in $names) { + [Environment]::SetEnvironmentVariable($name, $null, "Process") + } +} + +function Invoke-NcuProfile { + param( + [string] $Name, + [string] $KernelRegex, + [string] $MetricSet = "basic", + [int] $LaunchSkip = 0, + [int] $LaunchCount = 4, + [string[]] $ExtraArgs = @(), + [hashtable] $ExtraEnv = @{} + ) + + Clear-ExperimentEnv + Set-EnvMap $ExtraEnv + + $export = Join-Path $OutDir $Name + $args = @( + "--force-overwrite", + "--target-processes", "all", + "--graph-profiling", "node", + "--replay-mode", "application", + "--set", $MetricSet, + "--kernel-name", "regex:$KernelRegex", + "--launch-skip", "$LaunchSkip", + "--launch-count", "$LaunchCount", + "--export", $export, + "--" + ) + $runArgs + $ExtraArgs + + Write-Host "" + Write-Host "== $Name ==" + Write-Host "kernel regex: $KernelRegex" + Write-Host "set=$MetricSet skip=$LaunchSkip count=$LaunchCount" + & $ncu @args + if ($LASTEXITCODE -ne 0) { + throw "NCU failed for $Name with exit code $LASTEXITCODE" + } +} + +Invoke-NcuProfile ` + -Name "01_topk_fused_local4_detailed" ` + -KernelRegex "diffusion_sample_topk_fused_local_kernel" ` + -MetricSet "detailed" ` + -LaunchSkip 0 ` + -LaunchCount 4 ` + -ExtraArgs @("--diffusion-cuda-fused-top-k-sample") + +Invoke-NcuProfile ` + -Name "02_topk_baseline_select_detailed" ` + -KernelRegex "diffusion_select_topk_local_kernel" ` + -MetricSet "detailed" ` + -LaunchSkip 0 ` + -LaunchCount 4 ` + -ExtraArgs @("--diffusion-cuda-top-k-local-k", "8") + +Invoke-NcuProfile ` + -Name "03_mmq_denoise_basic" ` + -KernelRegex "mul_mat_q|mul_mat_q_stream_k_fixup|quantize_mmq_q8_1" ` + -MetricSet "basic" ` + -LaunchSkip 900 ` + -LaunchCount 16 ` + -ExtraArgs @("--diffusion-cuda-fused-top-k-sample") + +Invoke-NcuProfile ` + -Name "04_mmq_no_fixup_aggressive_basic" ` + -KernelRegex "mul_mat_q|mul_mat_q_stream_k_fixup|quantize_mmq_q8_1" ` + -MetricSet "basic" ` + -LaunchSkip 900 ` + -LaunchCount 16 ` + -ExtraArgs @("--diffusion-cuda-fused-top-k-sample") ` + -ExtraEnv @{ + GGML_CUDA_MMQ_STREAM_K_DIVISOR = "1" + GGML_CUDA_MMQ_AVOID_FIXUP = "1" + GGML_CUDA_MMQ_AVOID_FIXUP_MIN_EFF = "0" + } + +Invoke-NcuProfile ` + -Name "05_attention_basic" ` + -KernelRegex "flash_attn_ext_f16|flash_attn_stream_k_fixup" ` + -MetricSet "basic" ` + -LaunchSkip 80 ` + -LaunchCount 16 ` + -ExtraArgs @("--diffusion-cuda-fused-top-k-sample") + +Invoke-NcuProfile ` + -Name "06_misc_hotspots_basic" ` + -KernelRegex "cpy_scalar_transpose|reduce_rows_f32|k_get_rows_float|softcap_f32" ` + -MetricSet "basic" ` + -LaunchSkip 0 ` + -LaunchCount 16 ` + -ExtraArgs @("--diffusion-cuda-fused-top-k-sample") + +Write-Host "" +Write-Host "NCU collection complete. Reports are in:" +Write-Host $OutDir diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 680b5fc64df3..a671a9611ce3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -58,6 +58,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, + { LLM_ARCH_DIFFUSION_GEMMA, "diffusion-gemma" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -392,6 +393,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, + { LLM_TENSOR_SELF_COND_NORM, "self_cond_norm" }, + { LLM_TENSOR_SELF_COND_GATE, "self_cond_gate" }, + { LLM_TENSOR_SELF_COND_UP, "self_cond_up" }, + { LLM_TENSOR_SELF_COND_DOWN, "self_cond_down" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, @@ -584,6 +589,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_SELF_COND_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, // diffusion-gemma self-conditioning + {LLM_TENSOR_SELF_COND_GATE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SELF_COND_UP, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SELF_COND_DOWN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -890,6 +899,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: + case LLM_ARCH_DIFFUSION_GEMMA: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index b65fce72e646..14c3cae02846 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -62,6 +62,7 @@ enum llm_arch { LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, LLM_ARCH_GEMMA4_ASSISTANT, + LLM_ARCH_DIFFUSION_GEMMA, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -412,6 +413,10 @@ enum llm_tensor { LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_LAYER_OUT_SCALE, + LLM_TENSOR_SELF_COND_NORM, // diffusion-gemma + LLM_TENSOR_SELF_COND_GATE, // diffusion-gemma + LLM_TENSOR_SELF_COND_UP, // diffusion-gemma + LLM_TENSOR_SELF_COND_DOWN, // diffusion-gemma LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9a40c4366af1..45f966ca1935 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,6 +1,7 @@ #include "llama-context.h" #include "ggml.h" +#include "ggml-cuda.h" #include "llama-arch.h" #include "llama-graph.h" #include "llama-impl.h" @@ -196,6 +197,12 @@ llama_context::llama_context( cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + diffusion_cond.self_cond_top_k = params.diffusion_self_cond_top_k > 0 ? params.diffusion_self_cond_top_k : 256; + diffusion_cond.input_gpu_groups = params.diffusion_input_gpu_groups; + diffusion_cond.fused_self_cond_embd = params.diffusion_fused_self_cond_embd; + diffusion_cond.fuse_final_logit_softcap = params.diffusion_fuse_final_logit_softcap; + diffusion_cond.separate_encoder_decoder = params.diffusion_separate_encoder_decoder; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -1125,6 +1132,267 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) { cparams.embeddings_nextn_masked = masked; } +void llama_context::set_diffusion_self_cond(const float * probs, int64_t n_vocab, int64_t n_tokens) { + diffusion_cond.sc_topk_device_ready = false; + diffusion_cond.sc_topk_device_ids_data = nullptr; + diffusion_cond.sc_topk_device_probs_data = nullptr; + diffusion_cond.sc_topk_device_ids_bytes = 0; + diffusion_cond.sc_topk_device_probs_bytes = 0; + diffusion_cond.sc_embd_device_ready = false; + diffusion_cond.sc_embd_device_data = nullptr; + diffusion_cond.sc_embd_device_bytes = 0; + diffusion_cond.canvas_tokens_device_ready = false; + diffusion_cond.canvas_tokens_device_data = nullptr; + diffusion_cond.canvas_tokens_device_bytes = 0; + + if (probs == nullptr || n_vocab <= 0 || n_tokens <= 0) { + diffusion_cond.probs.clear(); + diffusion_cond.n_vocab = 0; + diffusion_cond.n_tokens = 0; + return; + } + + const size_t n = (size_t) n_vocab * (size_t) n_tokens; + diffusion_cond.probs.assign(probs, probs + n); + diffusion_cond.n_vocab = n_vocab; + diffusion_cond.n_tokens = n_tokens; +} + +void llama_context::set_diffusion_prompt_len(int64_t n_prompt) { + diffusion_cond.n_prompt = n_prompt; +} + +void llama_context::set_diffusion_decoder_phase(bool decoder_phase) { + diffusion_cond.decoder_phase = decoder_phase; +} + +void llama_context::set_diffusion_self_cond_topk(const int32_t * ids, const float * probs, int64_t k, int64_t n_tokens) { + diffusion_cond.sc_topk_device_ready = false; + diffusion_cond.sc_topk_device_ids_data = nullptr; + diffusion_cond.sc_topk_device_probs_data = nullptr; + diffusion_cond.sc_topk_device_ids_bytes = 0; + diffusion_cond.sc_topk_device_probs_bytes = 0; + diffusion_cond.sc_embd_device_ready = false; + diffusion_cond.sc_embd_device_data = nullptr; + diffusion_cond.sc_embd_device_bytes = 0; + diffusion_cond.canvas_tokens_device_ready = false; + diffusion_cond.canvas_tokens_device_data = nullptr; + diffusion_cond.canvas_tokens_device_bytes = 0; + + if (ids == nullptr || probs == nullptr || k <= 0 || n_tokens <= 0) { + diffusion_cond.sc_topk = 0; + diffusion_cond.sc_topk_ids.clear(); + diffusion_cond.sc_topk_probs.clear(); + return; + } + const size_t n = (size_t) k * (size_t) n_tokens; + diffusion_cond.sc_topk = k; + diffusion_cond.sc_topk_ids.assign(ids, ids + n); + diffusion_cond.sc_topk_probs.assign(probs, probs + n); +} + +static ggml_backend_cuda_diffusion_sample_topk_t get_cuda_diffusion_sample_topk_proc(ggml_backend_t backend) { + if (!backend) { + return nullptr; + } + + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + if (!dev) { + return nullptr; + } + + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (!reg) { + return nullptr; + } + + return (ggml_backend_cuda_diffusion_sample_topk_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_cuda_diffusion_sample_topk"); +} + +static bool diffusion_decoder_inputs_device_ready( + const llama_diffusion_cond & diffusion_cond, + const llm_graph_result * res) { + if (!diffusion_cond.decoder_phase || + !diffusion_cond.canvas_tokens_device_ready || + !res) { + return false; + } + + ggml_tensor * t_tokens = res->get_inp_tokens(); + if (!t_tokens || + diffusion_cond.canvas_tokens_device_data != t_tokens->data || + diffusion_cond.canvas_tokens_device_bytes != ggml_nbytes(t_tokens)) { + return false; + } + + auto * inp_sc_embd = res->get_inp_diffusion_self_cond_embd(); + if (inp_sc_embd) { + return inp_sc_embd->embd && + diffusion_cond.sc_embd_device_ready && + diffusion_cond.sc_embd_device_data == inp_sc_embd->embd->data && + diffusion_cond.sc_embd_device_bytes == ggml_nbytes(inp_sc_embd->embd); + } + + auto * inp_sc = res->get_inp_diffusion_self_cond_topk(); + if (!inp_sc || !inp_sc->ids || !inp_sc->probs || + !diffusion_cond.sc_topk_device_ready || + diffusion_cond.sc_topk_device_ids_data != inp_sc->ids->data || + diffusion_cond.sc_topk_device_probs_data != inp_sc->probs->data || + diffusion_cond.sc_topk_device_ids_bytes != ggml_nbytes(inp_sc->ids) || + diffusion_cond.sc_topk_device_probs_bytes != ggml_nbytes(inp_sc->probs)) { + return false; + } + + return true; +} + +bool llama_context::diffusion_sample_topk_supported() const { + for (ggml_backend_t backend : backend_ptrs) { + if (get_cuda_diffusion_sample_topk_proc(backend)) { + return true; + } + } + return false; +} + +void llama_context::set_diffusion_gpu_sampling(bool enabled) { + diffusion_gpu_sampling = enabled; +} + +bool llama_context::diffusion_sample_topk( + const llama_diffusion_sample_params * params, + llama_diffusion_sample_result * result) { + if (!params || !result || !gf_res_prev) { + return false; + } + + static_assert(sizeof(llama_token) == sizeof(int32_t), "llama_token must be int32_t for CUDA diffusion sampling"); + + ggml_tensor * t_logits = gf_res_prev->get_logits(); + if (!t_logits) { + return false; + } + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + auto sample_proc = get_cuda_diffusion_sample_topk_proc(backend); + if (!sample_proc) { + return false; + } + + ggml_tensor * t_self_cond_ids = nullptr; + ggml_tensor * t_self_cond_probs = nullptr; + ggml_tensor * t_self_cond_embd = nullptr; + const ggml_tensor * t_token_embd = nullptr; + ggml_tensor * t_canvas_tokens = nullptr; + const bool device_self_cond = result->self_cond_ids == nullptr && result->self_cond_probs == nullptr; + if (device_self_cond) { + auto * inp_embd = gf_res_prev->get_inp_diffusion_self_cond_embd(); + if (inp_embd && diffusion_cond.fused_self_cond_embd) { + t_self_cond_embd = inp_embd->embd; + t_token_embd = gf_res_prev->get_diffusion_token_embd(); + if (!t_self_cond_embd || !t_token_embd) { + return false; + } + } else { + auto * inp = gf_res_prev->get_inp_diffusion_self_cond_topk(); + if (!inp || !inp->ids || !inp->probs) { + return false; + } + t_self_cond_ids = inp->ids; + t_self_cond_probs = inp->probs; + } + } else if (result->self_cond_ids == nullptr || result->self_cond_probs == nullptr) { + return false; + } + + if (result->update_canvas_on_device) { + t_canvas_tokens = gf_res_prev->get_inp_tokens(); + if (!t_canvas_tokens) { + return false; + } + } + + const float fused_logit_softcap = diffusion_cond.fuse_final_logit_softcap + ? model.hparams.f_final_logit_softcapping + : 0.0f; + + ggml_cuda_diffusion_sample_params cuda_params = { + /* .n_vocab = */ (int32_t) model.vocab.n_tokens(), + /* .n_tokens = */ params->n_tokens, + /* .top_k = */ params->top_k, + /* .self_cond_top_k = */ params->self_cond_top_k, + /* .temperature = */ params->temperature, + /* .seed = */ params->seed, + /* .step = */ params->step, + /* .top_k_tail_correction = */ params->top_k_tail_correction, + /* .logit_softcap = */ fused_logit_softcap, + /* .fast_top_k = */ params->cuda_fast_top_k, + /* .direct_self_cond = */ params->cuda_direct_self_cond, + /* .final_tokens_on_stop = */ params->cuda_final_tokens_on_stop, + /* .fused_top_k_sample = */ params->cuda_fused_top_k_sample, + /* .tight_top_k = */ params->cuda_tight_top_k, + /* .parallel_full_softmax = */ params->cuda_parallel_full_softmax, + /* .fused_full_softmax = */ params->cuda_fused_full_softmax, + /* .top_k_local_k = */ params->cuda_top_k_local_k, + }; + + ggml_cuda_diffusion_sample_result cuda_result = { + /* .sampled = */ (int32_t *) result->sampled, + /* .argmax = */ (int32_t *) result->argmax, + /* .entropy = */ result->entropy, + /* .self_cond_ids = */ result->self_cond_ids, + /* .self_cond_probs = */ result->self_cond_probs, + /* .self_cond_ids_tensor = */ t_self_cond_ids, + /* .self_cond_probs_tensor = */ t_self_cond_probs, + /* .self_cond_embd_tensor = */ t_self_cond_embd, + /* .token_embd_tensor = */ t_token_embd, + /* .canvas_tokens_tensor = */ t_canvas_tokens, + /* .final_tokens = */ (int32_t *) result->final_tokens, + /* .stop = */ result->stop, + /* .entropy_bound = */ result->entropy_bound, + /* .confidence_threshold = */ result->confidence_threshold, + /* .stability_threshold = */ result->stability_threshold, + /* .update_canvas_on_device = */ result->update_canvas_on_device, + /* .update_stop_state_on_device = */ result->update_stop_state_on_device, + /* .check_stop_on_device = */ result->check_stop_on_device, + /* .reset_stop_state = */ result->reset_stop_state, + }; + + const bool ok = sample_proc(backend, t_logits, &cuda_params, &cuda_result); + if (ok && device_self_cond) { + diffusion_cond.sc_topk = params->self_cond_top_k; + diffusion_cond.sc_topk_ids.clear(); + diffusion_cond.sc_topk_probs.clear(); + diffusion_cond.sc_topk_device_ready = false; + diffusion_cond.sc_topk_device_ids_data = nullptr; + diffusion_cond.sc_topk_device_probs_data = nullptr; + diffusion_cond.sc_topk_device_ids_bytes = 0; + diffusion_cond.sc_topk_device_probs_bytes = 0; + diffusion_cond.sc_embd_device_ready = false; + diffusion_cond.sc_embd_device_data = nullptr; + diffusion_cond.sc_embd_device_bytes = 0; + if (t_self_cond_embd) { + diffusion_cond.sc_embd_device_ready = true; + diffusion_cond.sc_embd_device_data = t_self_cond_embd->data; + diffusion_cond.sc_embd_device_bytes = ggml_nbytes(t_self_cond_embd); + } else { + diffusion_cond.sc_topk_device_ready = true; + diffusion_cond.sc_topk_device_ids_data = t_self_cond_ids->data; + diffusion_cond.sc_topk_device_probs_data = t_self_cond_probs->data; + diffusion_cond.sc_topk_device_ids_bytes = ggml_nbytes(t_self_cond_ids); + diffusion_cond.sc_topk_device_probs_bytes = ggml_nbytes(t_self_cond_probs); + } + } + if (ok && result->update_canvas_on_device) { + diffusion_cond.canvas_tokens_device_ready = true; + diffusion_cond.canvas_tokens_device_data = t_canvas_tokens->data; + diffusion_cond.canvas_tokens_device_bytes = ggml_nbytes(t_canvas_tokens); + } + + return ok; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1321,7 +1589,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //const auto t_start_us = ggml_time_us(); // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated - res->set_inputs(&ubatch); + if (!diffusion_decoder_inputs_device_ready(diffusion_cond, res)) { + res->set_inputs(&ubatch); + } //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } @@ -1624,7 +1894,11 @@ static void copy_tensor_async_candidates( } } -static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers) { +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers, bool skip_raw_logits) { + if (skip_raw_logits) { + return false; + } + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { if (!ubatch.output[i]) { continue; @@ -1851,7 +2125,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers, diffusion_gpu_sampling)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits.data != nullptr); @@ -2316,6 +2590,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.diffusion =*/ &diffusion_cond, /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -3357,6 +3632,8 @@ llama_context_params llama_context_default_params() { /*.n_seq_max =*/ 1, /*.n_rs_seq =*/ 0, /*.n_outputs_max =*/ 0, + /*.diffusion_self_cond_top_k =*/ 256, + /*.diffusion_input_gpu_groups =*/ 63, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3384,6 +3661,9 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.diffusion_fused_self_cond_embd =*/ false, + /*.diffusion_fuse_final_logit_softcap =*/ false, + /*.diffusion_separate_encoder_decoder =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, /*.ctx_other =*/ nullptr, @@ -3554,6 +3834,37 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { ctx->set_causal_attn(causal_attn); } +void llama_set_diffusion_self_cond(llama_context * ctx, const float * probs, int64_t n_vocab, int64_t n_tokens) { + ctx->set_diffusion_self_cond(probs, n_vocab, n_tokens); +} + +void llama_set_diffusion_prompt_len(llama_context * ctx, int64_t n_prompt) { + ctx->set_diffusion_prompt_len(n_prompt); +} + +void llama_set_diffusion_decoder_phase(llama_context * ctx, bool decoder_phase) { + ctx->set_diffusion_decoder_phase(decoder_phase); +} + +void llama_set_diffusion_self_cond_topk(llama_context * ctx, const int32_t * ids, const float * probs, int64_t k, int64_t n_tokens) { + ctx->set_diffusion_self_cond_topk(ids, probs, k, n_tokens); +} + +bool llama_diffusion_sample_topk_supported(llama_context * ctx) { + return ctx->diffusion_sample_topk_supported(); +} + +void llama_set_diffusion_gpu_sampling(llama_context * ctx, bool enabled) { + ctx->set_diffusion_gpu_sampling(enabled); +} + +bool llama_diffusion_sample_topk( + llama_context * ctx, + const llama_diffusion_sample_params * params, + llama_diffusion_sample_result * result) { + return ctx->diffusion_sample_topk(params, result); +} + void llama_set_warmup(llama_context * ctx, bool warmup) { ctx->set_warmup(warmup); } diff --git a/src/llama-context.h b/src/llama-context.h index 6f8f59a22a3e..48928162acbd 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -115,6 +115,25 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); + // diffusion self-conditioning: set the previous denoising step's per-token probs + // ([n_vocab * n_tokens], row-major). Pass data=nullptr / n_tokens=0 to clear (-> zeros). + void set_diffusion_self_cond(const float * probs, int64_t n_vocab, int64_t n_tokens); + + // diffusion: length of the causal prompt prefix in the [prompt; canvas] sequence (0 = unconditioned) + void set_diffusion_prompt_len(int64_t n_prompt); + + // diffusion: select the KV-cache reuse phase (false = encoder/commit, true = decoder/denoise) + void set_diffusion_decoder_phase(bool decoder_phase); + + // diffusion: sparse (top-k) self-conditioning (top-k token ids + probs per position) + void set_diffusion_self_cond_topk(const int32_t * ids, const float * probs, int64_t k, int64_t n_tokens); + + bool diffusion_sample_topk_supported() const; + void set_diffusion_gpu_sampling(bool enabled); + bool diffusion_sample_topk( + const llama_diffusion_sample_params * params, + llama_diffusion_sample_result * result); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); @@ -274,6 +293,9 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably + llama_diffusion_cond diffusion_cond; // diffusion self-conditioning (set per-decode by the sampler) + bool diffusion_gpu_sampling = false; // skip dense logits D2H; sampled via CUDA backend proc + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4cc4a4a16a1c..5883c5a9ea9b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,8 +12,11 @@ #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" +#include "../ggml/src/ggml-impl.h" + #include #include +#include #include #include #include @@ -84,8 +87,16 @@ static ggml_tensor * ggml_mul_mat_aux( void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; + const size_t n_bytes = n_tokens*ggml_element_size(tokens); - ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + const bool have_device = diffusion + && diffusion->canvas_tokens_device_ready + && diffusion->canvas_tokens_device_data == tokens->data + && diffusion->canvas_tokens_device_bytes == n_bytes; + + if (!have_device) { + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_bytes); + } } if (ubatch->embd) { @@ -216,26 +227,25 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer)); - int32_t * data = (int32_t *) out_ids->data; + std::vector data(n_outputs); if (n_outputs == n_tokens) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } + } else { + GGML_ASSERT(ubatch->output); - return; - } - - GGML_ASSERT(ubatch->output); - - int n_outputs = 0; + int n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - if (ubatch->output[i]) { - data[n_outputs++] = i; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch->output[i]) { + data[n_outputs++] = i; + } } } + + ggml_backend_tensor_set(out_ids, data.data(), 0, data.size()*ggml_element_size(out_ids)); } bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { @@ -385,6 +395,88 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_diffusion_self_cond::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (!probs) { + return; + } + assert(probs->type == GGML_TYPE_F32); + + const size_t n_bytes = ggml_nbytes(probs); + if (diffusion && diffusion->probs.size() * sizeof(float) == n_bytes) { + ggml_backend_tensor_set(probs, diffusion->probs.data(), 0, n_bytes); + } else { + // no self-conditioning this step (e.g. first denoising step) -> zeros + std::vector zeros(n_bytes, 0); + ggml_backend_tensor_set(probs, zeros.data(), 0, n_bytes); + } +} + +void llm_graph_input_diffusion_self_cond_topk::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (!ids || !probs) { + return; + } + assert(ids->type == GGML_TYPE_I32); + assert(probs->type == GGML_TYPE_F32); + + const size_t n_id_bytes = ggml_nbytes(ids); + const size_t n_pr_bytes = ggml_nbytes(probs); + + const bool have_device = diffusion + && diffusion->sc_topk > 0 + && diffusion->sc_topk_device_ready + && diffusion->sc_topk_device_ids_data == ids->data + && diffusion->sc_topk_device_probs_data == probs->data + && diffusion->sc_topk_device_ids_bytes == n_id_bytes + && diffusion->sc_topk_device_probs_bytes == n_pr_bytes; + + if (have_device) { + return; + } + + const bool have = diffusion + && diffusion->sc_topk > 0 + && diffusion->sc_topk_ids.size() * sizeof(int32_t) == n_id_bytes + && diffusion->sc_topk_probs.size() * sizeof(float) == n_pr_bytes; + + if (have) { + ggml_backend_tensor_set(ids, diffusion->sc_topk_ids.data(), 0, n_id_bytes); + ggml_backend_tensor_set(probs, diffusion->sc_topk_probs.data(), 0, n_pr_bytes); + } else { + // no self-conditioning this step (e.g. first denoising step): + // ids -> 0 (any valid row), probs -> 0 so the gathered embeddings contribute nothing + std::vector zeros_id(n_id_bytes, 0); + std::vector zeros_pr(n_pr_bytes, 0); + ggml_backend_tensor_set(ids, zeros_id.data(), 0, n_id_bytes); + ggml_backend_tensor_set(probs, zeros_pr.data(), 0, n_pr_bytes); + } +} + +void llm_graph_input_diffusion_self_cond_embd::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (!embd) { + return; + } + assert(embd->type == GGML_TYPE_F32); + + const size_t n_bytes = ggml_nbytes(embd); + const bool have_device = diffusion + && diffusion->sc_embd_device_ready + && diffusion->sc_embd_device_data == embd->data + && diffusion->sc_embd_device_bytes == n_bytes; + + if (have_device) { + return; + } + + std::vector zeros(n_bytes, 0); + ggml_backend_tensor_set(embd, zeros.data(), 0, n_bytes); +} + template static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); @@ -464,24 +556,73 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { }; GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); if (self_kq_mask->type == GGML_TYPE_F16) { - fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + std::vector data(ggml_nelements(self_kq_mask)); + fill_mask(data.data(), ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + ggml_backend_tensor_set(self_kq_mask, data.data(), 0, ggml_nbytes(self_kq_mask)); } else { - fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + std::vector data(ggml_nelements(self_kq_mask)); + fill_mask(data.data(), ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + ggml_backend_tensor_set(self_kq_mask, data.data(), 0, ggml_nbytes(self_kq_mask)); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); if (self_kq_mask_swa->type == GGML_TYPE_F16) { - fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + std::vector data(ggml_nelements(self_kq_mask_swa)); + fill_mask(data.data(), ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + ggml_backend_tensor_set(self_kq_mask_swa, data.data(), 0, ggml_nbytes(self_kq_mask_swa)); } else { - fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + std::vector data(ggml_nelements(self_kq_mask_swa)); + fill_mask(data.data(), ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + ggml_backend_tensor_set(self_kq_mask_swa, data.data(), 0, ggml_nbytes(self_kq_mask_swa)); } } } +void llm_graph_input_attn_no_cache_prefix::set_input(const llama_ubatch * ubatch) { + const int64_t n_kv = ubatch->n_tokens; + const int64_t n_tokens = ubatch->n_tokens; + const int64_t P = n_prompt; // causal prompt prefix length + + const auto fill_mask = [&](auto * data, int64_t ne) { + using T = std::remove_reference_t; + std::fill(data, data + ne, llama_cast(-INFINITY)); + + for (int64_t i1 = 0; i1 < n_tokens; ++i1) { // query + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const uint64_t idst = i1*n_kv; + for (int64_t i0 = 0; i0 < n_tokens; ++i0) { // key + if (ubatch->seq_id[i0][0] != s1) { + continue; + } + bool allow; + if (i1 < P) { + // prompt query: causal, prompt keys only (no canvas) + allow = (i0 < P) && (i0 <= i1); + } else { + // canvas query: attend to everything (bidirectional + cross to prompt) + allow = true; + } + if (allow) { + data[idst + i0] = llama_cast(0.0f); + } + } + } + }; + + GGML_ASSERT(self_kq_mask); + if (self_kq_mask->type == GGML_TYPE_F16) { + std::vector data(ggml_nelements(self_kq_mask)); + fill_mask(data.data(), ggml_nelements(self_kq_mask)); + ggml_backend_tensor_set(self_kq_mask, data.data(), 0, ggml_nbytes(self_kq_mask)); + } else { + std::vector data(ggml_nelements(self_kq_mask)); + fill_mask(data.data(), ggml_nelements(self_kq_mask)); + ggml_backend_tensor_set(self_kq_mask, data.data(), 0, ggml_nbytes(self_kq_mask)); + } +} + void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_k_idxs(self_k_idxs, ubatch); mctx->set_input_v_idxs(self_v_idxs, ubatch); @@ -904,6 +1045,9 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_h_nextn = nullptr; + t_h_pre_norm = nullptr; + t_diffusion_token_embd = nullptr; t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -932,6 +1076,26 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } +llm_graph_input_diffusion_self_cond_topk * llm_graph_result::get_inp_diffusion_self_cond_topk() const { + for (auto & input : inputs) { + if (auto * topk = dynamic_cast(input.get())) { + return topk; + } + } + + return nullptr; +} + +llm_graph_input_diffusion_self_cond_embd * llm_graph_result::get_inp_diffusion_self_cond_embd() const { + for (auto & input : inputs) { + if (auto * embd = dynamic_cast(input.get())) { + return embd; + } + } + + return nullptr; +} + void llm_graph_result::set_outputs() { if (t_logits != nullptr) { ggml_set_output(t_logits); @@ -945,6 +1109,9 @@ void llm_graph_result::set_outputs() { if (t_h_nextn != nullptr) { ggml_set_output(t_h_nextn); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -1049,12 +1216,15 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + diffusion (params.diffusion), samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), gf (res->get_gf()) { res->set_params(params); + ggml_graph_set_flag(gf, GGML_CGRAPH_FLAG_DIFFUSION_DECODER, + llm_arch_is_diffusion(arch) && diffusion && diffusion->decoder_phase); } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -1063,6 +1233,32 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { } } +void llm_graph_context::set_diffusion_input_backend(ggml_tensor * tensor, uint32_t group) const { + if (!diffusion || !diffusion->decoder_phase || !sched || !tensor) { + return; + } + + // Keep the default to inputs used by the fixed diffusion decoder graph: + // canvas/self-cond, positions, attention scale, KV indices, masks and + // rotary helpers. Mark them as outputs too, matching ggml-backend's + // existing copy-tensor convention, so the allocator will not overwrite + // them between denoising replays. + const uint32_t enabled_groups = diffusion->input_gpu_groups; + if ((enabled_groups & group) == 0) { + return; + } + + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + if (backend && backend != backend_cpu && + ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_GPU) { + ggml_set_output(tensor); + ggml_backend_sched_set_tensor_backend(sched, tensor, backend); + break; + } + } +} + ggml_tensor * llm_graph_context::build_cvec( ggml_tensor * cur, int il) const { @@ -1810,12 +2006,13 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { assert(n_embd_inp >= n_embd); - auto inp = std::make_unique(n_embd_inp); + auto inp = std::make_unique(n_embd_inp, diffusion); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); res->t_inp_tokens = inp->tokens; + set_diffusion_input_backend(inp->tokens); inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); cb(inp->embd, "inp_embd", -1); @@ -1900,6 +2097,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd()); ggml_set_input(cur); + set_diffusion_input_backend(cur, 2); res->add_input(std::move(inp)); @@ -1914,6 +2112,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { // this need to be 1x1xN for broadcasting cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); ggml_set_input(cur); + set_diffusion_input_backend(cur, 4); ggml_set_name(cur, "attn_scale"); res->add_input(std::move(inp)); @@ -1922,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { } ggml_tensor * llm_graph_context::build_inp_out_ids() const { + if (diffusion && diffusion->decoder_phase && n_outputs == n_tokens) { + return nullptr; + } + // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, // but this would make the graph topology depend on the number of output tokens, which can interfere with // features that require constant topology such as pipeline parallelism @@ -1936,6 +2139,7 @@ ggml_tensor * llm_graph_context::build_inp_out_ids() const { cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); ggml_set_input(cur); + set_diffusion_input_backend(cur, 64); res->add_input(std::move(inp)); @@ -1992,6 +2196,53 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_diffusion_self_cond(int64_t n_vocab) const { + auto inp = std::make_unique(diffusion); + + auto & cur = inp->probs; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens); + ggml_set_input(cur); + set_diffusion_input_backend(cur); + + res->add_input(std::move(inp)); + + return cur; +} + +llm_graph_input_diffusion_self_cond_topk * llm_graph_context::build_inp_diffusion_self_cond_topk(int64_t k) const { + auto inp = std::make_unique(diffusion); + + // ids are flat [k*n_tokens] (ggml_get_rows treats higher dims of the index tensor as batch + // dims that must match the data tensor; a flat index list gathers into [n_embd, k*n_tokens]). + inp->ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, k * n_tokens); + ggml_set_input(inp->ids); + set_diffusion_input_backend(inp->ids); + + inp->probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, k, n_tokens); + ggml_set_input(inp->probs); + set_diffusion_input_backend(inp->probs); + + auto * ptr = inp.get(); + res->add_input(std::move(inp)); + + return ptr; +} + +ggml_tensor * llm_graph_context::build_inp_diffusion_self_cond_embd(int64_t n_embd) const { + auto inp = std::make_unique(diffusion); + + auto & cur = inp->embd; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(cur); + set_diffusion_input_backend(cur); + + res->add_input(std::move(inp)); + + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { auto inp = std::make_unique(hparams); @@ -2181,12 +2432,14 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); + set_diffusion_input_backend(inp->self_kq_mask, 128); inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); + set_diffusion_input_backend(inp->self_kq_mask_swa, 128); inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { @@ -2197,6 +2450,23 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); } +llm_graph_input_attn_no_cache_prefix * llm_graph_context::build_attn_inp_no_cache_prefix(int64_t n_prompt) const { + auto inp = std::make_unique(hparams, cparams, n_prompt); + + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); + ggml_set_input(inp->self_kq_mask); + set_diffusion_input_backend(inp->self_kq_mask, 128); + inp->self_kq_mask_cnv = inp->self_kq_mask; + + // sliding-window layers reuse the same prefix mask (valid while n_tokens <= sliding_window) + inp->self_kq_mask_swa = inp->self_kq_mask; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask; + + return (llm_graph_input_attn_no_cache_prefix *) res->add_input(std::move(inp)); +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, @@ -2278,6 +2548,11 @@ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const { const auto * mctx_cur = static_cast(mctx); auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + set_diffusion_input_backend(inp->self_k_idxs, 8); + set_diffusion_input_backend(inp->self_v_idxs, 8); + set_diffusion_input_backend(inp->self_kq_mask, 16); + set_diffusion_input_backend(inp->self_k_rot, 32); + set_diffusion_input_backend(inp->self_v_rot, 32); return (llm_graph_input_attn_kv *) res->add_input(std::move(inp)); } @@ -2382,6 +2657,8 @@ llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { const auto * mctx_cur = static_cast(mctx); auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + set_diffusion_input_backend(inp->self_k_idxs, 8); + set_diffusion_input_backend(inp->self_kq_mask, 16); return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); } @@ -2691,6 +2968,12 @@ llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); } + set_diffusion_input_backend(inp->self_k_idxs_mla, 8); + set_diffusion_input_backend(inp->self_kq_mask_mla, 16); + set_diffusion_input_backend(inp->self_k_idxs_lid, 8); + set_diffusion_input_backend(inp->self_kq_mask_lid, 16); + set_diffusion_input_backend(inp->self_k_rot_lid, 32); + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); } @@ -2726,6 +3009,17 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + set_diffusion_input_backend(inp->self_k_idxs, 8); + set_diffusion_input_backend(inp->self_v_idxs, 8); + set_diffusion_input_backend(inp->self_kq_mask, 16); + set_diffusion_input_backend(inp->self_k_idxs_swa, 8); + set_diffusion_input_backend(inp->self_v_idxs_swa, 8); + set_diffusion_input_backend(inp->self_kq_mask_swa, 16); + set_diffusion_input_backend(inp->self_k_rot, 32); + set_diffusion_input_backend(inp->self_v_rot, 32); + set_diffusion_input_backend(inp->self_k_rot_swa, 32); + set_diffusion_input_backend(inp->self_v_rot_swa, 32); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 6793846e3ea6..ce98d241836d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -76,6 +76,57 @@ struct llama_cross { std::vector> seq_ids_enc; }; +// diffusion self-conditioning: the previous denoising step's per-token probability +// distribution (softmax of the processed logits). The decoder turns this into +// soft-embeddings (probs @ token_embd * embed_scale) that are added to the input +// embeddings. Set per-decode via llama_set_diffusion_self_cond(); empty -> zeros. +struct llama_diffusion_cond { + int64_t n_vocab = 0; + int64_t n_tokens = 0; + int64_t n_prompt = 0; // length of the (causal) prompt prefix; 0 = unconditioned + std::vector probs; // [n_vocab * n_tokens], row-major per token (dense self-cond path) + + // Sparse (top-k) self-conditioning: the previous step's top-k token ids + probabilities per + // position. When sc_topk > 0 the graph gathers only these k embedding rows and blends them, + // instead of the full-vocab `probs @ token_embd` matmul. ids/probs are [sc_topk * n_tokens] + // (token-major: positions outer, k inner). Set via llama_set_diffusion_self_cond_topk(). + int64_t sc_topk = 0; + std::vector sc_topk_ids; + std::vector sc_topk_probs; + bool sc_topk_device_ready = false; + void * sc_topk_device_ids_data = nullptr; + void * sc_topk_device_probs_data = nullptr; + size_t sc_topk_device_ids_bytes = 0; + size_t sc_topk_device_probs_bytes = 0; + + bool sc_embd_device_ready = false; + void * sc_embd_device_data = nullptr; + size_t sc_embd_device_bytes = 0; + + bool canvas_tokens_device_ready = false; + void * canvas_tokens_device_data = nullptr; + size_t canvas_tokens_device_bytes = 0; + + // KV-cache reuse phase selector (block-diffusion): + // false (encoder phase) = plain token embeddings, no self-conditioning; KV is + // committed to the cache (prompt prefill / finalized-canvas commit). + // true (decoder phase) = self-conditioned canvas input that reads the cached + // prefix read-only; its own KV is written then rolled back by the caller. + // Causality is controlled separately via llama_set_causal_attn (encoder: causal, + // decoder: bidirectional). + // + // Defaults to true (decoder) so the init-time graph reserve builds the worst-case + // superset graph (the decoder adds the self-conditioning input + block on top of the + // encoder graph). The caller sets the actual phase before every decode. + bool decoder_phase = true; + + int64_t self_cond_top_k = 256; + uint32_t input_gpu_groups = 63; + bool fused_self_cond_embd = false; + bool fuse_final_logit_softcap = false; + bool separate_encoder_decoder = false; +}; + struct llm_graph_params; // @@ -110,7 +161,8 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} + llm_graph_input_embd(int64_t n_embd, const llama_diffusion_cond * diffusion = nullptr) : + diffusion(diffusion), n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -120,6 +172,7 @@ class llm_graph_input_embd : public llm_graph_input_i { ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + const llama_diffusion_cond * diffusion = nullptr; const int64_t n_embd = 0; }; @@ -279,6 +332,49 @@ class llm_graph_input_cross_embd : public llm_graph_input_i { const llama_cross * cross; }; +// diffusion self-conditioning probabilities input (see struct llama_diffusion_cond) +class llm_graph_input_diffusion_self_cond : public llm_graph_input_i { +public: + llm_graph_input_diffusion_self_cond(const llama_diffusion_cond * diffusion) : diffusion(diffusion) {} + virtual ~llm_graph_input_diffusion_self_cond() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * probs; // F32 [n_vocab, n_tokens] + + const llama_diffusion_cond * diffusion; +}; + +// sparse (top-k) diffusion self-conditioning input: top-k token ids + probabilities per position. +// Feeds [k, n_tokens] ids (I32) and probs (F32) so the graph gathers only k embedding rows per +// position instead of the dense full-vocab probs (see struct llama_diffusion_cond, sc_topk). +class llm_graph_input_diffusion_self_cond_topk : public llm_graph_input_i { +public: + llm_graph_input_diffusion_self_cond_topk(const llama_diffusion_cond * diffusion) : diffusion(diffusion) {} + virtual ~llm_graph_input_diffusion_self_cond_topk() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * ids; // I32 [k*n_tokens] (flat; gathered into [n_embd, k*n_tokens]) + ggml_tensor * probs; // F32 [k, n_tokens] + + const llama_diffusion_cond * diffusion; +}; + +// Dense diffusion self-conditioning embedding input. This is filled by the CUDA sampler from +// top-k ids/probs when the fused diffusion self-conditioning embedding path is enabled. +class llm_graph_input_diffusion_self_cond_embd : public llm_graph_input_i { +public: + llm_graph_input_diffusion_self_cond_embd(const llama_diffusion_cond * diffusion) : diffusion(diffusion) {} + virtual ~llm_graph_input_diffusion_self_cond_embd() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * embd = nullptr; // F32 [n_embd, n_tokens] + + const llama_diffusion_cond * diffusion; +}; + class llm_graph_input_attn_no_cache : public llm_graph_input_i { public: llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) : @@ -302,6 +398,24 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { const llama_cparams cparams; }; +// prefix attention mask (no KV cache), used by block-diffusion models over a +// [prompt(0..n_prompt-1) ; canvas(n_prompt..n_tokens-1)] sequence: +// - prompt queries attend causally to the prompt only (no canvas) +// - canvas queries attend to all positions (bidirectional + cross to prompt) +// Valid while n_tokens <= sliding_window (sliding == full); the swa mask reuses the +// same prefix mask. n_prompt = 0 reduces to a fully-bidirectional mask. +class llm_graph_input_attn_no_cache_prefix : public llm_graph_input_attn_no_cache { +public: + llm_graph_input_attn_no_cache_prefix(const llama_hparams & hparams, const llama_cparams & cparams, int64_t n_prompt) : + llm_graph_input_attn_no_cache(hparams, cparams), n_prompt(n_prompt) { + } + ~llm_graph_input_attn_no_cache_prefix() = default; + + void set_input(const llama_ubatch * ubatch) override; + + const int64_t n_prompt; +}; + class llm_graph_input_attn_kv : public llm_graph_input_i { public: llm_graph_input_attn_kv( @@ -602,6 +716,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + const llama_diffusion_cond * diffusion; std::map samplers; @@ -704,6 +819,10 @@ class llm_graph_result { ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } ggml_tensor * get_h_nextn() const { return t_h_nextn; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } + ggml_tensor * get_diffusion_token_embd() const { return t_diffusion_token_embd; } + llm_graph_input_diffusion_self_cond_topk * get_inp_diffusion_self_cond_topk() const; + llm_graph_input_diffusion_self_cond_embd * get_inp_diffusion_self_cond_embd() const; ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -733,6 +852,8 @@ class llm_graph_result { ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_diffusion_token_embd = nullptr; // on-device F16 token embedding used by diffusion CUDA helpers std::map t_sampled_logits; std::map t_candidates; @@ -820,6 +941,7 @@ struct llm_graph_context { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + const llama_diffusion_cond * diffusion; std::map samplers; @@ -834,6 +956,7 @@ struct llm_graph_context { virtual ~llm_graph_context() = default; void cb(ggml_tensor * cur, const char * name, int il) const; + void set_diffusion_input_backend(ggml_tensor * tensor, uint32_t group = 1) const; // // common @@ -940,6 +1063,10 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_diffusion_self_cond(int64_t n_vocab) const; // F32 [n_vocab, n_tokens] + // sparse self-cond: returns the input object exposing ->ids (I32 [k,n_tokens]) and ->probs (F32 [k,n_tokens]) + llm_graph_input_diffusion_self_cond_topk * build_inp_diffusion_self_cond_topk(int64_t k) const; + ggml_tensor * build_inp_diffusion_self_cond_embd(int64_t n_embd) const; // F32 [n_embd, n_tokens] ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; @@ -967,6 +1094,7 @@ struct llm_graph_context { int il) const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; + llm_graph_input_attn_no_cache_prefix * build_attn_inp_no_cache_prefix(int64_t n_prompt) const; ggml_tensor * build_attn( llm_graph_input_attn_no_cache * inp, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 2802103bdd82..0e7c92f91b0a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1450,8 +1450,15 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - int64_t * data = (int64_t *) dst->data; + std::vector data_staging; + int64_t * data = nullptr; + const bool is_host = ggml_backend_buffer_is_host(dst->buffer); + if (is_host) { + data = (int64_t *) dst->data; + } else { + data_staging.resize(ggml_nelements(dst)); + data = data_staging.data(); + } for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { const int64_t offs = sinfo.strm[s]*get_size(); @@ -1460,14 +1467,25 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; } } + + if (!is_host) { + ggml_backend_tensor_set(dst, data_staging.data(), 0, ggml_nbytes(dst)); + } } void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - int64_t * data = (int64_t *) dst->data; + std::vector data_staging; + int64_t * data = nullptr; + const bool is_host = ggml_backend_buffer_is_host(dst->buffer); + if (is_host) { + data = (int64_t *) dst->data; + } else { + data_staging.resize(ggml_nelements(dst)); + data = data_staging.data(); + } if (!v_trans) { for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { @@ -1493,6 +1511,10 @@ void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ub } } } + + if (!is_host) { + ggml_backend_tensor_set(dst, data_staging.data(), 0, ggml_nbytes(dst)); + } } void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { @@ -1715,8 +1737,6 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - const int64_t n_kv = dst->ne[0]; const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch @@ -1740,9 +1760,21 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u }; if (dst->type == GGML_TYPE_F16) { - set_input_kq_mask_impl(args, (ggml_fp16_t *) dst->data, causal_attn); + if (ggml_backend_buffer_is_host(dst->buffer)) { + set_input_kq_mask_impl(args, (ggml_fp16_t *) dst->data, causal_attn); + } else { + std::vector data(ggml_nelements(dst)); + set_input_kq_mask_impl(args, data.data(), causal_attn); + ggml_backend_tensor_set(dst, data.data(), 0, ggml_nbytes(dst)); + } } else { - set_input_kq_mask_impl(args, (float *) dst->data, causal_attn); + if (ggml_backend_buffer_is_host(dst->buffer)) { + set_input_kq_mask_impl(args, (float *) dst->data, causal_attn); + } else { + std::vector data(ggml_nelements(dst)); + set_input_kq_mask_impl(args, data.data(), causal_attn); + ggml_backend_tensor_set(dst, data.data(), 0, ggml_nbytes(dst)); + } } //const int64_t t_end = ggml_time_us(); @@ -1776,21 +1808,27 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - const auto n_rot = dst->ne[0]; GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); - memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); + const auto & data = attn_rot_hadamard.at(n_rot); + if (ggml_backend_buffer_is_host(dst->buffer)) { + memcpy(dst->data, data.data(), ggml_nbytes(dst)); + } else { + ggml_backend_tensor_set(dst, data.data(), 0, ggml_nbytes(dst)); + } } void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - const auto n_rot = dst->ne[0]; GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); - memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); + const auto & data = attn_rot_hadamard.at(n_rot); + if (ggml_backend_buffer_is_host(dst->buffer)) { + memcpy(dst->data, data.data(), ggml_nbytes(dst)); + } else { + ggml_backend_tensor_set(dst, data.data(), 0, ggml_nbytes(dst)); + } } size_t llama_kv_cache::total_size() const { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4f12e0949acb..abd0efaabad0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -141,6 +141,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma4(params); case LLM_ARCH_GEMMA4_ASSISTANT: return new llama_model_gemma4_assistant(params); + case LLM_ARCH_DIFFUSION_GEMMA: + return new llama_model_diffusion_gemma(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -1604,6 +1606,9 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } } + // per-arch precompute of derived tensors (data is now available) + load_arch_post(ml); + if (use_mmap_buffer) { for (auto & mapping : ml.mappings) { pimpl->mappings.emplace_back(std::move(mapping)); @@ -2121,7 +2126,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_share_cb share = nullptr; - if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_DIFFUSION_GEMMA) { reuse = [&](uint32_t il) { GGML_ASSERT(hparams.n_layer_kv_from_start >= 2); @@ -2447,6 +2452,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: case LLM_ARCH_GEMMA4_ASSISTANT: + case LLM_ARCH_DIFFUSION_GEMMA: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: diff --git a/src/llama-model.h b/src/llama-model.h index 992c8d9c8fd9..4b94b044098e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -647,6 +647,10 @@ struct llama_model { virtual void load_arch_tensors(llama_model_loader & ml) = 0; virtual std::unique_ptr build_arch_graph(const llm_graph_params & params) const = 0; + // optional per-arch hook run after all tensor data has been loaded; use to precompute + // derived tensors that are not stored in the GGUF (e.g. a transposed embedding). + virtual void load_arch_post(llama_model_loader & ml) { GGML_UNUSED(ml); } + protected: llama_model_params params; diff --git a/src/models/diffusion-gemma.cpp b/src/models/diffusion-gemma.cpp new file mode 100644 index 000000000000..e7a2d6e01727 --- /dev/null +++ b/src/models/diffusion-gemma.cpp @@ -0,0 +1,442 @@ +#include "models.h" + +#include "ggml-backend.h" +#include "ggml-alloc.h" + +#include + +// diffusion_gemma reuses the gemma4 decoder block (tensor layout + per-layer math) but runs +// as a bidirectional (non-causal) block-diffusion denoiser over a canvas, with KV-cache reuse: +// the prompt / previously-finalized canvases form a causal, read-only prefix in the unified +// sliding-window KV cache, and each denoising step decodes only the current canvas against it +// (self-conditioned, bidirectional), rolling back its own K/V afterwards. +// +// Two graph variants are provided (see build_arch_graph): a single phase-branching graph, and +// a separate encoder/decoder pair (--diffusion-separate-encoder-decoder). Both reuse the gemma4 transformer +// body and differ only in input-embedding handling (encoder: plain; decoder: self-conditioned). + +void llama_model_diffusion_gemma::load_arch_hparams(llama_model_loader & ml) { + // reuse the gemma4 hparam loading (sliding window pattern, dual head dims, MoE, rope, + // softcapping, layer types, ...) + llama_model_gemma4::load_arch_hparams(ml); + + // the diffusion decoder attends bidirectionally + hparams.causal_attn = false; +} + +void llama_model_diffusion_gemma::load_arch_tensors(llama_model_loader & ml) { + // load the shared gemma4 tensors (token embd, attention, dual dense+MoE FFN, norms, + // per-layer layer_scalar, output) + llama_model_gemma4::load_arch_tensors(ml); + + LLAMA_LOAD_LOCALS; + + // self_conditioning is a gated MLP at hidden_size -> intermediate_size -> hidden_size + const int64_t n_ff_sc = n_ff; + + self_cond_norm = create_tensor(tn(LLM_TENSOR_SELF_COND_NORM, "weight"), {n_embd}, 0); + self_cond_gate = create_tensor(tn(LLM_TENSOR_SELF_COND_GATE, "weight"), {n_embd, n_ff_sc}, 0); + self_cond_up = create_tensor(tn(LLM_TENSOR_SELF_COND_UP, "weight"), {n_embd, n_ff_sc}, 0); + self_cond_down = create_tensor(tn(LLM_TENSOR_SELF_COND_DOWN, "weight"), {n_ff_sc, n_embd}, 0); +} + +llama_model_diffusion_gemma::~llama_model_diffusion_gemma() { + if (tok_embd_gpu_buf) { + ggml_backend_buffer_free(tok_embd_gpu_buf); + } + if (tok_embd_gpu_ctx) { + ggml_free(tok_embd_gpu_ctx); + } + if (tok_embd_t_buf) { + ggml_backend_buffer_free(tok_embd_t_buf); + } + if (tok_embd_t_ctx) { + ggml_free(tok_embd_t_ctx); + } +} + +// Place the token embedding the diffusion decoder needs on an offloaded (GPU) backend. +// +// Primary path (sparse gather): store tok_embd {n_embd, n_vocab} as F16 in tok_embd_gpu on-device. +// The decoder graph gathers the canvas token rows and the self-conditioning top-k rows from this +// tensor. This keeps token-id driven embedding lookup on device: ~1.47 GiB resident, no per-decode +// token-id D2H for CPU row selection, no per-decode embedding H2D. F16 (not the native Q4_K) +// because CUDA get_rows has no Q4_K/Q6_K kernel -- a quantized gather would fall back to CPU every +// step (a large regression). F16 halves the VRAM vs the F32 dense path. +// +// Fallback path (dense matmul): if the F16 copy can't be allocated, precompute the transposed F32 +// embedding {n_vocab, n_embd} so the dense `probs @ token_embd` matmul (build_input) still runs +// on-device. This costs ~2.75 GiB and is only used when the gather copy is unavailable. +void llama_model_diffusion_gemma::load_arch_post(llama_model_loader & ml) { + GGML_UNUSED(ml); + + if (!tok_embd || !tok_embd->buffer) { + return; + } + + const int64_t n_embd_t = tok_embd->ne[0]; + const int64_t n_vocab_t = tok_embd->ne[1]; + + // Choose an offloaded (non-host) backend taken from a layer weight; fall back to token_embd's + // own buffer (CPU-only runs). Leaving the self-cond embedding host-resident would force the + // scheduler to stream it across PCIe every forward, which dominated the per-step time. + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(tok_embd->buffer); + for (const auto & layer : layers) { + ggml_tensor * t = layer.wq ? layer.wq : (layer.wk ? layer.wk : layer.ffn_down); + if (t && t->buffer) { + ggml_backend_buffer_type_t b = ggml_backend_buffer_get_type(t->buffer); + if (!ggml_backend_buft_is_host(b)) { buft = b; break; } + } + } + + // Dequantize the embedding to F32 once (host). It's needed by both paths below, and crucially + // CUDA ggml_get_rows supports F16/F32/BF16 but NOT Q4_K/Q6_K (see ggml-cuda supports_op) -- a + // gather from the native quantized type silently falls back to CPU every step. So the gather + // copy must be F16/F32; we use F16 to halve its VRAM (~1.47 GiB vs 2.75 GiB F32). + const int64_t n_elem = n_embd_t * n_vocab_t; + const auto * tt = ggml_get_type_traits(tok_embd->type); + if (!tt || !tt->to_float) { + LLAMA_LOG_WARN("%s: cannot dequantize token embedding type %s; self-conditioning will use " + "the per-decode transpose fallback\n", __func__, ggml_type_name(tok_embd->type)); + return; + } + + std::vector raw(ggml_nbytes(tok_embd)); + ggml_backend_tensor_get(tok_embd, raw.data(), 0, raw.size()); + + std::vector src((size_t) n_elem); + tt->to_float(raw.data(), src.data(), n_elem); + + // --- primary: F16 on-device copy for the sparse gather (get_rows runs on the GPU) --- + { + std::vector half((size_t) n_elem); + ggml_fp32_to_fp16_row(src.data(), half.data(), n_elem); + + ggml_init_params ip = { /*.mem_size =*/ ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true }; + tok_embd_gpu_ctx = ggml_init(ip); + tok_embd_gpu = ggml_new_tensor_2d(tok_embd_gpu_ctx, GGML_TYPE_F16, n_embd_t, n_vocab_t); + ggml_set_name(tok_embd_gpu, "token_embd.gpu.f16"); + + tok_embd_gpu_buf = ggml_backend_alloc_ctx_tensors_from_buft(tok_embd_gpu_ctx, buft); + if (tok_embd_gpu_buf) { + ggml_backend_tensor_set(tok_embd_gpu, half.data(), 0, ggml_nbytes(tok_embd_gpu)); + + LLAMA_LOG_INFO("%s: placed diffusion token embedding {%lld, %lld} F16 (%.2f GiB) on %s (gather, k=%lld)\n", + __func__, (long long) n_embd_t, (long long) n_vocab_t, + ggml_nbytes(tok_embd_gpu) / (1024.0 * 1024.0 * 1024.0), + ggml_backend_buffer_name(tok_embd_gpu->buffer), + (long long) llama_model_diffusion_gemma::N_SC_TOPK); + return; + } + LLAMA_LOG_WARN("%s: failed to allocate on-device gather embedding; falling back to dense matmul\n", __func__); + ggml_free(tok_embd_gpu_ctx); + tok_embd_gpu_ctx = nullptr; + tok_embd_gpu = nullptr; + } + + // --- fallback: transposed F32 embedding for the dense matmul path --- + std::vector dst((size_t) n_elem); + for (int64_t e = 0; e < n_embd_t; ++e) { + for (int64_t v = 0; v < n_vocab_t; ++v) { + dst[(size_t) e * n_vocab_t + v] = src[(size_t) v * n_embd_t + e]; + } + } + + ggml_init_params ip = { /*.mem_size =*/ ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true }; + tok_embd_t_ctx = ggml_init(ip); + tok_embd_t = ggml_new_tensor_2d(tok_embd_t_ctx, GGML_TYPE_F32, n_vocab_t, n_embd_t); + ggml_set_name(tok_embd_t, "token_embd_t.f32"); + + tok_embd_t_buf = ggml_backend_alloc_ctx_tensors_from_buft(tok_embd_t_ctx, buft); + if (!tok_embd_t_buf) { + LLAMA_LOG_WARN("%s: failed to allocate transposed embedding buffer; falling back\n", __func__); + ggml_free(tok_embd_t_ctx); + tok_embd_t_ctx = nullptr; + tok_embd_t = nullptr; + return; + } + ggml_backend_tensor_set(tok_embd_t, dst.data(), 0, ggml_nbytes(tok_embd_t)); + + LLAMA_LOG_INFO("%s: precomputed transposed F32 token embedding {%lld, %lld} (%.2f GiB) on %s (dense fallback)\n", + __func__, (long long) n_vocab_t, (long long) n_embd_t, + ggml_nbytes(tok_embd_t) / (1024.0 * 1024.0 * 1024.0), + ggml_backend_buffer_name(tok_embd_t->buffer)); +} + +std::unique_ptr llama_model_diffusion_gemma::build_arch_graph(const llm_graph_params & params) const { + const bool is_decoder = params.diffusion && params.diffusion->decoder_phase; + + // Variant B ("separate encoder and decoder block", shared weights): opt-in via + // --diffusion-separate-encoder-decoder. Two distinct graphs are built per phase. Functionally identical + // to Variant A here (the checkpoint shares encoder/decoder weights); the split mirrors + // the HF two-stack structure and generalizes to a checkpoint with distinct weights. + if (params.diffusion && params.diffusion->separate_encoder_decoder) { + if (is_decoder) { + return std::make_unique(*this, params); + } + return std::make_unique(*this, params); + } + + // Variant A ("single encoder/decoder block"): one graph that branches on the phase. + return std::make_unique(*this, params); +} + +// Scaled input embeddings. In the decoder phase, apply the self-conditioning transform: +// inpL = post_norm(scaled_embed + sc_mlp(pre_norm(soft))), +// soft = (probs @ token_embd) * sqrt(n_embd) [probs = previous step's softmax, 0 on step 1] +ggml_tensor * llama_model_diffusion_gemma::graph_base::build_input(bool is_decoder) { + const auto & dmodel = static_cast(model); + + ggml_tensor * inpL = build_inp_embd(dmodel.tok_embd_gpu ? dmodel.tok_embd_gpu : model.tok_embd); + if (dmodel.tok_embd_gpu) { + res->t_diffusion_token_embd = dmodel.tok_embd_gpu; + } + + // scaled word embeddings (sqrt(hidden_size)); raw embeddings input is not scaled + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + + if (is_decoder) { + ggml_tensor * soft; // soft-embedding {n_embd, n_tokens}: blend of the previous step's + // predicted token embeddings, scaled by sqrt(n_embd) + if (dmodel.tok_embd_gpu && diffusion && diffusion->fused_self_cond_embd) { + soft = build_inp_diffusion_self_cond_embd(n_embd); + } else if (dmodel.tok_embd_gpu) { + // Sparse gather path (Option-2): the previous step's top-k token ids+probs are fed per + // position; gather just those k embedding rows and blend them, instead of the dense + // full-vocab `probs @ token_embd` matmul. Gather width is fixed (N_SC_TOPK) so the + // graph shape is constant; unused slots carry prob 0 (the CLI zero-pads). + const int64_t k = diffusion ? std::min(std::max(diffusion->self_cond_top_k, 1), llama_model_diffusion_gemma::N_SC_TOPK) : llama_model_diffusion_gemma::N_SC_TOPK; + auto * inp = build_inp_diffusion_self_cond_topk(k); + ggml_tensor * ids = inp->ids; // I32 {k*n_tokens} + ggml_tensor * probs = inp->probs; // F32 {k, n_tokens} + + // gather: {n_embd, n_vocab} x {k*n_tokens} ids -> {n_embd, k*n_tokens} -> {n_embd, k, n_tokens} + ggml_tensor * emb = ggml_get_rows(ctx0, dmodel.tok_embd_gpu, ids); // {n_embd, k*n_tokens} + emb = ggml_reshape_3d(ctx0, emb, n_embd, k, n_tokens); // {n_embd, k, n_tokens} + // weight each gathered row by its prob (broadcast over n_embd): {n_embd, k, n_tokens} + ggml_tensor * w = ggml_mul(ctx0, emb, ggml_reshape_3d(ctx0, probs, 1, k, n_tokens)); + // sum over k: bring k to ne[0] then sum_rows -> {1, n_embd, n_tokens} -> {n_embd, n_tokens} + w = ggml_cont(ctx0, ggml_permute(ctx0, w, 1, 0, 2, 3)); // {k, n_embd, n_tokens} + w = ggml_sum_rows(ctx0, w); // {1, n_embd, n_tokens} + soft = ggml_reshape_2d(ctx0, w, n_embd, n_tokens); // {n_embd, n_tokens} + } else { + // Dense fallback: soft = (probs @ token_embd). mul_mat contracts ne[0], so token_embd + // needs vocab as ne[0]; prefer the transposed F32 embedding from load_arch_post, else + // dequantize+transpose every decode (a quantized tensor can't be transposed directly). + ggml_tensor * probs = build_inp_diffusion_self_cond(model.tok_embd->ne[1]); // {n_vocab, n_tokens} + ggml_tensor * embed_t = dmodel.tok_embd_t; // {n_vocab, n_embd} + if (!embed_t) { + ggml_tensor * embed_f = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32); // {n_embd, n_vocab} + embed_t = ggml_cont(ctx0, ggml_transpose(ctx0, embed_f)); // {n_vocab, n_embd} + } + soft = ggml_mul_mat(ctx0, embed_t, probs); // {n_embd, n_tokens} + } + soft = ggml_scale(ctx0, soft, sqrtf((float) n_embd)); + cb(soft, "self_cond_soft_embd", -1); + ggml_tensor * scn = build_norm(soft, dmodel.self_cond_norm, nullptr, LLM_NORM_RMS, -1); + ggml_tensor * sc = build_ffn(scn, + dmodel.self_cond_up, nullptr, nullptr, + dmodel.self_cond_gate, nullptr, nullptr, + dmodel.self_cond_down, nullptr, nullptr, + nullptr, LLM_FFN_GELU, LLM_FFN_PAR, -1); + inpL = ggml_rms_norm(ctx0, ggml_add(ctx0, inpL, sc), hparams.f_norm_rms_eps); // scale-less post_norm + cb(inpL, "self_cond_input", -1); + } + + return inpL; +} + +// Variant A: single graph, phase chosen at runtime from the diffusion cond. +llama_model_diffusion_gemma::graph::graph(const llama_model & model, const llm_graph_params & params) : + graph_base(model, params) { + build_transformer(build_input(diffusion && diffusion->decoder_phase)); +} + +// Variant B: separate encoder / decoder graphs (shared weight tensors). +llama_model_diffusion_gemma::graph_encoder::graph_encoder(const llama_model & model, const llm_graph_params & params) : + graph_base(model, params) { + build_transformer(build_input(/*is_decoder=*/false)); +} + +llama_model_diffusion_gemma::graph_decoder::graph_decoder(const llama_model & model, const llm_graph_params & params) : + graph_base(model, params) { + build_transformer(build_input(/*is_decoder=*/true)); +} + +// Run the reused gemma4 decoder block over the input embeddings and emit logits. +void llama_model_diffusion_gemma::graph_base::build_transformer(ggml_tensor * inpL) { + ggml_tensor * cur; + + ggml_tensor * inp_pos = build_inp_pos(); + + // Reuse the unified sliding-window KV cache: the canvas (decoder phase) reads the + // cached prompt/previous-canvas prefix; encoder-phase tokens write (commit) their KV. + // The prompt/canvas attention pattern is selected by cparams.causal_attn, toggled by + // the caller (encoder: causal; decoder: bidirectional + full cross to the prefix). + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_head = hparams.n_embd_head_k(il); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il)); + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + // full_attention layers use rope_freqs for proportional rope + ggml_tensor * freq_factors = hparams.is_swa(il) ? nullptr : model.layers[il].rope_freqs; + + // attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention (QK norm + scale-less V norm, mirrors Gemma4Attention) + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + // KV-sharing layers (n_kv_shared_layers) do not own a cache slot: they reuse an + // earlier layer's cached K/V (wk/wv/k_norm are absent). Mirror gemma4: compute and + // store K/V only for has_kv layers, otherwise pass nullptr (no store). + if (hparams.has_kv(il)) { + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + // global (full-attention) layers have no v_proj: V = K (before norms) + ggml_tensor * Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); // scale-less v_norm + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } else { + // reuse the cached K/V of an earlier layer + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + // feed-forward: dense MLP (shared expert) + routed MoE, summed (mirrors gemma4) + const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr; + if (is_moe_layer) { + ggml_tensor * cur_mlp = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_norm_1", il); + cur_mlp = build_ffn(cur_mlp, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, LLM_FFN_GELU, LLM_FFN_PAR, il); + cur_mlp = build_norm(cur_mlp, model.layers[il].ffn_post_norm_1, nullptr, LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_mlp", il); + + ggml_tensor * cur_moe = build_norm(attn_out, model.layers[il].ffn_pre_norm_2, nullptr, LLM_NORM_RMS, il); + cb(cur_moe, "ffn_norm_2", il); + + // router operates on attn_out (scale-less norm * 1/sqrt(n_embd) * router scale) + ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps); + tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd)); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s); + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); + cb(logits, "ffn_moe_logits", il); + + cur_moe = build_moe_ffn(cur_moe, + nullptr, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_GELU, true, + 1.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, logits, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cur_moe = build_norm(cur_moe, model.layers[il].ffn_post_norm_2, nullptr, LLM_NORM_RMS, il); + cb(cur_moe, "ffn_moe", il); + + cur = ggml_add(ctx0, cur_mlp, cur_moe); + cb(cur, "ffn_moe_combined", il); + } else { + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + // layer_scalar + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = build_norm(inpL, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur, model.output_s); + + const bool fuse_final_softcap = + diffusion && diffusion->decoder_phase && diffusion->fuse_final_logit_softcap; + + if (hparams.f_final_logit_softcapping && !fuse_final_softcap) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index c137e32e8fd1..e595a27a7816 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -821,6 +821,78 @@ struct llama_model_gemma4 : public llama_model_base { std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; +// Block-diffusion variant of Gemma 4. Reuses the gemma4 decoder block (tensor layout +// and per-layer math) but runs bidirectionally (non-causal, no KV cache) and applies +// the self-conditioning transform to the input embeddings. +// +// NOTE: this implements a single bidirectional denoising pass over the canvas with no +// prompt context. With soft-conditioning = 0 (the first denoising step) the +// self-conditioning module reduces to a scale-less RMS norm of the scaled embeddings. +// The soft-conditioning input path (later steps) and the encoder-KV cross-attention +// (prompted generation) are layered on top in a later step. +struct llama_model_diffusion_gemma : public llama_model_gemma4 { + llama_model_diffusion_gemma(const struct llama_model_params & params) : llama_model_gemma4(params) {} + ~llama_model_diffusion_gemma() override; + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + void load_arch_post(llama_model_loader & ml) override; // place the self-cond embedding on-device + + // width of the sparse (top-k) self-conditioning gather. The graph gathers exactly this many + // token embeddings per position each decoder step (the CLI feeds the top-N ids+probs, zero- + // padding unused slots), independent of the sampling k. 256 captures effectively all the + // softmax mass of a converged diffusion step, so the soft-embedding blend is near-exact. + static constexpr int64_t N_SC_TOPK = 256; + + // Shared transformer body for both KV-cache-reuse variants. The reused gemma4 decoder + // block (layers + final norm + tied lm_head + softcapping) is identical across phases; + // only the input embedding handling differs (encoder: plain; decoder: self-conditioned). + struct graph_base : public llm_graph_context { + const llama_model & model; + graph_base(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params), model(model) {} + // scaled input embeddings; if is_decoder, apply the self-conditioning transform + ggml_tensor * build_input(bool is_decoder); + // run the per-layer block over inpL (cached iswa attention) and emit logits + void build_transformer(ggml_tensor * inpL); + }; + + // Variant A ("single encoder/decoder block"): one graph that branches on the phase. + struct graph : public graph_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + // Variant B ("separate encoder and decoder block", shared weights): two graphs selected + // by phase. Enabled with the DG4_SEPARATE_ENC_DEC environment variable. + struct graph_encoder : public graph_base { + graph_encoder(const llama_model & model, const llm_graph_params & params); + }; + struct graph_decoder : public graph_base { + graph_decoder(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; + + // self_conditioning (top-level, not per-layer); post-norm is scale-less (no weight) + ggml_tensor * self_cond_norm = nullptr; + ggml_tensor * self_cond_gate = nullptr; + ggml_tensor * self_cond_up = nullptr; + ggml_tensor * self_cond_down = nullptr; + + // On-device F16 copy of the token embedding {n_embd, n_vocab}, used by the sparse self-cond + // gather (ggml_get_rows of the top-k ids). F16 (not the native Q4_K) because CUDA get_rows has + // no Q4_K/Q6_K kernel and would fall back to CPU. Allocated once in load_arch_post on an + // offloaded backend so the gather runs on-device. ~1.47 GiB vs the 2.75 GiB F32 dense transpose. + ggml_context * tok_embd_gpu_ctx = nullptr; + ggml_backend_buffer_t tok_embd_gpu_buf = nullptr; + ggml_tensor * tok_embd_gpu = nullptr; + + // Fallback only (used if the on-device gather copy can't be allocated): transposed F32 token + // embedding {n_vocab, n_embd} for the dense soft-embedding matmul (probs @ token_embd). + ggml_context * tok_embd_t_ctx = nullptr; + ggml_backend_buffer_t tok_embd_t_buf = nullptr; + ggml_tensor * tok_embd_t = nullptr; +}; + struct llama_model_gemma4_assistant : public llama_model_base { llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}