-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmain.py
More file actions
156 lines (121 loc) · 4.72 KB
/
main.py
File metadata and controls
156 lines (121 loc) · 4.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Based on https://github.com/morioka/tiny-openai-whisper-api
"""
import json
import os
import shutil
import tempfile
from datetime import timedelta
from functools import lru_cache
from typing import Iterable, Optional, Tuple

import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from faster_whisper import WhisperModel
from faster_whisper.transcribe import Segment, TranscriptionInfo
# Runtime configuration, read once from the environment at import time.
MODEL_SIZE = os.getenv("MODEL_SIZE", "base")
DEVICE = os.getenv("DEVICE", "auto")
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "default")

app = FastAPI()

# Allowed CORS origins as a comma-separated list, e.g.
# "https://a.example,https://b.example". Blank entries are ignored. The
# default "*" (any origin) matches the previously hard-coded setting.
origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
# NOTE(review): credentials are allowed together with the wildcard default;
# set CORS_ALLOW_ORIGINS to explicit origins if this API ever relies on
# cookies or other browser credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@lru_cache(maxsize=1)
def get_whisper_model(model_size: str):
    """Return a ``WhisperModel`` for *model_size*, memoising the latest one.

    Only a single model is kept in the cache (``maxsize=1``): asking for a
    different size builds a new model and evicts the previous cache entry.
    The model is built with the module-level ``DEVICE`` and ``COMPUTE_TYPE``.
    """
    return WhisperModel(model_size, device=DEVICE, compute_type=COMPUTE_TYPE)
def transcribe(
    audio_path: str, whisper_model: str, **whisper_args
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
    """Run speech recognition on *audio_path* with the named Whisper model.

    Args:
        audio_path: Path of the audio file to transcribe.
        whisper_model: Model size/name handed to ``get_whisper_model``.
        **whisper_args: Extra keyword arguments forwarded verbatim to the
            model's ``transcribe`` method.

    Returns:
        The ``(segments, info)`` pair produced by the model's ``transcribe``.
    """
    # NOTE: models come from get_whisper_model's LRU cache, so how many
    # models stay resident depends on that cache's size.
    model = get_whisper_model(whisper_model)
    result = model.transcribe(audio_path, **whisper_args)
    return result
# Directory where uploaded audio is staged while it is transcribed;
# overridable via the UPLOAD_DIR environment variable like the other settings.
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "tmp")
# Response formats accepted by the transcription endpoint.
_SUPPORTED_RESPONSE_FORMATS = ("json", "text", "srt", "verbose_json", "vtt")


def _format_timestamp(seconds: float, decimal_marker: str = ".") -> str:
    """Format *seconds* as ``HH:MM:SS<marker>mmm`` for subtitle cues.

    SRT separates milliseconds with ``","`` and WebVTT with ``"."``. Hours
    keep counting past 24 instead of wrapping.
    """
    total_ms = int(round(seconds * 1000))
    hours, rem_ms = divmod(total_ms, 3_600_000)
    minutes, rem_ms = divmod(rem_ms, 60_000)
    secs, millis = divmod(rem_ms, 1000)
    return f"{hours:02}:{minutes:02}:{secs:02}{decimal_marker}{millis:03}"


def _parse_settings_override(raw) -> dict:
    """Decode the optional ``settings_override`` field into a kwargs dict.

    Multipart form fields arrive as strings, so a JSON object string is
    expected; a dict (e.g. from a direct Python caller) is accepted as-is.

    Raises:
        HTTPException: 400 if the value is not a JSON object.
    """
    if raw is None:
        return {}
    if isinstance(raw, dict):
        return dict(raw)
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Bad Request, bad settings_override",
        ) from err
    if not isinstance(parsed, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Bad Request, bad settings_override",
        )
    return parsed


def _save_upload(upload: UploadFile) -> str:
    """Copy *upload* into UPLOAD_DIR under a fresh unique name; return the path.

    The client-supplied filename is never used as a path (it could contain
    ``../`` or be absolute); only a short alphanumeric extension is kept as a
    format hint, and ``mkstemp`` picks a collision-free name.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    _, suffix = os.path.splitext(os.path.basename(upload.filename or ""))
    if len(suffix) > 16 or not suffix[1:].isalnum():
        suffix = ""
    fd, path = tempfile.mkstemp(suffix=suffix, dir=UPLOAD_DIR)
    try:
        with os.fdopen(fd, "wb") as out:
            shutil.copyfileobj(upload.file, out)
    except BaseException:
        # Don't leave a half-written file behind.
        os.remove(path)
        raise
    return path


