build: first commit
This commit is contained in:
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Whisper microservice: транскрибация аудио через faster-whisper.
|
||||
|
||||
Модель загружается по требованию (lazy load) и выгружается
|
||||
через UNLOAD_AFTER секунд простоя. Явные endpoints /load и /unload
|
||||
позволяют управлять памятью вручную (например, из n8n).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from faster_whisper import BatchedInferencePipeline, WhisperModel
|
||||
|
||||
# Install the root handler/format, then obtain the service's named logger.
# The two calls are independent of each other, so their order does not matter.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
logger = logging.getLogger("whisper-service")
|
||||
|
||||
# ────────────────────────── config ──────────────────────────
|
||||
|
||||
# Every setting can be overridden through an environment variable.
MODEL_SIZE = os.environ.get("WHISPER_MODEL", "large-v3")  # passed to WhisperModel(...)
DEVICE = os.environ.get("WHISPER_DEVICE", "cuda")
COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE", "float16")
BATCH_SIZE = int(os.environ.get("WHISPER_BATCH_SIZE", "16"))  # default for /transcribe batch_size
# Seconds of idleness before the model is auto-unloaded; 0 means never.
UNLOAD_AFTER = int(os.environ.get("WHISPER_UNLOAD_AFTER", "300"))
|
||||
|
||||
# ────────────────────────── state ──────────────────────────
|
||||
|
||||
# The currently loaded inference pipeline, or None while no model is held.
model: BatchedInferencePipeline | None = None
# Guards every load/unload of `model` so they never overlap.
_model_lock = asyncio.Lock()
# time.monotonic() stamp taken by the latest _ensure_loaded() call; 0.0 = never used.
_last_used_at: float = 0.0
# The scheduled auto-unload timer task, if one has been created.
_unload_task: asyncio.Task | None = None
|
||||
|
||||
# ────────────────────────── model helpers ──────────────────────────
|
||||
|
||||
def _load_model_sync() -> BatchedInferencePipeline:
    """Construct the batched inference pipeline (blocking call).

    Builds a ``WhisperModel`` from ``MODEL_SIZE``, ``DEVICE`` and
    ``COMPUTE_TYPE``, wraps it in a ``BatchedInferencePipeline`` and logs
    how long construction took.

    Returns:
        The freshly built pipeline; the caller is responsible for storing it.
    """
    logger.info(f"Loading '{MODEL_SIZE}' on {DEVICE} ({COMPUTE_TYPE}), batch={BATCH_SIZE}…")
    t0 = time.perf_counter()
    pipeline = BatchedInferencePipeline(
        model=WhisperModel(MODEL_SIZE, device=DEVICE, compute_type=COMPUTE_TYPE),
    )
    logger.info(f"Model loaded in {time.perf_counter() - t0:.1f}s")
    return pipeline
|
||||
|
||||
|
||||
def _unload_model_sync() -> None:
    """Drop the module-level pipeline reference, if any, and force a GC pass.

    Does nothing when no model is loaded. Callers in this file hold
    ``_model_lock`` while invoking it.
    """
    global model
    if model is not None:
        # Rebinding the global releases this module's reference to the pipeline.
        model = None
        gc.collect()
        # Per the original author's note: CTranslate2 frees VRAM when the
        # object is destroyed, so no additional explicit call is needed.
        logger.info("Model unloaded, VRAM released")
|
||||
|
||||
|
||||
async def _ensure_loaded() -> BatchedInferencePipeline:
    """Return the pipeline, loading it first when it is not in memory.

    Under ``_model_lock`` this:
      1. loads the model in the default executor if ``model`` is None,
      2. stamps ``_last_used_at`` with the current monotonic time,
      3. when ``UNLOAD_AFTER`` is positive, cancels any unfinished idle timer
         and schedules a new one.

    Returns:
        The loaded ``BatchedInferencePipeline``.
    """
    global model, _last_used_at, _unload_task

    async with _model_lock:
        if model is None:
            # The loader is blocking, so run it off the event loop thread.
            model = await asyncio.get_running_loop().run_in_executor(None, _load_model_sync)

        _last_used_at = time.monotonic()

        if UNLOAD_AFTER > 0:
            previous_timer = _unload_task
            if previous_timer is not None and not previous_timer.done():
                previous_timer.cancel()
            # Every use restarts the idle countdown.
            _unload_task = asyncio.create_task(_auto_unload_after(UNLOAD_AFTER))

    return model
|
||||
|
||||
|
||||
async def _auto_unload_after(seconds: int) -> None:
    """Sleep for ``seconds``, then unload the model if it stayed idle that long.

    Args:
        seconds: Delay before the check, and the minimum idle time required
            for the unload to happen.
    """
    await asyncio.sleep(seconds)

    async with _model_lock:
        if model is None:
            return  # nothing to unload
        idle_for = time.monotonic() - _last_used_at
        if idle_for < seconds:
            return  # the model was used again in the meantime
        logger.info(f"Auto-unloading after {seconds}s of inactivity…")
        _unload_model_sync()
