-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi_power.py
More file actions
135 lines (114 loc) · 4.66 KB
/
api_power.py
File metadata and controls
135 lines (114 loc) · 4.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import asyncio
import io
import os
import shutil
import tempfile
import time
from contextlib import asynccontextmanager

import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from speechbrain.inference.separation import SepformerSeparation
# --- Global state (populated once in `lifespan` at startup) ---
model = None   # SepformerSeparation instance, resident for the process lifetime
device = None  # "cuda" or "cpu", selected at startup
# Lock is mandatory: it serializes requests so that concurrent inference
# cannot collide over GPU memory.
gpu_lock = asyncio.Lock()
# --- 1. Lifespan management (load at startup & warm up) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: pick the device, load the separation model
    once, warm up the GPU, and free cached CUDA memory on shutdown."""
    global model, device

    print("⏳ [Startup] 正在初始化环境...")

    # Hardware selection and backend tuning.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = "cuda"
        # Disable benchmark mode so the first request does not pay for
        # cuDNN's dynamic algorithm search.
        torch.backends.cudnn.benchmark = False
        # Allow TF32 matmuls on Ampere-class GPUs (compute capability >= 8).
        if torch.cuda.get_device_capability()[0] >= 8:
            torch.set_float32_matmul_precision('high')
        print(f"🚀 使用设备: GPU ({torch.cuda.get_device_name(0)})")
    else:
        device = "cpu"
        print("⚠️ 使用设备: CPU")

    # Load the model once; it stays resident in memory for all requests.
    print("⏳ [Startup] 正在加载模型 (常驻内存)...")
    model = SepformerSeparation.from_hparams(
        source="speechbrain/sepformer-wsj03mix",
        savedir="pretrained_models/sepformer-wsj03mix",
        run_opts={"device": device},
    )
    model.eval()

    # One dummy forward pass so the first real request does not pay for
    # lazy CUDA context / kernel initialization.
    if device == "cuda":
        print("🔥 [Startup] 正在预热 GPU...")
        warmup = torch.randn(1, 8000).to(device)
        with torch.inference_mode(), torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            model.separate_batch(warmup)
        torch.cuda.synchronize()

    print("✅ [Startup] 预热完成,服务已就绪!")
    yield

    # Shutdown: release cached GPU memory.
    print("🛑 [Shutdown] 服务关闭,清理资源...")
    if device == "cuda":
        torch.cuda.empty_cache()
# --- 2. FastAPI application (wired to the lifespan handler above) ---
app = FastAPI(title="Audio Separation API", lifespan=lifespan)
# --- 3. Core endpoint ---
@app.post("/separate")
async def separate_audio_endpoint(file: UploadFile = File(...)):
    """
    Accept a mixed-speech audio upload and stream back the separated
    source with the highest energy as a WAV file.

    Returns:
        StreamingResponse (``audio/wav``) on success, or a JSON body
        ``{"error": ...}`` with HTTP 500 on failure.
    """
    global model, device
    # Step A: persist the upload to a temp file — SpeechBrain's
    # separate_file() API takes a filesystem path, not a stream.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_input:
        shutil.copyfileobj(file.file, temp_input)
        temp_input_path = temp_input.name
    try:
        # Step B: serialize GPU access — only one request may run
        # inference at a time, otherwise concurrent requests can
        # exhaust GPU memory.
        async with gpu_lock:
            start_time = time.time()
            # NOTE(review): separate_file() is a blocking call executed
            # directly on the event loop; consider run_in_executor if
            # latency under concurrent load matters.
            with torch.inference_mode():
                if device == "cuda":
                    # fp16 autocast for faster GPU inference.
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                        est_sources = model.separate_file(path=temp_input_path)
                else:
                    est_sources = model.separate_file(path=temp_input_path)
                # Energy-based selection, done on-device.
                # est_sources: [batch=1, time, sources]; sum of squares
                # over time gives one energy per source.
                energies = est_sources.pow(2).sum(dim=1).squeeze()
                best_idx = torch.argmax(energies).item()
                best_source = est_sources[:, :, best_idx]
            if device == "cuda":
                torch.cuda.synchronize()
            infer_time = time.time() - start_time
            print(f"✅ [Request] 推理完成,耗时: {infer_time:.4f}s | 选中源索引: {best_idx}")
        # Step C: encode to WAV in an in-memory buffer (no disk round-trip).
        # Must convert back to float32 — half-precision tensors fail WAV
        # encoding.
        source_cpu = best_source.detach().cpu().float()
        buffer = io.BytesIO()
        torchaudio.save(buffer, source_cpu, 8000, format="wav")
        buffer.seek(0)  # rewind so StreamingResponse reads from the start
        # Step D: stream the result back.
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=best_source_{best_idx}.wav"}
        )
    except Exception as e:
        # Bug fix: previously returned a bare dict, which FastAPI served
        # with HTTP 200 — clients could not detect failure. Keep the same
        # body shape but report a proper 500 status.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Step E: always remove the temp input file.
        if os.path.exists(temp_input_path):
            os.remove(temp_input_path)
if __name__ == "__main__":
    import uvicorn
    # Start the service on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)