"""API Gateway: request routing and transcript/diarization reconciliation.

POST /transcribe → whisper-service (transcription only)
POST /diarize    → pyannote-service (diarization only)
POST /process    → both in parallel + word-level stitching

VRAM management:
POST /whisper/load    → load whisper
POST /whisper/unload  → unload whisper
POST /pyannote/load   → load pyannote
POST /pyannote/unload → unload pyannote
POST /ollama/unload   → unload the current ollama model (keep_alive=0)
"""

from __future__ import annotations

import asyncio
import logging
import os
import time

import httpx
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse

logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")

# Base URLs of the backing services; each can be overridden via environment.
WHISPER_URL = os.getenv("WHISPER_URL", "http://whisper-service:8001")
PYANNOTE_URL = os.getenv("PYANNOTE_URL", "http://pyannote-service:8002")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")

# 30-minute timeout used for the transcription and diarization requests.
TIMEOUT = httpx.Timeout(timeout=1800.0)

app = FastAPI(
    title="Speech Processing API",
    description="Транскрибация + диаризация аудио. Whisper (GPU) + pyannote (CPU), параллельно.",
    version="4.0.0",
)


# ────────────────────── VRAM management proxies ──────────────────────

async def _proxy_post(url: str, timeout: float):
    """Send a body-less POST to *url* and relay the upstream answer.

    On success the decoded JSON body is returned as-is. When the upstream
    answers with a 4xx/5xx status, that status is relayed to the caller
    instead of being reported as HTTP 200; a non-JSON error body is wrapped
    as ``{"detail": <text>}`` rather than raising a decode error.

    Connection failures and timeouts still propagate as httpx exceptions.
    """
    async with httpx.AsyncClient() as client:
        resp = await client.post(url, timeout=timeout)
    if resp.status_code < 400:
        return resp.json()
    try:
        payload = resp.json()
    except ValueError:  # error body is not JSON (json.JSONDecodeError is a ValueError)
        payload = {"detail": resp.text}
    return JSONResponse(status_code=resp.status_code, content=payload)


@app.post("/whisper/load")
async def whisper_load():
    """Ask whisper-service to load its model."""
    return await _proxy_post(f"{WHISPER_URL}/load", 120.0)


@app.post("/whisper/unload")
async def whisper_unload():
    """Ask whisper-service to unload its model."""
    return await _proxy_post(f"{WHISPER_URL}/unload", 30.0)


@app.post("/pyannote/load")
async def pyannote_load():
    """Ask pyannote-service to load its model."""
    return await _proxy_post(f"{PYANNOTE_URL}/load", 120.0)


@app.post("/pyannote/unload")
async def pyannote_unload():
    """Ask pyannote-service to unload its model."""
    return await _proxy_post(f"{PYANNOTE_URL}/unload", 30.0)


