207 lines
7.6 KiB
Python
207 lines
7.6 KiB
Python
"""
|
|
Whisper microservice: audio transcription via faster-whisper.

The model is loaded on demand (lazy load) and unloaded after
UNLOAD_AFTER seconds of inactivity. The explicit /load and /unload
endpoints allow managing memory manually (for example, from n8n).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import gc
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, File, Form, UploadFile
|
|
from faster_whisper import BatchedInferencePipeline, WhisperModel
|
|
|
|
# Configure process-wide logging at import time, then obtain the named
# logger that the rest of this module writes to.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("whisper-service")
|
|
|
|
# ────────────────────────── config ──────────────────────────
|
|
|
|
# Every setting is read from the environment once, at import time; the
# second argument is the default used when the variable is not set.
MODEL_SIZE: str = os.environ.get("WHISPER_MODEL", "large-v3")
DEVICE: str = os.environ.get("WHISPER_DEVICE", "cuda")
COMPUTE_TYPE: str = os.environ.get("WHISPER_COMPUTE_TYPE", "float16")
BATCH_SIZE: int = int(os.environ.get("WHISPER_BATCH_SIZE", "16"))
# Seconds of inactivity before the model is unloaded; 0 = never.
UNLOAD_AFTER: int = int(os.environ.get("WHISPER_UNLOAD_AFTER", "300"))
|
|
|
|
# ────────────────────────── state ──────────────────────────
|
|
|
|
# Guards every load/unload of the shared pipeline below.
_model_lock = asyncio.Lock()

# The loaded pipeline, or None while nothing is loaded.
model: BatchedInferencePipeline | None = None

# time.monotonic() stamp of the most recent use; 0.0 means "never used yet".
_last_used_at: float = 0.0

# The currently scheduled auto-unload timer task, if any.
_unload_task: asyncio.Task | None = None
|
|
|
|
# ────────────────────────── model helpers ──────────────────────────
|
|
|
|
def _load_model_sync() -> BatchedInferencePipeline:
    """Build and return the batched inference pipeline (synchronous).

    Constructs a ``WhisperModel`` from the module-level configuration,
    wraps it in a ``BatchedInferencePipeline`` and logs how long the
    construction took. The module-level ``model`` is not touched here;
    the caller stores the result.
    """
    logger.info(f"Loading '{MODEL_SIZE}' on {DEVICE} ({COMPUTE_TYPE}), batch={BATCH_SIZE}…")
    started = time.perf_counter()
    batched = BatchedInferencePipeline(
        model=WhisperModel(MODEL_SIZE, device=DEVICE, compute_type=COMPUTE_TYPE)
    )
    logger.info(f"Model loaded in {time.perf_counter() - started:.1f}s")
    return batched
|
|
|
|
|
|
def _unload_model_sync() -> None:
    """Drop the module-level pipeline reference, if one is loaded.

    Does nothing when no model is loaded. Otherwise clears the global
    ``model``, forces a garbage-collection pass and logs the unload.
    """
    global model
    if model is not None:
        # Rebinding the global drops this module's reference to the pipeline.
        model = None
        gc.collect()
        # Per the original author's note: CTranslate2 frees VRAM when the
        # object is deleted, so no additional call is needed.
        logger.info("Model unloaded, VRAM released")
|
|
|
|
|
|
async def _ensure_loaded() -> BatchedInferencePipeline:
    """Return the shared pipeline, loading it first when necessary.

    Under ``_model_lock``: loads the model in the default executor if it is
    not loaded, records the current monotonic time as the last use, and,
    when ``UNLOAD_AFTER`` is positive, replaces any still-pending
    auto-unload timer with a fresh one.
    """
    global model, _last_used_at, _unload_task
    async with _model_lock:
        if model is None:
            # The load is blocking, so it runs in the default executor.
            model = await asyncio.get_running_loop().run_in_executor(None, _load_model_sync)

        _last_used_at = time.monotonic()

        if UNLOAD_AFTER > 0:
            pending = _unload_task
            if pending is not None and not pending.done():
                pending.cancel()
            # Restart the idle countdown from this use.
            _unload_task = asyncio.create_task(_auto_unload_after(UNLOAD_AFTER))

        return model
|
|
|
|
|
|
async def _auto_unload_after(seconds: int) -> None:
    """Sleep for ``seconds``, then unload the model if it stayed idle.

    After the sleep the lock is taken and the model is unloaded only when
    one is loaded and at least ``seconds`` have elapsed since
    ``_last_used_at``.
    """
    await asyncio.sleep(seconds)
    async with _model_lock:
        if model is None:
            return
        idle_for = time.monotonic() - _last_used_at
        if idle_for < seconds:
            # Used again more recently than the timeout; keep it loaded.
            return
        logger.info(f"Auto-unloading after {seconds}s of inactivity…")
        _unload_model_sync()
|
|
|
|
# ────────────────────────── lifespan ──────────────────────────
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: log on startup, clean up on shutdown.

    Startup does no work beyond logging, since the model is loaded lazily
    on first use. On shutdown any pending auto-unload timer is cancelled
    and the model is unloaded, both under ``_model_lock``.
    """
    logger.info(f"Whisper service ready (lazy load, unload_after={UNLOAD_AFTER}s)")
    yield
    async with _model_lock:
        # Cancel the idle timer first, as /unload does, so no timer task is
        # left pending once the model has been dropped at shutdown.
        if _unload_task is not None and not _unload_task.done():
            _unload_task.cancel()
        _unload_model_sync()


