Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ This starts a local web interface with a shareable link. To load a LoRA checkpoi
uv run python run_gradio.py --model medium --lora-ckpt-path path/to/lora.ckpt
```

For a minimal WebUI that maps closely to the `stable-audio` CLI flags, run:

```bash
uv run python run_basic_webui.py --model medium
```

## Usage

Stable Audio 3 supports several inference modes. For full details, see [Inference Methods](docs/workflows/inference.md).
Expand Down
29 changes: 29 additions & 0 deletions run_basic_webui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import argparse

from stable_audio_3.interface.basic_cli_webui import launch_basic_cli_webui


def main():
parser = argparse.ArgumentParser(description="Run basic Stable Audio CLI-parity WebUI")
parser.add_argument(
"--model",
default="medium",
choices=[
"medium",
"small-music",
"small-sfx",
"medium-base",
"small-music-base",
"small-sfx-base",
],
help="Model to load",
)
parser.add_argument("--device", default=None, help="cuda/mps/cpu (optional)")
parser.add_argument("--no-half", action="store_true", help="Disable fp16 on CUDA")

args = parser.parse_args()
launch_basic_cli_webui(model_name=args.model, device=args.device, no_half=args.no_half)


if __name__ == "__main__":
main()
213 changes: 213 additions & 0 deletions stable_audio_3/interface/basic_cli_webui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import os
import tempfile
from typing import Optional

import gradio as gr
import torch
import torchaudio

from stable_audio_3 import StableAudioModel


MODEL_CHOICES = [
"medium",
"small-music",
"small-sfx",
"medium-base",
"small-music-base",
"small-sfx-base",
]


def _to_audio_tuple(audio_path: Optional[str]):
if not audio_path:
return None
waveform, sr = torchaudio.load(audio_path)
return (sr, waveform)


def _normalize_optional_text(s: str):
s = (s or "").strip()
return s if s else None


def _parse_float_list(csv_text: str):
text = (csv_text or "").strip()
if not text:
return None
values = [v.strip() for v in text.split(",") if v.strip()]
if not values:
return None
parsed = [float(v) for v in values]
return parsed[0] if len(parsed) == 1 else parsed


def _build_ui(model_name: str, device: Optional[str], no_half: bool):
state = {"model": None, "loaded": None}

def ensure_model(current_model_name: str, current_device: str, current_no_half: bool):
target = (current_model_name, current_device or None, current_no_half)
if state["loaded"] != target:
state["model"] = StableAudioModel.from_pretrained(
current_model_name,
device=(current_device or None),
model_half=not current_no_half,
)
state["loaded"] = target
return state["model"]

def generate(
model_name,
device,
no_half,
prompt,
negative_prompt,
duration,
steps,
cfg_scale,
seed,
batch_size,
init_audio,
init_noise_level,
inpaint_audio,
inpaint_starts_csv,
inpaint_ends_csv,
chunked_decode_mode,
lora_ckpt_paths_csv,
lora_strength,
lora_index,
output_basename,
):
prompt = _normalize_optional_text(prompt)
if not prompt:
raise gr.Error("Prompt is required")

model = ensure_model(model_name, device, no_half)

lora_paths = [p.strip() for p in (lora_ckpt_paths_csv or "").split(",") if p.strip()]
if lora_paths:
model.load_lora(lora_paths)
if lora_strength is not None:
model.set_lora_strength(lora_strength, lora_index=int(lora_index) if lora_index >= 0 else None)

inpaint_starts = _parse_float_list(inpaint_starts_csv)
inpaint_ends = _parse_float_list(inpaint_ends_csv)
if (inpaint_starts is None) != (inpaint_ends is None):
raise gr.Error("inpaint-start and inpaint-end must both be set")

chunked_decode = None
if chunked_decode_mode == "on":
chunked_decode = True
elif chunked_decode_mode == "off":
chunked_decode = False

init_audio_tuple = _to_audio_tuple(init_audio)
inpaint_audio_tuple = _to_audio_tuple(inpaint_audio)

audio = model.generate(
prompt=prompt,
negative_prompt=_normalize_optional_text(negative_prompt),
duration=float(duration),
steps=int(steps),
cfg_scale=float(cfg_scale),
seed=int(seed),
batch_size=int(batch_size),
init_audio=init_audio_tuple,
init_noise_level=float(init_noise_level),
inpaint_audio=inpaint_audio_tuple,
inpaint_mask_start_seconds=inpaint_starts,
inpaint_mask_end_seconds=inpaint_ends,
chunked_decode=chunked_decode,
)

sr = model.model.sample_rate
first = audio[0].detach().cpu()

safe_name = (output_basename or "output").strip() or "output"
fd, out_path = tempfile.mkstemp(prefix=f"{safe_name}_", suffix=".wav")
os.close(fd)
torchaudio.save(out_path, first, sr)

return (sr, first.numpy().T), out_path

with gr.Blocks(title="Stable Audio 3 Basic CLI WebUI") as demo:
gr.Markdown("# Stable Audio 3 — Basic CLI WebUI\nMatches the main stable-audio CLI options in a minimal form.")

with gr.Row():
model_dd = gr.Dropdown(MODEL_CHOICES, value=model_name, label="--model")
device_tb = gr.Textbox(value=device or "", label="--device (optional: cuda/mps/cpu)")
no_half_cb = gr.Checkbox(value=no_half, label="--no-half")

prompt_tb = gr.Textbox(label="--prompt", lines=3, placeholder="Describe the audio...")
negative_prompt_tb = gr.Textbox(label="--negative-prompt", lines=2)

with gr.Row():
duration_num = gr.Number(value=30, label="--duration")
steps_num = gr.Number(value=8, label="--steps", precision=0)
cfg_scale_num = gr.Number(value=1.0, label="--cfg-scale")
seed_num = gr.Number(value=-1, label="--seed", precision=0)
batch_size_num = gr.Number(value=1, label="--batch-size", precision=0)

with gr.Accordion("Audio-to-audio", open=False):
init_audio_in = gr.Audio(type="filepath", label="--init-audio")
init_noise_num = gr.Number(value=0.9, label="--init-noise-level")

with gr.Accordion("Inpainting / continuation", open=False):
inpaint_audio_in = gr.Audio(type="filepath", label="--inpaint-audio")
inpaint_starts_tb = gr.Textbox(label="--inpaint-start (comma-separated)", placeholder="4,16")
inpaint_ends_tb = gr.Textbox(label="--inpaint-end (comma-separated)", placeholder="8,20")

with gr.Accordion("Decode + LoRA", open=False):
chunked_decode_mode = gr.Radio(
["auto", "on", "off"],
value="auto",
label="chunked decode (--chunked-decode / --no-chunked-decode)",
)
lora_paths_tb = gr.Textbox(
label="--lora-ckpt-path (comma-separated)",
placeholder="/path/a.safetensors,/path/b.safetensors",
)
with gr.Row():
lora_strength_num = gr.Number(value=None, label="--lora-strength")
lora_index_num = gr.Number(value=-1, label="--lora-index (-1 = all)", precision=0)

output_basename_tb = gr.Textbox(value="output", label="--output basename")
run_btn = gr.Button("Generate", variant="primary")

audio_out = gr.Audio(label="Generated audio")
file_out = gr.File(label="Download WAV")

run_btn.click(
generate,
inputs=[
model_dd,
device_tb,
no_half_cb,
prompt_tb,
negative_prompt_tb,
duration_num,
steps_num,
cfg_scale_num,
seed_num,
batch_size_num,
init_audio_in,
init_noise_num,
inpaint_audio_in,
inpaint_starts_tb,
inpaint_ends_tb,
chunked_decode_mode,
lora_paths_tb,
lora_strength_num,
lora_index_num,
output_basename_tb,
],
outputs=[audio_out, file_out],
)

return demo


def launch_basic_cli_webui(model_name: str = "medium", device: Optional[str] = None, no_half: bool = False):
demo = _build_ui(model_name=model_name, device=device, no_half=no_half)
demo.queue()
demo.launch(share=True)