@app.post("/ollama/unload")
async def ollama_unload(model: str = "llama3.2"):
    """Unload an ollama model (keep_alive=0)."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{OLLAMA_URL}/api/generate",
            json={"model": model, "keep_alive": 0},
            timeout=30.0,
        )
    # Only claim success when ollama actually accepted the request.
    if resp.status_code >= 400:
        return JSONResponse(
            status_code=resp.status_code,
            content={"status": "error", "model": model, "detail": resp.text},
        )
    return {"status": "unloaded", "model": model}


# ────────────────────── helpers ──────────────────────

def _format_timestamp(seconds: float, ms_separator: str) -> str:
    """Render *seconds* as ``HH:MM:SS<sep>mmm``; negative values clamp to 0."""
    if seconds < 0:
        seconds = 0
    total_ms = round(seconds * 1000)
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    secs, ms = divmod(total_ms, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{ms_separator}{ms:03d}"


def _format_timestamp_srt(seconds: float) -> str:
    """Format a timestamp for SRT (comma before the milliseconds)."""
    return _format_timestamp(seconds, ",")


def _format_timestamp_vtt(seconds: float) -> str:
    """Format a timestamp for WebVTT (dot before the milliseconds)."""
    return _format_timestamp(seconds, ".")


def _blank_to_none(value: str | None) -> str | None:
    """Return None for missing, empty, or literal ``"string"`` values.

    NOTE(review): ``"string"`` is presumably the Swagger UI placeholder for
    untouched form fields — confirm.
    """
    return value if value and value != "string" else None


async def _call_whisper(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    language: str | None,
    initial_prompt: str | None,
    beam_size: int,
    batch_size: int,
) -> dict:
    """POST the audio to whisper-service ``/transcribe`` and return its JSON.

    Word timestamps are always requested. ``language`` and ``initial_prompt``
    are forwarded only when they carry a real value.

    Raises:
        httpx.HTTPStatusError: the service answered with a 4xx/5xx status.
    """
    language = _blank_to_none(language)
    initial_prompt = _blank_to_none(initial_prompt)

    data = {"beam_size": str(beam_size), "batch_size": str(batch_size), "word_timestamps": "true"}
    if language:
        data["language"] = language
    if initial_prompt:
        data["initial_prompt"] = initial_prompt

    logger.info("[whisper] Sending %s (%d bytes)", filename, len(audio_bytes))
    t0 = time.perf_counter()
    resp = await client.post(
        f"{WHISPER_URL}/transcribe",
        files={"file": (filename, audio_bytes)},
        data=data,
        timeout=TIMEOUT,
    )
    logger.info("[whisper] %s in %.1fs", resp.status_code, time.perf_counter() - t0)
    resp.raise_for_status()
    return resp.json()


async def _call_pyannote(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    num_speakers: int | None,
    min_speakers: int | None,
    max_speakers: int | None,
    min_duration: float = 0.5,
    merge_gap: float = 0.3,
) -> dict:
    """POST the audio to pyannote-service ``/diarize`` and return its JSON.

    The speaker-count hints are forwarded only when they are not None.

    Raises:
        httpx.HTTPStatusError: the service answered with a 4xx/5xx status.
    """
    data = {"min_duration": str(min_duration), "merge_gap": str(merge_gap)}
    hints = {
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
    }
    data.update({key: str(value) for key, value in hints.items() if value is not None})

    logger.info("[pyannote] Sending %s (%d bytes)", filename, len(audio_bytes))
    t0 = time.perf_counter()
    resp = await client.post(
        f"{PYANNOTE_URL}/diarize",
        files={"file": (filename, audio_bytes)},
        data=data,
        timeout=TIMEOUT,
    )
    logger.info("[pyannote] %s in %.1fs", resp.status_code, time.perf_counter() - t0)
    resp.raise_for_status()
    return resp.json()


def _utterance_from_words(speaker: str, words: list[dict]) -> dict:
    """Collapse consecutive words of one speaker into a single utterance."""
    return {
        "speaker": speaker,
        "start": round(words[0]["start"], 3),
        "end": round(words[-1]["end"], 3),
        "text": "".join(w["word"] for w in words).strip(),
    }


def _reconcile(whisper_result: dict, pyannote_result: dict) -> list[dict]:
    """Stitch whisper words and pyannote speaker turns into utterances.

    Each word is assigned to the first turn containing its temporal midpoint,
    or, if none does, to the turn whose boundary is nearest. Consecutive words
    with the same speaker are merged into one utterance. Segments without
    word-level data are treated as a single word spanning the segment.

    Without any turns, every segment is returned as-is under ``SPEAKER_00``.
    The input dicts are not modified.

    Returns:
        A list of ``{"speaker", "start", "end", "text"}`` dicts.
    """
    segments = whisper_result.get("segments", [])
    turns = pyannote_result.get("turns", [])
    if not turns:
        return [
            {"speaker": "SPEAKER_00", "start": seg["start"], "end": seg["end"], "text": seg["text"]}
            for seg in segments
        ]

    all_words: list[dict] = []
    for seg in segments:
        words = seg.get("words")
        if words:
            all_words.extend(words)
        else:
            all_words.append({"start": seg["start"], "end": seg["end"], "word": seg["text"]})
    if not all_words:
        return []

    def find_speaker(midpoint: float) -> str:
        for turn in turns:
            if turn["start"] <= midpoint <= turn["end"]:
                return turn["speaker"]
        # No turn contains the midpoint: fall back to the nearest turn
        # boundary (min() keeps the first turn on ties).
        nearest = min(
            turns,
            key=lambda t: min(abs(midpoint - t["start"]), abs(midpoint - t["end"])),
        )
        return nearest["speaker"]

    utterances: list[dict] = []
    current_speaker = ""
    current_words: list[dict] = []
    for word in all_words:
        speaker = find_speaker((word["start"] + word["end"]) / 2)
        if current_words and speaker != current_speaker:
            utterances.append(_utterance_from_words(current_speaker, current_words))
            current_words = []
        current_speaker = speaker
        current_words.append(word)
    # all_words is non-empty here, so there is always a trailing group to flush.
    utterances.append(_utterance_from_words(current_speaker, current_words))
    return utterances


def _caption_text(utterance: dict) -> str:
    """Return the caption line, prefixed with ``[speaker]`` only when one is set."""
    speaker = utterance.get("speaker")
    if speaker:
        return f"[{speaker}] {utterance['text']}"
    return utterance["text"]


def _to_srt(utterances: list[dict]) -> str:
    """Render utterances as SRT cues numbered from 1."""
    lines = []
    for i, u in enumerate(utterances, 1):
        lines += [
            str(i),
            f"{_format_timestamp_srt(u['start'])} --> {_format_timestamp_srt(u['end'])}",
            _caption_text(u),
            "",
        ]
    return "\n".join(lines)


def _to_vtt(utterances: list[dict]) -> str:
    """Render utterances as a WebVTT document."""
    lines = ["WEBVTT", ""]
    for u in utterances:
        lines += [
            f"{_format_timestamp_vtt(u['start'])} --> {_format_timestamp_vtt(u['end'])}",
            _caption_text(u),
            "",
        ]
    return "\n".join(lines)


def _to_txt(utterances: list[dict]) -> str:
    """Render utterances as plain ``speaker: text`` lines."""
    return "\n".join(f"{u['speaker']}: {u['text']}" for u in utterances)


# ────────────────────── endpoints ──────────────────────

@app.get("/health")
async def health():
    """Report reachability of the whisper, pyannote and ollama services."""
    # whisper and pyannote are probed at /health; ollama at its root path.
    targets = [
        ("whisper", f"{WHISPER_URL}/health"),
        ("pyannote", f"{PYANNOTE_URL}/health"),
        ("ollama", f"{OLLAMA_URL}/"),
    ]

    async def probe(client: httpx.AsyncClient, url: str) -> str:
        # Best effort by design: any failure means "unreachable", never a crash.
        try:
            resp = await client.get(url, timeout=3.0)
        except Exception:
            return "unreachable"
        return "ok" if resp.status_code < 400 else "error"

    # Probe concurrently so the worst case is one 3 s timeout, not three.
    async with httpx.AsyncClient() as client:
        states = await asyncio.gather(*(probe(client, url) for _, url in targets))
    services = {name: state for (name, _), state in zip(targets, states)}
    return {"status": "ok", "services": services}


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    response_format: str = Form("json"),
):
    """Transcribe audio with whisper; response_format is json, srt, vtt or txt."""
    audio_bytes = await file.read()
    async with httpx.AsyncClient() as client:
        result = await _call_whisper(
            client, audio_bytes, file.filename or "audio.wav",
            language, initial_prompt, beam_size, batch_size,
        )

    caption_writers = {"srt": _to_srt, "vtt": _to_vtt}
    if response_format in caption_writers:
        # No diarization here, so the cues carry an empty speaker.
        cues = [
            {"speaker": "", "start": s["start"], "end": s["end"], "text": s["text"]}
            for s in result.get("segments", [])
        ]
        content = caption_writers[response_format](cues)
        return JSONResponse({"format": response_format, "content": content})
    if response_format == "txt":
        content = " ".join(s["text"] for s in result.get("segments", []))
        return JSONResponse({"format": "txt", "content": content})
    # Any other value returns the raw whisper JSON.
    return result


@app.post("/diarize")
async def diarize(
    file: UploadFile = File(...),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
):
    """Diarize audio with pyannote and return the service's JSON unchanged."""
    audio_bytes = await file.read()
    async with httpx.AsyncClient() as client:
        return await _call_pyannote(
            client, audio_bytes, file.filename or "audio.wav",
            num_speakers, min_speakers, max_speakers, min_duration, merge_gap,
        )


