build: first commit
This commit is contained in:
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Pyannote microservice: диаризация (кто когда говорил).
|
||||
|
||||
Модель загружается по требованию (lazy load) и выгружается
|
||||
через UNLOAD_AFTER секунд простоя. Явные endpoints /load и /unload
|
||||
позволяют управлять памятью вручную (например, из n8n).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
logger = logging.getLogger("pyannote-service")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
|
||||
|
||||
# ────────────────────────── config ──────────────────────────
|
||||
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
DEVICE = os.getenv("PYANNOTE_DEVICE", "cuda")
|
||||
PIPELINE_NAME = "pyannote/speaker-diarization-3.1"
|
||||
UNLOAD_AFTER = int(os.getenv("PYANNOTE_UNLOAD_AFTER", "300")) # секунд простоя
|
||||
|
||||
# ────────────────────────── state ──────────────────────────
|
||||
|
||||
pipeline: Pipeline | None = None
|
||||
_model_lock = asyncio.Lock()
|
||||
_last_used_at: float = 0.0
|
||||
_unload_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
# ────────────────────────── постобработка (без изменений) ──────────────────────────
|
||||
|
||||
def extract_segments(diarization) -> list[tuple[float, float, str]]:
|
||||
annotation = getattr(diarization, "speaker_diarization", diarization)
|
||||
segments: list[tuple[float, float, str]] = []
|
||||
if hasattr(annotation, "itertracks"):
|
||||
for turn, _, speaker in annotation.itertracks(yield_label=True):
|
||||
segments.append((turn.start, turn.end, speaker))
|
||||
else:
|
||||
for turn, speaker in annotation:
|
||||
segments.append((turn.start, turn.end, speaker))
|
||||
return segments
|
||||
|
||||
|
||||
def filter_short(segments: list[tuple[float, float, str]], min_duration: float) -> list[tuple[float, float, str]]:
|
||||
if min_duration <= 0:
|
||||
return segments
|
||||
return [s for s in segments if (s[1] - s[0]) >= min_duration]
|
||||
|
||||
|
||||
def merge_adjacent(segments: list[tuple[float, float, str]], max_gap: float) -> list[tuple[float, float, str]]:
|
||||
if max_gap < 0 or not segments:
|
||||
return segments
|
||||
segments = sorted(segments, key=lambda s: s[0])
|
||||
merged: list[tuple[float, float, str]] = []
|
||||
for start, end, speaker in segments:
|
||||
if merged and merged[-1][2] == speaker and (start - merged[-1][1]) <= max_gap:
|
||||
prev_start, prev_end, prev_speaker = merged[-1]
|
||||
merged[-1] = (prev_start, max(prev_end, end), prev_speaker)
|
||||
else:
|
||||
merged.append((start, end, speaker))
|
||||
return merged
|
||||
|
||||
|
||||
# ────────────────────────── model helpers ──────────────────────────
|
||||
|
||||
def _load_pipeline_sync() -> Pipeline:
|
||||
if not HF_TOKEN:
|
||||
raise RuntimeError(
|
||||
"HF_TOKEN environment variable is required. "
|
||||
"Get it at https://huggingface.co/settings/tokens"
|
||||
)
|
||||
logger.info(f"Loading pipeline '{PIPELINE_NAME}' on {DEVICE}…")
|
||||
t0 = time.perf_counter()
|
||||
p = Pipeline.from_pretrained(PIPELINE_NAME, token=HF_TOKEN)
|
||||
if DEVICE == "cuda" and torch.cuda.is_available():
|
||||
p.to(torch.device("cuda"))
|
||||
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
logger.info("Using CPU")
|
||||
logger.info(f"Pipeline loaded in {time.perf_counter() - t0:.1f}s")
|
||||
return p
|
||||
|
||||
|
||||
def _unload_pipeline_sync() -> None:
|
||||
global pipeline
|
||||
if pipeline is None:
|
||||
return
|
||||
del pipeline
|
||||
pipeline = None
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
logger.info("Pipeline unloaded, VRAM released")
|
||||
|
||||
|
||||
async def _ensure_loaded() -> Pipeline:
|
||||
global pipeline, _last_used_at, _unload_task
|
||||
async with _model_lock:
|
||||
if pipeline is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
pipeline = await loop.run_in_executor(None, _load_pipeline_sync)
|
||||
|
||||
_last_used_at = time.monotonic()
|
||||
|
||||
if UNLOAD_AFTER > 0:
|
||||
if _unload_task and not _unload_task.done():
|
||||
_unload_task.cancel()
|
||||
_unload_task = asyncio.create_task(_auto_unload_after(UNLOAD_AFTER))
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
async def _auto_unload_after(seconds: int) -> None:
|
||||
await asyncio.sleep(seconds)
|
||||
async with _model_lock:
|
||||
if pipeline is not None and (time.monotonic() - _last_used_at) >= seconds:
|
||||
logger.info(f"Auto-unloading after {seconds}s of inactivity…")
|
||||
_unload_pipeline_sync()
|
||||
|
||||
|
||||
# ────────────────────────── lifespan ──────────────────────────
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info(f"Pyannote service ready (lazy load, unload_after={UNLOAD_AFTER}s)")
|
||||
yield
|
||||
async with _model_lock:
|
||||
_unload_pipeline_sync()
|
||||
|
||||
|
||||
app = FastAPI(title="Pyannote Service", lifespan=lifespan)
|
||||
|
||||
|
||||
# ────────────────────────── control endpoints ──────────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"pipeline": PIPELINE_NAME,
|
||||
"device": DEVICE,
|
||||
"loaded": pipeline is not None,
|
||||
"idle_seconds": round(time.monotonic() - _last_used_at, 1) if _last_used_at else None,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/load")
|
||||
async def load_model():
|
||||
"""Явная загрузка pipeline в VRAM."""
|
||||
await _ensure_loaded()
|
||||
return {"status": "loaded", "pipeline": PIPELINE_NAME}
|
||||
|
||||
|
||||
@app.post("/unload")
|
||||
async def unload_model():
|
||||
"""Явная выгрузка pipeline из VRAM."""
|
||||
global _unload_task
|
||||
async with _model_lock:
|
||||
if _unload_task and not _unload_task.done():
|
||||
_unload_task.cancel()
|
||||
_unload_pipeline_sync()
|
||||
return {"status": "unloaded"}
|
||||
|
||||
|
||||
# ────────────────────────── diarize ──────────────────────────
|
||||
|
||||
@app.post("/diarize")
|
||||
async def diarize(
|
||||
file: UploadFile = File(...),
|
||||
num_speakers: int | None = Form(None),
|
||||
min_speakers: int | None = Form(None),
|
||||
max_speakers: int | None = Form(None),
|
||||
min_duration: float = Form(0.5),
|
||||
merge_gap: float = Form(0.3),
|
||||
):
|
||||
num_speakers = num_speakers if num_speakers and num_speakers > 0 else None
|
||||
min_speakers = min_speakers if min_speakers and min_speakers > 0 else None
|
||||
max_speakers = max_speakers if max_speakers and max_speakers > 0 else None
|
||||
|
||||
suffix = Path(file.filename).suffix if file.filename else ".wav"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
p = await _ensure_loaded()
|
||||
|
||||
t0 = time.perf_counter()
|
||||
logger.info(f"Starting diarization: {file.filename} ({len(content)} bytes)")
|
||||
|
||||
kwargs: dict = {}
|
||||
if num_speakers is not None:
|
||||
kwargs["num_speakers"] = num_speakers
|
||||
else:
|
||||
if min_speakers is not None:
|
||||
kwargs["min_speakers"] = min_speakers
|
||||
if max_speakers is not None:
|
||||
kwargs["max_speakers"] = max_speakers
|
||||
|
||||
logger.info(f"Running pipeline (params: {kwargs})…")
|
||||
try:
|
||||
from pyannote.audio.pipelines.utils.hook import ProgressHook
|
||||
with ProgressHook() as hook:
|
||||
diarization = p(tmp_path, hook=hook, **kwargs)
|
||||
except ImportError:
|
||||
diarization = p(tmp_path, **kwargs)
|
||||
|
||||
t1 = time.perf_counter()
|
||||
logger.info(f"Pipeline finished in {t1 - t0:.1f}s")
|
||||
|
||||
raw_segments = extract_segments(diarization)
|
||||
raw_count = len(raw_segments)
|
||||
segments = filter_short(raw_segments, min_duration)
|
||||
after_filter = len(segments)
|
||||
segments = merge_adjacent(segments, merge_gap)
|
||||
after_merge = len(segments)
|
||||
|
||||
turns = [
|
||||
{
|
||||
"start": round(start, 3),
|
||||
"end": round(end, 3),
|
||||
"duration": round(end - start, 3),
|
||||
"speaker": speaker,
|
||||
}
|
||||
for start, end, speaker in segments
|
||||
]
|
||||
|
||||
speakers_stats: dict[str, float] = {}
|
||||
for start, end, speaker in segments:
|
||||
speakers_stats[speaker] = speakers_stats.get(speaker, 0.0) + (end - start)
|
||||
|
||||
speakers_info = {
|
||||
speaker: round(duration, 3)
|
||||
for speaker, duration in sorted(speakers_stats.items(), key=lambda x: -x[1])
|
||||
}
|
||||
|
||||
elapsed = time.perf_counter() - t0
|
||||
logger.info(
|
||||
f"Diarized {file.filename}: {raw_count} raw → {after_filter} filtered → "
|
||||
f"{after_merge} merged, {len(speakers_info)} speakers in {elapsed:.1f}s"
|
||||
)
|
||||
|
||||
return {
|
||||
"speakers": speakers_info,
|
||||
"num_speakers": len(speakers_info),
|
||||
"total_speech_duration": round(sum(speakers_stats.values()), 3),
|
||||
"processing_time": round(elapsed, 3),
|
||||
"segments_raw": raw_count,
|
||||
"segments_filtered": after_filter,
|
||||
"segments_final": after_merge,
|
||||
"turns": turns,
|
||||
}
|
||||
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
Reference in New Issue
Block a user