353 lines
12 KiB
Python
353 lines
12 KiB
Python
"""
|
|
API Gateway: маршрутизация запросов и reconciliation.
|
|
|
|
POST /transcribe → whisper-service (только транскрибация)
|
|
POST /diarize → pyannote-service (только диаризация)
|
|
POST /process → оба параллельно + сшивка по словам
|
|
|
|
Управление VRAM:
|
|
POST /whisper/load → загрузить whisper
|
|
POST /whisper/unload → выгрузить whisper
|
|
POST /pyannote/load → загрузить pyannote
|
|
POST /pyannote/unload → выгрузить pyannote
|
|
POST /ollama/unload → выгрузить текущую модель ollama (keep_alive=0)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, File, Form, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
|
|
logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")

# Backend service base URLs; each can be overridden through the environment.
WHISPER_URL = os.getenv("WHISPER_URL", "http://whisper-service:8001")
PYANNOTE_URL = os.getenv("PYANNOTE_URL", "http://pyannote-service:8002")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")

# Timeout for the long-running transcription / diarization calls:
# 1800 s applied to every httpx timeout phase (connect, read, write, pool).
TIMEOUT = httpx.Timeout(timeout=1800.0)

app = FastAPI(
    title="Speech Processing API",
    description="Транскрибация + диаризация аудио. Whisper (GPU) + pyannote (CPU), параллельно.",
    version="4.0.0",
)


@app.exception_handler(httpx.HTTPError)
async def _upstream_error_handler(request, exc):
    """Translate failed backend calls into gateway errors instead of a bare 500.

    Args:
        request: The incoming request (unused; required by the handler protocol).
        exc: The httpx error raised while talking to a backend service.

    Returns:
        A JSONResponse with:
        * 502 and the upstream URL/status/body when a backend answered with an
          error status (``raise_for_status`` raised ``HTTPStatusError``);
        * 504 when the backend request timed out;
        * 502 for any other transport-level failure (connection refused, etc.).
    """
    if isinstance(exc, httpx.HTTPStatusError):
        upstream = exc.response
        logger.warning("Upstream %s returned %s", exc.request.url, upstream.status_code)
        return JSONResponse(
            status_code=502,
            content={
                "detail": "upstream service returned an error",
                "upstream_url": str(exc.request.url),
                "upstream_status": upstream.status_code,
                # Truncated so a huge error page cannot bloat the gateway reply.
                "upstream_body": upstream.text[:2000],
            },
        )
    status = 504 if isinstance(exc, httpx.TimeoutException) else 502
    logger.warning("Upstream request failed: %r", exc)
    return JSONResponse(status_code=status, content={"detail": f"upstream request failed: {exc!r}"})
|
|
|
|
|
|
# ────────────────────── VRAM management proxies ──────────────────────
|
|
|
|
async def _proxy_post(url: str, timeout: float) -> JSONResponse:
    """POST to a backend control endpoint and relay its reply to the caller.

    Args:
        url: Full backend URL to POST to (no request body is sent).
        timeout: Per-request timeout in seconds.

    Returns:
        A JSONResponse carrying the backend's status code and JSON body. A
        non-JSON body is wrapped as ``{"detail": <text>}``. Transport failures
        become 504 (timeout) or 502 (anything else), so callers never see an
        unhandled 500 or a misleading 200.
    """
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(url, timeout=timeout)
        except httpx.RequestError as exc:
            logger.warning("POST %s failed: %r", url, exc)
            status = 504 if isinstance(exc, httpx.TimeoutException) else 502
            return JSONResponse(status_code=status, content={"detail": f"upstream request failed: {exc!r}"})
    try:
        body = r.json()
    except ValueError:
        # Backend replied with something other than JSON (e.g. an error page).
        body = {"detail": r.text}
    # Relay the backend status so load/unload failures are not reported as 200.
    return JSONResponse(status_code=r.status_code, content=body)


@app.post("/whisper/load")
async def whisper_load():
    """Ask whisper-service to load its model (POST {WHISPER_URL}/load, 120 s timeout)."""
    return await _proxy_post(f"{WHISPER_URL}/load", timeout=120.0)


@app.post("/whisper/unload")
async def whisper_unload():
    """Ask whisper-service to unload its model (POST {WHISPER_URL}/unload, 30 s timeout)."""
    return await _proxy_post(f"{WHISPER_URL}/unload", timeout=30.0)


@app.post("/pyannote/load")
async def pyannote_load():
    """Ask pyannote-service to load its pipeline (POST {PYANNOTE_URL}/load, 120 s timeout)."""
    return await _proxy_post(f"{PYANNOTE_URL}/load", timeout=120.0)


@app.post("/pyannote/unload")
async def pyannote_unload():
    """Ask pyannote-service to unload its pipeline (POST {PYANNOTE_URL}/unload, 30 s timeout)."""
    return await _proxy_post(f"{PYANNOTE_URL}/unload", timeout=30.0)
|
|
|
|
|
|
@app.post("/ollama/unload")
async def ollama_unload(model: str = "llama3.2"):
    """Unload an ollama model from memory by sending ``keep_alive=0``.

    Posts ``{"model": model, "keep_alive": 0}`` to ``{OLLAMA_URL}/api/generate``.

    Args:
        model: Name of the ollama model to unload (query parameter).

    Returns:
        ``{"status": "unloaded", "model": model}`` when ollama accepts the
        request. Otherwise a JSONResponse with ``"status": "error"`` and:
        * ollama's own status code and body, if it answered with an error;
        * 504 on timeout, or 502 on any other connection failure.
    """
    async with httpx.AsyncClient() as client:
        try:
            r = await client.post(
                f"{OLLAMA_URL}/api/generate",
                json={"model": model, "keep_alive": 0},
                timeout=30.0,
            )
        except httpx.RequestError as exc:
            logger.warning("ollama unload of %s failed: %r", model, exc)
            status = 504 if isinstance(exc, httpx.TimeoutException) else 502
            return JSONResponse(
                status_code=status,
                content={"status": "error", "model": model, "detail": repr(exc)},
            )
    # Previously the reply was ignored and "unloaded" was reported even when
    # ollama rejected the request (e.g. unknown model).
    if r.is_error:
        return JSONResponse(
            status_code=r.status_code,
            content={"status": "error", "model": model, "detail": r.text},
        )
    return {"status": "unloaded", "model": model}
|
|
|
|
|
|
# ────────────────────── helpers ──────────────────────
|
|
|
|
def _format_timestamp_srt(seconds: float) -> str:
|
|
if seconds < 0:
|
|
seconds = 0
|
|
total_ms = round(seconds * 1000)
|
|
hours, total_ms = divmod(total_ms, 3_600_000)
|
|
minutes, total_ms = divmod(total_ms, 60_000)
|
|
secs, ms = divmod(total_ms, 1000)
|
|
return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
|
|
|
|
|
|
def _format_timestamp_vtt(seconds: float) -> str:
|
|
if seconds < 0:
|
|
seconds = 0
|
|
total_ms = round(seconds * 1000)
|
|
hours, total_ms = divmod(total_ms, 3_600_000)
|
|
minutes, total_ms = divmod(total_ms, 60_000)
|
|
secs, ms = divmod(total_ms, 1000)
|
|
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{ms:03d}"
|
|
|
|
|
|
async def _call_whisper(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    language: str | None,
    initial_prompt: str | None,
    beam_size: int,
    batch_size: int,
) -> dict:
    """Upload audio to whisper-service ``/transcribe`` and return its JSON reply.

    Word timestamps are always requested. ``language`` and ``initial_prompt``
    are forwarded only when they are non-empty and not the literal ``"string"``.

    Raises:
        httpx.HTTPStatusError: if whisper-service answers with a 4xx/5xx status.
    """

    def _meaningful(value):
        # "string" is presumably the Swagger UI placeholder for optional text
        # fields; treat it (and "") as "not provided".
        return value if value and value not in ("string", "") else None

    form = {"beam_size": str(beam_size), "batch_size": str(batch_size), "word_timestamps": "true"}
    for key, value in (("language", _meaningful(language)), ("initial_prompt", _meaningful(initial_prompt))):
        if value:
            form[key] = value

    logger.info(f"[whisper] Sending {filename} ({len(audio_bytes)} bytes)")
    started = time.perf_counter()
    response = await client.post(
        f"{WHISPER_URL}/transcribe",
        files={"file": (filename, audio_bytes)},
        data=form,
        timeout=TIMEOUT,
    )
    logger.info(f"[whisper] {response.status_code} in {time.perf_counter() - started:.1f}s")
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
async def _call_pyannote(
    client: httpx.AsyncClient,
    audio_bytes: bytes,
    filename: str,
    num_speakers: int | None,
    min_speakers: int | None,
    max_speakers: int | None,
    min_duration: float = 0.5,
    merge_gap: float = 0.3,
) -> dict:
    """Upload audio to pyannote-service ``/diarize`` and return its JSON reply.

    ``min_duration`` and ``merge_gap`` are always sent. The speaker-count hints
    are sent only when they are not ``None``. All form values are stringified.

    Raises:
        httpx.HTTPStatusError: if pyannote-service answers with a 4xx/5xx status.
    """
    speaker_hints = {
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
    }
    form = {"min_duration": str(min_duration), "merge_gap": str(merge_gap)}
    form.update({key: str(value) for key, value in speaker_hints.items() if value is not None})

    logger.info(f"[pyannote] Sending {filename} ({len(audio_bytes)} bytes)")
    started = time.perf_counter()
    response = await client.post(
        f"{PYANNOTE_URL}/diarize",
        files={"file": (filename, audio_bytes)},
        data=form,
        timeout=TIMEOUT,
    )
    logger.info(f"[pyannote] {response.status_code} in {time.perf_counter() - started:.1f}s")
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def _reconcile(whisper_result: dict, pyannote_result: dict) -> list[dict]:
|
|
turns = pyannote_result.get("turns", [])
|
|
if not turns:
|
|
return [
|
|
{"speaker": "SPEAKER_00", "start": seg["start"], "end": seg["end"], "text": seg["text"]}
|
|
for seg in whisper_result.get("segments", [])
|
|
]
|
|
|
|
all_words = []
|
|
for seg in whisper_result.get("segments", []):
|
|
words = seg.get("words")
|
|
if words:
|
|
all_words.extend(words)
|
|
else:
|
|
all_words.append({"start": seg["start"], "end": seg["end"], "word": seg["text"]})
|
|
|
|
if not all_words:
|
|
return []
|
|
|
|
def find_speaker(midpoint: float) -> str:
|
|
for turn in turns:
|
|
if turn["start"] <= midpoint <= turn["end"]:
|
|
return turn["speaker"]
|
|
min_dist = float("inf")
|
|
closest = turns[0]["speaker"]
|
|
for turn in turns:
|
|
dist = min(abs(midpoint - turn["start"]), abs(midpoint - turn["end"]))
|
|
if dist < min_dist:
|
|
min_dist = dist
|
|
closest = turn["speaker"]
|
|
return closest
|
|
|
|
for w in all_words:
|
|
w["speaker"] = find_speaker((w["start"] + w["end"]) / 2)
|
|
|
|
utterances: list[dict] = []
|
|
current_speaker = all_words[0]["speaker"]
|
|
current_words = [all_words[0]]
|
|
|
|
for w in all_words[1:]:
|
|
if w["speaker"] == current_speaker:
|
|
current_words.append(w)
|
|
else:
|
|
utterances.append({
|
|
"speaker": current_speaker,
|
|
"start": round(current_words[0]["start"], 3),
|
|
"end": round(current_words[-1]["end"], 3),
|
|
"text": "".join(w["word"] for w in current_words).strip(),
|
|
})
|
|
current_speaker = w["speaker"]
|
|
current_words = [w]
|
|
|
|
if current_words:
|
|
utterances.append({
|
|
"speaker": current_speaker,
|
|
"start": round(current_words[0]["start"], 3),
|
|
"end": round(current_words[-1]["end"], 3),
|
|
"text": "".join(w["word"] for w in current_words).strip(),
|
|
})
|
|
|
|
return utterances
|
|
|
|
|
|
def _to_srt(utterances: list[dict]) -> str:
    """Render utterances as SRT subtitles.

    Args:
        utterances: Dicts with ``start``/``end`` (seconds), ``text`` and
            ``speaker``. An empty speaker produces a cue without a prefix.

    Returns:
        The SRT document. Cues are numbered from 1 and each cue ends with a
        blank line. The text is ``"[SPEAKER] text"`` when a speaker is set.
    """
    lines: list[str] = []
    for index, u in enumerate(utterances, 1):
        text = u["text"].strip()
        # /transcribe passes speaker="" — don't emit a stray "[] " prefix then.
        if u.get("speaker"):
            text = f"[{u['speaker']}] {text}"
        lines += [
            str(index),
            f"{_format_timestamp_srt(u['start'])} --> {_format_timestamp_srt(u['end'])}",
            text,
            "",
        ]
    return "\n".join(lines)
|
|
|
|
|
|
def _to_vtt(utterances: list[dict]) -> str:
    """Render utterances as a WebVTT document.

    Args:
        utterances: Dicts with ``start``/``end`` (seconds), ``text`` and
            ``speaker``. An empty speaker produces a cue without a prefix.

    Returns:
        ``"WEBVTT"`` followed by one unnumbered cue per utterance, each
        terminated by a blank line. The text is ``"[SPEAKER] text"`` when a
        speaker is set.
    """
    lines: list[str] = ["WEBVTT", ""]
    for u in utterances:
        text = u["text"].strip()
        # /transcribe passes speaker="" — don't emit a stray "[] " prefix then.
        if u.get("speaker"):
            text = f"[{u['speaker']}] {text}"
        lines += [
            f"{_format_timestamp_vtt(u['start'])} --> {_format_timestamp_vtt(u['end'])}",
            text,
            "",
        ]
    return "\n".join(lines)
|
|
|
|
|
|
def _to_txt(utterances: list[dict]) -> str:
|
|
return "\n".join(f"{u['speaker']}: {u['text']}" for u in utterances)
|
|
|
|
|
|
# ────────────────────── endpoints ──────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Report gateway liveness plus reachability of each backend service.

    whisper and pyannote are probed at ``/health`` and ollama at ``/``, all
    concurrently with a 3 s timeout each. Per-service state is:
    * ``"ok"`` for a status below 400;
    * ``"error"`` for a 4xx/5xx status;
    * ``"unreachable"`` if the request raised.

    NOTE: the top-level ``"status"`` is always ``"ok"``; it reflects only that
    the gateway itself is answering.
    """
    probes = [
        ("whisper", f"{WHISPER_URL}/health"),
        ("pyannote", f"{PYANNOTE_URL}/health"),
        ("ollama", f"{OLLAMA_URL}/"),
    ]

    async def probe(client: httpx.AsyncClient, url: str) -> str:
        try:
            r = await client.get(url, timeout=3.0)
        except Exception:  # best-effort probe: any failure means "unreachable"
            return "unreachable"
        return "ok" if r.status_code < 400 else "error"

    # Probing concurrently bounds the worst case to ~3 s instead of ~9 s.
    async with httpx.AsyncClient() as client:
        states = await asyncio.gather(*(probe(client, url) for _, url in probes))

    return {"status": "ok", "services": {name: state for (name, _), state in zip(probes, states)}}