|
||||
|
||||
# ────────────────────────── lifespan ──────────────────────────
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log readiness, clean up on shutdown.

    Nothing is loaded at startup; the model is loaded on first use.
    On shutdown the idle timer (if still pending) is cancelled and the
    model is unloaded, both under ``_model_lock``.

    Args:
        app: The FastAPI application (unused here).
    """
    logger.info(f"Whisper service ready (lazy load, unload_after={UNLOAD_AFTER}s)")
    yield
    async with _model_lock:
        # Mirror the /unload endpoint: stop the idle timer before dropping the
        # model so no auto-unload task is left pending when the app exits.
        if _unload_task and not _unload_task.done():
            _unload_task.cancel()
        _unload_model_sync()
|
||||
|
||||
|
||||
# The ASGI application; `lifespan` logs readiness and unloads the model on shutdown.
app = FastAPI(title="Whisper Service", lifespan=lifespan)
|
||||
|
||||
# ────────────────────────── control endpoints ──────────────────────────
|
||||
|
||||
@app.get("/health")
async def health():
    """Report service status without loading the model.

    Returns:
        A dict with the configured model and device, whether a model is
        currently loaded, and the seconds since the last use (None if the
        model has never been used).
    """
    if _last_used_at:
        idle_seconds = round(time.monotonic() - _last_used_at, 1)
    else:
        idle_seconds = None

    return {
        "status": "ok",
        "model": MODEL_SIZE,
        "device": DEVICE,
        "loaded": model is not None,
        "idle_seconds": idle_seconds,
    }
|
||||
|
||||
|
||||
@app.post("/load")
async def load_model():
    """Explicitly load the model into VRAM (e.g. before a series of requests)."""
    await _ensure_loaded()
    response = {"status": "loaded", "model": MODEL_SIZE}
    return response
|
||||
|
||||
|
||||
@app.post("/unload")
async def unload_model():
    """Explicitly unload the model from VRAM (frees memory for Ollama)."""
    async with _model_lock:
        timer = _unload_task
        # Stop a still-running idle timer before unloading.
        if timer is not None and not timer.done():
            timer.cancel()
        _unload_model_sync()
    return {"status": "unloaded"}
|
||||
|
||||
# ────────────────────────── transcribe ──────────────────────────
|
||||
|
||||
def _blank_to_none(value: str | None) -> str | None:
    """Return None for empty values and for the literal placeholder "string"."""
    # NOTE(review): "string" looks like the Swagger UI default for text
    # fields — confirm that is why it is filtered out.
    if not value or value == "string":
        return None
    return value


def _segment_to_dict(seg, with_words: bool) -> dict:
    """Convert one transcription segment into a JSON-serialisable dict."""
    data = {
        "start": round(seg.start, 3),
        "end": round(seg.end, 3),
        "text": seg.text.strip(),
    }
    if with_words and seg.words:
        data["words"] = [
            {"start": round(w.start, 3), "end": round(w.end, 3), "word": w.word}
            for w in seg.words
        ]
    return data


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(BATCH_SIZE),
    word_timestamps: bool = Form(True),
    vad_filter: bool = Form(True),
):
    """Transcribe an uploaded audio file.

    The upload is spooled to a temporary file, the model is loaded on demand,
    and the file is always removed afterwards.

    Args:
        file: The uploaded audio; its extension is reused for the temp file
            (".wav" when the upload has no filename).
        language: Language code forwarded to the pipeline; empty or "string"
            is treated as "not provided".
        initial_prompt: Prompt forwarded to the pipeline; same normalisation
            as ``language``.
        beam_size: Forwarded to ``pipeline.transcribe``.
        batch_size: Forwarded to ``pipeline.transcribe``.
        word_timestamps: Forwarded to the pipeline; also controls whether a
            "words" list is attached to each segment.
        vad_filter: Forwarded to ``pipeline.transcribe``.

    Returns:
        A dict with "language", "language_probability", "duration",
        "processing_time" and the list of "segments".
    """
    language = _blank_to_none(language)
    initial_prompt = _blank_to_none(initial_prompt)

    suffix = Path(file.filename).suffix if file.filename else ".wav"
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name
    try:
        # Writing happens inside the try so that a failed read/write still
        # reaches the cleanup below (the file is created with delete=False).
        with tmp:
            content = await file.read()
            tmp.write(content)

        pipeline = await _ensure_loaded()

        # NOTE(review): pipeline.transcribe() and the segment loop are plain
        # synchronous calls inside this coroutine, so the event loop is
        # occupied until they finish. Consider an executor if other endpoints
        # must stay responsive during long transcriptions.
        t0 = time.perf_counter()
        logger.info(
            f"Received: {file.filename} ({len(content)} bytes), "
            f"language={language}, model={MODEL_SIZE}, batch_size={batch_size}"
        )

        segments_iter, info = pipeline.transcribe(
            tmp_path,
            language=language,
            initial_prompt=initial_prompt,
            beam_size=beam_size,
            batch_size=batch_size,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
        )
        t1 = time.perf_counter()
        logger.info(
            f"Transcribe call returned in {t1 - t0:.1f}s. "
            f"Language: {info.language} ({info.language_probability:.2f}), "
            f"duration: {info.duration:.1f}s"
        )

        segments = [_segment_to_dict(seg, word_timestamps) for seg in segments_iter]

        elapsed = time.perf_counter() - t0
        # Zero-length audio must not turn a finished transcription into a
        # ZeroDivisionError just for the sake of a log line.
        rtf = f"{elapsed / info.duration:.3f}" if info.duration else "n/a"
        logger.info(
            f"Done: {file.filename} — {info.duration:.1f}s audio, "
            f"{len(segments)} segments, {elapsed:.1f}s total "
            f"(RTF={rtf})"
        )

        return {
            "language": info.language,
            "language_probability": round(info.language_probability, 3),
            "duration": round(info.duration, 3),
            "processing_time": round(elapsed, 3),
            "segments": segments,
        }
    finally:
        Path(tmp_path).unlink(missing_ok=True)
|
||||
Reference in New Issue
Block a user