The gateway can now aggregate models across SEVERAL Ollama backends and route each request to the correct one. Opt-in via OLLAMA_BACKENDS in .env — single-backend deployments are unaffected (effective_backends() synthesizes a single "default" backend from the legacy OLLAMA_BASE_URL / OLLAMA_AUTH_TOKEN fields when the list is empty).

Behavior:
- Discovery polls EVERY configured backend in parallel each tick; the cache stores per-backend model lists plus a model → backends priority list (config order = priority order).
- /api/tags and /v1/models surface the DEDUPLICATED UNION of all backends' models.
- A request's model is looked up in the priority list and proxied to the FIRST backend that hosts it. If that backend errors on the request, the pipeline transparently fails over to the next backend that has the same model (the streaming failover path probes the first chunk before releasing the response, so we never serve partial bytes from a dead backend).
- No existence disclosure: a model not hosted by any backend yields the same generic 403 as "model not allowed" (SPEC §13.6 preserved).

Components:
- config.py: new BackendSpec model + ollama_backends list field + an effective_backends() helper.
- proxy/router.py (new): BackendRouter (clients_for_with_failover), build_http_clients() builds one httpx client per backend with its own auth headers, build_backend_headers() exposes the per-backend header composition for the CLI probe.
- proxy/discovery.py: DiscoveryCache.set_per_backend() + backends_for(), refresh_all_backends() polls all in parallel, discovery_loop_multi() replaces the single-backend loop in production; the legacy single-backend functions are kept for the dependency-override tests.
- proxy/pipeline.py: Pipeline accepts an optional router; the four proxy methods now retry against each candidate backend in priority order on transport error.
- lifespan.py: constructs the per-backend client dict, stores the router on app.state, launches discovery_loop_multi.
- deps.py: get_backend_router provider + BackendRouterDep type alias; get_pipeline passes the router into Pipeline.
- cli/manage.py: probe-ollama iterates every backend and reports per-backend status; list-models groups its output by backend and prints the union count + Redis cache size for sanity.
- .env.example + docker-compose.yml: document and pass through OLLAMA_BACKENDS with a real example.

Verified: ruff check (clean), mypy --strict src/ + tests/ (clean, 66 source files), pytest (60 passed + 39 skipped — same baseline as before this change; integration tests are Docker-socket-gated).
181 lines
6.7 KiB
Python
181 lines
6.7 KiB
Python
"""Application lifespan: connect/dispose backends and run background tasks.
|
|
|
|
Startup connects Postgres + Redis, builds one upstream httpx client per
configured Ollama backend, builds the argon2 hasher and the buffered audit
writer, and launches the background tasks:
|
|
the model-discovery poller (SPEC §4.6) and the Postgres revocation NOTIFY
|
|
listener (SPEC §4.5). Connection failures are tolerated so ``/healthz`` always
|
|
serves; ``/readyz`` reports true readiness. All handles live on ``app.state``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
from collections.abc import AsyncIterator
|
|
from contextlib import asynccontextmanager
|
|
from typing import TYPE_CHECKING
|
|
|
|
import httpx
|
|
import redis.asyncio as redis
|
|
|
|
from neuronetz_gateway.audit.writer import AuditWriter
|
|
from neuronetz_gateway.auth.hashing import build_hasher
|
|
from neuronetz_gateway.config import Settings, get_settings
|
|
from neuronetz_gateway.db.session import create_engine, create_session_factory
|
|
from neuronetz_gateway.observability.logging import get_logger
|
|
from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop_multi
|
|
from neuronetz_gateway.proxy.ollama import OllamaClient
|
|
from neuronetz_gateway.proxy.router import BackendRouter, build_http_clients
|
|
from neuronetz_gateway.revocation import revocation_listener
|
|
|
|
if TYPE_CHECKING:
|
|
from fastapi import FastAPI
|
|
|
|
_log = get_logger("lifespan")
|
|
|
|
|
|
def _build_upstream_headers(settings: Settings) -> dict[str, str]:
    """Return the default headers sent with every upstream Ollama request.

    A ``User-Agent`` header is always present. When ``OLLAMA_AUTH_TOKEN`` is
    configured and non-blank after stripping, one auth header is added under
    the name given by ``settings.ollama_auth_header``:

    * for ``Authorization`` (compared case-insensitively) the value is
      ``"<ollama_auth_scheme> <token>"``, stripped so an empty scheme leaves
      just the token;
    * for any other header name (e.g. ``X-API-Key``) the raw token is sent.

    The secret is unwrapped only inside this function and is not logged here.
    """
    headers: dict[str, str] = {"User-Agent": "neuronetz-gateway"}

    secret = settings.ollama_auth_token
    if secret is None:
        return headers

    token = secret.get_secret_value().strip()
    if not token:
        # A blank token is treated the same as "no token configured".
        return headers

    name = settings.ollama_auth_header
    uses_scheme = name.lower() == "authorization"
    headers[name] = (
        f"{settings.ollama_auth_scheme} {token}".strip() if uses_scheme else token
    )
    return headers
|
|
|
|
|
|
def _build_http_client(settings: Settings) -> httpx.AsyncClient:
    """Create an ``httpx.AsyncClient`` pointed at ``settings.ollama_base_url``.

    The connect timeout also bounds pool acquisition, and the read timeout is
    reused for writes. The connection pool is capped at
    ``settings.ollama_max_connections`` and the default headers come from
    ``_build_upstream_headers``.
    """
    connect_s = settings.ollama_connect_timeout_s
    read_s = settings.ollama_read_timeout_s
    return httpx.AsyncClient(
        base_url=settings.ollama_base_url,
        timeout=httpx.Timeout(
            connect=connect_s,
            read=read_s,
            write=read_s,
            pool=connect_s,
        ),
        limits=httpx.Limits(max_connections=settings.ollama_max_connections),
        headers=_build_upstream_headers(settings),
    )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run gateway startup, yield while the app serves, then shut down.

    Startup stores every handle on ``app.state``: settings, hasher, discovery
    cache, DB engine + session factory, Redis client, per-backend httpx
    clients, the backend router, the audit writer and the background tasks.
    A failure while creating the DB engine/session factory or the Redis client
    is logged and the corresponding attributes are set to ``None`` instead of
    aborting startup. The revocation listener is only launched when both Redis
    and the session factory are available. On exit ``_shutdown`` is awaited.
    """
    state = app.state

    settings: Settings = get_settings()
    state.settings = settings
    state.hasher = build_hasher(settings)
    discovery_cache = DiscoveryCache()
    state.discovery_cache = discovery_cache

    # --- Postgres -----------------------------------------------------------
    try:
        db_engine = create_engine(settings)
        state.db_engine = db_engine
        state.db_sessionmaker = create_session_factory(db_engine)
    except Exception as exc:  # noqa: BLE001 - tolerate so /healthz still serves
        _log.error("db_engine_init_failed", error=str(exc))
        state.db_engine = None
        state.db_sessionmaker = None

    # --- Redis --------------------------------------------------------------
    try:
        state.redis = redis.from_url(settings.redis_url, decode_responses=True)
    except Exception as exc:  # noqa: BLE001 - tolerate so /healthz still serves
        _log.error("redis_init_failed", error=str(exc))
        state.redis = None

    redis_client = state.redis
    session_factory = state.db_sessionmaker

    # --- Upstream Ollama backends -------------------------------------------
    # One httpx client per entry in OLLAMA_BACKENDS, or a single "default"
    # backend synthesized from the legacy OLLAMA_BASE_URL.
    backend_clients, backend_order = build_http_clients(settings)
    state.backend_http_clients = backend_clients
    state.backend_order = backend_order
    # ``http_client`` keeps its single-client meaning for code paths (and
    # tests) not yet migrated to the router: it is the FIRST backend's httpx
    # client. New code should reach upstream via the router.
    state.http_client = backend_clients[backend_order[0]]
    ollama_clients = {
        name: OllamaClient(client) for name, client in backend_clients.items()
    }
    state.backend_router = BackendRouter(
        clients=ollama_clients,
        order=backend_order,
        discovery=discovery_cache,
    )
    _log.info("backends_configured", backends=backend_order)

    # --- Audit writer -------------------------------------------------------
    audit_writer = AuditWriter(settings.audit_buffer_size, session_factory)
    audit_writer.start()
    state.audit_writer = audit_writer

    # --- Background tasks (cancelled on shutdown) ---------------------------
    discovery_task = asyncio.create_task(
        discovery_loop_multi(
            backend_clients,
            backend_order,
            redis_client,
            discovery_cache,
            settings,
        )
    )
    tasks: list[asyncio.Task[None]] = [discovery_task]
    if redis_client is not None and session_factory is not None:
        revocation_task = asyncio.create_task(
            revocation_listener(settings, redis_client, session_factory)
        )
        tasks.append(revocation_task)
    state.background_tasks = tasks

    _log.info("gateway_startup_complete")
    try:
        yield
    finally:
        await _shutdown(app, tasks, audit_writer)