|
|
|
|
|
|
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    response_format: str = Form("json"),
):
    """Transcribe an uploaded file with whisper only (no diarization).

    Args:
        file: Audio upload; its filename is forwarded ("audio.wav" if missing).
        language, initial_prompt, beam_size, batch_size: Forwarded to whisper.
        response_format: One of "json", "srt", "vtt" or "txt".

    Returns:
        * "srt"/"vtt": ``{"format", "content"}`` with one cue per segment and
          no speaker labels.
        * "txt": ``{"format", "content"}`` with the stripped segment texts
          joined by single spaces.
        * Anything else: whisper-service's raw JSON reply.
    """
    audio_bytes = await file.read()
    async with httpx.AsyncClient() as client:
        result = await _call_whisper(
            client, audio_bytes, file.filename or "audio.wav",
            language, initial_prompt, beam_size, batch_size,
        )

    # Tolerate a reply without "segments" instead of failing with KeyError.
    segments = result.get("segments", [])

    if response_format in ("srt", "vtt"):
        cues = [
            {"speaker": "", "start": s["start"], "end": s["end"], "text": s["text"].strip()}
            for s in segments
        ]
        content = _to_srt(cues) if response_format == "srt" else _to_vtt(cues)
        return JSONResponse({"format": response_format, "content": content})

    if response_format == "txt":
        # Strip each segment (whisper texts may carry a leading space) and drop
        # empty ones, so the join never produces doubled spaces.
        text = " ".join(filter(None, (s["text"].strip() for s in segments)))
        return JSONResponse({"format": "txt", "content": text})

    return result
