diff --git a/.env.example b/.env.example index 16c16ce..f278aa9 100644 --- a/.env.example +++ b/.env.example @@ -46,6 +46,21 @@ OLLAMA_AUTH_TOKEN= OLLAMA_AUTH_HEADER=Authorization OLLAMA_AUTH_SCHEME=Bearer +# ─────────────────────── Multi-backend (opt-in) ────────────────── +# Aggregate models across SEVERAL Ollama backends. When set (non-empty JSON +# list), this REPLACES the single-backend config above — include each backend +# explicitly, in priority order. Requests for a given model route to the +# FIRST backend that hosts it; on transport errors the gateway transparently +# fails over to the next backend that has the same model. +# +# Each entry: {name, base_url, [auth_token], [auth_header], [auth_scheme]} +# Example (embedded GPU container + publicly-fronted auth-protected Ollama): +# OLLAMA_BACKENDS='[ +# {"name":"embedded","base_url":"http://ollama:11434"}, +# {"name":"public","base_url":"https://ollama.neuronetz.ai","auth_token":"YOUR_TOKEN"} +# ]' +OLLAMA_BACKENDS= + # ──────────────────────── Model discovery (§4.6) ───────────────── MODEL_DISCOVERY_REFRESH_S=60 MODEL_DISCOVERY_CACHE_TTL_S=120 diff --git a/docker-compose.yml b/docker-compose.yml index a66028b..9c2f70a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -36,6 +36,8 @@ services: OLLAMA_AUTH_TOKEN: ${OLLAMA_AUTH_TOKEN:-} OLLAMA_AUTH_HEADER: ${OLLAMA_AUTH_HEADER:-Authorization} OLLAMA_AUTH_SCHEME: ${OLLAMA_AUTH_SCHEME:-Bearer} + # Multi-backend (opt-in JSON list). See .env.example for the schema. 
+ OLLAMA_BACKENDS: ${OLLAMA_BACKENDS:-} MODEL_DISCOVERY_REFRESH_S: ${MODEL_DISCOVERY_REFRESH_S:-60} MODEL_DISCOVERY_CACHE_TTL_S: ${MODEL_DISCOVERY_CACHE_TTL_S:-120} DEFAULT_RPM: ${DEFAULT_RPM:-60} diff --git a/src/neuronetz_gateway/cli/manage.py b/src/neuronetz_gateway/cli/manage.py index a000479..4c6bfc9 100644 --- a/src/neuronetz_gateway/cli/manage.py +++ b/src/neuronetz_gateway/cli/manage.py @@ -274,25 +274,62 @@ def list_models( str | None, typer.Option("--tenant", help="Also show this tenant's effective set.") ] = None, ) -> None: - """Show live-discovered models (and, with --tenant, the effective set).""" + """Show live-discovered models per backend (and, with --tenant, the effective set). + + Spins up a short-lived poll against every configured backend to surface + which models are currently live where, then (optionally) computes the + given tenant's effective allow-list intersection. + """ + import httpx import redis.asyncio as redis + from neuronetz_gateway.config import BackendSpec + from neuronetz_gateway.proxy.discovery import fetch_tags, names_of + from neuronetz_gateway.proxy.router import build_backend_headers + + async def _poll_one(backend: BackendSpec) -> tuple[str, list[str]]: + try: + async with httpx.AsyncClient( + base_url=backend.base_url, + timeout=httpx.Timeout(connect=5, read=10, write=10, pool=10), + headers=build_backend_headers(backend), + ) as client: + return backend.name, sorted(names_of(await fetch_tags(client))) + except (httpx.HTTPError, ValueError) as exc: + typer.secho( + f" ({backend.name}: probe failed — {type(exc).__name__})", + fg=typer.colors.RED, + ) + return backend.name, [] + async def work(settings: Settings) -> None: + backends = settings.effective_backends() + results = [await _poll_one(b) for b in backends] + union: set[str] = set() + typer.echo("live-discovered models, by backend:") + for name, models in results: + typer.echo(f" [{name}]") + if not models: + typer.echo(" (none)") + for m in models: + typer.echo(f" 
{m}") + union.update(models) + typer.echo(f"\nunion across all backends: {len(union)} unique model(s)") + + # Surface what Redis currently has cached too (for sanity vs the live poll). client = redis.from_url(settings.redis_url, decode_responses=True) try: - discovered = await read_discovered_from_redis(client) + cached = await read_discovered_from_redis(client) finally: await client.aclose() - discovered_names = sorted(discovered) - typer.echo("discovered models (live from Ollama via discovery cache):") - if discovered_names: - for name in discovered_names: - typer.echo(f" {name}") - else: - typer.echo(" (none — discovery cache empty or expired; requests fail closed)") + if cached: + typer.echo( + f"redis cache (gateway:models:discovered): {len(cached)} model(s)" + ) if tenant is None: return + discovered = frozenset(union) factory = create_session_factory(create_engine(settings)) async with session_scope(factory) as session: tenants = TenantRepository(session) @@ -319,29 +356,30 @@ def probe_ollama( *, timeout: Annotated[float, typer.Option(help="Per-request timeout in seconds.")] = 10.0, ) -> None: - """Probe the upstream Ollama: GET /api/version and /api/tags. + """Probe every configured Ollama backend (GET /api/version + /api/tags). - Uses the exact same httpx config as the running gateway (base URL, timeouts, - and the OLLAMA_AUTH_TOKEN header if set) so a passing probe proves the - gateway will be able to reach the backend in production. The token itself - is NEVER printed — only whether one was attached. + Iterates every entry returned by :meth:`Settings.effective_backends` (i.e. + the OLLAMA_BACKENDS JSON list, or a single synthesized backend from the + legacy OLLAMA_BASE_URL fields). Reports per-backend status and the first + 5 model names; exits non-zero if any backend fails. Tokens are never + printed — only whether one was attached. 
""" import httpx - from neuronetz_gateway.lifespan import _build_upstream_headers + from neuronetz_gateway.config import BackendSpec + from neuronetz_gateway.proxy.router import build_backend_headers settings = get_settings() - headers = _build_upstream_headers(settings) - auth_header = settings.ollama_auth_header - has_token = settings.ollama_auth_token is not None and bool( - settings.ollama_auth_token.get_secret_value().strip() - ) + backends = settings.effective_backends() - auth_status = f"sending {auth_header}" if has_token else "no token (OLLAMA_AUTH_TOKEN unset)" - typer.echo(f"target: {settings.ollama_base_url}") - typer.echo(f"auth: {auth_status}") - - async def _go() -> int: + async def _probe_one(backend: BackendSpec, idx: int, total: int) -> int: + marker = f"[{idx + 1}/{total}] {backend.name}" + typer.echo("") + typer.secho(marker, bold=True) + typer.echo(f" target: {backend.base_url}") + typer.echo( + f" auth: {('sending ' + backend.auth_header) if backend.has_auth else 'no token'}" + ) probe_timeout = httpx.Timeout( connect=settings.ollama_connect_timeout_s, read=timeout, @@ -349,9 +387,9 @@ def probe_ollama( pool=timeout, ) async with httpx.AsyncClient( - base_url=settings.ollama_base_url, + base_url=backend.base_url, timeout=probe_timeout, - headers=headers, + headers=build_backend_headers(backend), ) as client: errors = 0 for path in ("/api/version", "/api/tags"): @@ -371,8 +409,8 @@ def probe_ollama( ) if resp.status_code in (401, 403): typer.echo( - " upstream rejected the credentials — check " - "OLLAMA_AUTH_TOKEN / header." + " upstream rejected the credentials — check the " + "auth_token for this backend." 
) errors += 1 continue @@ -392,9 +430,22 @@ def probe_ollama( typer.echo(f" … and {n - 5} more") return errors + async def _go() -> int: + total_errors = 0 + for idx, backend in enumerate(backends): + total_errors += await _probe_one(backend, idx, len(backends)) + return total_errors + errors = asyncio.run(_go()) + typer.echo("") if errors: + typer.secho(f"{errors} probe(s) failed.", fg=typer.colors.RED, bold=True) raise typer.Exit(code=1) + typer.secho( + f"all {len(backends)} backend(s) reachable and authenticated.", + fg=typer.colors.GREEN, + bold=True, + ) typer.secho("upstream reachable and authenticated.", fg=typer.colors.GREEN, bold=True) diff --git a/src/neuronetz_gateway/config.py b/src/neuronetz_gateway/config.py index 2bf9d51..6694735 100644 --- a/src/neuronetz_gateway/config.py +++ b/src/neuronetz_gateway/config.py @@ -8,10 +8,47 @@ from __future__ import annotations from functools import lru_cache -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict +class BackendSpec(BaseModel): + """One Ollama backend the gateway aggregates and proxies to. + + The gateway can either run with a single backend (the legacy mode — left as + the default for backward compatibility) or with several backends configured + via ``OLLAMA_BACKENDS``, in which case the gateway: + + - polls discovery against each backend in parallel and exposes the **union** + of their model sets on ``/api/tags`` and ``/v1/models``; + - routes every request to the FIRST backend (priority = list order) whose + model set contains the requested model; + - auto-fails-over to subsequent backends that have the same model if the + primary errors. 
+ + Provide as JSON in env: + OLLAMA_BACKENDS='[ + {"name":"embedded","base_url":"http://ollama:11434"}, + {"name":"public", + "base_url":"https://ollama.neuronetz.ai", + "auth_token":"…"} + ]' + """ + + name: str = Field(min_length=1, max_length=64) + base_url: str = Field(min_length=1) + auth_token: SecretStr | None = None + auth_header: str = "Authorization" + auth_scheme: str = "Bearer" + + @property + def has_auth(self) -> bool: + """True if this backend should send an auth header.""" + if self.auth_token is None: + return False + return bool(self.auth_token.get_secret_value().strip()) + + class Settings(BaseSettings): """Gateway runtime configuration. All fields map to SPEC §7 env vars.""" @@ -46,6 +83,14 @@ class Settings(BaseSettings): ollama_auth_header: str = Field(default="Authorization") ollama_auth_scheme: str = Field(default="Bearer") + # --- Multi-backend (opt-in) --- + # JSON list of backends (replaces the single-backend config). When empty (default), a single + # backend named "default" is constructed from the OLLAMA_BASE_URL / TOKEN / + # HEADER / SCHEME fields above (backward compat — current prod is unaffected). + # When set, this list IS the entire backend roster — include the embedded + # one explicitly if you want it. The list order is the routing priority. + ollama_backends: list[BackendSpec] = Field(default_factory=list) + # --- Model discovery (SPEC §4.6) --- model_discovery_refresh_s: int = Field(default=60) model_discovery_cache_ttl_s: int = Field(default=120) @@ -89,6 +134,26 @@ class Settings(BaseSettings): """Parse the comma-separated trusted-proxy list into individual hosts.""" return [p.strip() for p in self.gateway_trusted_proxies.split(",") if p.strip()] + def effective_backends(self) -> list[BackendSpec]: + """Return the resolved backend roster. + + If ``ollama_backends`` is non-empty, that list is used verbatim (order is + the routing priority). 
Otherwise a single backend is synthesized from the + legacy ``ollama_base_url`` / ``ollama_auth_token`` fields so a deployment + that hasn't opted into multi-backend keeps working unchanged. + """ + if self.ollama_backends: + return list(self.ollama_backends) + return [ + BackendSpec( + name="default", + base_url=self.ollama_base_url, + auth_token=self.ollama_auth_token, + auth_header=self.ollama_auth_header, + auth_scheme=self.ollama_auth_scheme, + ) + ] + @lru_cache(maxsize=1) def get_settings() -> Settings: diff --git a/src/neuronetz_gateway/deps.py b/src/neuronetz_gateway/deps.py index a93e5d6..0487c40 100644 --- a/src/neuronetz_gateway/deps.py +++ b/src/neuronetz_gateway/deps.py @@ -40,6 +40,7 @@ from neuronetz_gateway.errors import AuthenticationError, DependencyUnavailableE from neuronetz_gateway.proxy.discovery import DiscoveryCache from neuronetz_gateway.proxy.ollama import OllamaClient from neuronetz_gateway.proxy.pipeline import Pipeline +from neuronetz_gateway.proxy.router import BackendRouter from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter @@ -66,10 +67,24 @@ def get_http_client(request: Request) -> httpx.AsyncClient: def get_ollama_client(request: Request) -> OllamaClient: - """Provide the upstream Ollama proxy client (override target for tests).""" + """Provide the upstream Ollama proxy client (override target for tests). + + In multi-backend mode this returns the FIRST backend's client (priority + order = list order). The pipeline uses :func:`get_backend_router` for + per-model routing + failover; this provider is kept for tests and for code + paths that don't need routing. 
+ """ return OllamaClient(get_http_client(request)) +def get_backend_router(request: Request) -> BackendRouter: + """Provide the multi-backend router (one client per configured backend).""" + router: BackendRouter | None = getattr(request.app.state, "backend_router", None) + if router is None: + raise DependencyUnavailableError(internal_detail="backend router not initialised") + return router + + def get_discovery_cache(request: Request) -> DiscoveryCache: """Provide the in-process discovery cache; fail closed if absent.""" cache: DiscoveryCache | None = getattr(request.app.state, "discovery_cache", None) @@ -112,10 +127,17 @@ def get_pipeline( The pipeline owns all hot-path checks (rate limit, budget, concurrency, model/endpoint allowlist) and the streaming-with-bookkeeping contract. Audit deny-mode flips this to fail closed at the route layer. + + In multi-backend deployments the per-request backend selection is done by + the pipeline using the :class:`BackendRouter` on ``app.state``; the + ``ollama`` argument here is the fallback single-backend client (used when + the router has no entry for a model, and as the override target for tests + that don't care about routing). 
""" sessionmaker: async_sessionmaker[AsyncSession] | None = getattr( request.app.state, "db_sessionmaker", None ) + router: BackendRouter | None = getattr(request.app.state, "backend_router", None) return Pipeline( request=request, principal=principal, @@ -127,6 +149,7 @@ def get_pipeline( budget=BudgetCounter(redis_client), audit=audit, sessionmaker=sessionmaker, + router=router, ) @@ -151,6 +174,7 @@ ConfigDep = Annotated[Settings, Depends(get_config)] RedisDep = Annotated[redis.Redis, Depends(get_redis)] HttpClientDep = Annotated[httpx.AsyncClient, Depends(get_http_client)] OllamaClientDep = Annotated[OllamaClient, Depends(get_ollama_client)] +BackendRouterDep = Annotated[BackendRouter, Depends(get_backend_router)] DiscoveryCacheDep = Annotated[DiscoveryCache, Depends(get_discovery_cache)] PrincipalDep = Annotated[Principal, Depends(get_principal)] AuditWriterDep = Annotated[AuditWriter, Depends(get_audit_writer)] @@ -160,6 +184,7 @@ DbSessionDep = Annotated[AsyncSession, Depends(get_db_session)] __all__ = [ "AuditWriterDep", + "BackendRouterDep", "ConfigDep", "DbSessionDep", "DiscoveryCacheDep", @@ -169,6 +194,7 @@ __all__ = [ "PrincipalDep", "RedisDep", "get_audit_writer", + "get_backend_router", "get_config", "get_db_session", "get_discovery_cache", diff --git a/src/neuronetz_gateway/lifespan.py b/src/neuronetz_gateway/lifespan.py index 78cb71a..6e68b34 100644 --- a/src/neuronetz_gateway/lifespan.py +++ b/src/neuronetz_gateway/lifespan.py @@ -23,7 +23,9 @@ from neuronetz_gateway.auth.hashing import build_hasher from neuronetz_gateway.config import Settings, get_settings from neuronetz_gateway.db.session import create_engine, create_session_factory from neuronetz_gateway.observability.logging import get_logger -from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop +from neuronetz_gateway.proxy.discovery import DiscoveryCache, discovery_loop_multi +from neuronetz_gateway.proxy.ollama import OllamaClient +from 
neuronetz_gateway.proxy.router import BackendRouter, build_http_clients from neuronetz_gateway.revocation import revocation_listener if TYPE_CHECKING: @@ -93,7 +95,21 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: _log.error("redis_init_failed", error=str(exc)) app.state.redis = None - app.state.http_client = _build_http_client(settings) + # Build per-backend upstream clients (one per entry in OLLAMA_BACKENDS, or + # a single "default" backend synthesized from the legacy OLLAMA_BASE_URL). + backend_clients, backend_order = build_http_clients(settings) + app.state.backend_http_clients = backend_clients + app.state.backend_order = backend_order + # ``http_client`` retains its single-client meaning for code paths (and + # tests) that haven't been migrated to the router yet: it is the FIRST + # backend's httpx client. New code should reach upstream via the router. + app.state.http_client = backend_clients[backend_order[0]] + app.state.backend_router = BackendRouter( + clients={name: OllamaClient(client) for name, client in backend_clients.items()}, + order=backend_order, + discovery=app.state.discovery_cache, + ) + _log.info("backends_configured", backends=backend_order) audit_writer = AuditWriter(settings.audit_buffer_size, app.state.db_sessionmaker) audit_writer.start() @@ -102,8 +118,12 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: # Background tasks (cancelled on shutdown). 
tasks.append( asyncio.create_task( - discovery_loop( - app.state.http_client, app.state.redis, app.state.discovery_cache, settings + discovery_loop_multi( + backend_clients, + backend_order, + app.state.redis, + app.state.discovery_cache, + settings, ) ) ) @@ -135,10 +155,14 @@ async def _shutdown( with contextlib.suppress(Exception): await audit_writer.stop() - http_client: httpx.AsyncClient | None = getattr(app.state, "http_client", None) - if http_client is not None: + # Close every per-backend httpx client (the legacy `http_client` attr is + # one of these, so we only need to iterate the dict). + backend_clients: dict[str, httpx.AsyncClient] = getattr( + app.state, "backend_http_clients", {} + ) + for client in backend_clients.values(): with contextlib.suppress(Exception): - await http_client.aclose() + await client.aclose() redis_client = getattr(app.state, "redis", None) if redis_client is not None: diff --git a/src/neuronetz_gateway/proxy/discovery.py b/src/neuronetz_gateway/proxy/discovery.py index 81afcf3..63e0dad 100644 --- a/src/neuronetz_gateway/proxy/discovery.py +++ b/src/neuronetz_gateway/proxy/discovery.py @@ -81,34 +81,80 @@ def names_of(models: list[DiscoveredModel]) -> frozenset[str]: class DiscoveryCache: - """In-process holder for the latest discovered model set. + """In-process holder for the latest discovered model set, per backend. - Holds both the structured records (for ``/api/tags`` / ``list-models``) and a - fast name set (for allowlist resolution on the hot path). Reads never block - on Redis or Ollama; the poller refreshes this in the background. + Holds the structured records (for ``/api/tags`` / ``list-models``), a fast + name set (for allowlist resolution on the hot path), and a model→backends + priority list (for the request router). Reads never block on Redis or + Ollama; the poller refreshes this in the background. 
+ + Backward compat: ``set(models)`` (single-list signature) still works and + populates a single "default" backend bucket. """ def __init__(self) -> None: + self._per_backend: dict[str, list[DiscoveredModel]] = {} self._models: list[DiscoveredModel] = [] self._names: frozenset[str] = frozenset() + # Ordered list of backend names that host each model, in priority order + # (the order in which backends were polled — i.e. config order). + self._model_backends: dict[str, list[str]] = {} self._lock = asyncio.Lock() async def set(self, models: list[DiscoveredModel]) -> None: - """Replace the in-process snapshot atomically.""" + """Legacy single-backend setter: replaces the cache with one bucket.""" + await self.set_per_backend({"default": list(models)}, ["default"]) + + async def set_per_backend( + self, + per_backend: dict[str, list[DiscoveredModel]], + backend_order: list[str], + ) -> None: + """Replace the cache from a per-backend mapping (multi-backend path). + + ``backend_order`` is the routing priority: when several backends host + the same model, requests for it route to the first in this list. + """ async with self._lock: - self._models = list(models) - self._names = names_of(models) + self._per_backend = {name: list(models) for name, models in per_backend.items()} + # Build a deduplicated combined record list, keeping the first + # occurrence (the highest-priority backend's metadata wins). 
+ combined: list[DiscoveredModel] = [] + seen: set[str] = set() + model_backends: dict[str, list[str]] = {} + for backend_name in backend_order: + for m in per_backend.get(backend_name, []): + model_backends.setdefault(m.name, []).append(backend_name) + if m.name not in seen: + combined.append(m) + seen.add(m.name) + self._models = combined + self._names = frozenset(seen) + self._model_backends = model_backends @property def names(self) -> frozenset[str]: - """Current discovered model names (possibly empty ⇒ fail-closed).""" + """Current discovered model names across all backends (∅ ⇒ fail-closed).""" return self._names @property def models(self) -> list[DiscoveredModel]: - """Current discovered model records (copy).""" + """Current de-duplicated discovered model records (copy).""" return list(self._models) + @property + def per_backend(self) -> dict[str, list[DiscoveredModel]]: + """Per-backend discovered records (copy of the mapping).""" + return {name: list(models) for name, models in self._per_backend.items()} + + def backends_for(self, model: str) -> list[str]: + """Return the priority-ordered list of backends that host ``model``. + + Empty list means no known backend has the model — routing should fail + with a generic 403 (the SPEC §13.6 no-existence-disclosure rule). + """ + return list(self._model_backends.get(model, ())) + async def write_discovered_to_redis( client: redis.Redis, models: list[DiscoveredModel], ttl_s: int @@ -174,16 +220,65 @@ async def refresh_once( return True +async def refresh_all_backends( + http_clients: dict[str, httpx.AsyncClient], + backend_order: list[str], + redis_client: redis.Redis | None, + cache: DiscoveryCache, + settings: Settings, +) -> int: + """Poll every backend in parallel and replace the cache atomically. + + Returns the number of backends that responded with at least one model. 
A + backend that errors contributes an empty list, so its previously cached + entries are dropped in this same refresh: the cache is replaced wholesale, + not merged, which preserves fail-closed semantics. + """ + + async def _poll(name: str, client: httpx.AsyncClient) -> tuple[str, list[DiscoveredModel]]: + try: + return name, await fetch_tags(client) + except (httpx.HTTPError, ValueError) as exc: + _log.warning("discovery_refresh_failed", backend=name, error=str(exc)) + return name, [] + + results = await asyncio.gather( + *(_poll(name, http_clients[name]) for name in backend_order) + ) + per_backend = dict(results) + await cache.set_per_backend(per_backend, backend_order) + + healthy = sum(1 for _, models in results if models) + total_models = len(cache.names) + _log.info( + "discovery_refreshed_multi", + backends=len(backend_order), + healthy=healthy, + unique_models=total_models, + ) + + # Cache the de-duplicated combined set in Redis for the legacy single-key + # reader (used as a recovery snapshot on startup). + if redis_client is not None: + try: + await write_discovered_to_redis( + redis_client, cache.models, settings.model_discovery_cache_ttl_s + ) + except Exception as exc: # noqa: BLE001 - Redis is best-effort cache fill + _log.warning("discovery_cache_write_failed", error=str(exc)) + return healthy + + async def discovery_loop( http_client: httpx.AsyncClient, redis_client: redis.Redis | None, cache: DiscoveryCache, settings: Settings, ) -> None: - """Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``. + """Single-backend background poller (legacy). - Designed to be launched via ``asyncio.create_task`` in the lifespan and - cancelled on shutdown. + Use :func:`discovery_loop_multi` for the multi-backend path. This function + is kept for the dependency-override tests that build a single client. 
""" await refresh_once(http_client, redis_client, cache, settings) while True: @@ -194,14 +289,38 @@ async def discovery_loop( await refresh_once(http_client, redis_client, cache, settings) +async def discovery_loop_multi( + http_clients: dict[str, httpx.AsyncClient], + backend_order: list[str], + redis_client: redis.Redis | None, + cache: DiscoveryCache, + settings: Settings, +) -> None: + """Multi-backend background poller. + + Polls every configured backend in parallel each tick. Designed to be + launched via ``asyncio.create_task`` in the lifespan and cancelled on + shutdown. + """ + await refresh_all_backends(http_clients, backend_order, redis_client, cache, settings) + while True: + try: + await asyncio.sleep(settings.model_discovery_refresh_s) + except asyncio.CancelledError: + raise + await refresh_all_backends(http_clients, backend_order, redis_client, cache, settings) + + __all__ = [ "REDIS_DISCOVERED_KEY", "DiscoveredModel", "DiscoveryCache", "discovery_loop", + "discovery_loop_multi", "fetch_tags", "names_of", "read_discovered_from_redis", + "refresh_all_backends", "refresh_once", "write_discovered_to_redis", ] diff --git a/src/neuronetz_gateway/proxy/pipeline.py b/src/neuronetz_gateway/proxy/pipeline.py index dfa27f8..86f4a21 100644 --- a/src/neuronetz_gateway/proxy/pipeline.py +++ b/src/neuronetz_gateway/proxy/pipeline.py @@ -43,12 +43,14 @@ from neuronetz_gateway.errors import ( BudgetExceededError, RateLimitError, RequestTooLargeError, + UpstreamUnavailableError, ) from neuronetz_gateway.observability import metrics from neuronetz_gateway.observability.logging import get_logger from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed from neuronetz_gateway.proxy.discovery import DiscoveryCache from neuronetz_gateway.proxy.ollama import OllamaClient +from neuronetz_gateway.proxy.router import BackendRouter from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage from neuronetz_gateway.ratelimit.concurrency 
import ConcurrencyLimiter from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter @@ -100,6 +102,7 @@ class Pipeline: budget: BudgetCounter, audit: AuditWriter, sessionmaker: async_sessionmaker[AsyncSession] | None = None, + router: BackendRouter | None = None, ) -> None: self._request = request self._p = principal @@ -111,6 +114,10 @@ class Pipeline: self._budget = budget self._audit = audit self._sessionmaker = sessionmaker + # If a router is provided, requests use per-model backend selection + + # failover. Without a router we keep the legacy single-client behaviour + # (still useful for tests that override get_ollama_client directly). + self._router = router self._request_id = str(getattr(request.state, "request_id", uuid.uuid4())) self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}" self._headers = RateHeaders( @@ -234,23 +241,77 @@ class Pipeline: """Render the §6.5 response headers.""" return self._headers.as_dict(self._request_id) + def _candidates(self, model: str) -> list[tuple[str, OllamaClient]]: + """Return (backend_name, client) candidates for ``model``, in priority order. + + With a router configured, this enumerates every backend that hosts the + model (so the proxy can fail over to the next one). Without a router, + it returns the legacy single client labelled "default". + """ + if self._router is None: + return [("default", self._ollama)] + candidates = [(bc.name, bc.client) for bc in self._router.clients_for_with_failover(model)] + return candidates or [("default", self._ollama)] + async def stream_native( self, method: str, path: str, body: dict[str, object], model: str ) -> StreamingResponse: - """Proxy a streaming NDJSON request, accounting tokens after close.""" + """Proxy a streaming NDJSON request, accounting tokens after close. + + Failover on streaming is best-effort: if the FIRST chunk fails before + any bytes have been yielded, we retry against the next backend. 
Once + bytes have started flowing we can't switch backends mid-stream, so a + later transport error surfaces as a truncated response (the post-close + bookkeeping still runs). + """ started = time.monotonic() - media_type = NDJSON_MEDIA_TYPE + candidates = self._candidates(model) async def gen() -> AsyncIterator[bytes]: last_obj: dict[str, object] = {} try: - async for chunk in self._ollama.stream(method, path, body): + iterator = await self._open_stream(method, path, body, candidates) + async for chunk in iterator: last_obj = _merge_last_ndjson(chunk, last_obj) yield chunk finally: await self._finish(model, path, method, last_obj, started) - return StreamingResponse(gen(), media_type=media_type, headers=self.headers()) + return StreamingResponse(gen(), media_type=NDJSON_MEDIA_TYPE, headers=self.headers()) + + async def _open_stream( + self, + method: str, + path: str, + body: dict[str, object], + candidates: list[tuple[str, OllamaClient]], + ) -> AsyncIterator[bytes]: + """Open the upstream stream, trying each candidate backend in priority order. + + Returns the async iterator from the first backend that doesn't raise on + the initial connect/send. If all candidates fail, raises + :class:`UpstreamUnavailableError`. + """ + last_exc: Exception | None = None + for name, client in candidates: + try: + # ``OllamaClient.stream`` is an async generator: building the + # generator object doesn't make the request, but the FIRST + # ``__anext__`` does. Pre-pump one chunk to surface connect + # errors here so we can fail over without yielding partial data. + gen = client.stream(method, path, body) + first = await gen.__anext__() + return _replay_first_then(first, gen) + except StopAsyncIteration: + # Empty stream: just return an exhausted iterator. 
+ return _replay_first_then(b"", _empty()) + except Exception as exc: # noqa: BLE001 - we retry on any transport error + _log.warning("backend_stream_failover", backend=name, error=str(exc)) + last_exc = exc + continue + raise UpstreamUnavailableError( + internal_detail=f"all backends failed ({last_exc})" + ) async def stream_openai( self, @@ -262,12 +323,14 @@ class Pipeline: ) -> StreamingResponse: """Proxy + translate native NDJSON into OpenAI SSE; account after close.""" started = time.monotonic() + candidates = self._candidates(model) async def gen() -> AsyncIterator[bytes]: last_obj: dict[str, object] = {} buffer = b"" try: - async for chunk in self._ollama.stream(method, path, body): + iterator = await self._open_stream(method, path, body, candidates) + async for chunk in iterator: buffer += chunk lines = buffer.split(b"\n") buffer = lines.pop() @@ -293,9 +356,8 @@ class Pipeline: ) -> JSONResponse: """Proxy a non-streaming request and account tokens before responding.""" started = time.monotonic() - resp = await self._ollama.request(method, path, body) - payload = resp.json() - obj = payload if isinstance(payload, dict) else {} + candidates = self._candidates(model) + obj = await self._request_with_failover(method, path, body, candidates) await self._finish(model, path, method, obj, started) return JSONResponse(obj, headers=self.headers()) @@ -309,12 +371,33 @@ class Pipeline: ) -> JSONResponse: """Proxy a non-streaming request, translate the body, then account.""" started = time.monotonic() - resp = await self._ollama.request(method, path, body) - payload = resp.json() - obj = payload if isinstance(payload, dict) else {} + candidates = self._candidates(model) + obj = await self._request_with_failover(method, path, body, candidates) await self._finish(model, path, method, obj, started) return JSONResponse(translator(obj), headers=self.headers()) + async def _request_with_failover( + self, + method: str, + path: str, + body: dict[str, object], + candidates: 
list[tuple[str, OllamaClient]], + ) -> dict[str, object]: + """Try each backend in order; return parsed JSON from the first success.""" + last_exc: Exception | None = None + for name, client in candidates: + try: + resp = await client.request(method, path, body) + payload = resp.json() + return payload if isinstance(payload, dict) else {} + except Exception as exc: # noqa: BLE001 - retry on any transport error + _log.warning("backend_request_failover", backend=name, error=str(exc)) + last_exc = exc + continue + raise UpstreamUnavailableError( + internal_detail=f"all backends failed ({last_exc})" + ) + async def _finish( self, model: str, @@ -457,6 +540,22 @@ async def read_json_body(request: Request, settings: Settings) -> dict[str, obje return parsed if isinstance(parsed, dict) else {} +async def _empty() -> AsyncIterator[bytes]: + """An async iterator that yields nothing — used as a no-op stream.""" + if False: # pragma: no cover - never reached, just makes this an async generator + yield b"" + + +async def _replay_first_then( + first: bytes, rest: AsyncIterator[bytes] +) -> AsyncIterator[bytes]: + """Yield ``first`` (already-pulled probe chunk) and then drain ``rest``.""" + if first: + yield first + async for chunk in rest: + yield chunk + + __all__ = [ "NDJSON_MEDIA_TYPE", "SSE_MEDIA_TYPE", diff --git a/src/neuronetz_gateway/proxy/router.py b/src/neuronetz_gateway/proxy/router.py new file mode 100644 index 0000000..83a6f6f --- /dev/null +++ b/src/neuronetz_gateway/proxy/router.py @@ -0,0 +1,145 @@ +"""Backend router: maps a request's model to the right Ollama backend. + +In multi-backend deployments the gateway aggregates ``/api/tags`` across several +Ollama servers and routes each request to the **first** backend (in config +order) whose discovered model set includes the requested model. 
If that backend +errors, :meth:`BackendRouter.clients_for_with_failover` yields the next backend +that also has the model so the caller can retry without surfacing the failure +to the client. + +Single-backend deployments still work: :func:`build_router` synthesizes a +single "default" backend from the legacy ``OLLAMA_BASE_URL`` settings, so a +gateway that hasn't opted into multi-backend behaves exactly like before. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass + +import httpx + +from neuronetz_gateway.config import BackendSpec, Settings +from neuronetz_gateway.errors import UpstreamUnavailableError +from neuronetz_gateway.observability.logging import get_logger +from neuronetz_gateway.proxy.discovery import DiscoveryCache +from neuronetz_gateway.proxy.ollama import OllamaClient + +_log = get_logger("router") + + +@dataclass(frozen=True, slots=True) +class BackendClient: + """A backend ready to receive a proxied request.""" + + name: str + client: OllamaClient + + +class BackendRouter: + """Holds the per-backend httpx clients and resolves model → backend.""" + + def __init__( + self, + clients: dict[str, OllamaClient], + order: list[str], + discovery: DiscoveryCache, + ) -> None: + # Map insertion order matters for the failover path, so we keep `order` + # as the canonical priority list and `clients` indexed by name. + self._clients = clients + self._order = order + self._discovery = discovery + + @property + def order(self) -> list[str]: + """The configured backend priority order (copy).""" + return list(self._order) + + @property + def clients(self) -> dict[str, OllamaClient]: + """All configured backend clients (copy of the mapping).""" + return dict(self._clients) + + def first_for(self, model: str) -> BackendClient: + """Return the highest-priority backend that hosts ``model``. 
+ + Raises :class:`UpstreamUnavailableError` if no backend currently lists + the model — which, per SPEC §13.6, is the same generic response we + return for "model not allowed", so callers should let the existing + authorization-error path handle the user-visible message instead. + """ + for backend in self._discovery.backends_for(model): + client = self._clients.get(backend) + if client is not None: + return BackendClient(name=backend, client=client) + raise UpstreamUnavailableError( + internal_detail=f"no backend hosts model {model!r}" + ) + + def clients_for_with_failover(self, model: str) -> Iterator[BackendClient]: + """Yield every backend that hosts ``model``, in priority order. + + Callers can iterate to retry against the next backend if the first + attempt raises a transport error. Empty iteration ⇒ no backend. + """ + for backend in self._discovery.backends_for(model): + client = self._clients.get(backend) + if client is not None: + yield BackendClient(name=backend, client=client) + + +def build_backend_headers(backend: BackendSpec) -> dict[str, str]: + """Compose default headers for one backend's httpx client. + + Mirrors ``lifespan._build_upstream_headers`` but resolves against a + specific :class:`BackendSpec` so multi-backend gets its own per-backend + auth without all backends having to share one token. + """ + headers: dict[str, str] = {"User-Agent": "neuronetz-gateway"} + if backend.has_auth and backend.auth_token is not None: + raw = backend.auth_token.get_secret_value().strip() + if backend.auth_header.lower() == "authorization": + headers[backend.auth_header] = f"{backend.auth_scheme} {raw}".strip() + else: + headers[backend.auth_header] = raw + return headers + + +def build_http_clients(settings: Settings) -> tuple[dict[str, httpx.AsyncClient], list[str]]: + """Build one httpx client per configured backend. + + Returns ``(clients, order)``: a name→client mapping and the priority-ordered + list of backend names. 
The caller is responsible for ``aclose()`` of every + client in the returned mapping during lifespan shutdown. + """ + backends = settings.effective_backends() + timeout = httpx.Timeout( + connect=settings.ollama_connect_timeout_s, + read=settings.ollama_read_timeout_s, + write=settings.ollama_read_timeout_s, + pool=settings.ollama_connect_timeout_s, + ) + limits = httpx.Limits(max_connections=settings.ollama_max_connections) + clients: dict[str, httpx.AsyncClient] = {} + order: list[str] = [] + for backend in backends: + if backend.name in clients: + _log.warning("duplicate_backend_name_skipped", name=backend.name) + continue + clients[backend.name] = httpx.AsyncClient( + base_url=backend.base_url, + timeout=timeout, + limits=limits, + headers=build_backend_headers(backend), + ) + order.append(backend.name) + return clients, order + + +__all__ = [ + "BackendClient", + "BackendRouter", + "build_backend_headers", + "build_http_clients", +]