Files
2026-05-19 10:12:57 +00:00

207 lines
7.6 KiB
Python

"""
Whisper microservice: audio transcription via faster-whisper.
The model is loaded on demand (lazy load) and unloaded after
UNLOAD_AFTER seconds of inactivity. The explicit /load and /unload
endpoints allow managing memory manually (for example, from n8n).
"""
from __future__ import annotations
import asyncio
import gc
import logging
import os
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, File, Form, UploadFile
from faster_whisper import BatchedInferencePipeline, WhisperModel
logger = logging.getLogger("whisper-service")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
# ────────────────────────── config ──────────────────────────
# Every setting is read once from the environment when the module is imported.
MODEL_SIZE = os.getenv("WHISPER_MODEL", "large-v3")
DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
BATCH_SIZE = int(os.getenv("WHISPER_BATCH_SIZE", "16"))
UNLOAD_AFTER = int(os.getenv("WHISPER_UNLOAD_AFTER", "300")) # seconds of inactivity, 0 = never
# ────────────────────────── state ──────────────────────────
# Currently loaded pipeline; None while the model is not loaded.
model: BatchedInferencePipeline | None = None
# Held while the model is being loaded or unloaded.
_model_lock = asyncio.Lock()
# time.monotonic() value recorded by the last _ensure_loaded() call; 0.0 until first use.
_last_used_at: float = 0.0
# Pending auto-unload timer task, if one has been scheduled.
_unload_task: asyncio.Task | None = None
# ────────────────────────── model helpers ──────────────────────────
def _load_model_sync() -> BatchedInferencePipeline:
    """Build the batched faster-whisper pipeline (blocking call).

    Creates a ``WhisperModel`` from the module-level configuration
    (``MODEL_SIZE``, ``DEVICE``, ``COMPUTE_TYPE``), wraps it in a
    ``BatchedInferencePipeline`` and logs how long construction took.

    Returns:
        The freshly constructed pipeline.
    """
    logger.info(f"Loading '{MODEL_SIZE}' on {DEVICE} ({COMPUTE_TYPE}), batch={BATCH_SIZE}")
    started = time.perf_counter()
    whisper = WhisperModel(MODEL_SIZE, device=DEVICE, compute_type=COMPUTE_TYPE)
    batched = BatchedInferencePipeline(model=whisper)
    logger.info(f"Model loaded in {time.perf_counter() - started:.1f}s")
    return batched
def _unload_model_sync() -> None:
    """Drop the global pipeline reference and run a garbage-collection pass.

    Does nothing when no model is currently loaded.
    """
    global model
    if model is not None:
        # Unbind the global reference, then restore the "not loaded" marker.
        del model
        model = None
        gc.collect()
        # Original author's note: CTranslate2 frees VRAM when the object is
        # deleted, so no additional call is needed.
        logger.info("Model unloaded, VRAM released")
async def _ensure_loaded() -> BatchedInferencePipeline:
    """Return the pipeline, loading it first if necessary, and restart the idle timer.

    Under ``_model_lock``:
      * loads the model in the default executor when none is loaded;
      * records the current monotonic time in ``_last_used_at``;
      * when ``UNLOAD_AFTER`` is positive, cancels a still-pending
        auto-unload task and schedules a new one.

    Returns:
        The loaded pipeline.
    """
    global model, _last_used_at, _unload_task
    async with _model_lock:
        if model is None:
            # The blocking model construction runs off the event loop thread.
            model = await asyncio.get_running_loop().run_in_executor(None, _load_model_sync)
        _last_used_at = time.monotonic()
        if UNLOAD_AFTER <= 0:
            # Auto-unload is disabled: no timer to manage.
            return model
        pending = _unload_task
        if pending is not None and not pending.done():
            pending.cancel()
        _unload_task = asyncio.create_task(_auto_unload_after(UNLOAD_AFTER))
        return model
async def _auto_unload_after(seconds: int) -> None:
    """Sleep for *seconds*, then unload the model if it has been idle that long.

    Args:
        seconds: Delay before the check, and the minimum time since
            ``_last_used_at`` required for the unload to happen.
    """
    await asyncio.sleep(seconds)
    async with _model_lock:
        if model is None:
            return
        idle_for = time.monotonic() - _last_used_at
        if idle_for < seconds:
            return
        logger.info(f"Auto-unloading after {seconds}s of inactivity…")
        _unload_model_sync()