def _transcribe_all(audio_path: str, whisper_args: dict) -> Tuple[list, TranscriptionInfo]:
    """Run ``transcribe`` and drain its segment iterable into a list.

    ``transcribe`` only promises an ``Iterable`` of segments (possibly lazy),
    so draining it here keeps all decoding work inside the calling worker
    thread and lets the uploaded file be deleted right afterwards.
    """
    segments, info = transcribe(audio_path=audio_path, **whisper_args)
    return list(segments), info


@app.post("/v1/audio/transcriptions")
async def transcriptions(
    model: str = Form(...),
    file: UploadFile = File(...),
    response_format: Optional[str] = Form(None),
    temperature: Optional[float] = Form(None),
    settings_override: Optional[str] = Form(None),
):
    """Transcribe an uploaded audio file (OpenAI-style transcription endpoint).

    Args:
        model: Must be ``"whisper-1"``; the server always uses MODEL_SIZE
            unless ``settings_override`` supplies ``whisper_model``.
        file: The uploaded audio file.
        response_format: One of json (default), text, srt, verbose_json, vtt.
        temperature: Optional sampling temperature in [0, 1]; forwarded to the
            model only when given.
        settings_override: Optional JSON object of extra keyword arguments
            forwarded to ``transcribe`` (applied after the defaults above).

    Returns:
        A JSON dict for ``json``/``verbose_json``; plain text for
        ``text``/``srt``/``vtt``.

    Raises:
        HTTPException: 400 for any invalid request parameter.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and otherwise surface as a 500 rather than a client error.
    if model != "whisper-1":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Bad Request, bad model"
        )
    if file is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Bad Request, bad file"
        )
    if response_format is None:
        response_format = "json"
    if response_format not in _SUPPORTED_RESPONSE_FORMATS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Bad Request, bad response_format",
        )
    # Chained comparison also rejects NaN, which ``< 0.0 or > 1.0`` let through.
    if temperature is not None and not 0.0 <= temperature <= 1.0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Bad Request, bad temperature",
        )

    whisper_args = {"whisper_model": MODEL_SIZE}
    if temperature is not None:
        whisper_args["temperature"] = temperature
    whisper_args.update(_parse_settings_override(settings_override))

    # File copy and inference are blocking; run them off the event loop so
    # other requests are still served meanwhile.
    upload_path = await run_in_threadpool(_save_upload, file)
    try:
        segments, info = await run_in_threadpool(
            _transcribe_all, upload_path, whisper_args
        )
    finally:
        try:
            os.remove(upload_path)
        except FileNotFoundError:
            pass

    text = "\n".join(seg.text for seg in segments)

    if response_format == "json":
        return {"text": text}
    if response_format == "text":
        return PlainTextResponse(text)
    if response_format == "srt":
        # SRT cues are numbered from 1 and use "," before the milliseconds.
        cues = "".join(
            f"{index}\n"
            f"{_format_timestamp(seg.start, ',')} --> {_format_timestamp(seg.end, ',')}\n"
            f"{seg.text}\n\n"
            for index, seg in enumerate(segments, start=1)
        )
        return PlainTextResponse(cues + "\n")
    if response_format == "vtt":
        cues = "".join(
            f"{_format_timestamp(seg.start)} --> {_format_timestamp(seg.end)}\n"
            f"{seg.text}\n\n"
            for seg in segments
        )
        return PlainTextResponse("WEBVTT\n\n" + cues)
    if response_format == "verbose_json":
        return {
            "text": text,
            "task": "transcribe",
            "language": info.language,
            # Silent audio can yield no segments at all.
            "duration": segments[-1].end if segments else 0.0,
            "segments": [
                {"id": seg.id, "start": seg.start, "end": seg.end, "text": seg.text}
                for seg in segments
            ],
        }
    # Defensive: every supported format is handled above.
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Bad Request, bad response_format",
    )
if __name__ == "__main__":
    # Bind address is overridable via HOST / PORT; defaults match the
    # previously hard-coded 0.0.0.0:8000.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )