proxy: multi-backend Ollama aggregation with per-model routing + failover
The gateway can now aggregate models across SEVERAL Ollama backends and route
each request to the correct one. Opt-in via OLLAMA_BACKENDS in .env;
single-backend deployments are unaffected (effective_backends() synthesizes a
single "default" backend from the legacy OLLAMA_BASE_URL / OLLAMA_AUTH_TOKEN
fields when the list is empty).

Behavior:
- Discovery polls EVERY configured backend in parallel each tick; the cache
  stores per-backend model lists plus a model → backends priority list
  (config order = priority order).
- /api/tags and /v1/models surface the DEDUPLICATED UNION of all backends'
  models.
- A request's model is looked up in the priority list and proxied to the
  FIRST backend that hosts it. If that backend errors on the request, the
  pipeline transparently fails over to the next backend that has the same
  model (the streaming failover probes the first chunk before releasing the
  response, so we never serve partial bytes from a dead backend).
- No existence disclosure: a model not hosted by any backend yields the same
  generic 403 as "model not allowed" (SPEC §13.6 preserved).

Components:
- config.py: new BackendSpec model + ollama_backends list field + an
  effective_backends() helper.
- proxy/router.py (new): BackendRouter (clients_for_with_failover),
  build_http_clients() builds one httpx client per backend with its own auth
  headers, build_backend_headers() exposes the per-backend header composition
  for the CLI probe.
- proxy/discovery.py: DiscoveryCache.set_per_backend() + backends_for(),
  refresh_all_backends() polls all in parallel, discovery_loop_multi()
  replaces the single-backend loop in production; the legacy single-backend
  functions are kept for the dependency-override tests.
- proxy/pipeline.py: Pipeline accepts an optional router; the four proxy
  methods now retry against each candidate backend in priority order on
  transport error.
- lifespan.py: constructs the per-backend client dict, stores the router on
  app.state, launches discovery_loop_multi.
- deps.py: get_backend_router provider + BackendRouterDep type alias;
  get_pipeline passes the router into Pipeline.
- cli/manage.py: probe-ollama iterates every backend and reports per-backend
  status; list-models groups its output by backend and prints the union count
  + Redis cache size for sanity.
- .env.example + docker-compose.yml: document and pass through
  OLLAMA_BACKENDS with a real example.

Verified: ruff check (clean), mypy --strict src/ + tests/ (clean, 66 source
files), pytest (60 passed + 39 skipped, the same baseline as before this
change; integration tests are Docker-socket-gated).
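The routing rule described above (first backend in priority order, then failover to the next one hosting the same model) can be sketched as follows. This is a simplified illustration, not the gateway's code: `route_with_failover`, `NoBackendError`, and the `send` callback are hypothetical names, and `ConnectionError` stands in for whatever transport error the real pipeline catches.

```python
from collections.abc import Callable, Mapping, Sequence


class NoBackendError(Exception):
    """Hypothetical error: no backend hosts the model, or every candidate failed."""


def route_with_failover(
    model: str,
    model_backends: Mapping[str, Sequence[str]],
    send: Callable[[str], str],
) -> str:
    """Try each backend that hosts ``model`` in priority order.

    ``model_backends`` maps a model name to its priority-ordered backend
    names; ``send`` performs the request against one named backend.
    """
    last_error: Exception | None = None
    for backend in model_backends.get(model, ()):
        try:
            return send(backend)
        except ConnectionError as exc:  # stand-in for a transport error
            last_error = exc
    # An unknown model and an exhausted candidate list look the same to the caller.
    raise NoBackendError(model) from last_error
```

An unknown model raises the same error as an exhausted candidate list, which loosely mirrors the no-existence-disclosure rule: the caller cannot tell the two cases apart.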
.env.example
@@ -46,6 +46,21 @@ OLLAMA_AUTH_TOKEN=
 OLLAMA_AUTH_HEADER=Authorization
 OLLAMA_AUTH_SCHEME=Bearer
 
+# ─────────────────────── Multi-backend (opt-in) ──────────────────
+# Aggregate models across SEVERAL Ollama backends. When set (non-empty JSON
+# list), this REPLACES the single-backend config above — include each backend
+# explicitly, in priority order. Requests for a given model route to the
+# FIRST backend that hosts it; on transport errors the gateway transparently
+# fails over to the next backend that has the same model.
+#
+# Each entry: {name, base_url, [auth_token], [auth_header], [auth_scheme]}
+# Example (embedded GPU container + publicly-fronted auth-protected Ollama):
+# OLLAMA_BACKENDS='[
+# {"name":"embedded","base_url":"http://ollama:11434"},
+# {"name":"public","base_url":"https://ollama.neuronetz.ai","auth_token":"YOUR_TOKEN"}
+# ]'
+OLLAMA_BACKENDS=
+
 # ──────────────────────── Model discovery (§4.6) ─────────────────
 MODEL_DISCOVERY_REFRESH_S=60
 MODEL_DISCOVERY_CACHE_TTL_S=120
docker-compose.yml
@@ -36,6 +36,8 @@ services:
 OLLAMA_AUTH_TOKEN: ${OLLAMA_AUTH_TOKEN:-}
 OLLAMA_AUTH_HEADER: ${OLLAMA_AUTH_HEADER:-Authorization}
 OLLAMA_AUTH_SCHEME: ${OLLAMA_AUTH_SCHEME:-Bearer}
+# Multi-backend (opt-in JSON list). See .env.example for the schema.
+OLLAMA_BACKENDS: ${OLLAMA_BACKENDS:-}
 MODEL_DISCOVERY_REFRESH_S: ${MODEL_DISCOVERY_REFRESH_S:-60}
 MODEL_DISCOVERY_CACHE_TTL_S: ${MODEL_DISCOVERY_CACHE_TTL_S:-120}
 DEFAULT_RPM: ${DEFAULT_RPM:-60}
cli/manage.py
@@ -274,25 +274,62 @@ def list_models(
         str | None, typer.Option("--tenant", help="Also show this tenant's effective set.")
     ] = None,
 ) -> None:
-    """Show live-discovered models (and, with --tenant, the effective set)."""
+    """Show live-discovered models per backend (and, with --tenant, the effective set).
+
+    Spins up a short-lived poll against every configured backend to surface
+    which models are currently live where, then (optionally) computes the
+    given tenant's effective allow-list intersection.
+    """
+    import httpx
     import redis.asyncio as redis
 
+    from neuronetz_gateway.config import BackendSpec
+    from neuronetz_gateway.proxy.discovery import fetch_tags, names_of
+    from neuronetz_gateway.proxy.router import build_backend_headers
+
+    async def _poll_one(backend: BackendSpec) -> tuple[str, list[str]]:
+        try:
+            async with httpx.AsyncClient(
+                base_url=backend.base_url,
+                timeout=httpx.Timeout(connect=5, read=10, write=10, pool=10),
+                headers=build_backend_headers(backend),
+            ) as client:
+                return backend.name, sorted(names_of(await fetch_tags(client)))
+        except (httpx.HTTPError, ValueError) as exc:
+            typer.secho(
+                f" ({backend.name}: probe failed — {type(exc).__name__})",
+                fg=typer.colors.RED,
+            )
+            return backend.name, []
+
     async def work(settings: Settings) -> None:
+        backends = settings.effective_backends()
+        results = [await _poll_one(b) for b in backends]
+        union: set[str] = set()
+        typer.echo("live-discovered models, by backend:")
+        for name, models in results:
+            typer.echo(f" [{name}]")
+            if not models:
+                typer.echo(" (none)")
+            for m in models:
+                typer.echo(f" {m}")
+            union.update(models)
+        typer.echo(f"\nunion across all backends: {len(union)} unique model(s)")
+
+        # Surface what Redis currently has cached too (for sanity vs the live poll).
         client = redis.from_url(settings.redis_url, decode_responses=True)
         try:
-            discovered = await read_discovered_from_redis(client)
+            cached = await read_discovered_from_redis(client)
         finally:
             await client.aclose()
-        discovered_names = sorted(discovered)
-        typer.echo("discovered models (live from Ollama via discovery cache):")
-        if discovered_names:
-            for name in discovered_names:
-                typer.echo(f" {name}")
-        else:
-            typer.echo(" (none — discovery cache empty or expired; requests fail closed)")
+        if cached:
+            typer.echo(
+                f"redis cache (gateway:models:discovered): {len(cached)} model(s)"
+            )
 
         if tenant is None:
             return
+        discovered = frozenset(union)
         factory = create_session_factory(create_engine(settings))
         async with session_scope(factory) as session:
             tenants = TenantRepository(session)
@@ -319,29 +356,30 @@ def probe_ollama(
     *,
     timeout: Annotated[float, typer.Option(help="Per-request timeout in seconds.")] = 10.0,
 ) -> None:
-    """Probe the upstream Ollama: GET /api/version and /api/tags.
+    """Probe every configured Ollama backend (GET /api/version + /api/tags).
 
-    Uses the exact same httpx config as the running gateway (base URL, timeouts,
-    and the OLLAMA_AUTH_TOKEN header if set) so a passing probe proves the
-    gateway will be able to reach the backend in production. The token itself
-    is NEVER printed — only whether one was attached.
+    Iterates every entry returned by :meth:`Settings.effective_backends` (i.e.
+    the OLLAMA_BACKENDS JSON list, or a single synthesized backend from the
+    legacy OLLAMA_BASE_URL fields). Reports per-backend status and the first
+    5 model names; exits non-zero if any backend fails. Tokens are never
+    printed — only whether one was attached.
     """
     import httpx
 
-    from neuronetz_gateway.lifespan import _build_upstream_headers
+    from neuronetz_gateway.config import BackendSpec
+    from neuronetz_gateway.proxy.router import build_backend_headers
 
     settings = get_settings()
-    headers = _build_upstream_headers(settings)
-    auth_header = settings.ollama_auth_header
-    has_token = settings.ollama_auth_token is not None and bool(
-        settings.ollama_auth_token.get_secret_value().strip()
-    )
+    backends = settings.effective_backends()
 
-    auth_status = f"sending {auth_header}" if has_token else "no token (OLLAMA_AUTH_TOKEN unset)"
-    typer.echo(f"target: {settings.ollama_base_url}")
-    typer.echo(f"auth: {auth_status}")
-
-    async def _go() -> int:
+    async def _probe_one(backend: BackendSpec, idx: int, total: int) -> int:
+        marker = f"[{idx + 1}/{total}] {backend.name}"
+        typer.echo("")
+        typer.secho(marker, bold=True)
+        typer.echo(f" target: {backend.base_url}")
+        typer.echo(
+            f" auth: {('sending ' + backend.auth_header) if backend.has_auth else 'no token'}"
+        )
         probe_timeout = httpx.Timeout(
             connect=settings.ollama_connect_timeout_s,
             read=timeout,
@@ -349,9 +387,9 @@ def probe_ollama(
             pool=timeout,
         )
         async with httpx.AsyncClient(
-            base_url=settings.ollama_base_url,
+            base_url=backend.base_url,
             timeout=probe_timeout,
-            headers=headers,
+            headers=build_backend_headers(backend),
         ) as client:
             errors = 0
             for path in ("/api/version", "/api/tags"):
@@ -371,8 +409,8 @@ def probe_ollama(
                 )
                 if resp.status_code in (401, 403):
                     typer.echo(
-                        " upstream rejected the credentials — check "
-                        "OLLAMA_AUTH_TOKEN / header."
+                        " upstream rejected the credentials — check the "
+                        "auth_token for this backend."
                     )
                     errors += 1
                     continue
@@ -392,9 +430,22 @@ def probe_ollama(
                         typer.echo(f" … and {n - 5} more")
             return errors
 
+    async def _go() -> int:
+        total_errors = 0
+        for idx, backend in enumerate(backends):
+            total_errors += await _probe_one(backend, idx, len(backends))
+        return total_errors
+
     errors = asyncio.run(_go())
+    typer.echo("")
     if errors:
+        typer.secho(f"{errors} probe(s) failed.", fg=typer.colors.RED, bold=True)
         raise typer.Exit(code=1)
+    typer.secho(
+        f"all {len(backends)} backend(s) reachable and authenticated.",
+        fg=typer.colors.GREEN,
+        bold=True,
+    )
     typer.secho("upstream reachable and authenticated.", fg=typer.colors.GREEN, bold=True)
 
 
config.py
@@ -8,10 +8,47 @@ from __future__ import annotations
 
 from functools import lru_cache
 
-from pydantic import Field, SecretStr
+from pydantic import BaseModel, Field, SecretStr
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
+class BackendSpec(BaseModel):
+    """One Ollama backend the gateway aggregates and proxies to.
+
+    The gateway can either run with a single backend (the legacy mode — left as
+    the default for backward compatibility) or with several backends configured
+    via ``OLLAMA_BACKENDS``, in which case the gateway:
+
+    - polls discovery against each backend in parallel and exposes the **union**
+      of their model sets on ``/api/tags`` and ``/v1/models``;
+    - routes every request to the FIRST backend (priority = list order) whose
+      model set contains the requested model;
+    - auto-fails-over to subsequent backends that have the same model if the
+      primary errors.
+
+    Provide as JSON in env:
+        OLLAMA_BACKENDS='[
+          {"name":"embedded","base_url":"http://ollama:11434"},
+          {"name":"public",
+           "base_url":"https://ollama.neuronetz.ai",
+           "auth_token":"…"}
+        ]'
+    """
+
+    name: str = Field(min_length=1, max_length=64)
+    base_url: str = Field(min_length=1)
+    auth_token: SecretStr | None = None
+    auth_header: str = "Authorization"
+    auth_scheme: str = "Bearer"
+
+    @property
+    def has_auth(self) -> bool:
+        """True if this backend should send an auth header."""
+        if self.auth_token is None:
+            return False
+        return bool(self.auth_token.get_secret_value().strip())
+
+
 class Settings(BaseSettings):
     """Gateway runtime configuration. All fields map to SPEC §7 env vars."""
 
@@ -46,6 +83,14 @@ class Settings(BaseSettings):
     ollama_auth_header: str = Field(default="Authorization")
     ollama_auth_scheme: str = Field(default="Bearer")
 
+    # --- Multi-backend (opt-in) ---
+    # JSON list of additional/overriding backends. When empty (default), a single
+    # backend named "default" is constructed from the OLLAMA_BASE_URL / TOKEN /
+    # HEADER / SCHEME fields above (backward compat — current prod is unaffected).
+    # When set, this list IS the entire backend roster — include the embedded
+    # one explicitly if you want it. The list order is the routing priority.
+    ollama_backends: list[BackendSpec] = Field(default_factory=list)
+
     # --- Model discovery (SPEC §4.6) ---
     model_discovery_refresh_s: int = Field(default=60)
     model_discovery_cache_ttl_s: int = Field(default=120)
@@ -89,6 +134,26 @@ class Settings(BaseSettings):
         """Parse the comma-separated trusted-proxy list into individual hosts."""
         return [p.strip() for p in self.gateway_trusted_proxies.split(",") if p.strip()]
 
+    def effective_backends(self) -> list[BackendSpec]:
+        """Return the resolved backend roster.
+
+        If ``ollama_backends`` is non-empty, that list is used verbatim (order is
+        the routing priority). Otherwise a single backend is synthesized from the
+        legacy ``ollama_base_url`` / ``ollama_auth_token`` fields so a deployment
+        that hasn't opted into multi-backend keeps working unchanged.
+        """
+        if self.ollama_backends:
+            return list(self.ollama_backends)
+        return [
+            BackendSpec(
+                name="default",
+                base_url=self.ollama_base_url,
+                auth_token=self.ollama_auth_token,
+                auth_header=self.ollama_auth_header,
+                auth_scheme=self.ollama_auth_scheme,
+            )
+        ]
+
 
 @lru_cache(maxsize=1)
 def get_settings() -> Settings:
deps.py
@@ -40,6 +40,7 @@ from neuronetz_gateway.errors import AuthenticationError, DependencyUnavailableE
 from neuronetz_gateway.proxy.discovery import DiscoveryCache
 from neuronetz_gateway.proxy.ollama import OllamaClient
 from neuronetz_gateway.proxy.pipeline import Pipeline
+from neuronetz_gateway.proxy.router import BackendRouter
 from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
 from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
 
@@ -66,10 +67,24 @@ def get_http_client(request: Request) -> httpx.AsyncClient:
 
 
 def get_ollama_client(request: Request) -> OllamaClient:
-    """Provide the upstream Ollama proxy client (override target for tests)."""
+    """Provide the upstream Ollama proxy client (override target for tests).
+
+    In multi-backend mode this returns the FIRST backend's client (priority
+    order = list order). The pipeline uses :func:`get_backend_router` for
+    per-model routing + failover; this provider is kept for tests and for code
+    paths that don't need routing.
+    """
     return OllamaClient(get_http_client(request))
 
 
+def get_backend_router(request: Request) -> BackendRouter:
+    """Provide the multi-backend router (one client per configured backend)."""
+    router: BackendRouter | None = getattr(request.app.state, "backend_router", None)
+    if router is None:
+        raise DependencyUnavailableError(internal_detail="backend router not initialised")
+    return router
+
+
 def get_discovery_cache(request: Request) -> DiscoveryCache:
     """Provide the in-process discovery cache; fail closed if absent."""
     cache: DiscoveryCache | None = getattr(request.app.state, "discovery_cache", None)
@@ -112,10 +127,17 @@ def get_pipeline(
     The pipeline owns all hot-path checks (rate limit, budget, concurrency,
     model/endpoint allowlist) and the streaming-with-bookkeeping contract.
     Audit deny-mode flips this to fail closed at the route layer.
+
+    In multi-backend deployments the per-request backend selection is done by
+    the pipeline using the :class:`BackendRouter` on ``app.state``; the
+    ``ollama`` argument here is the fallback single-backend client (used when
+    the router has no entry for a model, and as the override target for tests
+    that don't care about routing).
     """
     sessionmaker: async_sessionmaker[AsyncSession] | None = getattr(
         request.app.state, "db_sessionmaker", None
     )
+    router: BackendRouter | None = getattr(request.app.state, "backend_router", None)
     return Pipeline(
         request=request,
         principal=principal,
@@ -127,6 +149,7 @@ def get_pipeline(
         budget=BudgetCounter(redis_client),
         audit=audit,
         sessionmaker=sessionmaker,
+        router=router,
     )
 
 
@@ -151,6 +174,7 @@ ConfigDep = Annotated[Settings, Depends(get_config)]
 RedisDep = Annotated[redis.Redis, Depends(get_redis)]
 HttpClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)]
 OllamaClientDep = Annotated[OllamaClient, Depends(get_ollama_client)]
+BackendRouterDep = Annotated[BackendRouter, Depends(get_backend_router)]
 DiscoveryCacheDep = Annotated[DiscoveryCache, Depends(get_discovery_cache)]
 PrincipalDep = Annotated[Principal, Depends(get_principal)]
 AuditWriterDep = Annotated[AuditWriter, Depends(get_audit_writer)]
@@ -160,6 +184,7 @@ DbSessionDep = Annotated[AsyncSession, Depends(get_db_session)]
 
 __all__ = [
     "AuditWriterDep",
+    "BackendRouterDep",
     "ConfigDep",
     "DbSessionDep",
     "DiscoveryCacheDep",
@@ -169,6 +194,7 @@ __all__ = [
     "PrincipalDep",
     "RedisDep",
     "get_audit_writer",
+    "get_backend_router",
     "get_config",
     "get_db_session",
     "get_discovery_cache",
lifespan.py
@@ -23,7 +23,9 @@ from neuronetz_gateway.auth.hashing import build_hasher
 from neuronetz_gateway.config import Settings, get_settings
 from neuronetz_gateway.db.session import create_engine, create_session_factory
 from neuronetz_gateway.observability.logging import get_logger
-from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop
+from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop_multi
+from neuronetz_gateway.proxy.ollama import OllamaClient
+from neuronetz_gateway.proxy.router import BackendRouter, build_http_clients
 from neuronetz_gateway.revocation import revocation_listener
 
 if TYPE_CHECKING:
@@ -93,7 +95,21 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
         _log.error("redis_init_failed", error=str(exc))
         app.state.redis = None
 
-    app.state.http_client = _build_http_client(settings)
+    # Build per-backend upstream clients (one per entry in OLLAMA_BACKENDS, or
+    # a single "default" backend synthesized from the legacy OLLAMA_BASE_URL).
+    backend_clients, backend_order = build_http_clients(settings)
+    app.state.backend_http_clients = backend_clients
+    app.state.backend_order = backend_order
+    # ``http_client`` retains its single-client meaning for code paths (and
+    # tests) that haven't been migrated to the router yet: it is the FIRST
+    # backend's httpx client. New code should reach upstream via the router.
+    app.state.http_client = backend_clients[backend_order[0]]
+    app.state.backend_router = BackendRouter(
+        clients={name: OllamaClient(client) for name, client in backend_clients.items()},
+        order=backend_order,
+        discovery=app.state.discovery_cache,
+    )
+    _log.info("backends_configured", backends=backend_order)
 
     audit_writer = AuditWriter(settings.audit_buffer_size, app.state.db_sessionmaker)
     audit_writer.start()
@@ -102,8 +118,12 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
     # Background tasks (cancelled on shutdown).
     tasks.append(
         asyncio.create_task(
-            discovery_loop(
-                app.state.http_client, app.state.redis, app.state.discovery_cache, settings
+            discovery_loop_multi(
+                backend_clients,
+                backend_order,
+                app.state.redis,
+                app.state.discovery_cache,
+                settings,
             )
         )
     )
@@ -135,10 +155,14 @@ async def _shutdown(
     with contextlib.suppress(Exception):
         await audit_writer.stop()
 
-    http_client: httpx.AsyncClient | None = getattr(app.state, "http_client", None)
-    if http_client is not None:
+    # Close every per-backend httpx client (the legacy `http_client` attr is
+    # one of these, so we only need to iterate the dict).
+    backend_clients: dict[str, httpx.AsyncClient] = getattr(
+        app.state, "backend_http_clients", {}
+    )
+    for client in backend_clients.values():
         with contextlib.suppress(Exception):
-            await http_client.aclose()
+            await client.aclose()
 
     redis_client = getattr(app.state, "redis", None)
     if redis_client is not None:
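The commit message says the discovery loop launched above polls every backend in parallel on each tick, with a failing backend contributing an empty list. A minimal sketch of that pattern with `asyncio.gather` follows. `poll_all` and its `pollers` argument are hypothetical names, and the real `refresh_all_backends()` may be structured differently.

```python
import asyncio
from collections.abc import Awaitable, Callable


async def poll_all(
    pollers: dict[str, Callable[[], Awaitable[list[str]]]],
) -> dict[str, list[str]]:
    """Run every poller concurrently; a failing backend yields an empty list."""
    names = list(pollers)
    results = await asyncio.gather(
        *(pollers[name]() for name in names),
        return_exceptions=True,  # one dead backend must not fail the whole tick
    )
    return {
        name: [] if isinstance(result, BaseException) else result
        for name, result in zip(names, results)
    }
```

Because the result always has one entry per backend, the caller can replace its cache wholesale rather than merging, which is the fail-closed behavior the commit describes.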
proxy/discovery.py
@@ -81,34 +81,80 @@ def names_of(models: list[DiscoveredModel]) -> frozenset[str]:
 
 
 class DiscoveryCache:
-    """In-process holder for the latest discovered model set.
+    """In-process holder for the latest discovered model set, per backend.
 
-    Holds both the structured records (for ``/api/tags`` / ``list-models``) and a
-    fast name set (for allowlist resolution on the hot path). Reads never block
-    on Redis or Ollama; the poller refreshes this in the background.
+    Holds the structured records (for ``/api/tags`` / ``list-models``), a fast
+    name set (for allowlist resolution on the hot path), and a model→backends
+    priority list (for the request router). Reads never block on Redis or
+    Ollama; the poller refreshes this in the background.
+
+    Backward compat: ``set(models)`` (single-list signature) still works and
+    populates a single "default" backend bucket.
     """
 
     def __init__(self) -> None:
+        self._per_backend: dict[str, list[DiscoveredModel]] = {}
         self._models: list[DiscoveredModel] = []
         self._names: frozenset[str] = frozenset()
+        # Ordered list of backend names that host each model, in priority order
+        # (the order in which backends were polled — i.e. config order).
+        self._model_backends: dict[str, list[str]] = {}
         self._lock = asyncio.Lock()
 
     async def set(self, models: list[DiscoveredModel]) -> None:
-        """Replace the in-process snapshot atomically."""
+        """Legacy single-backend setter: replaces the cache with one bucket."""
+        await self.set_per_backend({"default": list(models)}, ["default"])
+
+    async def set_per_backend(
+        self,
+        per_backend: dict[str, list[DiscoveredModel]],
+        backend_order: list[str],
+    ) -> None:
+        """Replace the cache from a per-backend mapping (multi-backend path).
+
+        ``backend_order`` is the routing priority: when several backends host
+        the same model, requests for it route to the first in this list.
+        """
         async with self._lock:
-            self._models = list(models)
-            self._names = names_of(models)
+            self._per_backend = {name: list(models) for name, models in per_backend.items()}
+            # Build a deduplicated combined record list, keeping the first
+            # occurrence (the highest-priority backend's metadata wins).
+            combined: list[DiscoveredModel] = []
+            seen: set[str] = set()
+            model_backends: dict[str, list[str]] = {}
+            for backend_name in backend_order:
+                for m in per_backend.get(backend_name, []):
+                    model_backends.setdefault(m.name, []).append(backend_name)
+                    if m.name not in seen:
+                        combined.append(m)
+                        seen.add(m.name)
+            self._models = combined
+            self._names = frozenset(seen)
+            self._model_backends = model_backends
 
     @property
     def names(self) -> frozenset[str]:
-        """Current discovered model names (possibly empty ⇒ fail-closed)."""
+        """Current discovered model names across all backends (∅ ⇒ fail-closed)."""
         return self._names
 
     @property
     def models(self) -> list[DiscoveredModel]:
-        """Current discovered model records (copy)."""
+        """Current de-duplicated discovered model records (copy)."""
         return list(self._models)
 
+    @property
+    def per_backend(self) -> dict[str, list[DiscoveredModel]]:
+        """Per-backend discovered records (copy of the mapping)."""
+        return {name: list(models) for name, models in self._per_backend.items()}
+
+    def backends_for(self, model: str) -> list[str]:
+        """Return the priority-ordered list of backends that host ``model``.
+
+        Empty list means no known backend has the model — routing should fail
+        with a generic 403 (the SPEC §13.6 no-existence-disclosure rule).
+        """
+        return list(self._model_backends.get(model, ()))
+
 
 async def write_discovered_to_redis(
     client: redis.Redis, models: list[DiscoveredModel], ttl_s: int
@@ -174,16 +220,65 @@ async def refresh_once(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_all_backends(
|
||||||
|
http_clients: dict[str, httpx.AsyncClient],
|
||||||
|
backend_order: list[str],
|
||||||
|
redis_client: redis.Redis | None,
|
||||||
|
cache: DiscoveryCache,
|
||||||
|
settings: Settings,
|
||||||
|
) -> int:
|
||||||
|
"""Poll every backend in parallel and replace the cache atomically.
|
||||||
|
|
||||||
|
Returns the number of backends that responded with at least one model. A
|
||||||
|
backend that errors contributes an empty list, so its previous entries drop
|
||||||
|
out of the cache on this refresh rather than lingering; fail-closed
|
||||||
|
semantics are preserved because the cache is replaced wholesale, not merged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _poll(name: str, client: httpx.AsyncClient) -> tuple[str, list[DiscoveredModel]]:
|
||||||
|
try:
|
||||||
|
return name, await fetch_tags(client)
|
||||||
|
except (httpx.HTTPError, ValueError) as exc:
|
||||||
|
_log.warning("discovery_refresh_failed", backend=name, error=str(exc))
|
||||||
|
return name, []
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(_poll(name, http_clients[name]) for name in backend_order)
|
||||||
|
)
|
||||||
|
per_backend = dict(results)
|
||||||
|
await cache.set_per_backend(per_backend, backend_order)
|
||||||
|
|
||||||
|
healthy = sum(1 for _, models in results if models)
|
||||||
|
total_models = len(cache.names)
|
||||||
|
_log.info(
|
||||||
|
"discovery_refreshed_multi",
|
||||||
|
backends=len(backend_order),
|
||||||
|
healthy=healthy,
|
||||||
|
unique_models=total_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache the de-duplicated combined set in Redis for the legacy single-key
|
||||||
|
# reader (used as a recovery snapshot on startup).
|
||||||
|
if redis_client is not None:
|
||||||
|
try:
|
||||||
|
await write_discovered_to_redis(
|
||||||
|
redis_client, cache.models, settings.model_discovery_cache_ttl_s
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - Redis is best-effort cache fill
|
||||||
|
_log.warning("discovery_cache_write_failed", error=str(exc))
|
||||||
|
return healthy
|
||||||
|
|
||||||
|
|
||||||
async def discovery_loop(
|
async def discovery_loop(
|
||||||
http_client: httpx.AsyncClient,
|
http_client: httpx.AsyncClient,
|
||||||
redis_client: redis.Redis | None,
|
redis_client: redis.Redis | None,
|
||||||
cache: DiscoveryCache,
|
cache: DiscoveryCache,
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``.
|
"""Single-backend background poller (legacy).
|
||||||
|
|
||||||
Designed to be launched via ``asyncio.create_task`` in the lifespan and
|
Use :func:`discovery_loop_multi` for the multi-backend path. This function
|
||||||
cancelled on shutdown.
|
is kept for the dependency-override tests that build a single client.
|
||||||
"""
|
"""
|
||||||
await refresh_once(http_client, redis_client, cache, settings)
|
await refresh_once(http_client, redis_client, cache, settings)
|
||||||
while True:
|
while True:
|
||||||
@@ -194,14 +289,38 @@ async def discovery_loop(
|
|||||||
await refresh_once(http_client, redis_client, cache, settings)
|
await refresh_once(http_client, redis_client, cache, settings)
|
||||||
|
|
||||||
|
|
||||||
|
async def discovery_loop_multi(
|
||||||
|
http_clients: dict[str, httpx.AsyncClient],
|
||||||
|
backend_order: list[str],
|
||||||
|
redis_client: redis.Redis | None,
|
||||||
|
cache: DiscoveryCache,
|
||||||
|
settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""Multi-backend background poller.
|
||||||
|
|
||||||
|
Polls every configured backend in parallel each tick. Designed to be
|
||||||
|
launched via ``asyncio.create_task`` in the lifespan and cancelled on
|
||||||
|
shutdown.
|
||||||
|
"""
|
||||||
|
await refresh_all_backends(http_clients, backend_order, redis_client, cache, settings)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(settings.model_discovery_refresh_s)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
await refresh_all_backends(http_clients, backend_order, redis_client, cache, settings)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"REDIS_DISCOVERED_KEY",
|
"REDIS_DISCOVERED_KEY",
|
||||||
"DiscoveredModel",
|
"DiscoveredModel",
|
||||||
"DiscoveryCache",
|
"DiscoveryCache",
|
||||||
"discovery_loop",
|
"discovery_loop",
|
||||||
|
"discovery_loop_multi",
|
||||||
"fetch_tags",
|
"fetch_tags",
|
||||||
"names_of",
|
"names_of",
|
||||||
"read_discovered_from_redis",
|
"read_discovered_from_redis",
|
||||||
|
"refresh_all_backends",
|
||||||
"refresh_once",
|
"refresh_once",
|
||||||
"write_discovered_to_redis",
|
"write_discovered_to_redis",
|
||||||
]
|
]
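Review note: the priority merge done inside `set_per_backend` can be read in isolation. The following is a minimal sketch, not the gateway's code: `Model` is a hypothetical stand-in for `DiscoveredModel`, and `merge_by_priority` is an illustrative free function that does not exist in this change.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Model:
    """Hypothetical stand-in for DiscoveredModel (only the fields used here)."""

    name: str
    digest: str = ""


def merge_by_priority(
    per_backend: dict[str, list[Model]],
    backend_order: list[str],
) -> tuple[list[Model], dict[str, list[str]]]:
    """Return (deduplicated records, model -> backends in priority order)."""
    combined: list[Model] = []
    seen: set[str] = set()
    model_backends: dict[str, list[str]] = {}
    # Walk backends in config order so the first occurrence of a model name
    # comes from the highest-priority backend that hosts it.
    for backend in backend_order:
        for m in per_backend.get(backend, []):
            model_backends.setdefault(m.name, []).append(backend)
            if m.name not in seen:
                combined.append(m)
                seen.add(m.name)
    return combined, model_backends
```

With backends `a` and `b` both hosting `llama3`, the combined list keeps the record from `a`, and the routing list for `llama3` is `["a", "b"]`. A backend missing from `per_backend` simply contributes nothing.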
|
||||||
|
|||||||
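Review note: the parallel poll in `refresh_all_backends` relies on each per-backend task swallowing its own error, so one dead backend cannot fail the whole `asyncio.gather`. A minimal sketch of that pattern follows. The `fetchers` mapping of zero-argument coroutine functions and the `poll_all` name are assumptions for illustration; the real code passes httpx clients to `fetch_tags` and catches `httpx.HTTPError` instead of `OSError`.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable

Fetcher = Callable[[], Awaitable[list[str]]]


async def poll_all(
    fetchers: dict[str, Fetcher], order: list[str]
) -> tuple[dict[str, list[str]], int]:
    """Poll every backend concurrently; a failing backend yields an empty list."""

    async def _poll(name: str) -> tuple[str, list[str]]:
        try:
            return name, await fetchers[name]()
        except (OSError, ValueError):
            # Errors are contained per backend so gather() never raises here.
            return name, []

    results = await asyncio.gather(*(_poll(name) for name in order))
    per_backend = dict(results)
    # "Healthy" means the backend answered with at least one model.
    healthy = sum(1 for _, models in results if models)
    return per_backend, healthy
```

Because the result mapping is rebuilt from scratch on every call, a backend that fails contributes `[]` and its previously discovered models are not carried over.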
@@ -43,12 +43,14 @@ from neuronetz_gateway.errors import (
|
|||||||
BudgetExceededError,
|
BudgetExceededError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
RequestTooLargeError,
|
RequestTooLargeError,
|
||||||
|
UpstreamUnavailableError,
|
||||||
)
|
)
|
||||||
from neuronetz_gateway.observability import metrics
|
from neuronetz_gateway.observability import metrics
|
||||||
from neuronetz_gateway.observability.logging import get_logger
|
from neuronetz_gateway.observability.logging import get_logger
|
||||||
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed
|
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed
|
||||||
from neuronetz_gateway.proxy.discovery import DiscoveryCache
|
from neuronetz_gateway.proxy.discovery import DiscoveryCache
|
||||||
from neuronetz_gateway.proxy.ollama import OllamaClient
|
from neuronetz_gateway.proxy.ollama import OllamaClient
|
||||||
|
from neuronetz_gateway.proxy.router import BackendRouter
|
||||||
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
|
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
|
||||||
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
|
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
|
||||||
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
|
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
|
||||||
@@ -100,6 +102,7 @@ class Pipeline:
|
|||||||
budget: BudgetCounter,
|
budget: BudgetCounter,
|
||||||
audit: AuditWriter,
|
audit: AuditWriter,
|
||||||
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
|
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
|
||||||
|
router: BackendRouter | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._request = request
|
self._request = request
|
||||||
self._p = principal
|
self._p = principal
|
||||||
@@ -111,6 +114,10 @@ class Pipeline:
|
|||||||
self._budget = budget
|
self._budget = budget
|
||||||
self._audit = audit
|
self._audit = audit
|
||||||
self._sessionmaker = sessionmaker
|
self._sessionmaker = sessionmaker
|
||||||
|
# If a router is provided, requests use per-model backend selection +
|
||||||
|
# failover. Without a router we keep the legacy single-client behaviour
|
||||||
|
# (still useful for tests that override get_ollama_client directly).
|
||||||
|
self._router = router
|
||||||
self._request_id = str(getattr(request.state, "request_id", uuid.uuid4()))
|
self._request_id = str(getattr(request.state, "request_id", uuid.uuid4()))
|
||||||
self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}"
|
self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}"
|
||||||
self._headers = RateHeaders(
|
self._headers = RateHeaders(
|
||||||
@@ -234,23 +241,77 @@ class Pipeline:
|
|||||||
"""Render the §6.5 response headers."""
|
"""Render the §6.5 response headers."""
|
||||||
return self._headers.as_dict(self._request_id)
|
return self._headers.as_dict(self._request_id)
|
||||||
|
|
||||||
|
def _candidates(self, model: str) -> list[tuple[str, OllamaClient]]:
|
||||||
|
"""Return (backend_name, client) candidates for ``model``, in priority order.
|
||||||
|
|
||||||
|
With a router configured, this enumerates every backend that hosts the
|
||||||
|
model (so the proxy can fail over to the next one). Without a router,
|
||||||
|
it returns the legacy single client labelled "default".
|
||||||
|
"""
|
||||||
|
if self._router is None:
|
||||||
|
return [("default", self._ollama)]
|
||||||
|
candidates = [(bc.name, bc.client) for bc in self._router.clients_for_with_failover(model)]
|
||||||
|
return candidates or [("default", self._ollama)]
|
||||||
|
|
||||||
async def stream_native(
|
async def stream_native(
|
||||||
self, method: str, path: str, body: dict[str, object], model: str
|
self, method: str, path: str, body: dict[str, object], model: str
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Proxy a streaming NDJSON request, accounting tokens after close."""
|
"""Proxy a streaming NDJSON request, accounting tokens after close.
|
||||||
|
|
||||||
|
Failover on streaming is best-effort: if the FIRST chunk fails before
|
||||||
|
any bytes have been yielded, we retry against the next backend. Once
|
||||||
|
bytes have started flowing we can't switch backends mid-stream, so a
|
||||||
|
later transport error surfaces as a truncated response (the post-close
|
||||||
|
bookkeeping still runs).
|
||||||
|
"""
|
||||||
started = time.monotonic()
|
started = time.monotonic()
|
||||||
media_type = NDJSON_MEDIA_TYPE
|
candidates = self._candidates(model)
|
||||||
|
|
||||||
async def gen() -> AsyncIterator[bytes]:
|
async def gen() -> AsyncIterator[bytes]:
|
||||||
last_obj: dict[str, object] = {}
|
last_obj: dict[str, object] = {}
|
||||||
try:
|
try:
|
||||||
async for chunk in self._ollama.stream(method, path, body):
|
iterator = await self._open_stream(method, path, body, candidates)
|
||||||
|
async for chunk in iterator:
|
||||||
last_obj = _merge_last_ndjson(chunk, last_obj)
|
last_obj = _merge_last_ndjson(chunk, last_obj)
|
||||||
yield chunk
|
yield chunk
|
||||||
finally:
|
finally:
|
||||||
await self._finish(model, path, method, last_obj, started)
|
await self._finish(model, path, method, last_obj, started)
|
||||||
|
|
||||||
return StreamingResponse(gen(), media_type=media_type, headers=self.headers())
|
return StreamingResponse(gen(), media_type=NDJSON_MEDIA_TYPE, headers=self.headers())
|
||||||
|
|
||||||
|
async def _open_stream(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
body: dict[str, object],
|
||||||
|
candidates: list[tuple[str, OllamaClient]],
|
||||||
|
) -> AsyncIterator[bytes]:
|
||||||
|
"""Open the upstream stream, trying each candidate backend in priority order.
|
||||||
|
|
||||||
|
Returns the async iterator from the first backend that doesn't raise on
|
||||||
|
the initial connect/send. If all candidates fail, raises
|
||||||
|
:class:`UpstreamUnavailableError`.
|
||||||
|
"""
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for name, client in candidates:
|
||||||
|
try:
|
||||||
|
# ``OllamaClient.stream`` is an async generator: building the
|
||||||
|
# generator object doesn't make the request, but the FIRST
|
||||||
|
# ``__anext__`` does. Pre-pump one chunk to surface connect
|
||||||
|
# errors here so we can fail over without yielding partial data.
|
||||||
|
gen = client.stream(method, path, body)
|
||||||
|
first = await gen.__anext__()
|
||||||
|
return _replay_first_then(first, gen)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
# Empty stream: just return an exhausted iterator.
|
||||||
|
return _replay_first_then(b"", _empty())
|
||||||
|
except Exception as exc: # noqa: BLE001 - we retry on any transport error
|
||||||
|
_log.warning("backend_stream_failover", backend=name, error=str(exc))
|
||||||
|
last_exc = exc
|
||||||
|
continue
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"all backends failed ({last_exc})"
|
||||||
|
)
|
||||||
|
|
||||||
async def stream_openai(
|
async def stream_openai(
|
||||||
self,
|
self,
|
||||||
@@ -262,12 +323,14 @@ class Pipeline:
|
|||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Proxy + translate native NDJSON into OpenAI SSE; account after close."""
|
"""Proxy + translate native NDJSON into OpenAI SSE; account after close."""
|
||||||
started = time.monotonic()
|
started = time.monotonic()
|
||||||
|
candidates = self._candidates(model)
|
||||||
|
|
||||||
async def gen() -> AsyncIterator[bytes]:
|
async def gen() -> AsyncIterator[bytes]:
|
||||||
last_obj: dict[str, object] = {}
|
last_obj: dict[str, object] = {}
|
||||||
buffer = b""
|
buffer = b""
|
||||||
try:
|
try:
|
||||||
async for chunk in self._ollama.stream(method, path, body):
|
iterator = await self._open_stream(method, path, body, candidates)
|
||||||
|
async for chunk in iterator:
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
lines = buffer.split(b"\n")
|
lines = buffer.split(b"\n")
|
||||||
buffer = lines.pop()
|
buffer = lines.pop()
|
||||||
@@ -293,9 +356,8 @@ class Pipeline:
|
|||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Proxy a non-streaming request and account tokens before responding."""
|
"""Proxy a non-streaming request and account tokens before responding."""
|
||||||
started = time.monotonic()
|
started = time.monotonic()
|
||||||
resp = await self._ollama.request(method, path, body)
|
candidates = self._candidates(model)
|
||||||
payload = resp.json()
|
obj = await self._request_with_failover(method, path, body, candidates)
|
||||||
obj = payload if isinstance(payload, dict) else {}
|
|
||||||
await self._finish(model, path, method, obj, started)
|
await self._finish(model, path, method, obj, started)
|
||||||
return JSONResponse(obj, headers=self.headers())
|
return JSONResponse(obj, headers=self.headers())
|
||||||
|
|
||||||
@@ -309,12 +371,33 @@ class Pipeline:
|
|||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Proxy a non-streaming request, translate the body, then account."""
|
"""Proxy a non-streaming request, translate the body, then account."""
|
||||||
started = time.monotonic()
|
started = time.monotonic()
|
||||||
resp = await self._ollama.request(method, path, body)
|
candidates = self._candidates(model)
|
||||||
payload = resp.json()
|
obj = await self._request_with_failover(method, path, body, candidates)
|
||||||
obj = payload if isinstance(payload, dict) else {}
|
|
||||||
await self._finish(model, path, method, obj, started)
|
await self._finish(model, path, method, obj, started)
|
||||||
return JSONResponse(translator(obj), headers=self.headers())
|
return JSONResponse(translator(obj), headers=self.headers())
|
||||||
|
|
||||||
|
async def _request_with_failover(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
body: dict[str, object],
|
||||||
|
candidates: list[tuple[str, OllamaClient]],
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""Try each backend in order; return parsed JSON from the first success."""
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for name, client in candidates:
|
||||||
|
try:
|
||||||
|
resp = await client.request(method, path, body)
|
||||||
|
payload = resp.json()
|
||||||
|
return payload if isinstance(payload, dict) else {}
|
||||||
|
except Exception as exc: # noqa: BLE001 - retry on any transport error
|
||||||
|
_log.warning("backend_request_failover", backend=name, error=str(exc))
|
||||||
|
last_exc = exc
|
||||||
|
continue
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"all backends failed ({last_exc})"
|
||||||
|
)
|
||||||
|
|
||||||
async def _finish(
|
async def _finish(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
@@ -457,6 +540,22 @@ async def read_json_body(request: Request, settings: Settings) -> dict[str, obje
|
|||||||
return parsed if isinstance(parsed, dict) else {}
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _empty() -> AsyncIterator[bytes]:
|
||||||
|
"""An async iterator that yields nothing — used as a no-op stream."""
|
||||||
|
if False: # pragma: no cover - never reached, just makes this an async generator
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
|
||||||
|
async def _replay_first_then(
|
||||||
|
first: bytes, rest: AsyncIterator[bytes]
|
||||||
|
) -> AsyncIterator[bytes]:
|
||||||
|
"""Yield ``first`` (already-pulled probe chunk) and then drain ``rest``."""
|
||||||
|
if first:
|
||||||
|
yield first
|
||||||
|
async for chunk in rest:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"NDJSON_MEDIA_TYPE",
|
"NDJSON_MEDIA_TYPE",
|
||||||
"SSE_MEDIA_TYPE",
|
"SSE_MEDIA_TYPE",
|
||||||
|
|||||||
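Review note: the streaming failover in `_open_stream` hinges on pulling the first chunk before handing the iterator to the response. A self-contained sketch of that probe-then-replay idea is below. The `open_with_failover` name, the `(name, factory)` candidate shape, and the use of `RuntimeError` in place of `UpstreamUnavailableError` are assumptions made so the example runs without the gateway.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable

StreamFactory = Callable[[], AsyncIterator[bytes]]


async def _empty() -> AsyncIterator[bytes]:
    """An async generator that yields nothing."""
    return
    yield b""  # unreachable; its presence makes this an async generator


async def _replay_first_then(
    first: bytes, rest: AsyncIterator[bytes]
) -> AsyncIterator[bytes]:
    """Yield the already-pulled probe chunk, then drain the remainder."""
    if first:
        yield first
    async for chunk in rest:
        yield chunk


async def open_with_failover(
    candidates: list[tuple[str, StreamFactory]],
) -> tuple[str, AsyncIterator[bytes]]:
    """Return (backend name, stream) from the first backend whose probe succeeds."""
    last_exc: Exception | None = None
    for name, factory in candidates:
        gen = factory()  # building the generator does not start the request
        try:
            first = await gen.__anext__()  # the first pull surfaces connect errors
        except StopAsyncIteration:
            return name, _empty()  # backend answered with an empty stream
        except Exception as exc:  # any error before the first byte: try the next one
            last_exc = exc
            continue
        return name, _replay_first_then(first, gen)
    raise RuntimeError(f"all backends failed ({last_exc})")
```

An error raised after the first chunk has been replayed still propagates to the consumer, which matches the best-effort caveat in the `stream_native` docstring.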
145
src/neuronetz_gateway/proxy/router.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Backend router: maps a request's model to the right Ollama backend.
|
||||||
|
|
||||||
|
In multi-backend deployments the gateway aggregates ``/api/tags`` across several
|
||||||
|
Ollama servers and routes each request to the **first** backend (in config
|
||||||
|
order) whose discovered model set includes the requested model. If that backend
|
||||||
|
errors, :meth:`BackendRouter.clients_for_with_failover` yields the next backend
|
||||||
|
that also has the model so the caller can retry without surfacing the failure
|
||||||
|
to the client.
|
||||||
|
|
||||||
|
Single-backend deployments still work: :meth:`Settings.effective_backends` synthesizes a
|
||||||
|
single "default" backend from the legacy ``OLLAMA_BASE_URL`` settings, so a
|
||||||
|
gateway that hasn't opted into multi-backend behaves exactly like before.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from neuronetz_gateway.config import BackendSpec, Settings
|
||||||
|
from neuronetz_gateway.errors import UpstreamUnavailableError
|
||||||
|
from neuronetz_gateway.observability.logging import get_logger
|
||||||
|
from neuronetz_gateway.proxy.discovery import DiscoveryCache
|
||||||
|
from neuronetz_gateway.proxy.ollama import OllamaClient
|
||||||
|
|
||||||
|
_log = get_logger("router")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class BackendClient:
|
||||||
|
"""A backend ready to receive a proxied request."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
client: OllamaClient
|
||||||
|
|
||||||
|
|
||||||
|
class BackendRouter:
|
||||||
|
"""Holds the per-backend httpx clients and resolves model → backend."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
clients: dict[str, OllamaClient],
|
||||||
|
order: list[str],
|
||||||
|
discovery: DiscoveryCache,
|
||||||
|
) -> None:
|
||||||
|
# Map insertion order matters for the failover path, so we keep `order`
|
||||||
|
# as the canonical priority list and `clients` indexed by name.
|
||||||
|
self._clients = clients
|
||||||
|
self._order = order
|
||||||
|
self._discovery = discovery
|
||||||
|
|
||||||
|
@property
|
||||||
|
def order(self) -> list[str]:
|
||||||
|
"""The configured backend priority order (copy)."""
|
||||||
|
return list(self._order)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def clients(self) -> dict[str, OllamaClient]:
|
||||||
|
"""All configured backend clients (copy of the mapping)."""
|
||||||
|
return dict(self._clients)
|
||||||
|
|
||||||
|
def first_for(self, model: str) -> BackendClient:
|
||||||
|
"""Return the highest-priority backend that hosts ``model``.
|
||||||
|
|
||||||
|
Raises :class:`UpstreamUnavailableError` if no backend currently lists
|
||||||
|
the model — which, per SPEC §13.6, is the same generic response we
|
||||||
|
return for "model not allowed", so callers should let the existing
|
||||||
|
authorization-error path handle the user-visible message instead.
|
||||||
|
"""
|
||||||
|
for backend in self._discovery.backends_for(model):
|
||||||
|
client = self._clients.get(backend)
|
||||||
|
if client is not None:
|
||||||
|
return BackendClient(name=backend, client=client)
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"no backend hosts model {model!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def clients_for_with_failover(self, model: str) -> Iterator[BackendClient]:
|
||||||
|
"""Yield every backend that hosts ``model``, in priority order.
|
||||||
|
|
||||||
|
Callers can iterate to retry against the next backend if the first
|
||||||
|
attempt raises a transport error. Empty iteration ⇒ no backend.
|
||||||
|
"""
|
||||||
|
for backend in self._discovery.backends_for(model):
|
||||||
|
client = self._clients.get(backend)
|
||||||
|
if client is not None:
|
||||||
|
yield BackendClient(name=backend, client=client)
|
||||||
|
|
||||||
|
|
||||||
|
def build_backend_headers(backend: BackendSpec) -> dict[str, str]:
|
||||||
|
"""Compose default headers for one backend's httpx client.
|
||||||
|
|
||||||
|
Mirrors ``lifespan._build_upstream_headers`` but resolves against a
|
||||||
|
specific :class:`BackendSpec` so multi-backend gets its own per-backend
|
||||||
|
auth without all backends having to share one token.
|
||||||
|
"""
|
||||||
|
headers: dict[str, str] = {"User-Agent": "neuronetz-gateway"}
|
||||||
|
if backend.has_auth and backend.auth_token is not None:
|
||||||
|
raw = backend.auth_token.get_secret_value().strip()
|
||||||
|
if backend.auth_header.lower() == "authorization":
|
||||||
|
headers[backend.auth_header] = f"{backend.auth_scheme} {raw}".strip()
|
||||||
|
else:
|
||||||
|
headers[backend.auth_header] = raw
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def build_http_clients(settings: Settings) -> tuple[dict[str, httpx.AsyncClient], list[str]]:
|
||||||
|
"""Build one httpx client per configured backend.
|
||||||
|
|
||||||
|
Returns ``(clients, order)``: a name→client mapping and the priority-ordered
|
||||||
|
list of backend names. The caller is responsible for ``aclose()`` of every
|
||||||
|
client in the returned mapping during lifespan shutdown.
|
||||||
|
"""
|
||||||
|
backends = settings.effective_backends()
|
||||||
|
timeout = httpx.Timeout(
|
||||||
|
connect=settings.ollama_connect_timeout_s,
|
||||||
|
read=settings.ollama_read_timeout_s,
|
||||||
|
write=settings.ollama_read_timeout_s,
|
||||||
|
pool=settings.ollama_connect_timeout_s,
|
||||||
|
)
|
||||||
|
limits = httpx.Limits(max_connections=settings.ollama_max_connections)
|
||||||
|
clients: dict[str, httpx.AsyncClient] = {}
|
||||||
|
order: list[str] = []
|
||||||
|
for backend in backends:
|
||||||
|
if backend.name in clients:
|
||||||
|
_log.warning("duplicate_backend_name_skipped", name=backend.name)
|
||||||
|
continue
|
||||||
|
clients[backend.name] = httpx.AsyncClient(
|
||||||
|
base_url=backend.base_url,
|
||||||
|
timeout=timeout,
|
||||||
|
limits=limits,
|
||||||
|
headers=build_backend_headers(backend),
|
||||||
|
)
|
||||||
|
order.append(backend.name)
|
||||||
|
return clients, order
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackendClient",
|
||||||
|
"BackendRouter",
|
||||||
|
"build_backend_headers",
|
||||||
|
"build_http_clients",
|
||||||
|
]
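Review note: the header rule in `build_backend_headers` is easy to check in isolation. The sketch below restates it with plain strings. `compose_auth_headers` and its parameters are hypothetical; the real function reads these values from a `BackendSpec` and unwraps a secret, and treating `token=None` as "no auth" is an assumption standing in for the `has_auth` check.

```python
from __future__ import annotations


def compose_auth_headers(
    auth_header: str, auth_scheme: str, token: str | None
) -> dict[str, str]:
    """Build default headers for one backend from plain string settings."""
    headers: dict[str, str] = {"User-Agent": "neuronetz-gateway"}
    if token is None:
        return headers  # backend without auth: only the User-Agent is sent
    raw = token.strip()
    if auth_header.lower() == "authorization":
        # Standard header: prefix the scheme; strip() covers an empty scheme.
        headers[auth_header] = f"{auth_scheme} {raw}".strip()
    else:
        # Custom header (an API-key style header, for example): raw token only.
        headers[auth_header] = raw
    return headers
```

Each backend gets its own header dict, so two backends can use different tokens or even different header names without sharing state.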
|
||||||