@app.post("/process")
async def process(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
    response_format: str = Form("json"),
):
    """Transcribe and diarize in parallel, then merge into speaker utterances."""
    t0 = time.perf_counter()
    audio_bytes = await file.read()
    filename = file.filename or "audio.wav"

    # Both services receive the same bytes and run concurrently.
    async with httpx.AsyncClient() as client:
        whisper_result, pyannote_result = await asyncio.gather(
            _call_whisper(client, audio_bytes, filename, language, initial_prompt, beam_size, batch_size),
            _call_pyannote(client, audio_bytes, filename, num_speakers, min_speakers, max_speakers, min_duration, merge_gap),
        )

    utterances = _reconcile(whisper_result, pyannote_result)
    elapsed = time.perf_counter() - t0

    renderers = {"srt": _to_srt, "vtt": _to_vtt, "txt": _to_txt}
    render = renderers.get(response_format)
    if render is not None:
        return JSONResponse({
            "format": response_format,
            "content": render(utterances),
            "processing_time": round(elapsed, 3),
        })

    # Any other value returns the full JSON report.
    return {
        "language": whisper_result.get("language"),
        "language_probability": whisper_result.get("language_probability"),
        "duration": whisper_result.get("duration"),
        "processing_time": round(elapsed, 3),
        "whisper_time": whisper_result.get("processing_time"),
        "pyannote_time": pyannote_result.get("processing_time"),
        "speakers": pyannote_result.get("speakers", []),
        "num_speakers": pyannote_result.get("num_speakers", 0),
        "utterances": utterances,
    }