|
|
|
|
|
|
@app.post("/diarize")
async def diarize(
    file: UploadFile = File(...),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
):
    """Diarize an uploaded file with pyannote only and return its raw JSON reply.

    The filename defaults to "audio.wav" when the upload carries none. All
    other form fields are forwarded unchanged to pyannote-service.
    """
    payload = await file.read()
    upload_name = file.filename or "audio.wav"
    async with httpx.AsyncClient() as client:
        diarization = await _call_pyannote(
            client,
            payload,
            upload_name,
            num_speakers,
            min_speakers,
            max_speakers,
            min_duration,
            merge_gap,
        )
    return diarization
|
|
|
|
|
|
@app.post("/process")
async def process(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    initial_prompt: str | None = Form(None),
    beam_size: int = Form(5),
    batch_size: int = Form(16),
    num_speakers: int | None = Form(None),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
    min_duration: float = Form(0.5),
    merge_gap: float = Form(0.3),
    response_format: str = Form("json"),
):
    """Transcribe and diarize one upload concurrently, then merge the results.

    whisper and pyannote receive the same bytes in parallel. Their outputs are
    combined by ``_reconcile`` into speaker-attributed utterances.

    Returns:
        * "srt"/"vtt"/"txt": ``{"format", "content", "processing_time"}``.
        * Otherwise: a JSON summary with language info, durations, per-service
          times, speakers and the utterance list.
        ``processing_time`` is the wall time of this request in seconds,
        rounded to milliseconds.

    If either backend call fails, the other one is cancelled before the shared
    HTTP client is closed, and the error propagates.
    """
    t0 = time.perf_counter()
    audio_bytes = await file.read()
    filename = file.filename or "audio.wav"

    async with httpx.AsyncClient() as client:
        tasks = (
            asyncio.create_task(
                _call_whisper(client, audio_bytes, filename, language, initial_prompt, beam_size, batch_size)
            ),
            asyncio.create_task(
                _call_pyannote(
                    client, audio_bytes, filename,
                    num_speakers, min_speakers, max_speakers, min_duration, merge_gap,
                )
            ),
        )
        try:
            whisper_result, pyannote_result = await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel the sibling when one call fails; without
            # this the other (possibly 30-min) request would keep running while
            # the client is closed under it.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    utterances = _reconcile(whisper_result, pyannote_result)
    elapsed = round(time.perf_counter() - t0, 3)

    renderers = {"srt": _to_srt, "vtt": _to_vtt, "txt": _to_txt}
    render = renderers.get(response_format)
    if render is not None:
        return JSONResponse({"format": response_format, "content": render(utterances), "processing_time": elapsed})

    return {
        "language": whisper_result.get("language"),
        "language_probability": whisper_result.get("language_probability"),
        "duration": whisper_result.get("duration"),
        "processing_time": elapsed,
        "whisper_time": whisper_result.get("processing_time"),
        "pyannote_time": pyannote_result.get("processing_time"),
        "speakers": pyannote_result.get("speakers", []),
        "num_speakers": pyannote_result.get("num_speakers", 0),
        "utterances": utterances,
    }