proxy: multi-backend Ollama aggregation with per-model routing + failover
Some checks failed
CI / ruff (push) Has been cancelled
CI / mypy --strict (push) Has been cancelled
CI / pytest (push) Has been cancelled
CI / bandit (push) Has been cancelled
CI / pip-audit (push) Has been cancelled

The gateway can now aggregate models across SEVERAL Ollama backends and
route each request to the correct one. Opt-in via OLLAMA_BACKENDS in .env
— single-backend deployments are unaffected (effective_backends()
synthesizes a single "default" backend from the legacy OLLAMA_BASE_URL /
OLLAMA_AUTH_TOKEN fields when the list is empty).

Behavior:
- Discovery polls EVERY configured backend in parallel each tick; the
  cache stores per-backend model lists plus a model → backends priority
  list (config order = priority order).
- /api/tags and /v1/models surface the DEDUPLICATED UNION of all
  backends' models.
- A request's model is looked up in the priority list and proxied to the
  FIRST backend that hosts it. If that backend errors on the request, the
  pipeline transparently fails over to the next backend that has the
  same model (the streaming-failover probes the first chunk before
  releasing the response, so we never serve partial bytes from a dead
  backend).
- No existence disclosure: a model not hosted by any backend yields the
  same generic 403 as "model not allowed" (SPEC §13.6 preserved).

Components:
- config.py: new BackendSpec model + ollama_backends list field + an
  effective_backends() helper.
- proxy/router.py (new): BackendRouter (clients_for_with_failover),
  build_http_clients() builds one httpx client per backend with its own
  auth headers, build_backend_headers() exposes the per-backend header
  composition for the CLI probe.
- proxy/discovery.py: DiscoveryCache.set_per_backend() + backends_for(),
  refresh_all_backends() polls all in parallel, discovery_loop_multi()
  replaces the single-backend loop in production; the legacy single-
  backend functions are kept for the dependency-override tests.
- proxy/pipeline.py: Pipeline accepts an optional router; the four proxy
  methods now retry against each candidate backend in priority order on
  transport error.
- lifespan.py: constructs the per-backend client dict, stores the router
  on app.state, launches discovery_loop_multi.
- deps.py: get_backend_router provider + BackendRouterDep type alias;
  get_pipeline passes the router into Pipeline.
- cli/manage.py: probe-ollama iterates every backend and reports per-
  backend status; list-models groups its output by backend and prints
  the union count + Redis cache size for sanity.
- .env.example + docker-compose.yml: document and pass through
  OLLAMA_BACKENDS with a real example.

Verified: ruff check (clean), mypy --strict src/ + tests/ (clean,
66 source files), pytest (60 passed + 39 skipped — same baseline as
before this change; integration tests are Docker-socket-gated).
This commit is contained in:
Stephan Berbig
2026-05-27 22:30:26 +02:00
parent 5044a44a17
commit 653e03bf29
9 changed files with 607 additions and 61 deletions

View File