app = FastAPI(title="Whisper Service", lifespan=lifespan)
|
|
|
|
# ────────────────────────── control endpoints ──────────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Report service status without loading the model.

    Returns the configured model name and device, whether a model is
    currently loaded, and the seconds since the last use (``None`` when
    the model has not been used yet).
    """
    if _last_used_at:
        idle = round(time.monotonic() - _last_used_at, 1)
    else:
        idle = None
    return {
        "status": "ok",
        "model": MODEL_SIZE,
        "device": DEVICE,
        "loaded": model is not None,
        "idle_seconds": idle,
    }
|
|
|
|
|
|
@app.post("/load")
async def load_model():
    """Explicitly load the model into VRAM (ahead of a series of requests)."""
    await _ensure_loaded()
    response = {"status": "loaded", "model": MODEL_SIZE}
    return response
|
|
|
|
|
|
@app.post("/unload")
async def unload_model():
    """Explicitly unload the model from VRAM (freeing memory for Ollama)."""
    async with _model_lock:
        timer = _unload_task
        if timer is not None and not timer.done():
            # The idle timer is pointless once the model is unloaded by hand.
            timer.cancel()
        _unload_model_sync()
    return {"status": "unloaded"}
|
|
|
|
# ────────────────────────── transcribe ──────────────────────────
|
|
|
|
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(BATCH_SIZE),
    word_timestamps: bool = Form(True),
    vad_filter: bool = Form(True),
):
    """Transcribe an uploaded audio file.

    The upload is copied to a temporary file, the model is loaded if
    necessary, and the file is passed to the pipeline. The temporary file
    is always removed afterwards, including when copying the upload fails.

    Args:
        file: The uploaded audio. Its filename suffix is reused for the
            temporary file; ".wav" is used when there is no filename.
        language: Forwarded to ``pipeline.transcribe``. An empty value or
            the literal "string" is treated as not provided.
        initial_prompt: Forwarded to ``pipeline.transcribe``; the same
            empty/"string" handling applies.
        beam_size: Forwarded to ``pipeline.transcribe``.
        batch_size: Forwarded to ``pipeline.transcribe``.
        word_timestamps: Forwarded to ``pipeline.transcribe``; also
            controls whether per-word entries are added to each segment.
        vad_filter: Forwarded to ``pipeline.transcribe``.

    Returns:
        A dict with ``language``, ``language_probability``, ``duration``,
        ``processing_time`` and ``segments``. Each segment has ``start``,
        ``end`` and ``text``, plus ``words`` when word timestamps are
        requested and present.
    """
    # An empty value or the literal "string" counts as "not provided".
    # NOTE(review): "string" looks like the Swagger UI placeholder — confirm.
    if language in ("", "string"):
        language = None
    if initial_prompt in ("", "string"):
        initial_prompt = None

    suffix = Path(file.filename).suffix if file.filename else ".wav"
    tmp_path: str | None = None
    try:
        # The temp file is created inside the try so that a failed read or
        # write still reaches the cleanup in `finally`.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            size = 0
            # Copy in 1 MiB chunks instead of holding the whole upload in memory.
            while chunk := await file.read(1024 * 1024):
                tmp.write(chunk)
                size += len(chunk)

        pipeline = await _ensure_loaded()

        t0 = time.perf_counter()
        logger.info(
            f"Received: {file.filename} ({size} bytes), "
            f"language={language}, model={MODEL_SIZE}, batch_size={batch_size}"
        )

        # NOTE(review): this call and the segment loop below are synchronous
        # calls inside an async handler, so no other request (including
        # /health) is served until they finish. Offloading to a worker
        # thread would change how requests overlap on the device, so it is
        # deliberately left as is — confirm the desired behaviour.
        segments_iter, info = pipeline.transcribe(
            tmp_path,
            language=language,
            initial_prompt=initial_prompt,
            beam_size=beam_size,
            batch_size=batch_size,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
        )
        t1 = time.perf_counter()
        logger.info(
            f"Transcribe call returned in {t1 - t0:.1f}s. "
            f"Language: {info.language} ({info.language_probability:.2f}), "
            f"duration: {info.duration:.1f}s"
        )

        segments = []
        for seg in segments_iter:
            seg_data = {
                "start": round(seg.start, 3),
                "end": round(seg.end, 3),
                "text": seg.text.strip(),
            }
            if word_timestamps and seg.words:
                seg_data["words"] = [
                    {"start": round(w.start, 3), "end": round(w.end, 3), "word": w.word}
                    for w in seg.words
                ]
            segments.append(seg_data)

        elapsed = time.perf_counter() - t0
        # Zero-length audio must not fail the request on a log line.
        rtf = f"{elapsed / info.duration:.3f}" if info.duration else "n/a"
        logger.info(
            f"Done: {file.filename} — {info.duration:.1f}s audio, "
            f"{len(segments)} segments, {elapsed:.1f}s total "
            f"(RTF={rtf})"
        )

        return {
            "language": info.language,
            "language_probability": round(info.language_probability, 3),
            "duration": round(info.duration, 3),
            "processing_time": round(elapsed, 3),
            "segments": segments,
        }
    finally:
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)