|
|
|
|
|
|
async def _shutdown(
    app: FastAPI, tasks: list[asyncio.Task[None]], audit_writer: AuditWriter
) -> None:
    """Cancel background tasks and dispose of all backend handles.

    Steps, in order:

    1. Request cancellation of every task in ``tasks`` and wait for all of
       them to finish.
    2. Stop the audit writer.
    3. Close every httpx client in ``app.state.backend_http_clients``.
    4. Close the Redis client and dispose of the DB engine, when present.

    Steps 2-4 are best-effort: an exception from one handle is suppressed so
    the remaining handles are still released.

    A background task that already died with its own exception (anything
    other than cancellation) is logged as ``background_task_failed`` instead
    of being re-raised, so a crashed task cannot abort the rest of shutdown
    and leak the remaining handles. Cancellation of this coroutine itself is
    not swallowed.

    Args:
        app: Application whose ``state`` holds the handles to release.
        tasks: Background tasks started during startup.
        audit_writer: Audit writer to stop.
    """
    for task in tasks:
        task.cancel()

    if tasks:
        # return_exceptions=True collects each task's outcome instead of
        # raising the first failure, so every task is awaited to completion.
        await asyncio.gather(*tasks, return_exceptions=True)
        for task in tasks:
            if task.cancelled():
                continue  # the expected outcome of task.cancel() above
            failure = task.exception()
            if failure is not None:
                _log.error(
                    "background_task_failed",
                    task=task.get_name(),
                    error=str(failure),
                )

    with contextlib.suppress(Exception):
        await audit_writer.stop()

    # Close every per-backend httpx client (the legacy `http_client` attr is
    # one of these, so we only need to iterate the dict).
    backend_clients: dict[str, httpx.AsyncClient] = getattr(
        app.state, "backend_http_clients", {}
    )
    for client in backend_clients.values():
        with contextlib.suppress(Exception):
            await client.aclose()

    redis_client = getattr(app.state, "redis", None)
    if redis_client is not None:
        with contextlib.suppress(Exception):
            await redis_client.aclose()

    engine = getattr(app.state, "db_engine", None)
    if engine is not None:
        with contextlib.suppress(Exception):
            await engine.dispose()

    _log.info("gateway_shutdown_complete")
|
|
|
|
|
|
__all__ = ["lifespan"]
|