Files
2026-05-19 10:12:57 +00:00

353 lines
12 KiB
Python

"""
API Gateway: маршрутизация запросов и reconciliation.
POST /transcribe → whisper-service (только транскрибация)
POST /diarize → pyannote-service (только диаризация)
POST /process → оба параллельно + сшивка по словам
Управление VRAM:
POST /whisper/load → загрузить whisper
POST /whisper/unload → выгрузить whisper
POST /pyannote/load → загрузить pyannote
POST /pyannote/unload → выгрузить pyannote
POST /ollama/unload → выгрузить текущую модель ollama (keep_alive=0)
"""
from __future__ import annotations

import asyncio
import itertools
import logging
import os
import time

import httpx
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
# Module logger. The root logger is configured at import time (INFO level,
# timestamped format) so the gateway's per-request timing logs are emitted.
logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
# Base URLs of the upstream services; each can be overridden via environment variable.
# NOTE(review): the defaults look like docker-compose service hostnames — confirm.
WHISPER_URL = os.getenv("WHISPER_URL", "http://whisper-service:8001")
PYANNOTE_URL = os.getenv("PYANNOTE_URL", "http://pyannote-service:8002")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
# Timeout used by the whisper/pyannote processing calls: 1800 s (30 min),
# applied to every httpx phase (connect, read, write, pool). The VRAM-management
# proxies and /health pass their own, much shorter per-request timeouts.
TIMEOUT = httpx.Timeout(timeout=1800.0)
# FastAPI application. The (Russian) description reads:
# "Transcription + diarization of audio. Whisper (GPU) + pyannote (CPU), in parallel."
app = FastAPI(
    title="Speech Processing API",
    description="Транскрибация + диаризация аудио. Whisper (GPU) + pyannote (CPU), параллельно.",
    version="4.0.0",
)
# ────────────────────── VRAM management proxies ──────────────────────
async def _proxy_post(url: str, timeout: float) -> JSONResponse:
    """POST (without a body) to an upstream control endpoint and relay its reply.

    The upstream status code is passed through, so a failed load/unload is
    visible to the caller instead of being reported as HTTP 200.

    Args:
        url: Full upstream URL, e.g. ``f"{WHISPER_URL}/load"``.
        timeout: Per-request timeout in seconds.

    Returns:
        A JSONResponse with the upstream JSON body, or ``{"detail": <text>}``
        when the body is not JSON, and the upstream status code. Returns HTTP
        502 when the request itself fails (connection error, timeout, ...).
    """
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(url, timeout=timeout)
        except httpx.RequestError as exc:
            logger.error("[proxy] POST %s failed: %s", url, exc)
            return JSONResponse({"detail": f"upstream request failed: {exc}"}, status_code=502)
        try:
            payload = r.json()
        except ValueError:
            # A non-JSON error page would otherwise raise here and become an opaque 500.
            payload = {"detail": r.text}
    return JSONResponse(payload, status_code=r.status_code)


@app.post("/whisper/load")
async def whisper_load():
    """Ask whisper-service to load its model (proxies ``POST {WHISPER_URL}/load``)."""
    return await _proxy_post(f"{WHISPER_URL}/load", timeout=120.0)


@app.post("/whisper/unload")
async def whisper_unload():
    """Ask whisper-service to unload its model (proxies ``POST {WHISPER_URL}/unload``)."""
    return await _proxy_post(f"{WHISPER_URL}/unload", timeout=30.0)


@app.post("/pyannote/load")
async def pyannote_load():
    """Ask pyannote-service to load its pipeline (proxies ``POST {PYANNOTE_URL}/load``)."""
    return await _proxy_post(f"{PYANNOTE_URL}/load", timeout=120.0)


@app.post("/pyannote/unload")
async def pyannote_unload():
    """Ask pyannote-service to unload its pipeline (proxies ``POST {PYANNOTE_URL}/unload``)."""
    return await _proxy_post(f"{PYANNOTE_URL}/unload", timeout=30.0)