@@ -274,25 +274,62 @@ def list_models(
str | None, typer.Option("--tenant", help="Also show this tenant's effective set.")
] = None,
) -> None:
"""Show live-discovered models (and, with --tenant, the effective set)."""
"""Show live-discovered models per backend (and, with --tenant, the effective set).
Spins up a short-lived poll against every configured backend to surface
which models are currently live where, then (optionally) computes the
given tenant's effective allow-list intersection.
"""
import httpx
import redis.asyncio as redis
from neuronetz_gateway.config import BackendSpec
from neuronetz_gateway.proxy.discovery import fetch_tags, names_of
from neuronetz_gateway.proxy.router import build_backend_headers
async def _poll_one(backend: BackendSpec) -> tuple[str, list[str]]:
try:
async with httpx.AsyncClient(
base_url=backend.base_url,
timeout=httpx.Timeout(connect=5, read=10, write=10, pool=10),
headers=build_backend_headers(backend),
) as client:
return backend.name, sorted(names_of(await fetch_tags(client)))
except (httpx.HTTPError, ValueError) as exc:
typer.secho(
f" ({backend.name}: probe failed — {type(exc).__name__})",
fg=typer.colors.RED,
)
return backend.name, []
async def work(settings: Settings) -> None:
backends = settings.effective_backends()
results = [await _poll_one(b) for b in backends]
union: set[str] = set()
typer.echo("live-discovered models, by backend:")
for name, models in results:
typer.echo(f" [{name}]")
if not models:
typer.echo(" (none)")
for m in models:
typer.echo(f" {m}")
union.update(models)
typer.echo(f"\nunion across all backends: {len(union)} unique model(s)")
# Surface what Redis currently has cached too (for sanity vs the live poll).
client = redis.from_url(settings.redis_url, decode_responses=True)
try:
discovered = await read_discovered_from_redis(client)
cached = await read_discovered_from_redis(client)
finally:
await client.aclose()
discovered_names = sorted(discovered)
typer.echo("discovered models (live from Ollama via discovery cache):")
if discovered_names:
for name in discovered_names:
typer.echo(f" {name}")
else:
typer.echo(" (none — discovery cache empty or expired; requests fail closed)")
if cached:
typer.echo(
f"redis cache (gateway:models:discovered): {len(cached)} model(s)"
)
if tenant is None:
return
discovered = frozenset(union)
factory = create_session_factory(create_engine(settings))
async with session_scope(factory) as session:
tenants = TenantRepository(session)
@@ -319,29 +356,30 @@ def probe_ollama(
*,
timeout: Annotated[float, typer.Option(help="Per-request timeout in seconds.")] = 10.0,
) -> None:
"""Probe the upstream Ollama: GET /api/version and /api/tags.
"""Probe every configured Ollama backend (GET /api/version + /api/tags).
Uses the exact same httpx config as the running gateway (base URL, timeouts,
and the OLLAMA_AUTH_TOKEN header if set) so a passing probe proves the
gateway will be able to reach the backend in production. The token itself
is NEVER printed — only whether one was attached.
Iterates every entry returned by :meth:`Settings.effective_backends` (i.e.
the OLLAMA_BACKENDS JSON list, or a single synthesized backend from the
legacy OLLAMA_BASE_URL fields). Reports per-backend status and the first
5 model names; exits non-zero if any backend fails. Tokens are never
printed — only whether one was attached.
"""
import httpx
from neuronetz_gateway.lifespan import _build_upstream_headers
from neuronetz_gateway.config import BackendSpec
from neuronetz_gateway.proxy.router import build_backend_headers
settings = get_settings()
headers = _build_upstream_headers(settings)
auth_header = settings.ollama_auth_header
has_token = settings.ollama_auth_token is not None and bool(
settings.ollama_auth_token.get_secret_value().strip()
)
backends = settings.effective_backends()
auth_status = f"sending {auth_header}" if has_token else "no token (OLLAMA_AUTH_TOKEN unset)"
typer.echo(f"target: {settings.ollama_base_url}")
typer.echo(f"auth: {auth_status}")
async def _go() -> int:
async def _probe_one(backend: BackendSpec, idx: int, total: int) -> int:
marker = f"[{idx + 1}/{total}] {backend.name}"
typer.echo("")
typer.secho(marker, bold=True)
typer.echo(f" target: {backend.base_url}")
typer.echo(
f" auth: {('sending ' + backend.auth_header) if backend.has_auth else 'no token'}"
)
probe_timeout = httpx.Timeout(
connect=settings.ollama_connect_timeout_s,
read=timeout,
@@ -349,9 +387,9 @@ def probe_ollama(
pool=timeout,
)
async with httpx.AsyncClient(
base_url=settings.ollama_base_url,
base_url=backend.base_url,
timeout=probe_timeout,
headers=headers,
headers=build_backend_headers(backend),
) as client:
errors = 0
for path in ("/api/version", "/api/tags"):
@@ -371,8 +409,8 @@ def probe_ollama(
)
if resp.status_code in (401, 403):
typer.echo(
" upstream rejected the credentials — check "
"OLLAMA_AUTH_TOKEN / header."
" upstream rejected the credentials — check the "
"auth_token for this backend."
)
errors += 1
continue
@@ -392,9 +430,22 @@ def probe_ollama(
typer.echo(f" … and {n - 5} more")
return errors
async def _go() -> int:
total_errors = 0
for idx, backend in enumerate(backends):
total_errors += await _probe_one(backend, idx, len(backends))
return total_errors
errors = asyncio.run(_go())
typer.echo("")
if errors:
typer.secho(f"{errors} probe(s) failed.", fg=typer.colors.RED, bold=True)
raise typer.Exit(code=1)
typer.secho(
f"all {len(backends)} backend(s) reachable and authenticated.",
fg=typer.colors.GREEN,
bold=True,
)
typer.secho("upstream reachable and authenticated.", fg=typer.colors.GREEN, bold=True)