""" Pyannote microservice: диаризация (кто когда говорил). Модель загружается по требованию (lazy load) и выгружается через UNLOAD_AFTER секунд простоя. Явные endpoints /load и /unload позволяют управлять памятью вручную (например, из n8n). """ from __future__ import annotations import asyncio import gc import logging import os import tempfile import time from contextlib import asynccontextmanager from pathlib import Path import torch from fastapi import FastAPI, File, Form, HTTPException, UploadFile from pyannote.audio import Pipeline logger = logging.getLogger("pyannote-service") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s") # ────────────────────────── config ────────────────────────── HF_TOKEN = os.getenv("HF_TOKEN") DEVICE = os.getenv("PYANNOTE_DEVICE", "cuda") PIPELINE_NAME = "pyannote/speaker-diarization-3.1" UNLOAD_AFTER = int(os.getenv("PYANNOTE_UNLOAD_AFTER", "300")) # секунд простоя # ────────────────────────── state ────────────────────────── pipeline: Pipeline | None = None _model_lock = asyncio.Lock() _last_used_at: float = 0.0 _unload_task: asyncio.Task | None = None # ────────────────────────── постобработка (без изменений) ────────────────────────── def extract_segments(diarization) -> list[tuple[float, float, str]]: annotation = getattr(diarization, "speaker_diarization", diarization) segments: list[tuple[float, float, str]] = [] if hasattr(annotation, "itertracks"): for turn, _, speaker in annotation.itertracks(yield_label=True): segments.append((turn.start, turn.end, speaker)) else: for turn, speaker in annotation: segments.append((turn.start, turn.end, speaker)) return segments def filter_short(segments: list[tuple[float, float, str]], min_duration: float) -> list[tuple[float, float, str]]: if min_duration <= 0: return segments return [s for s in segments if (s[1] - s[0]) >= min_duration] def merge_adjacent(segments: list[tuple[float, float, str]], max_gap: float) -> list[tuple[float, float, str]]: if max_gap < 0 or not segments: return segments segments = sorted(segments, key=lambda s: s[0]) merged: list[tuple[float, float, str]] = [] for start, end, speaker in segments: if merged and merged[-1][2] == speaker and (start - merged[-1][1]) <= max_gap: prev_start, prev_end, prev_speaker = merged[-1] merged[-1] = (prev_start, max(prev_end, end), prev_speaker) else: merged.append((start, end, speaker)) return merged # ────────────────────────── model helpers ────────────────────────── def _load_pipeline_sync() -> Pipeline: if not HF_TOKEN: raise RuntimeError( "HF_TOKEN environment variable is required. " "Get it at https://huggingface.co/settings/tokens" ) logger.info(f"Loading pipeline '{PIPELINE_NAME}' on {DEVICE}…") t0 = time.perf_counter() p = Pipeline.from_pretrained(PIPELINE_NAME, token=HF_TOKEN) if DEVICE == "cuda" and torch.cuda.is_available(): p.to(torch.device("cuda")) logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}") else: logger.info("Using CPU") logger.info(f"Pipeline loaded in {time.perf_counter() - t0:.1f}s") return p def _unload_pipeline_sync() -> None: global pipeline if pipeline is None: return del pipeline pipeline = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() logger.info("Pipeline unloaded, VRAM released") async def _ensure_loaded() -> Pipeline: global pipeline, _last_used_at, _unload_task async with _model_lock: if pipeline is None: loop = asyncio.get_running_loop() pipeline = await loop.run_in_executor(None, _load_pipeline_sync) _last_used_at = time.monotonic() if UNLOAD_AFTER > 0: if _unload_task and not _unload_task.done(): _unload_task.cancel() _unload_task = asyncio.create_task(_auto_unload_after(UNLOAD_AFTER)) return pipeline async def _auto_unload_after(seconds: int) -> None: await asyncio.sleep(seconds) async with _model_lock: if pipeline is not None and (time.monotonic() - _last_used_at) >= seconds: logger.info(f"Auto-unloading after {seconds}s of inactivity…") _unload_pipeline_sync() # ────────────────────────── lifespan ────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): logger.info(f"Pyannote service ready (lazy load, unload_after={UNLOAD_AFTER}s)") yield async with _model_lock: _unload_pipeline_sync() app = FastAPI(title="Pyannote Service", lifespan=lifespan) # ────────────────────────── control endpoints ────────────────────────── @app.get("/health") async def health(): return { "status": "ok", "pipeline": PIPELINE_NAME, "device": DEVICE, "loaded": pipeline is not None, "idle_seconds": round(time.monotonic() - _last_used_at, 1) if _last_used_at else None, } @app.post("/load") async def load_model(): """Явная загрузка pipeline в VRAM.""" await _ensure_loaded() return {"status": "loaded", "pipeline": PIPELINE_NAME} @app.post("/unload") async def unload_model(): """Явная выгрузка pipeline из VRAM.""" global _unload_task async with _model_lock: if _unload_task and not _unload_task.done(): _unload_task.cancel() _unload_pipeline_sync() return {"status": "unloaded"} # ────────────────────────── diarize ────────────────────────── @app.post("/diarize") async def diarize( file: UploadFile = File(...), num_speakers: int | None = Form(None), min_speakers: int | None = Form(None), max_speakers: int | None = Form(None), min_duration: float = Form(0.5), merge_gap: float = Form(0.3), ): num_speakers = num_speakers if num_speakers and num_speakers > 0 else None min_speakers = min_speakers if min_speakers and min_speakers > 0 else None max_speakers = max_speakers if max_speakers and max_speakers > 0 else None suffix = Path(file.filename).suffix if file.filename else ".wav" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: content = await file.read() tmp.write(content) tmp_path = tmp.name try: p = await _ensure_loaded() t0 = time.perf_counter() logger.info(f"Starting diarization: {file.filename} ({len(content)} bytes)") kwargs: dict = {} if num_speakers is not None: kwargs["num_speakers"] = num_speakers else: if min_speakers is not None: kwargs["min_speakers"] = min_speakers if max_speakers is not None: kwargs["max_speakers"] = max_speakers logger.info(f"Running pipeline (params: {kwargs})…") try: from pyannote.audio.pipelines.utils.hook import ProgressHook with ProgressHook() as hook: diarization = p(tmp_path, hook=hook, **kwargs) except ImportError: diarization = p(tmp_path, **kwargs) t1 = time.perf_counter() logger.info(f"Pipeline finished in {t1 - t0:.1f}s") raw_segments = extract_segments(diarization) raw_count = len(raw_segments) segments = filter_short(raw_segments, min_duration) after_filter = len(segments) segments = merge_adjacent(segments, merge_gap) after_merge = len(segments) turns = [ { "start": round(start, 3), "end": round(end, 3), "duration": round(end - start, 3), "speaker": speaker, } for start, end, speaker in segments ] speakers_stats: dict[str, float] = {} for start, end, speaker in segments: speakers_stats[speaker] = speakers_stats.get(speaker, 0.0) + (end - start) speakers_info = { speaker: round(duration, 3) for speaker, duration in sorted(speakers_stats.items(), key=lambda x: -x[1]) } elapsed = time.perf_counter() - t0 logger.info( f"Diarized {file.filename}: {raw_count} raw → {after_filter} filtered → " f"{after_merge} merged, {len(speakers_info)} speakers in {elapsed:.1f}s" ) return { "speakers": speakers_info, "num_speakers": len(speakers_info), "total_speech_duration": round(sum(speakers_stats.values()), 3), "processing_time": round(elapsed, 3), "segments_raw": raw_count, "segments_filtered": after_filter, "segments_final": after_merge, "turns": turns, } finally: Path(tmp_path).unlink(missing_ok=True)