Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ MAX_JOBS=8 \
- `TORCH_CUDA_ARCH_LIST` — set to your GPU's compute capability: `8.0` (A100), `8.6` (A10/RTX 3090), `8.9` (L4/RTX 4090), `9.0` (H100/H200)
- `MAX_JOBS` — number of parallel compile jobs; 4–8 is typical, reduce if you run out of RAM during compilation

**Note:** `flash-attn` is not declared in `pyproject.toml`, so a plain `uv sync` will remove it. Use `uv sync --inexact` to install/update dependencies without removing packages that aren't in the lockfile:

```bash
uv sync --inexact
```

## Quick Start

Launch the Gradio UI:
Expand Down Expand Up @@ -199,6 +205,37 @@ audio_out = ae.decode(latents)

See [Autoencoder Workflows](docs/workflows/autoencoder.md) for encoding batches, chunked processing, and pre-encoding datasets for LoRA training.

## CLI

A `stable-audio` command-line interface (CLI) is included for running generation without writing any Python.

**Text-to-audio:**
```bash
stable-audio --model small-music -p "lo-fi hip hop beat, 90 BPM" --duration 30 -o beat.wav
```

**Audio-to-audio** — restyle an existing recording:
```bash
stable-audio -p "bossa nova bassline" --init-audio input.wav --init-noise-level 0.8 -o out.wav
```

**Inpainting** — regenerate a region while keeping the rest:
```bash
stable-audio -p "punchy kick drum fill" --inpaint-audio input.wav --inpaint-start 4 --inpaint-end 8 -o out.wav
```

**Continuation** — extend a clip beyond its original length:
```bash
stable-audio -p "dreamy synth outro" --inpaint-audio input.wav --inpaint-start 10 --inpaint-end 30 --duration 30 -o out.wav
```

**With a LoRA:**
```bash
stable-audio -p "orchestral strings" --lora-ckpt-path my_lora.safetensors --lora-strength 0.8 -o out.wav
```

Run `stable-audio --help` for the full list of flags.