# ────────────────────────── lifespan ──────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: log readiness on startup, unload the model on shutdown.

    Args:
        app: The FastAPI application (not used by this function).
    """
    ready_message = f"Whisper service ready (lazy load, unload_after={UNLOAD_AFTER}s)"
    logger.info(ready_message)
    yield
    # Shutdown: unload under the same lock that load/unload use elsewhere.
    async with _model_lock:
        _unload_model_sync()
app = FastAPI(title="Whisper Service", lifespan=lifespan)
# ────────────────────────── control endpoints ──────────────────────────
@app.get("/health")
async def health():
    # Status probe: reports the configured model and device, whether the model
    # is currently loaded, and the seconds elapsed since the last recorded use
    # (None if the model has never been used).
    is_loaded = model is not None
    idle_seconds = None
    if _last_used_at:
        idle_seconds = round(time.monotonic() - _last_used_at, 1)
    return {
        "status": "ok",
        "model": MODEL_SIZE,
        "device": DEVICE,
        "loaded": is_loaded,
        "idle_seconds": idle_seconds,
    }
@app.post("/load")
async def load_model():
    """Явная загрузка модели в VRAM (перед серией запросов)."""
    # Explicitly load the model into VRAM (e.g. ahead of a series of requests).
    # The docstring is kept verbatim: FastAPI publishes handler docstrings as
    # the endpoint description in the OpenAPI schema.
    await _ensure_loaded()
    response = {"status": "loaded", "model": MODEL_SIZE}
    return response
@app.post("/unload")
async def unload_model():
    """Явная выгрузка модели из VRAM (освобождаем память для Ollama)."""
    # Explicitly unload the model from VRAM (frees memory for Ollama).
    # The docstring is kept verbatim: FastAPI publishes handler docstrings as
    # the endpoint description in the OpenAPI schema.
    async with _model_lock:
        timer = _unload_task
        timer_is_pending = timer is not None and not timer.done()
        if timer_is_pending:
            # The model is going away now, so the idle timer is no longer needed.
            timer.cancel()
        _unload_model_sync()
    return {"status": "unloaded"}
# ────────────────────────── transcribe ──────────────────────────
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(BATCH_SIZE),
    word_timestamps: bool = Form(True),
    vad_filter: bool = Form(True),
):
    """Транскрибация загруженного аудиофайла."""
    # Transcribe an uploaded audio file.
    #
    # The docstring above is kept verbatim: FastAPI publishes handler
    # docstrings as the endpoint description in the OpenAPI schema.
    #
    # Form fields:
    #   file            - uploaded audio; written to a temporary file whose
    #                     suffix is taken from the upload's filename
    #                     (".wav" when there is no filename).
    #   language        - passed to the pipeline; "" and "string" mean "unset"
    #                     (presumably the Swagger UI placeholder - confirm).
    #   initial_prompt  - same normalisation as `language`.
    #   beam_size, batch_size, word_timestamps, vad_filter
    #                   - forwarded unchanged to pipeline.transcribe().
    #
    # Returns a dict with "language", "language_probability", "duration",
    # "processing_time" and "segments"; each segment has "start", "end",
    # "text" and, when word timestamps were requested and are present, "words".
    #
    # NOTE(review): pipeline.transcribe() and the segment loop run
    # synchronously inside this coroutine, so the event loop serves no other
    # request (including /health) until they finish. Moving them to a worker
    # thread needs confirmation that the shared pipeline tolerates concurrent
    # use - left unchanged here.
    language = language if language and language != "string" else None
    initial_prompt = initial_prompt if initial_prompt and initial_prompt != "string" else None
    suffix = Path(file.filename).suffix if file.filename else ".wav"

    tmp_path: str | None = None
    try:
        # The temp file is created inside the try so that a failed upload read
        # or disk write cannot leave it behind (delete=False disables the
        # automatic removal on close).
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)

        pipeline = await _ensure_loaded()
        t0 = time.perf_counter()
        logger.info(
            f"Received: {file.filename} ({len(content)} bytes), "
            f"language={language}, model={MODEL_SIZE}, batch_size={batch_size}"
        )
        segments_iter, info = pipeline.transcribe(
            tmp_path,
            language=language,
            initial_prompt=initial_prompt,
            beam_size=beam_size,
            batch_size=batch_size,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
        )
        t1 = time.perf_counter()
        logger.info(
            f"Transcribe call returned in {t1 - t0:.1f}s. "
            f"Language: {info.language} ({info.language_probability:.2f}), "
            f"duration: {info.duration:.1f}s"
        )

        segments = []
        for seg in segments_iter:
            seg_data = {
                "start": round(seg.start, 3),
                "end": round(seg.end, 3),
                "text": seg.text.strip(),
            }
            if word_timestamps and seg.words:
                seg_data["words"] = [
                    {"start": round(w.start, 3), "end": round(w.end, 3), "word": w.word}
                    for w in seg.words
                ]
            segments.append(seg_data)

        elapsed = time.perf_counter() - t0
        # Zero-length audio would otherwise raise ZeroDivisionError here and
        # turn a finished transcription into a server error.
        rtf = f"{elapsed / info.duration:.3f}" if info.duration else "n/a"
        logger.info(
            f"Done: {file.filename} - {info.duration:.1f}s audio, "
            f"{len(segments)} segments, {elapsed:.1f}s total "
            f"(RTF={rtf})"
        )
        return {
            "language": info.language,
            "language_probability": round(info.language_probability, 3),
            "duration": round(info.duration, 3),
            "processing_time": round(elapsed, 3),
            "segments": segments,
        }
    finally:
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)