build: first commit
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY app.py .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
+353
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
API Gateway: маршрутизация запросов и reconciliation.
|
||||
|
||||
POST /transcribe → whisper-service (только транскрибация)
|
||||
POST /diarize → pyannote-service (только диаризация)
|
||||
POST /process → оба параллельно + сшивка по словам
|
||||
|
||||
Управление VRAM:
|
||||
POST /whisper/load → загрузить whisper
|
||||
POST /whisper/unload → выгрузить whisper
|
||||
POST /pyannote/load → загрузить pyannote
|
||||
POST /pyannote/unload → выгрузить pyannote
|
||||
POST /ollama/unload → выгрузить текущую модель ollama (keep_alive=0)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
logger = logging.getLogger("gateway")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s")
|
||||
|
||||
WHISPER_URL = os.getenv("WHISPER_URL", "http://whisper-service:8001")
|
||||
PYANNOTE_URL = os.getenv("PYANNOTE_URL", "http://pyannote-service:8002")
|
||||
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
|
||||
|
||||
TIMEOUT = httpx.Timeout(timeout=1800.0)
|
||||
|
||||
app = FastAPI(
|
||||
title="Speech Processing API",
|
||||
description="Транскрибация + диаризация аудио. Whisper (GPU) + pyannote (CPU), параллельно.",
|
||||
version="4.0.0",
|
||||
)
|
||||
|
||||
|
||||
# ────────────────────── VRAM management proxies ──────────────────────
|
||||
|
||||
@app.post("/whisper/load")
|
||||
async def whisper_load():
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(f"{WHISPER_URL}/load", timeout=120.0)
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.post("/whisper/unload")
|
||||
async def whisper_unload():
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(f"{WHISPER_URL}/unload", timeout=30.0)
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.post("/pyannote/load")
|
||||
async def pyannote_load():
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(f"{PYANNOTE_URL}/load", timeout=120.0)
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.post("/pyannote/unload")
|
||||
async def pyannote_unload():
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(f"{PYANNOTE_URL}/unload", timeout=30.0)
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.post("/ollama/unload")
|
||||
async def ollama_unload(model: str = "llama3.2"):
|
||||
"""Выгрузить модель ollama (keep_alive=0)."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.post(
|
||||
f"{OLLAMA_URL}/api/generate",
|
||||
json={"model": model, "keep_alive": 0},
|
||||
timeout=30.0,
|
||||
)
|
||||
return {"status": "unloaded", "model": model}
|
||||
|
||||
|
||||
# ────────────────────── helpers ──────────────────────
|
||||
|
||||
def _format_timestamp_srt(seconds: float) -> str:
|
||||
if seconds < 0:
|
||||
seconds = 0
|
||||
total_ms = round(seconds * 1000)
|
||||
hours, total_ms = divmod(total_ms, 3_600_000)
|
||||
minutes, total_ms = divmod(total_ms, 60_000)
|
||||
secs, ms = divmod(total_ms, 1000)
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
|
||||
|
||||
|
||||
def _format_timestamp_vtt(seconds: float) -> str:
|
||||
if seconds < 0:
|
||||
seconds = 0
|
||||
total_ms = round(seconds * 1000)
|
||||
hours, total_ms = divmod(total_ms, 3_600_000)
|
||||
minutes, total_ms = divmod(total_ms, 60_000)
|
||||
secs, ms = divmod(total_ms, 1000)
|
||||
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{ms:03d}"
|
||||
|
||||
|
||||
async def _call_whisper(
|
||||
client: httpx.AsyncClient,
|
||||
audio_bytes: bytes,
|
||||
filename: str,
|
||||
language: str | None,
|
||||
initial_prompt: str | None,
|
||||
beam_size: int,
|
||||
batch_size: int,
|
||||
) -> dict:
|
||||
language = language if language and language not in ("string", "") else None
|
||||
initial_prompt = initial_prompt if initial_prompt and initial_prompt not in ("string", "") else None
|
||||
|
||||
data = {"beam_size": str(beam_size), "batch_size": str(batch_size), "word_timestamps": "true"}
|
||||
if language:
|
||||
data["language"] = language
|
||||
if initial_prompt:
|
||||
data["initial_prompt"] = initial_prompt
|
||||
|
||||
logger.info(f"[whisper] Sending {filename} ({len(audio_bytes)} bytes)")
|
||||
t0 = time.perf_counter()
|
||||
resp = await client.post(
|
||||
f"{WHISPER_URL}/transcribe",
|
||||
files={"file": (filename, audio_bytes)},
|
||||
data=data,
|
||||
timeout=TIMEOUT,
|
||||
)
|
||||
logger.info(f"[whisper] {resp.status_code} in {time.perf_counter() - t0:.1f}s")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def _call_pyannote(
|
||||
client: httpx.AsyncClient,
|
||||
audio_bytes: bytes,
|
||||
filename: str,
|
||||
num_speakers: int | None,
|
||||
min_speakers: int | None,
|
||||
max_speakers: int | None,
|
||||
min_duration: float = 0.5,
|
||||
merge_gap: float = 0.3,
|
||||
) -> dict:
|
||||
data = {"min_duration": str(min_duration), "merge_gap": str(merge_gap)}
|
||||
if num_speakers is not None:
|
||||
data["num_speakers"] = str(num_speakers)
|
||||
if min_speakers is not None:
|
||||
data["min_speakers"] = str(min_speakers)
|
||||
if max_speakers is not None:
|
||||
data["max_speakers"] = str(max_speakers)
|
||||
|
||||
logger.info(f"[pyannote] Sending {filename} ({len(audio_bytes)} bytes)")
|
||||
t0 = time.perf_counter()
|
||||
resp = await client.post(
|
||||
f"{PYANNOTE_URL}/diarize",
|
||||
files={"file": (filename, audio_bytes)},
|
||||
data=data,
|
||||
timeout=TIMEOUT,
|
||||
)
|
||||
logger.info(f"[pyannote] {resp.status_code} in {time.perf_counter() - t0:.1f}s")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _reconcile(whisper_result: dict, pyannote_result: dict) -> list[dict]:
|
||||
turns = pyannote_result.get("turns", [])
|
||||
if not turns:
|
||||
return [
|
||||
{"speaker": "SPEAKER_00", "start": seg["start"], "end": seg["end"], "text": seg["text"]}
|
||||
for seg in whisper_result.get("segments", [])
|
||||
]
|
||||
|
||||
all_words = []
|
||||
for seg in whisper_result.get("segments", []):
|
||||
words = seg.get("words")
|
||||
if words:
|
||||
all_words.extend(words)
|
||||
else:
|
||||
all_words.append({"start": seg["start"], "end": seg["end"], "word": seg["text"]})
|
||||
|
||||
if not all_words:
|
||||
return []
|
||||
|
||||
def find_speaker(midpoint: float) -> str:
|
||||
for turn in turns:
|
||||
if turn["start"] <= midpoint <= turn["end"]:
|
||||
return turn["speaker"]
|
||||
min_dist = float("inf")
|
||||
closest = turns[0]["speaker"]
|
||||
for turn in turns:
|
||||
dist = min(abs(midpoint - turn["start"]), abs(midpoint - turn["end"]))
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
closest = turn["speaker"]
|
||||
return closest
|
||||
|
||||
for w in all_words:
|
||||
w["speaker"] = find_speaker((w["start"] + w["end"]) / 2)
|
||||
|
||||
utterances: list[dict] = []
|
||||
current_speaker = all_words[0]["speaker"]
|
||||
current_words = [all_words[0]]
|
||||
|
||||
for w in all_words[1:]:
|
||||
if w["speaker"] == current_speaker:
|
||||
current_words.append(w)
|
||||
else:
|
||||
utterances.append({
|
||||
"speaker": current_speaker,
|
||||
"start": round(current_words[0]["start"], 3),
|
||||
"end": round(current_words[-1]["end"], 3),
|
||||
"text": "".join(w["word"] for w in current_words).strip(),
|
||||
})
|
||||
current_speaker = w["speaker"]
|
||||
current_words = [w]
|
||||
|
||||
if current_words:
|
||||
utterances.append({
|
||||
"speaker": current_speaker,
|
||||
"start": round(current_words[0]["start"], 3),
|
||||
"end": round(current_words[-1]["end"], 3),
|
||||
"text": "".join(w["word"] for w in current_words).strip(),
|
||||
})
|
||||
|
||||
return utterances
|
||||
|
||||
|
||||
def _to_srt(utterances: list[dict]) -> str:
|
||||
lines = []
|
||||
for i, u in enumerate(utterances, 1):
|
||||
lines += [str(i), f"{_format_timestamp_srt(u['start'])} --> {_format_timestamp_srt(u['end'])}", f"[{u['speaker']}] {u['text']}", ""]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _to_vtt(utterances: list[dict]) -> str:
|
||||
lines = ["WEBVTT", ""]
|
||||
for u in utterances:
|
||||
lines += [f"{_format_timestamp_vtt(u['start'])} --> {_format_timestamp_vtt(u['end'])}", f"[{u['speaker']}] {u['text']}", ""]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _to_txt(utterances: list[dict]) -> str:
|
||||
return "\n".join(f"{u['speaker']}: {u['text']}" for u in utterances)
|
||||
|
||||
|
||||
# ────────────────────── endpoints ──────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async with httpx.AsyncClient() as client:
|
||||
results = {}
|
||||
for name, url in [("whisper", WHISPER_URL), ("pyannote", PYANNOTE_URL), ("ollama", OLLAMA_URL)]:
|
||||
try:
|
||||
r = await client.get(f"{url}/health" if name != "ollama" else f"{url}/", timeout=3.0)
|
||||
results[name] = "ok" if r.status_code < 400 else "error"
|
||||
except Exception:
|
||||
results[name] = "unreachable"
|
||||
return {"status": "ok", "services": results}
|
||||
|
||||
|
||||
@app.post("/transcribe")
|
||||
async def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
language: str | None = Form(None),
|
||||
initial_prompt: str | None = Form(None),
|
||||
beam_size: int = Form(5),
|
||||
batch_size: int = Form(16),
|
||||
response_format: str = Form("json"),
|
||||
):
|
||||
audio_bytes = await file.read()
|
||||
async with httpx.AsyncClient() as client:
|
||||
result = await _call_whisper(client, audio_bytes, file.filename or "audio.wav",
|
||||
language, initial_prompt, beam_size, batch_size)
|
||||
|
||||
if response_format == "srt":
|
||||
return JSONResponse({"format": "srt", "content": _to_srt([
|
||||
{"speaker": "", "start": s["start"], "end": s["end"], "text": s["text"]}
|
||||
for s in result["segments"]
|
||||
])})
|
||||
if response_format == "vtt":
|
||||
return JSONResponse({"format": "vtt", "content": _to_vtt([
|
||||
{"speaker": "", "start": s["start"], "end": s["end"], "text": s["text"]}
|
||||
for s in result["segments"]
|
||||
])})
|
||||
if response_format == "txt":
|
||||
return JSONResponse({"format": "txt", "content": " ".join(s["text"] for s in result["segments"])})
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/diarize")
|
||||
async def diarize(
|
||||
file: UploadFile = File(...),
|
||||
num_speakers: int | None = Form(None),
|
||||
min_speakers: int | None = Form(None),
|
||||
max_speakers: int | None = Form(None),
|
||||
min_duration: float = Form(0.5),
|
||||
merge_gap: float = Form(0.3),
|
||||
):
|
||||
audio_bytes = await file.read()
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await _call_pyannote(client, audio_bytes, file.filename or "audio.wav",
|
||||
num_speakers, min_speakers, max_speakers, min_duration, merge_gap)
|
||||
|
||||
|
||||
@app.post("/process")
|
||||
async def process(
|
||||
file: UploadFile = File(...),
|
||||
language: str | None = Form(None),
|
||||
initial_prompt: str | None = Form(None),
|
||||
beam_size: int = Form(5),
|
||||
batch_size: int = Form(16),
|
||||
num_speakers: int | None = Form(None),
|
||||
min_speakers: int | None = Form(None),
|
||||
max_speakers: int | None = Form(None),
|
||||
min_duration: float = Form(0.5),
|
||||
merge_gap: float = Form(0.3),
|
||||
response_format: str = Form("json"),
|
||||
):
|
||||
t0 = time.perf_counter()
|
||||
audio_bytes = await file.read()
|
||||
filename = file.filename or "audio.wav"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
whisper_result, pyannote_result = await asyncio.gather(
|
||||
_call_whisper(client, audio_bytes, filename, language, initial_prompt, beam_size, batch_size),
|
||||
_call_pyannote(client, audio_bytes, filename, num_speakers, min_speakers, max_speakers, min_duration, merge_gap),
|
||||
)
|
||||
|
||||
utterances = _reconcile(whisper_result, pyannote_result)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
if response_format == "srt":
|
||||
return JSONResponse({"format": "srt", "content": _to_srt(utterances), "processing_time": round(elapsed, 3)})
|
||||
if response_format == "vtt":
|
||||
return JSONResponse({"format": "vtt", "content": _to_vtt(utterances), "processing_time": round(elapsed, 3)})
|
||||
if response_format == "txt":
|
||||
return JSONResponse({"format": "txt", "content": _to_txt(utterances), "processing_time": round(elapsed, 3)})
|
||||
|
||||
return {
|
||||
"language": whisper_result.get("language"),
|
||||
"language_probability": whisper_result.get("language_probability"),
|
||||
"duration": whisper_result.get("duration"),
|
||||
"processing_time": round(elapsed, 3),
|
||||
"whisper_time": whisper_result.get("processing_time"),
|
||||
"pyannote_time": pyannote_result.get("processing_time"),
|
||||
"speakers": pyannote_result.get("speakers", []),
|
||||
"num_speakers": pyannote_result.get("num_speakers", 0),
|
||||
"utterances": utterances,
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
fastapi==0.115.*
|
||||
uvicorn[standard]==0.34.*
|
||||
python-multipart==0.0.*
|
||||
httpx==0.28.*
|
||||
Reference in New Issue
Block a user