@app.post("/ollama/unload")
async def ollama_unload(model: str = "llama3.2"):
    """Unload an Ollama model by sending a generate request with ``keep_alive=0``.

    Args:
        model: Name of the Ollama model to unload.

    Returns:
        ``{"status": "unloaded", "model": model}`` when Ollama accepts the request.
        When Ollama answers with a 4xx/5xx status, that status is relayed together
        with ``{"status": "error", "model": ..., "upstream": <Ollama body>}``.
        When Ollama cannot be reached, HTTP 502 is returned.
    """
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(
                f"{OLLAMA_URL}/api/generate",
                json={"model": model, "keep_alive": 0},
                timeout=30.0,
            )
        except httpx.RequestError as exc:
            logger.error("[ollama] unload of %s failed: %s", model, exc)
            return JSONResponse(
                {"status": "error", "model": model, "detail": f"upstream request failed: {exc}"},
                status_code=502,
            )
    if r.status_code >= 400:
        # Don't report success when Ollama rejected the request.
        try:
            upstream = r.json()
        except ValueError:
            upstream = {"detail": r.text}
        return JSONResponse(
            {"status": "error", "model": model, "upstream": upstream},
            status_code=r.status_code,
        )
    return {"status": "unloaded", "model": model}
# ────────────────────── helpers ──────────────────────
def _format_timestamp_srt(seconds: float) -> str:
if seconds < 0:
seconds = 0
total_ms = round(seconds * 1000)
hours, total_ms = divmod(total_ms, 3_600_000)
minutes, total_ms = divmod(total_ms, 60_000)
secs, ms = divmod(total_ms, 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
def _format_timestamp_vtt(seconds: float) -> str:
if seconds < 0:
seconds = 0
total_ms = round(seconds * 1000)
hours, total_ms = divmod(total_ms, 3_600_000)
minutes, total_ms = divmod(total_ms, 60_000)
secs, ms = divmod(total_ms, 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{ms:03d}"
async def _call_whisper(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    language: str | None,
    initial_prompt: str | None,
    beam_size: int,
    batch_size: int,
) -> dict:
    """Send audio to whisper-service ``/transcribe`` and return its decoded JSON reply.

    Word-level timestamps are always requested. ``language`` and ``initial_prompt``
    are forwarded only when they carry a real value: ``None``, ``""`` and the
    literal ``"string"`` (presumably the placeholder Swagger UI pre-fills) are
    treated as "not set".

    Raises:
        httpx.HTTPStatusError: whisper-service answered with a 4xx/5xx status.
    """
    placeholders = ("string", "")
    form = {
        "beam_size": str(beam_size),
        "batch_size": str(batch_size),
        "word_timestamps": "true",
    }
    for key, value in (("language", language), ("initial_prompt", initial_prompt)):
        if value and value not in placeholders:
            form[key] = value

    logger.info("[whisper] Sending %s (%d bytes)", filename, len(audio_bytes))
    started = time.perf_counter()
    response = await client.post(
        f"{WHISPER_URL}/transcribe",
        files={"file": (filename, audio_bytes)},
        data=form,
        timeout=TIMEOUT,
    )
    logger.info("[whisper] %s in %.1fs", response.status_code, time.perf_counter() - started)
    response.raise_for_status()
    return response.json()
async def _call_pyannote(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    num_speakers: int | None,
    min_speakers: int | None,
    max_speakers: int | None,
    min_duration: float = 0.5,
    merge_gap: float = 0.3,
) -> dict:
    """Send audio to pyannote-service ``/diarize`` and return its decoded JSON reply.

    ``min_duration`` and ``merge_gap`` are always sent. The speaker-count hints
    (``num_speakers``, ``min_speakers``, ``max_speakers``) are sent only when
    they are not ``None``. All form values are stringified.

    Raises:
        httpx.HTTPStatusError: pyannote-service answered with a 4xx/5xx status.
    """
    form = {"min_duration": str(min_duration), "merge_gap": str(merge_gap)}
    speaker_hints = {
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
    }
    form.update({key: str(value) for key, value in speaker_hints.items() if value is not None})

    logger.info("[pyannote] Sending %s (%d bytes)", filename, len(audio_bytes))
    started = time.perf_counter()
    response = await client.post(
        f"{PYANNOTE_URL}/diarize",
        files={"file": (filename, audio_bytes)},
        data=form,
        timeout=TIMEOUT,
    )
    logger.info("[pyannote] %s in %.1fs", response.status_code, time.perf_counter() - started)
    response.raise_for_status()
    return response.json()
def _reconcile(whisper_result: dict, pyannote_result: dict) -> list[dict]:
turns = pyannote_result.get("turns", [])
if not turns:
return [
{"speaker": "SPEAKER_00", "start": seg["start"], "end": seg["end"], "text": seg["text"]}
for seg in whisper_result.get("segments", [])
]
all_words = []
for seg in whisper_result.get("segments", []):
words = seg.get("words")
if words:
all_words.extend(words)
else:
all_words.append({"start": seg["start"], "end": seg["end"], "word": seg["text"]})
if not all_words:
return []
def find_speaker(midpoint: float) -> str:
for turn in turns:
if turn["start"] <= midpoint <= turn["end"]:
return turn["speaker"]
min_dist = float("inf")
closest = turns[0]["speaker"]
for turn in turns:
dist = min(abs(midpoint - turn["start"]), abs(midpoint - turn["end"]))
if dist < min_dist:
min_dist = dist
closest = turn["speaker"]
return closest
for w in all_words:
w["speaker"] = find_speaker((w["start"] + w["end"]) / 2)
utterances: list[dict] = []
current_speaker = all_words[0]["speaker"]
current_words = [all_words[0]]
for w in all_words[1:]:
if w["speaker"] == current_speaker:
current_words.append(w)
else:
utterances.append({
"speaker": current_speaker,
"start": round(current_words[0]["start"], 3),
"end": round(current_words[-1]["end"], 3),
"text": "".join(w["word"] for w in current_words).strip(),
})
current_speaker = w["speaker"]
current_words = [w]
if current_words:
utterances.append({
"speaker": current_speaker,
"start": round(current_words[0]["start"], 3),
"end": round(current_words[-1]["end"], 3),
"text": "".join(w["word"] for w in current_words).strip(),
})
return utterances
def _to_srt(utterances: list[dict]) -> str:
    """Render utterances as SubRip (SRT) subtitles.

    Cues are numbered from 1 and separated by a blank line. The ``[SPEAKER] ``
    prefix is added only when the utterance has a non-empty ``speaker``, so
    speaker-less cues (``/transcribe`` passes ``""``) don't get an empty ``[]`` tag.
    """
    lines: list[str] = []
    for index, u in enumerate(utterances, 1):
        speaker = u.get("speaker")
        text = f"[{speaker}] {u['text']}" if speaker else u["text"]
        lines += [
            str(index),
            f"{_format_timestamp_srt(u['start'])} --> {_format_timestamp_srt(u['end'])}",
            text,
            "",
        ]
    return "\n".join(lines)
def _to_vtt(utterances: list[dict]) -> str:
    """Render utterances as a WebVTT document.

    The output starts with the ``WEBVTT`` header and a blank line, then one
    unnumbered cue per utterance followed by a blank line. The ``[SPEAKER] ``
    prefix is added only when the utterance has a non-empty ``speaker``, so
    speaker-less cues (``/transcribe`` passes ``""``) don't get an empty ``[]`` tag.
    """
    lines = ["WEBVTT", ""]
    for u in utterances:
        speaker = u.get("speaker")
        text = f"[{speaker}] {u['text']}" if speaker else u["text"]
        lines += [
            f"{_format_timestamp_vtt(u['start'])} --> {_format_timestamp_vtt(u['end'])}",
            text,
            "",
        ]
    return "\n".join(lines)
def _to_txt(utterances: list[dict]) -> str:
return "\n".join(f"{u['speaker']}: {u['text']}" for u in utterances)
# ────────────────────── endpoints ──────────────────────
@app.get("/health")
async def health():
    """Report the reachability of the upstream services.

    whisper and pyannote are probed at ``/health`` and Ollama at ``/``, each with
    a 3 s timeout. The probes run concurrently, so the slowest probe, rather than
    the sum of all three, bounds the response time. Each service is reported as
    ``"ok"`` (status < 400), ``"error"`` (status >= 400) or ``"unreachable"`` (the
    request raised). The top-level ``status`` describes the gateway itself and
    is always ``"ok"``.
    """
    targets = {
        "whisper": f"{WHISPER_URL}/health",
        "pyannote": f"{PYANNOTE_URL}/health",
        "ollama": f"{OLLAMA_URL}/",
    }

    async def probe(client: httpx.AsyncClient, name: str, url: str) -> str:
        try:
            r = await client.get(url, timeout=3.0)
        except Exception as exc:  # best-effort: a broken upstream must not break /health
            logger.debug("[health] %s unreachable: %s", name, exc)
            return "unreachable"
        return "ok" if r.status_code < 400 else "error"

    async with httpx.AsyncClient() as client:
        states = await asyncio.gather(*(probe(client, name, url) for name, url in targets.items()))
    return {"status": "ok", "services": dict(zip(targets, states))}
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    response_format: str = Form("json"),
):
    """Transcribe the uploaded audio with whisper only (no diarization).

    ``response_format`` selects the output:
    ``"srt"`` / ``"vtt"`` / ``"txt"`` return ``{"format": ..., "content": ...}``;
    any other value (default ``"json"``) returns the whisper-service reply unchanged.
    """
    audio_bytes = await file.read()
    async with httpx.AsyncClient() as client:
        result = await _call_whisper(client, audio_bytes, file.filename or "audio.wav",
                                     language, initial_prompt, beam_size, batch_size)
    if response_format not in ("srt", "vtt", "txt"):
        return result

    segments = result.get("segments", [])
    if response_format == "txt":
        # Strip each segment before joining so any leading space in the segment text
        # doesn't produce double spaces; drop segments that end up empty.
        texts = (s["text"].strip() for s in segments)
        content = " ".join(t for t in texts if t)
    else:
        cues = [
            {"speaker": "", "start": s["start"], "end": s["end"], "text": s["text"].strip()}
            for s in segments
        ]
        content = _to_srt(cues) if response_format == "srt" else _to_vtt(cues)
    return JSONResponse({"format": response_format, "content": content})
@app.post("/diarize")
async def diarize(
    file: UploadFile = File(...),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
):
    """Diarize the uploaded audio with pyannote only and return pyannote-service's JSON as is.

    A missing upload filename is replaced by ``"audio.wav"``. The speaker-count
    hints and the ``min_duration`` / ``merge_gap`` knobs are forwarded unchanged.
    """
    payload = await file.read()
    upload_name = file.filename or "audio.wav"
    async with httpx.AsyncClient() as client:
        diarization = await _call_pyannote(
            client,
            payload,
            upload_name,
            num_speakers,
            min_speakers,
            max_speakers,
            min_duration=min_duration,
            merge_gap=merge_gap,
        )
    return diarization
@app.post("/process")
async def process(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
    response_format: str = Form("json"),
):
    """Transcribe and diarize the uploaded audio concurrently and merge the results.

    The same audio is sent to whisper-service and pyannote-service in parallel.
    Their results are merged into speaker utterances by ``_reconcile``.

    Returns:
        For ``response_format`` ``"srt"``, ``"vtt"`` or ``"txt"``: a dict with
        ``format``, ``content`` and ``processing_time``. Otherwise a dict with
        language info, duration, the total and per-service processing times, the
        speaker list and the ``utterances``.

    Raises:
        httpx.HTTPError: from either upstream call. The other call is cancelled
        first, so it does not keep running against the client being closed.
    """
    t0 = time.perf_counter()
    audio_bytes = await file.read()
    filename = file.filename or "audio.wav"
    async with httpx.AsyncClient() as client:
        whisper_task = asyncio.create_task(
            _call_whisper(client, audio_bytes, filename, language, initial_prompt, beam_size, batch_size)
        )
        pyannote_task = asyncio.create_task(
            _call_pyannote(client, audio_bytes, filename, num_speakers, min_speakers,
                           max_speakers, min_duration, merge_gap)
        )
        try:
            whisper_result, pyannote_result = await asyncio.gather(whisper_task, pyannote_task)
        except Exception:
            # gather() does not cancel the sibling when one child fails; cancel it and
            # wait for it to settle before the shared client is closed.
            for task in (whisper_task, pyannote_task):
                task.cancel()
            await asyncio.gather(whisper_task, pyannote_task, return_exceptions=True)
            raise

    utterances = _reconcile(whisper_result, pyannote_result)
    processing_time = round(time.perf_counter() - t0, 3)

    renderers = {"srt": _to_srt, "vtt": _to_vtt, "txt": _to_txt}
    render = renderers.get(response_format)
    if render is not None:
        return JSONResponse({
            "format": response_format,
            "content": render(utterances),
            "processing_time": processing_time,
        })
    return {
        "language": whisper_result.get("language"),
        "language_probability": whisper_result.get("language_probability"),
        "duration": whisper_result.get("duration"),
        "processing_time": processing_time,
        "whisper_time": whisper_result.get("processing_time"),
        "pyannote_time": pyannote_result.get("processing_time"),
        "speakers": pyannote_result.get("speakers", []),
        "num_speakers": pyannote_result.get("num_speakers", 0),
        "utterances": utterances,
    }