## Hardware Support
Stable Audio 3 scales from a laptop to a GPU server.

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ torchaudio = [
{ index = "pytorch-cu126", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" }
]

[project.scripts]
stable-audio = "stable_audio_3.cli:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down
292 changes: 292 additions & 0 deletions stable_audio_3/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
"""
stable-audio — command-line interface for Stable Audio 3.

Basic usage::

stable-audio --model small-music -p "lo-fi hip hop beat, 90 BPM" --duration 30 -o beat.wav

"""

import argparse
import os
import torch
import torchaudio

from stable_audio_3 import StableAudioModel


def _save_output(audio: torch.Tensor, sample_rate: int, output: str, batch_size: int):
    """Save generated audio tensor(s) to disk.

    Args:
        audio: Generated audio; indexed along its first dimension, one entry
            per batch item.
        sample_rate: Sample rate passed to ``torchaudio.save``.
        output: Destination path. ``.wav`` is appended when the path has no
            extension. Missing parent directories are created.
        batch_size: Number of items to write. When greater than 1, files are
            written as ``<base>_<i><ext>``; otherwise ``output`` is used as-is.
    """
    base, ext = os.path.splitext(output)
    if not ext:
        ext = ".wav"

    # Create the target directory up front so that a finished generation is
    # not lost because the output folder does not exist yet.
    out_dir = os.path.dirname(base)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for i in range(batch_size):
        path = f"{base}_{i}{ext}" if batch_size > 1 else f"{base}{ext}"
        # Move to CPU before handing the tensor to torchaudio.
        torchaudio.save(path, audio[i].cpu(), sample_rate)
        print(f"Saved: {path}")


def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``stable-audio`` command."""
    parser = argparse.ArgumentParser(
        prog="stable-audio",
        description="Stable Audio 3 — CLI for text-to-audio, audio-to-audio, and inpainting",
    )

    # Model
    parser.add_argument(
        "--model",
        default="medium",
        choices=[
            "medium",
            "small-music",
            "small-sfx",
            "medium-base",
            "small-music-base",
            "small-sfx-base",
        ],
        help="Model to load (default: medium)",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device: cuda / mps / cpu (auto-detected if omitted)",
    )
    parser.add_argument(
        "--no-half", action="store_true", help="Disable half-precision (fp16) on CUDA"
    )

    # Generation
    parser.add_argument(
        "-p",
        "--prompt",
        required=True,
        nargs="+",
        help="Text prompt(s). Pass multiple for per-batch prompts",
    )
    parser.add_argument(
        "--negative-prompt", nargs="+", default=None, help="Negative prompt(s)"
    )
    parser.add_argument(
        "--duration",
        type=float,
        nargs="+",
        default=[120.0],
        help="Duration in seconds (default: 120). Pass multiple for per-batch durations",
    )
    parser.add_argument(
        "--steps", type=int, default=8, help="Diffusion steps (default: 8)"
    )
    parser.add_argument(
        "--cfg-scale",
        type=float,
        default=1.0,
        help="CFG scale (default: 1.0; try 7.0 for base models)",
    )
    parser.add_argument(
        "--seed", type=int, default=-1, help="Random seed (-1 = random, default: -1)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Batch size (default: inferred from number of prompts, or 1)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="output.wav",
        help="Output file path (default: output.wav)",
    )

    # Audio-to-Audio
    parser.add_argument(
        "--init-audio",
        default=None,
        metavar="PATH",
        help="Source audio file for audio-to-audio generation",
    )
    parser.add_argument(
        "--init-noise-level",
        type=float,
        default=0.9,
        help="Noise level for audio-to-audio (0.0–1.0, default: 0.9)",
    )

    # Inpainting / Continuation
    parser.add_argument(
        "--inpaint-audio",
        default=None,
        metavar="PATH",
        help="Source audio file for inpainting or continuation",
    )
    parser.add_argument(
        "--inpaint-start",
        type=float,
        action="append",
        dest="inpaint_starts",
        metavar="SECONDS",
        help="Start of inpaint region in seconds. Repeat for multiple regions.",
    )
    parser.add_argument(
        "--inpaint-end",
        type=float,
        action="append",
        dest="inpaint_ends",
        metavar="SECONDS",
        help="End of inpaint region in seconds. Repeat for multiple regions.",
    )

    # Chunked decode: both flags default to None so that "neither given"
    # can be told apart from an explicit on/off choice.
    decode_group = parser.add_mutually_exclusive_group()
    decode_group.add_argument(
        "--chunked-decode",
        action="store_true",
        default=None,
        help="Force chunked decoding on",
    )
    decode_group.add_argument(
        "--no-chunked-decode",
        action="store_true",
        default=None,
        help="Force chunked decoding off",
    )

    # LoRA
    parser.add_argument(
        "--lora-ckpt-path",
        action="append",
        dest="loras",
        metavar="PATH",
        help="LoRA checkpoint path. Repeat to stack multiple LoRAs.",
    )
    parser.add_argument(
        "--lora-strength",
        type=float,
        default=None,
        help="LoRA strength (applied to all LoRAs)",
    )
    parser.add_argument(
        "--lora-index",
        type=int,
        default=None,
        help="Target a specific LoRA index when setting strength",
    )

    return parser


def _scalar_or_list(values):
    """Return the sole element of a one-item list, otherwise the list itself."""
    return values[0] if len(values) == 1 else values


def _load_audio(path):
    """Load an audio file and return it as ``(sample_rate, waveform)``.

    ``torchaudio.load`` returns ``(waveform, sample_rate)``; ``model.generate``
    expects the two in the opposite order.
    """
    waveform, sr = torchaudio.load(path)
    return (sr, waveform)


def main():
    """Entry point for the ``stable-audio`` command.

    Parses and validates the command-line flags, loads the requested model
    (and any LoRAs), runs generation, and writes the result to disk. Invalid
    flag combinations terminate through ``parser.error`` (exit status 2)
    before any model or audio file is loaded.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # --- Validate inpaint args ---
    if (args.inpaint_starts is None) != (args.inpaint_ends is None):
        parser.error("--inpaint-start and --inpaint-end must both be provided together")
    if args.inpaint_starts and len(args.inpaint_starts) != len(args.inpaint_ends):
        parser.error(
            "--inpaint-start and --inpaint-end must be specified the same number of times"
        )
    if args.inpaint_starts and not args.inpaint_audio:
        parser.error("--inpaint-start/--inpaint-end require --inpaint-audio")
    if args.inpaint_audio and not args.inpaint_starts:
        parser.error("--inpaint-audio requires --inpaint-start and --inpaint-end")
    if args.inpaint_starts:
        # Reject empty or reversed regions here rather than deep inside generation.
        for start, end in zip(args.inpaint_starts, args.inpaint_ends):
            if start < 0 or end <= start:
                parser.error(
                    f"--inpaint-start ({start}) must be non-negative and less than "
                    f"--inpaint-end ({end})"
                )

    # --- Validate LoRA args ---
    # These flags are only acted on when a LoRA is loaded; reject combinations
    # that would otherwise be silently ignored.
    if args.lora_strength is not None and not args.loras:
        parser.error("--lora-strength requires --lora-ckpt-path")
    if args.lora_index is not None and args.lora_strength is None:
        parser.error("--lora-index requires --lora-strength")

    # --- Resolve batch size ---
    n_prompts = len(args.prompt)
    if args.batch_size is not None and args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.batch_size is None:
        batch_size = n_prompts
    else:
        if n_prompts > 1 and args.batch_size != n_prompts:
            parser.error(
                f"--batch-size {args.batch_size} does not match the number of prompts "
                f"({n_prompts}); omit --batch-size to have it inferred automatically"
            )
        batch_size = args.batch_size

    # --- Validate list-flag lengths against batch size ---
    # A single value is allowed for any batch size; multiple values must
    # match the batch size exactly.
    if (
        args.negative_prompt
        and len(args.negative_prompt) > 1
        and len(args.negative_prompt) != batch_size
    ):
        parser.error(
            f"Got {len(args.negative_prompt)} --negative-prompt values but batch size is {batch_size}"
        )
    if len(args.duration) > 1 and len(args.duration) != batch_size:
        parser.error(
            f"Got {len(args.duration)} --duration values but batch size is {batch_size}"
        )

    # --- Build scalar / list args ---
    prompt = _scalar_or_list(args.prompt)
    negative_prompt = (
        _scalar_or_list(args.negative_prompt) if args.negative_prompt else None
    )
    duration = _scalar_or_list(args.duration)

    inpaint_start = None
    inpaint_end = None
    if args.inpaint_starts:
        inpaint_start = _scalar_or_list(args.inpaint_starts)
        inpaint_end = _scalar_or_list(args.inpaint_ends)

    # --- chunked_decode flag ---
    # None means "neither flag given".
    chunked_decode = None
    if args.chunked_decode:
        chunked_decode = True
    elif args.no_chunked_decode:
        chunked_decode = False

    # --- Load audio inputs ---
    # Done before the model is loaded so an unreadable input file fails fast
    # instead of after the (much slower) model load.
    init_audio = _load_audio(args.init_audio) if args.init_audio else None
    inpaint_audio = _load_audio(args.inpaint_audio) if args.inpaint_audio else None

    # --- Load model ---
    print(f"Loading model '{args.model}'…")
    model = StableAudioModel.from_pretrained(
        args.model, device=args.device, model_half=not args.no_half
    )

    # --- LoRA ---
    if args.loras:
        print(f"Loading LoRA(s): {args.loras}")
        model.load_lora(args.loras)
        if args.lora_strength is not None:
            model.set_lora_strength(args.lora_strength, lora_index=args.lora_index)

    # --- Generate ---
    print("Generating…")
    audio = model.generate(
        prompt=prompt,
        negative_prompt=negative_prompt,
        duration=duration,
        steps=args.steps,
        cfg_scale=args.cfg_scale,
        seed=args.seed,
        batch_size=batch_size,
        init_audio=init_audio,
        init_noise_level=args.init_noise_level,
        inpaint_audio=inpaint_audio,
        inpaint_mask_start_seconds=inpaint_start,
        inpaint_mask_end_seconds=inpaint_end,
        chunked_decode=chunked_decode,
    )

    _save_output(audio, model.model.sample_rate, args.output, batch_size)


# Allow running this module directly as a script, in addition to the
# `stable-audio` console-script entry point that calls main().
if __name__ == "__main__":
    main()
Loading
Loading