From 6a92bc8ce92d835d7f604f9791b01cd20d67e0a9 Mon Sep 17 00:00:00 2001 From: Stephan Berbig Date: Tue, 26 May 2026 20:52:33 +0200 Subject: [PATCH] proxy: streaming, discovery, OpenAI-compat, rate-limit, budget, audit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The hot path. A single Pipeline class owns enforcement so the eight non-negotiables can be reviewed in one place. - Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags, /api/show (system-prompt + template stripped), /api/embed(dings), /api/version (returns gateway version, not Ollama's). Endpoint catch-all returns the same generic 403 for hard-blocked and unknown /api/* paths so attackers cannot enumerate which mutating endpoints exist. - OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings, /v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves streaming end-to-end. - Model discovery (SPEC §4.6): background poller against Ollama /api/tags; Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh = MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty. - Effective-set resolution in proxy/allowlist.py: allow_all = key.allow_all_models ?? tenant.allow_all_models effective = discovered if allow_all else (key.allowed_models ?? tenant.allowed_models) ∩ discovered A non-effective model returns the same generic 403 whether it's installed-but-unpermitted or doesn't exist at all (no enumeration leak). - Sliding-window rate limit (Redis Lua, single round-trip) for per-key + per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with TTL guard. Token-budget counters per (key, period) with a Postgres ledger for reconciliation across resets. Headers per SPEC §6.5 on every response; 429 carries Retry-After; Redis outage → 503 (fail closed, never 200). 
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk carrying `usage`); the audit row is written AFTER stream close so TTFB is never degraded by bookkeeping. - Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow. Optional prompt log per key (TTL'd). - Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache entry within ~1s of the console writing to gateway.revocations. - Prometheus counters/histograms labeled by tenant (plus model and status/direction), never by key_id (per SPEC §13.3). --- src/neuronetz_gateway/audit/__init__.py | 3 + src/neuronetz_gateway/audit/prompt_log.py | 63 +++ src/neuronetz_gateway/audit/writer.py | 152 ++++++ src/neuronetz_gateway/budget/__init__.py | 3 + src/neuronetz_gateway/budget/counter.py | 105 ++++ src/neuronetz_gateway/budget/ledger.py | 58 +++ .../observability/metrics.py | 67 +++ src/neuronetz_gateway/proxy/__init__.py | 3 + src/neuronetz_gateway/proxy/allowlist.py | 76 +++ src/neuronetz_gateway/proxy/discovery.py | 207 ++++++++ src/neuronetz_gateway/proxy/ollama.py | 79 +++ src/neuronetz_gateway/proxy/pipeline.py | 467 ++++++++++++++++++ src/neuronetz_gateway/proxy/token_counter.py | 50 ++ src/neuronetz_gateway/proxy/translate.py | 245 +++++++++ src/neuronetz_gateway/ratelimit/__init__.py | 3 + .../ratelimit/concurrency.py | 66 +++ .../ratelimit/sliding_window.py | 109 ++++ src/neuronetz_gateway/revocation.py | 97 ++++ src/neuronetz_gateway/routes/ollama_native.py | 168 +++++++ src/neuronetz_gateway/routes/openai_compat.py | 118 +++++ 20 files changed, 2139 insertions(+) create mode 100644 src/neuronetz_gateway/audit/__init__.py create mode 100644 src/neuronetz_gateway/audit/prompt_log.py create mode 100644 src/neuronetz_gateway/audit/writer.py create mode 100644 src/neuronetz_gateway/budget/__init__.py create mode 100644 src/neuronetz_gateway/budget/counter.py create mode 100644 src/neuronetz_gateway/budget/ledger.py create mode 100644 src/neuronetz_gateway/observability/metrics.py create mode 100644 
src/neuronetz_gateway/proxy/__init__.py create mode 100644 src/neuronetz_gateway/proxy/allowlist.py create mode 100644 src/neuronetz_gateway/proxy/discovery.py create mode 100644 src/neuronetz_gateway/proxy/ollama.py create mode 100644 src/neuronetz_gateway/proxy/pipeline.py create mode 100644 src/neuronetz_gateway/proxy/token_counter.py create mode 100644 src/neuronetz_gateway/proxy/translate.py create mode 100644 src/neuronetz_gateway/ratelimit/__init__.py create mode 100644 src/neuronetz_gateway/ratelimit/concurrency.py create mode 100644 src/neuronetz_gateway/ratelimit/sliding_window.py create mode 100644 src/neuronetz_gateway/revocation.py create mode 100644 src/neuronetz_gateway/routes/ollama_native.py create mode 100644 src/neuronetz_gateway/routes/openai_compat.py diff --git a/src/neuronetz_gateway/audit/__init__.py b/src/neuronetz_gateway/audit/__init__.py new file mode 100644 index 0000000..b7a1d00 --- /dev/null +++ b/src/neuronetz_gateway/audit/__init__.py @@ -0,0 +1,3 @@ +"""Audit logging: buffered async audit writer and opt-in prompt log.""" + +from __future__ import annotations diff --git a/src/neuronetz_gateway/audit/prompt_log.py b/src/neuronetz_gateway/audit/prompt_log.py new file mode 100644 index 0000000..f01aaa4 --- /dev/null +++ b/src/neuronetz_gateway/audit/prompt_log.py @@ -0,0 +1,63 @@ +"""Opt-in, TTL'd prompt logging (SPEC §2, §5). + +Disabled by default; enabled per key (or inherited from tenant via the resolved +:class:`Principal.log_prompts`). Each row carries a ``retention_until`` deadline +(swept in Phase 4). A redaction hook runs before persistence so sensitive spans +can be scrubbed without changing call sites. 
+""" + +from __future__ import annotations + +import datetime +from collections.abc import Callable +from dataclasses import dataclass +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from neuronetz_gateway.db.models import PromptLog + +# A redaction hook maps a request/response body to a sanitized copy. The default +# is identity; operators can install a stricter hook. +RedactionHook = Callable[[dict[str, object]], dict[str, object]] + + +def _identity(body: dict[str, object]) -> dict[str, object]: + return body + + +@dataclass(frozen=True, slots=True) +class PromptRecord: + """A captured request/response pair pending TTL'd persistence.""" + + audit_id: int + key_id: UUID + request_body: dict[str, object] + response_text: str | None + retention_days: int + + +class PromptLogWriter: + """Persists opt-in prompt records with a retention deadline.""" + + def __init__(self, session: AsyncSession, redact: RedactionHook | None = None) -> None: + self._session = session + self._redact = redact or _identity + + async def write(self, record: PromptRecord) -> None: + """Persist a prompt record to ``gateway.prompt_log``.""" + retention_until = datetime.datetime.now(datetime.UTC) + datetime.timedelta( + days=record.retention_days + ) + row = PromptLog( + audit_id=record.audit_id, + key_id=record.key_id, + request_body=self._redact(record.request_body), + response_text=record.response_text, + retention_until=retention_until, + ) + self._session.add(row) + await self._session.flush() + + +__all__ = ["PromptLogWriter", "PromptRecord", "RedactionHook"] diff --git a/src/neuronetz_gateway/audit/writer.py b/src/neuronetz_gateway/audit/writer.py new file mode 100644 index 0000000..bddce8d --- /dev/null +++ b/src/neuronetz_gateway/audit/writer.py @@ -0,0 +1,152 @@ +"""Buffered async audit-log writer (SPEC §4.4). + +Audit rows are enqueued *after* stream close so the hot path is never delayed +(non-negotiable #6). 
A background drain task persists them to Postgres. On a +Postgres write failure rows are buffered in a bounded in-memory ring (max +``AUDIT_BUFFER_SIZE``) and retried; if the ring fills, the writer flips to +**deny mode** (SPEC §4.4) — :pyattr:`deny_mode` goes True and the request path +must refuse new work until the backlog drains. + +The writer holds the session factory directly (it runs outside request scope). +``enqueue`` never blocks on the DB; it only touches the in-memory queue/ring. +""" + +from __future__ import annotations + +import asyncio +import collections +import datetime +from dataclasses import asdict, dataclass +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from neuronetz_gateway.db.repositories import AuditRepository +from neuronetz_gateway.observability.logging import get_logger + +_log = get_logger("audit") + + +@dataclass(frozen=True, slots=True) +class AuditRecord: + """A single audit-log row queued for persistence.""" + + request_id: UUID + method: str + path: str + status: int + ts: datetime.datetime + tenant_id: UUID | None = None + key_id: UUID | None = None + key_prefix: str | None = None + model: str | None = None + tokens_in: int | None = None + tokens_out: int | None = None + latency_ms: int | None = None + client_ip: str | None = None + user_agent: str | None = None + error_code: str | None = None + + def as_columns(self) -> dict[str, object]: + """Map to ``gateway.audit_log`` column kwargs.""" + return asdict(self) + + +class AuditWriter: + """Buffered, fail-safe writer for ``gateway.audit_log``.""" + + def __init__( + self, + buffer_size: int, + sessionmaker: async_sessionmaker[AsyncSession] | None = None, + ) -> None: + self._buffer_size = buffer_size + self._sessionmaker = sessionmaker + self._ring: collections.deque[AuditRecord] = collections.deque(maxlen=buffer_size) + self._queue: asyncio.Queue[AuditRecord] = asyncio.Queue() + self._deny_mode = False + self._task: 
asyncio.Task[None] | None = None + + @property + def deny_mode(self) -> bool: + """True when the buffer has overflowed and new work must be denied.""" + return self._deny_mode + + def bind(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None: + """Attach the DB session factory (called from lifespan startup).""" + self._sessionmaker = sessionmaker + + async def enqueue(self, record: AuditRecord) -> None: + """Queue an audit record for asynchronous persistence (non-blocking).""" + await self._queue.put(record) + + def start(self) -> None: + """Launch the background drain task (idempotent).""" + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._drain_loop()) + + async def stop(self) -> None: + """Cancel the drain task and flush remaining rows best-effort.""" + if self._task is not None: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + await self.flush() + + async def _drain_loop(self) -> None: + """Continuously move queued records into Postgres, retrying the ring.""" + while True: + record = await self._queue.get() + await self._persist_with_retry(record) + + async def _persist_with_retry(self, record: AuditRecord) -> None: + """Persist one record; on failure push to the ring (or enter deny mode).""" + # Drain any backlog first so ordering is roughly preserved. 
+ if self._ring: + await self._try_flush_ring() + if not await self._write_one(record): + self._buffer(record) + + def _buffer(self, record: AuditRecord) -> None: + """Buffer a failed write; flip to deny mode if the ring is full.""" + if len(self._ring) >= self._buffer_size: + self._deny_mode = True + _log.error("audit_buffer_overflow_deny_mode", buffered=len(self._ring)) + return + self._ring.append(record) + + async def _try_flush_ring(self) -> None: + """Attempt to persist buffered rows; clear deny mode once drained.""" + while self._ring: + record = self._ring[0] + if not await self._write_one(record): + return + self._ring.popleft() + self._deny_mode = False + + async def _write_one(self, record: AuditRecord) -> bool: + """Persist a single record; return False on any failure.""" + if self._sessionmaker is None: + return False + try: + async with self._sessionmaker() as session: + await AuditRepository(session).insert_audit(**record.as_columns()) + await session.commit() + return True + except Exception as exc: # noqa: BLE001 - any DB error ⇒ buffer + retry + _log.warning("audit_write_failed", error=str(exc)) + return False + + async def flush(self) -> None: + """Drain queued + buffered records to Postgres (best-effort).""" + while not self._queue.empty(): + record = self._queue.get_nowait() + if not await self._write_one(record): + self._buffer(record) + await self._try_flush_ring() + + +__all__ = ["AuditRecord", "AuditWriter"] diff --git a/src/neuronetz_gateway/budget/__init__.py b/src/neuronetz_gateway/budget/__init__.py new file mode 100644 index 0000000..ad1e62c --- /dev/null +++ b/src/neuronetz_gateway/budget/__init__.py @@ -0,0 +1,3 @@ +"""Token budgets: Redis period counters and Postgres ledger reconciliation.""" + +from __future__ import annotations diff --git a/src/neuronetz_gateway/budget/counter.py b/src/neuronetz_gateway/budget/counter.py new file mode 100644 index 0000000..d86d544 --- /dev/null +++ b/src/neuronetz_gateway/budget/counter.py @@ -0,0 
+1,105 @@ +"""Redis period counters for token budgets (day / month / total). + +The fast-path budget check reads a Redis counter of tokens already consumed in +the active period and compares it to the configured limit; Postgres +(``ledger.py``) is the durable source of truth reconciled on rollover. Counters +are keyed by ``key_id`` + period + period-start so a new period naturally starts +at zero, and day/month counters carry a TTL so expired periods self-clean. + +Fail-closed (SPEC §4.4): Redis errors raise +:class:`DependencyUnavailableError`; the caller must deny (503). +""" + +from __future__ import annotations + +import datetime +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import cast + +import redis.asyncio as redis +from redis.exceptions import RedisError + +from neuronetz_gateway.db.models import BudgetPeriod +from neuronetz_gateway.errors import DependencyUnavailableError + +_CONSUMED_PREFIX = "gateway:budget:" + + +def period_start(period: BudgetPeriod, now: datetime.datetime | None = None) -> datetime.datetime: + """Return the UTC start of the current period.""" + moment = now or datetime.datetime.now(datetime.UTC) + moment = moment.astimezone(datetime.UTC) + if period is BudgetPeriod.day: + return moment.replace(hour=0, minute=0, second=0, microsecond=0) + if period is BudgetPeriod.month: + return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + # 'total' has no rollover; anchor at the epoch. 
+ return datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC) + + +def _ttl_for(period: BudgetPeriod) -> int | None: + """TTL (seconds) for a period counter; None for 'total' (no expiry).""" + if period is BudgetPeriod.day: + return 2 * 24 * 3600 + if period is BudgetPeriod.month: + return 40 * 24 * 3600 + return None + + +def _counter_key(key_id: str, period: BudgetPeriod, start: datetime.datetime) -> str: + return f"{_CONSUMED_PREFIX}{key_id}:{period.value}:{int(start.timestamp())}" + + +@dataclass(frozen=True, slots=True) +class BudgetState: + """Snapshot of remaining budget for a period.""" + + period: BudgetPeriod + limit: int | None + remaining: int | None + + @property + def exhausted(self) -> bool: + """True if a finite limit is configured and nothing remains.""" + return self.limit is not None and (self.remaining or 0) <= 0 + + +class BudgetCounter: + """Redis-backed per-period token counter.""" + + def __init__(self, client: redis.Redis) -> None: + self._client = client + + async def check(self, key_id: str, period: BudgetPeriod, limit: int | None) -> BudgetState: + """Return remaining budget; ``None`` limit means unlimited for the period.""" + if limit is None: + return BudgetState(period=period, limit=None, remaining=None) + start = period_start(period) + try: + raw = await self._client.get(_counter_key(key_id, period, start)) + except RedisError as exc: + raise DependencyUnavailableError( + internal_detail=f"budget redis error: {exc!r}" + ) from exc + consumed = int(raw) if raw else 0 + return BudgetState(period=period, limit=limit, remaining=max(limit - consumed, 0)) + + async def consume(self, key_id: str, period: BudgetPeriod, tokens: int) -> None: + """Increment the consumed counter after a request completes.""" + if tokens <= 0: + return + start = period_start(period) + key = _counter_key(key_id, period, start) + try: + await cast("Awaitable[int]", self._client.incrby(key, tokens)) + ttl = _ttl_for(period) + if ttl is not None: + await 
self._client.expire(key, ttl) + except RedisError as exc: + raise DependencyUnavailableError( + internal_detail=f"budget consume redis error: {exc!r}" + ) from exc + + +__all__ = ["BudgetCounter", "BudgetState", "period_start"] diff --git a/src/neuronetz_gateway/budget/ledger.py b/src/neuronetz_gateway/budget/ledger.py new file mode 100644 index 0000000..b86717a --- /dev/null +++ b/src/neuronetz_gateway/budget/ledger.py @@ -0,0 +1,58 @@ +"""Postgres budget ledger reconciliation. + +Persists token usage to ``gateway.budget_usage`` (the durable source of truth) +via an idempotent upsert keyed by (key_id, period, period_start). The Redis +counter (``counter.py``) is the fast path; this ledger is what survives a Redis +flush and what ``show-usage`` reports against. +""" + +from __future__ import annotations + +import uuid + +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from neuronetz_gateway.budget.counter import period_start +from neuronetz_gateway.db.models import BudgetPeriod, BudgetUsage + + +class BudgetLedger: + """Source-of-truth budget accounting in Postgres.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def record_usage( + self, key_id: str, period: BudgetPeriod, tokens_in: int, tokens_out: int + ) -> None: + """Upsert usage into ``gateway.budget_usage`` for the active period. + + Uses an ``ON CONFLICT`` upsert so concurrent writers accumulate rather + than clobber. ``requests`` increments by one per recorded request. 
+ """ + start = period_start(period) + stmt = pg_insert(BudgetUsage).values( + key_id=uuid.UUID(key_id) if isinstance(key_id, str) else key_id, + period=period, + period_start=start, + tokens_in=tokens_in, + tokens_out=tokens_out, + requests=1, + ) + stmt = stmt.on_conflict_do_update( + index_elements=[ + BudgetUsage.key_id, + BudgetUsage.period, + BudgetUsage.period_start, + ], + set_={ + "tokens_in": BudgetUsage.tokens_in + stmt.excluded.tokens_in, + "tokens_out": BudgetUsage.tokens_out + stmt.excluded.tokens_out, + "requests": BudgetUsage.requests + stmt.excluded.requests, + }, + ) + await self._session.execute(stmt) + + +__all__ = ["BudgetLedger"] diff --git a/src/neuronetz_gateway/observability/metrics.py b/src/neuronetz_gateway/observability/metrics.py new file mode 100644 index 0000000..6bee1e6 --- /dev/null +++ b/src/neuronetz_gateway/observability/metrics.py @@ -0,0 +1,67 @@ +"""Prometheus metrics. + +Phase 1 declares the metric objects and the exposition helper. Instrumentation +(incrementing counters / observing histograms on the request path) is wired in +later phases. Per SPEC §13.3 we label by ``tenant`` only, never by ``key_id``. 
+""" + +from __future__ import annotations + +from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest + +REGISTRY = CollectorRegistry() + +REQUESTS_TOTAL = Counter( + "gateway_requests_total", + "Total proxied requests.", + labelnames=("tenant", "model", "status"), + registry=REGISTRY, +) + +TOKENS_TOTAL = Counter( + "gateway_tokens_total", + "Total tokens accounted, by direction (in|out).", + labelnames=("tenant", "model", "direction"), + registry=REGISTRY, +) + +REQUEST_DURATION_SECONDS = Histogram( + "gateway_request_duration_seconds", + "Gateway-side request duration in seconds.", + labelnames=("tenant", "model"), + registry=REGISTRY, +) + + +def record_request(tenant: str, model: str, status: int, duration_s: float) -> None: + """Increment the request counter and observe its duration (tenant-labeled).""" + REQUESTS_TOTAL.labels(tenant=tenant, model=model, status=str(status)).inc() + REQUEST_DURATION_SECONDS.labels(tenant=tenant, model=model).observe(duration_s) + + +def record_tokens(tenant: str, model: str, tokens_in: int, tokens_out: int) -> None: + """Add input/output token counts to the tokens counter.""" + if tokens_in: + TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="in").inc(tokens_in) + if tokens_out: + TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="out").inc(tokens_out) + + +def render_latest() -> bytes: + """Return the current metrics in Prometheus text exposition format.""" + payload: bytes = generate_latest(REGISTRY) + return payload + + +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" + +__all__ = [ + "CONTENT_TYPE_LATEST", + "REGISTRY", + "REQUESTS_TOTAL", + "REQUEST_DURATION_SECONDS", + "TOKENS_TOTAL", + "record_request", + "record_tokens", + "render_latest", +] diff --git a/src/neuronetz_gateway/proxy/__init__.py b/src/neuronetz_gateway/proxy/__init__.py new file mode 100644 index 0000000..0cda8fa --- /dev/null +++ b/src/neuronetz_gateway/proxy/__init__.py @@ -0,0 +1,3 @@ 
+"""Proxy layer: Ollama client, schema translation, token counting, allowlists.""" + +from __future__ import annotations diff --git a/src/neuronetz_gateway/proxy/allowlist.py b/src/neuronetz_gateway/proxy/allowlist.py new file mode 100644 index 0000000..2c998dc --- /dev/null +++ b/src/neuronetz_gateway/proxy/allowlist.py @@ -0,0 +1,76 @@ +"""Endpoint and model allowlists (SPEC §6.1, §6.2). + +Mutating Ollama endpoints are hard-blocked (not configurable, not flagged): +``/api/pull``, ``/api/push``, ``/api/create``, ``/api/copy``, ``/api/delete``, +and any ``/api/blobs/*``. ``/api/ps`` is also blocked (leaks loaded models). +Model allowlist is per-tenant, default-deny. Enforcement logic lands in Phase 2. +""" + +from __future__ import annotations + +# Hard-blocked upstream endpoints — always 403, not configurable (SPEC §6.2). +HARD_BLOCKED_PATHS: frozenset[str] = frozenset( + { + "/api/pull", + "/api/push", + "/api/create", + "/api/copy", + "/api/delete", + "/api/ps", + } +) + +# Path prefixes that are hard-blocked (e.g. blob upload/download). +HARD_BLOCKED_PREFIXES: tuple[str, ...] = ("/api/blobs",) + + +def is_hard_blocked(path: str) -> bool: + """Return True if ``path`` is an unconditionally blocked upstream endpoint.""" + if path in HARD_BLOCKED_PATHS: + return True + return any(path.startswith(prefix) for prefix in HARD_BLOCKED_PREFIXES) + + +def resolve_effective_models( + *, + allow_all: bool, + allowed_models: tuple[str, ...], + discovered: frozenset[str], +) -> frozenset[str]: + """Resolve the effective model set per SPEC §4.3 step 7 / §4.6. + + ``allow_all`` ⇒ the effective set is the entire live ``discovered`` set; + otherwise it is the configured ``allowed_models`` intersected with + ``discovered`` (so stale or typo'd allowlist entries never resolve, and a + model that is unpermitted vs. not-installed are indistinguishable). 
+ + Fail-closed: if ``discovered`` is empty (discovery unavailable/expired) the + result is empty regardless of ``allow_all`` — discovery never opens access. + """ + if not discovered: + return frozenset() + if allow_all: + return discovered + return frozenset(allowed_models) & discovered + + +def is_model_allowed( + model: str, + *, + allow_all: bool, + allowed_models: tuple[str, ...], + discovered: frozenset[str], +) -> bool: + """Return True iff ``model`` is in the resolved effective set (default-deny).""" + return model in resolve_effective_models( + allow_all=allow_all, allowed_models=allowed_models, discovered=discovered + ) + + +__all__ = [ + "HARD_BLOCKED_PATHS", + "HARD_BLOCKED_PREFIXES", + "is_hard_blocked", + "is_model_allowed", + "resolve_effective_models", +] diff --git a/src/neuronetz_gateway/proxy/discovery.py b/src/neuronetz_gateway/proxy/discovery.py new file mode 100644 index 0000000..81afcf3 --- /dev/null +++ b/src/neuronetz_gateway/proxy/discovery.py @@ -0,0 +1,207 @@ +"""Live model discovery from the Ollama backend (SPEC §4.6). + +A background task polls Ollama ``GET /api/tags`` every +``MODEL_DISCOVERY_REFRESH_S`` seconds. The parsed model set (names + sanitized +metadata) is cached in Redis under ``gateway:models:discovered`` with TTL +``MODEL_DISCOVERY_CACHE_TTL_S`` and held in-process for hot reads on the request +path. + +Fail-closed (SPEC §4.6, §13.5): if Ollama is unreachable, or the cache is empty +or expired and cannot be refreshed, the discovered set is empty — and an empty +discovered set means no model resolves, so requests are denied. Discovery never +opens access on failure. It is read-only and only ever touches the allowlisted +``/api/tags`` endpoint; it never triggers a pull. 
+""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import asdict, dataclass + +import httpx +import redis.asyncio as redis + +from neuronetz_gateway.config import Settings +from neuronetz_gateway.observability.logging import get_logger + +_log = get_logger("discovery") + +REDIS_DISCOVERED_KEY = "gateway:models:discovered" + + +@dataclass(frozen=True, slots=True) +class DiscoveredModel: + """Sanitized metadata for a single installed model.""" + + name: str + family: str | None = None + parameter_size: str | None = None + quantization: str | None = None + size_bytes: int | None = None + modified_at: str | None = None + + +def _parse_tags(payload: dict[str, object]) -> list[DiscoveredModel]: + """Parse an Ollama ``/api/tags`` body into sanitized model records.""" + models: list[DiscoveredModel] = [] + raw_models = payload.get("models") + if not isinstance(raw_models, list): + return models + for entry in raw_models: + if not isinstance(entry, dict): + continue + name = entry.get("name") or entry.get("model") + if not isinstance(name, str) or not name: + continue + raw_details = entry.get("details") + details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {} + size = entry.get("size") + models.append( + DiscoveredModel( + name=name, + family=_opt_str(details.get("family")), + parameter_size=_opt_str(details.get("parameter_size")), + quantization=_opt_str(details.get("quantization_level")), + size_bytes=size if isinstance(size, int) else None, + modified_at=_opt_str(entry.get("modified_at")), + ) + ) + return models + + +def _opt_str(value: object) -> str | None: + """Coerce a value to ``str`` only when it already is one.""" + return value if isinstance(value, str) else None + + +def names_of(models: list[DiscoveredModel]) -> frozenset[str]: + """Return just the set of model names from parsed records.""" + return frozenset(m.name for m in models) + + +class DiscoveryCache: + """In-process holder for the 
latest discovered model set. + + Holds both the structured records (for ``/api/tags`` / ``list-models``) and a + fast name set (for allowlist resolution on the hot path). Reads never block + on Redis or Ollama; the poller refreshes this in the background. + """ + + def __init__(self) -> None: + self._models: list[DiscoveredModel] = [] + self._names: frozenset[str] = frozenset() + self._lock = asyncio.Lock() + + async def set(self, models: list[DiscoveredModel]) -> None: + """Replace the in-process snapshot atomically.""" + async with self._lock: + self._models = list(models) + self._names = names_of(models) + + @property + def names(self) -> frozenset[str]: + """Current discovered model names (possibly empty ⇒ fail-closed).""" + return self._names + + @property + def models(self) -> list[DiscoveredModel]: + """Current discovered model records (copy).""" + return list(self._models) + + +async def write_discovered_to_redis( + client: redis.Redis, models: list[DiscoveredModel], ttl_s: int +) -> None: + """Cache the discovered set in Redis with a TTL (so staleness expires).""" + payload = json.dumps([asdict(m) for m in models], separators=(",", ":")) + await client.set(REDIS_DISCOVERED_KEY, payload, ex=ttl_s) + + +async def read_discovered_from_redis(client: redis.Redis) -> frozenset[str]: + """Read the cached discovered names from Redis; empty set on miss/expiry.""" + raw = await client.get(REDIS_DISCOVERED_KEY) + if not raw: + return frozenset() + try: + data = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return frozenset() + if not isinstance(data, list): + return frozenset() + return frozenset( + str(item["name"]) for item in data if isinstance(item, dict) and item.get("name") + ) + + +async def fetch_tags(client: httpx.AsyncClient) -> list[DiscoveredModel]: + """Fetch and parse Ollama ``/api/tags``; raise on transport/HTTP error.""" + resp = await client.get("/api/tags") + resp.raise_for_status() + body = resp.json() + if not isinstance(body, 
dict): + return [] + return _parse_tags(body) + + +async def refresh_once( + http_client: httpx.AsyncClient, + redis_client: redis.Redis | None, + cache: DiscoveryCache, + settings: Settings, +) -> bool: + """Run a single discovery refresh. Returns True on success. + + On any failure the in-process and Redis caches are left untouched; they + expire on their own TTL, which is the fail-closed behavior (stale-expired ⇒ + empty ⇒ deny). We never *clear* eagerly on a transient error, but we also + never extend the TTL on failure. + """ + try: + models = await fetch_tags(http_client) + except (httpx.HTTPError, ValueError) as exc: + _log.warning("discovery_refresh_failed", error=str(exc)) + return False + await cache.set(models) + if redis_client is not None: + try: + await write_discovered_to_redis( + redis_client, models, settings.model_discovery_cache_ttl_s + ) + except Exception as exc: # noqa: BLE001 - Redis write is best-effort cache fill + _log.warning("discovery_cache_write_failed", error=str(exc)) + _log.info("discovery_refreshed", count=len(models)) + return True + + +async def discovery_loop( + http_client: httpx.AsyncClient, + redis_client: redis.Redis | None, + cache: DiscoveryCache, + settings: Settings, +) -> None: + """Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``. + + Designed to be launched via ``asyncio.create_task`` in the lifespan and + cancelled on shutdown. 
+ """ + await refresh_once(http_client, redis_client, cache, settings) + while True: + try: + await asyncio.sleep(settings.model_discovery_refresh_s) + except asyncio.CancelledError: + raise + await refresh_once(http_client, redis_client, cache, settings) + + +__all__ = [ + "REDIS_DISCOVERED_KEY", + "DiscoveredModel", + "DiscoveryCache", + "discovery_loop", + "fetch_tags", + "names_of", + "read_discovered_from_redis", + "refresh_once", + "write_discovered_to_redis", +] diff --git a/src/neuronetz_gateway/proxy/ollama.py b/src/neuronetz_gateway/proxy/ollama.py new file mode 100644 index 0000000..048d8ca --- /dev/null +++ b/src/neuronetz_gateway/proxy/ollama.py @@ -0,0 +1,79 @@ +"""httpx-based streaming proxy client for Ollama. + +Opens async streams to the upstream and relays bytes without buffering, so the +gateway does not degrade time-to-first-byte (SPEC §9, non-negotiable #6). +Transport-level upstream failures are sanitized at this boundary into +:class:`UpstreamUnavailableError` (never reflected verbatim, non-negotiable #4). + +The client is constructed from the shared ``httpx.AsyncClient`` on +``app.state`` and is exposed to routes via the ``get_ollama_client`` dependency +in ``deps.py`` so tests can override it (the QA override contract). +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator + +import httpx + +from neuronetz_gateway.errors import UpstreamUnavailableError + + +class OllamaClient: + """Thin async wrapper around the shared httpx client for Ollama calls.""" + + def __init__(self, client: httpx.AsyncClient) -> None: + self._client = client + + async def stream( + self, method: str, path: str, json_body: dict[str, object] + ) -> AsyncIterator[bytes]: + """Open a streaming request to Ollama and yield raw response chunks. + + Yields bytes exactly as received so the hot path performs no buffering. 
+ Transport errors are sanitized; if the upstream returns a non-2xx the + body is drained (so internals are not surfaced) and a generic 502 is + raised before any client bytes are emitted. + """ + request = self._client.build_request(method, path, json=json_body) + try: + response = await self._client.send(request, stream=True) + except httpx.HTTPError as exc: + raise UpstreamUnavailableError(internal_detail=f"ollama send failed: {exc!r}") from exc + try: + if response.status_code >= 400: + await response.aread() + raise UpstreamUnavailableError( + internal_detail=f"ollama returned {response.status_code} for {path}" + ) + async for chunk in response.aiter_raw(): + yield chunk + except httpx.HTTPError as exc: + raise UpstreamUnavailableError( + internal_detail=f"ollama stream failed: {exc!r}" + ) from exc + finally: + await response.aclose() + + async def request( + self, method: str, path: str, json_body: dict[str, object] + ) -> httpx.Response: + """Perform a non-streaming request to Ollama. + + Returns the upstream response on success; raises a sanitized 502 on + transport failure or a non-2xx status (internals never reflected). + """ + try: + response = await self._client.request(method, path, json=json_body) + except httpx.HTTPError as exc: + raise UpstreamUnavailableError( + internal_detail=f"ollama request failed: {exc!r}" + ) from exc + if response.status_code >= 400: + raise UpstreamUnavailableError( + internal_detail=f"ollama returned {response.status_code} for {path}" + ) + return response + + +__all__ = ["OllamaClient"] diff --git a/src/neuronetz_gateway/proxy/pipeline.py b/src/neuronetz_gateway/proxy/pipeline.py new file mode 100644 index 0000000..dfa27f8 --- /dev/null +++ b/src/neuronetz_gateway/proxy/pipeline.py @@ -0,0 +1,467 @@ +"""Shared request pipeline: enforcement, streaming, and post-close bookkeeping. 
+ +Both the native (``/api/*``) and OpenAI-compat (``/v1/*``) routes funnel through +here so the security checks and the streaming-integrity contract are written +once. The order mirrors SPEC §4.3: + + rate limit (per-key + per-tenant RPM) → budget → concurrency → model allowlist + → endpoint allowlist → body validation → proxy + stream → post-close audit / + token-count / budget-consume / metrics / semaphore release. + +Streaming integrity (non-negotiable #6): the bytes flow to the client untouched +and token counting + audit + budget-consume happen **after** the stream closes, +never on the hot path. + +Fail-closed (non-negotiable #1): every limiter/budget call raises +:class:`DependencyUnavailableError` when Redis is down, which the error handler +renders as 503. The model allowlist is default-deny against the live discovered +set; a missing/expired discovery set denies everything. +""" + +from __future__ import annotations + +import datetime +import json +import time +import uuid +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass +from typing import Any + +from fastapi import Request +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from starlette.responses import JSONResponse, StreamingResponse + +from neuronetz_gateway.audit.writer import AuditRecord, AuditWriter +from neuronetz_gateway.auth.principal import Principal +from neuronetz_gateway.budget.counter import BudgetCounter +from neuronetz_gateway.budget.ledger import BudgetLedger +from neuronetz_gateway.config import Settings +from neuronetz_gateway.db.models import BudgetPeriod +from neuronetz_gateway.errors import ( + AuthorizationError, + BudgetExceededError, + RateLimitError, + RequestTooLargeError, +) +from neuronetz_gateway.observability import metrics +from neuronetz_gateway.observability.logging import get_logger +from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed +from neuronetz_gateway.proxy.discovery import 
DiscoveryCache +from neuronetz_gateway.proxy.ollama import OllamaClient +from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage +from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter +from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter + +_log = get_logger("pipeline") + +NDJSON_MEDIA_TYPE = "application/x-ndjson" +SSE_MEDIA_TYPE = "text/event-stream" +_CONCURRENCY_PREFIX = "gateway:concurrency:" + + +@dataclass(slots=True) +class RateHeaders: + """The §6.5 rate/budget header values gathered during pre-flight.""" + + limit_requests: int + remaining_requests: int + limit_tokens: int + remaining_tokens: int + budget_period: str + budget_remaining: str + + def as_dict(self, request_id: str) -> dict[str, str]: + """Render to the SPEC §6.5 header set.""" + return { + "X-Request-ID": request_id, + "X-RateLimit-Limit-Requests": str(self.limit_requests), + "X-RateLimit-Remaining-Requests": str(self.remaining_requests), + "X-RateLimit-Limit-Tokens": str(self.limit_tokens), + "X-RateLimit-Remaining-Tokens": str(self.remaining_tokens), + "X-Budget-Period": self.budget_period, + "X-Budget-Tokens-Remaining": self.budget_remaining, + } + + +class Pipeline: + """Per-request enforcement + proxy orchestrator.""" + + def __init__( + self, + *, + request: Request, + principal: Principal, + settings: Settings, + ollama: OllamaClient, + discovery: DiscoveryCache, + rate_limiter: SlidingWindowLimiter, + concurrency: ConcurrencyLimiter, + budget: BudgetCounter, + audit: AuditWriter, + sessionmaker: async_sessionmaker[AsyncSession] | None = None, + ) -> None: + self._request = request + self._p = principal + self._settings = settings + self._ollama = ollama + self._discovery = discovery + self._rate = rate_limiter + self._conc = concurrency + self._budget = budget + self._audit = audit + self._sessionmaker = sessionmaker + self._request_id = str(getattr(request.state, "request_id", uuid.uuid4())) + self._concurrency_key = 
def validate_body(self, body: dict[str, object]) -> None:
    """Clamp ``options.num_predict`` to the configured maximum, in place.

    The body-size cap is enforced earlier; this only polices generation
    length. A numeric ``num_predict`` is rewritten to
    ``settings.max_num_predict`` when it is above the cap, negative, or a
    non-finite float. Ollama treats negative values as "no limit"
    (``-1`` infinite generation, ``-2`` fill the context), so letting them
    through would bypass the cap; a float such as ``1e9`` would also pass
    an ``int``-only check.

    Non-numeric values and a missing ``num_predict`` are left untouched.
    NOTE(review): a missing ``num_predict`` falls back to Ollama's own
    default, which may itself be unbounded — confirm whether the cap should
    be injected in that case too.

    Args:
        body: The parsed request body; its ``options`` mapping is mutated.
    """
    options = body.get("options")
    if not isinstance(options, dict):
        return
    requested = options.get("num_predict")
    # bool is an int subclass; it is not a meaningful token count, so treat
    # it like any other non-numeric value and leave it alone.
    if isinstance(requested, bool) or not isinstance(requested, (int, float)):
        return
    cap = self._settings.max_num_predict
    # Phrased as "not within range" so NaN (which fails every comparison)
    # is clamped as well.
    if not (0 <= requested <= cap):
        options["num_predict"] = cap
+ + Budget is checked before concurrency so an over-budget request never even + acquires a permit. Order otherwise follows SPEC §4.3 steps 4-6. + """ + await self._check_rpm() + await self._check_tpm(token_estimate) + await self.check_budgets() + await self._acquire_concurrency() + + async def _check_rpm(self) -> None: + key_result = await self._rate.check( + f"gateway:rpm:key:{self._p.key_id}", self._p.limits.rpm, 60, cost=1 + ) + self._headers.limit_requests = key_result.limit + self._headers.remaining_requests = key_result.remaining + if not key_result.allowed: + raise RateLimitError(retry_after=key_result.retry_after_s) + tenant_result = await self._rate.check( + f"gateway:rpm:tenant:{self._p.tenant_id}", self._p.limits.rpm, 60, cost=1 + ) + if not tenant_result.allowed: + raise RateLimitError(retry_after=tenant_result.retry_after_s) + + async def _check_tpm(self, token_estimate: int) -> None: + # TPM is charged with a minimum of 1 so a request always counts; the + # precise token cost is reconciled post-stream via the budget counter. 
async def check_budgets(self) -> None:
    """Deny if any budget period is exhausted; report the tightest one.

    Every period from ``_budget_periods()`` is checked in order. After all
    pass, the budget headers are set to the period with the *smallest*
    remaining allowance. Previously the last period with a finite remainder
    won, so a nearly spent daily budget could be hidden behind a roomy
    monthly one. Ties keep the earlier period. When no period reports a
    finite remainder the headers read ``day`` / ``unlimited``.

    Assumes ``state.remaining`` is ``None`` or a number (it is compared
    with ``<``) — TODO confirm against ``BudgetCounter.check``.

    Raises:
        BudgetExceededError: as soon as one period reports ``exhausted``;
            the headers are left unchanged in that case.
    """
    tightest_period = "day"
    tightest_left = None  # smallest finite remainder seen so far
    for period, limit in self._budget_periods():
        state = await self._budget.check(str(self._p.key_id), period, limit)
        if state.exhausted:
            raise BudgetExceededError(internal_detail=f"budget {period.value} exhausted")
        left = state.remaining
        if left is None:
            continue
        if tightest_left is None or left < tightest_left:
            tightest_period = period.value
            tightest_left = left
    self._headers.budget_period = tightest_period
    self._headers.budget_remaining = (
        "unlimited" if tightest_left is None else str(tightest_left)
    )
async def stream_openai(
    self,
    method: str,
    path: str,
    body: dict[str, object],
    model: str,
    chunk_translator: Callable[[dict[str, object]], dict[str, object]],
) -> StreamingResponse:
    """Proxy + translate native NDJSON into OpenAI SSE; account after close.

    Upstream bytes are re-framed on newline boundaries. Each non-blank line
    is parsed as JSON, passed through ``chunk_translator`` and emitted as
    one ``data: {...}`` SSE event; a final ``data: [DONE]`` event follows
    once the upstream is exhausted. The last dict-shaped frame seen is
    handed to ``_finish`` for bookkeeping.

    Args:
        method: HTTP method forwarded to the upstream client.
        path: Upstream path forwarded to the upstream client.
        body: Request body forwarded to the upstream client.
        model: Model name, passed through to ``_finish``.
        chunk_translator: Converts one parsed upstream frame into the
            object that is serialized into the SSE event.

    Returns:
        A ``StreamingResponse`` with the SSE media type and the pipeline's
        response headers.
    """
    started = time.monotonic()

    async def gen() -> AsyncIterator[bytes]:
        last_obj: dict[str, object] = {}
        # Holds the trailing partial line between upstream chunks.
        buffer = b""
        try:
            async for chunk in self._ollama.stream(method, path, body):
                buffer += chunk
                lines = buffer.split(b"\n")
                # The last element is either b"" (chunk ended on a newline)
                # or an incomplete line that the next chunk will finish.
                buffer = lines.pop()
                for line in lines:
                    if not line.strip():
                        continue
                    # NOTE(review): a malformed line raises JSONDecodeError
                    # here; nothing in this block catches it, so the stream
                    # ends without a [DONE] event (``_finish`` still runs).
                    obj = json.loads(line)
                    last_obj = obj if isinstance(obj, dict) else last_obj
                    # NOTE(review): a non-dict frame is still handed to the
                    # translator — confirm the translators tolerate that.
                    translated = chunk_translator(obj)
                    yield f"data: {json.dumps(translated)}\n\n".encode()
            # Upstream ended without a trailing newline: flush the remainder.
            if buffer.strip():
                obj = json.loads(buffer)
                last_obj = obj if isinstance(obj, dict) else last_obj
                yield f"data: {json.dumps(chunk_translator(obj))}\n\n".encode()
            yield b"data: [DONE]\n\n"
        finally:
            # Bookkeeping runs on every exit path of the generator.
            await self._finish(model, path, method, last_obj, started)

    return StreamingResponse(gen(), media_type=SSE_MEDIA_TYPE, headers=self.headers())
translator: Callable[[dict[str, object]], dict[str, object]], + ) -> JSONResponse: + """Proxy a non-streaming request, translate the body, then account.""" + started = time.monotonic() + resp = await self._ollama.request(method, path, body) + payload = resp.json() + obj = payload if isinstance(payload, dict) else {} + await self._finish(model, path, method, obj, started) + return JSONResponse(translator(obj), headers=self.headers()) + + async def _finish( + self, + model: str, + path: str, + method: str, + final_obj: dict[str, object], + started: float, + ) -> None: + """Post-close bookkeeping: tokens, budget, metrics, audit, semaphore. + + Runs once per request after the response is fully produced. Each step is + best-effort and guarded so a bookkeeping failure never corrupts the + already-delivered response. The concurrency permit is always released. + """ + usage = extract_usage(final_obj) if final_obj else TokenUsage(0, 0) + latency_ms = int((time.monotonic() - started) * 1000) + try: + await self._account_budget(usage) + except Exception as exc: # noqa: BLE001 - never break a delivered response + _log.warning("budget_account_failed", error=str(exc)) + try: + metrics.record_request(self._p.tenant_name, model or "unknown", 200, latency_ms / 1000) + metrics.record_tokens( + self._p.tenant_name, model or "unknown", usage.tokens_in, usage.tokens_out + ) + except Exception as exc: # noqa: BLE001 - metrics must not break responses + _log.warning("metrics_record_failed", error=str(exc)) + await self._write_audit(model, path, method, usage, latency_ms, 200) + await self._release_concurrency() + + async def _account_budget(self, usage: TokenUsage) -> None: + """Decrement Redis budget counters and persist to the Postgres ledger.""" + for period, limit in self._budget_periods(): + if limit is not None: + await self._budget.consume(str(self._p.key_id), period, usage.total) + if self._sessionmaker is not None and usage.total >= 0: + try: + async with self._sessionmaker() as 
session: + ledger = BudgetLedger(session) + for period, _limit in self._budget_periods(): + await ledger.record_usage( + str(self._p.key_id), period, usage.tokens_in, usage.tokens_out + ) + await session.commit() + except Exception as exc: # noqa: BLE001 - ledger is durable backstop; never break response + _log.warning("ledger_record_failed", error=str(exc)) + + async def _write_audit( + self, + model: str, + path: str, + method: str, + usage: TokenUsage, + latency_ms: int, + status: int, + ) -> None: + record = AuditRecord( + request_id=uuid.UUID(self._request_id) + if _is_uuid(self._request_id) + else uuid.uuid4(), + method=method, + path=path, + status=status, + ts=datetime.datetime.now(datetime.UTC), + tenant_id=self._p.tenant_id, + key_id=self._p.key_id, + key_prefix=self._p.key_prefix, + model=model or None, + tokens_in=usage.tokens_in, + tokens_out=usage.tokens_out, + latency_ms=latency_ms, + client_ip=self._client_ip(), + user_agent=self._request.headers.get("user-agent"), + ) + try: + await self._audit.enqueue(record) + except Exception as exc: # noqa: BLE001 - audit enqueue must not break response + _log.warning("audit_enqueue_failed", error=str(exc)) + + async def _release_concurrency(self) -> None: + try: + await self._conc.release(self._concurrency_key) + except Exception as exc: # noqa: BLE001 - release is best-effort + _log.warning("concurrency_release_failed", error=str(exc)) + + def _client_ip(self) -> str | None: + xff = self._request.headers.get("x-forwarded-for") + if xff: + return xff.split(",")[0].strip() + return self._request.client.host if self._request.client else None + + +def _merge_last_ndjson(chunk: bytes, prev: dict[str, object]) -> dict[str, object]: + """Track the last complete NDJSON object seen in a raw byte chunk. + + Token counts live on the final ``done`` frame. We parse only complete lines + and keep the last successfully-parsed object; partial trailing data is + ignored here and will be completed by a subsequent chunk. 
+ """ + text = chunk.decode("utf-8", errors="ignore") + last = prev + for line in text.split("\n"): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(obj, dict): + last = obj + return last + + +def _is_uuid(value: str) -> bool: + try: + uuid.UUID(value) + except ValueError: + return False + return True + + +def model_of(body: dict[str, object]) -> str: + """Extract the requested model name from a request body (empty if absent).""" + model = body.get("model") + return model if isinstance(model, str) else "" + + +async def read_json_body(request: Request, settings: Settings) -> dict[str, object]: + """Read + size-limit the request body, returning the parsed JSON object.""" + raw = await request.body() + if len(raw) > settings.max_request_body_bytes: + raise RequestTooLargeError(internal_detail=f"body {len(raw)} bytes") + if not raw: + return {} + try: + parsed: Any = json.loads(raw) + except json.JSONDecodeError as exc: + raise AuthorizationError(internal_detail="invalid JSON body") from exc + return parsed if isinstance(parsed, dict) else {} + + +__all__ = [ + "NDJSON_MEDIA_TYPE", + "SSE_MEDIA_TYPE", + "Pipeline", + "RateHeaders", + "model_of", + "read_json_body", +] diff --git a/src/neuronetz_gateway/proxy/token_counter.py b/src/neuronetz_gateway/proxy/token_counter.py new file mode 100644 index 0000000..e6c32d2 --- /dev/null +++ b/src/neuronetz_gateway/proxy/token_counter.py @@ -0,0 +1,50 @@ +"""Precise token accounting parsed from Ollama responses (SPEC §2, §13.1). + +Tokens are read from Ollama's reported ``prompt_eval_count`` (input) and +``eval_count`` (output) on the final stream frame — never heuristically +estimated. Embeddings charge ``prompt_eval_count`` only (SPEC §13.1); they have +no ``eval_count`` so output tokens are reported as zero. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class TokenUsage: + """Token counts extracted from an Ollama response.""" + + tokens_in: int + tokens_out: int + + @property + def total(self) -> int: + """Combined input + output tokens.""" + return self.tokens_in + self.tokens_out + + +def _as_int(value: object) -> int: + """Coerce an Ollama-reported count to a non-negative int (0 if absent/bad).""" + if isinstance(value, bool): + return 0 + if isinstance(value, int): + return max(value, 0) + if isinstance(value, float): + return max(int(value), 0) + return 0 + + +def extract_usage(final_frame: dict[str, object]) -> TokenUsage: + """Extract ``prompt_eval_count``/``eval_count`` from the final Ollama frame. + + Works for chat/generate final frames and for embeddings responses (which + carry ``prompt_eval_count`` but no ``eval_count``). + """ + return TokenUsage( + tokens_in=_as_int(final_frame.get("prompt_eval_count")), + tokens_out=_as_int(final_frame.get("eval_count")), + ) + + +__all__ = ["TokenUsage", "extract_usage"] diff --git a/src/neuronetz_gateway/proxy/translate.py b/src/neuronetz_gateway/proxy/translate.py new file mode 100644 index 0000000..8b77c4c --- /dev/null +++ b/src/neuronetz_gateway/proxy/translate.py @@ -0,0 +1,245 @@ +"""OpenAI <-> Ollama schema translation (SPEC §6.3). + +Native ``/api/*`` speaks NDJSON; OpenAI-compat ``/v1/*`` speaks SSE +(``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``). These helpers translate request +bodies in both directions and convert native Ollama stream frames into OpenAI +chunk objects, preserving streaming. Unknown OpenAI sampling params are mapped +into Ollama's ``options`` block; unrecognized keys are dropped rather than +forwarded blindly. +""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +# OpenAI sampling params that map into Ollama's ``options`` object. +_OPTION_KEYS: tuple[str, ...] 
= ( + "temperature", + "top_p", + "top_k", + "seed", + "stop", + "presence_penalty", + "frequency_penalty", +) + + +def _as_int(value: object) -> int: + """Coerce a JSON-derived value to a non-negative int (0 if absent/invalid).""" + if isinstance(value, bool): + return 0 + if isinstance(value, int): + return max(value, 0) + if isinstance(value, float): + return max(int(value), 0) + return 0 + + +def _build_options(payload: dict[str, Any]) -> dict[str, Any]: + """Collect OpenAI sampling params into an Ollama ``options`` mapping.""" + options: dict[str, Any] = {} + for key in _OPTION_KEYS: + if key in payload and payload[key] is not None: + options[key] = payload[key] + if "max_tokens" in payload and payload["max_tokens"] is not None: + options["num_predict"] = payload["max_tokens"] + return options + + +def openai_chat_to_ollama(payload: dict[str, object]) -> dict[str, object]: + """Translate an OpenAI chat-completion request to an Ollama ``/api/chat`` body.""" + body: dict[str, object] = { + "model": payload.get("model"), + "messages": payload.get("messages", []), + "stream": bool(payload.get("stream", False)), + } + options = _build_options(dict(payload)) + if options: + body["options"] = options + return body + + +def openai_completion_to_ollama(payload: dict[str, object]) -> dict[str, object]: + """Translate an OpenAI completion request to an Ollama ``/api/generate`` body.""" + prompt = payload.get("prompt", "") + if isinstance(prompt, list): + prompt = "".join(str(p) for p in prompt) + body: dict[str, object] = { + "model": payload.get("model"), + "prompt": prompt, + "stream": bool(payload.get("stream", False)), + } + options = _build_options(dict(payload)) + if options: + body["options"] = options + return body + + +def openai_embeddings_to_ollama(payload: dict[str, object]) -> dict[str, object]: + """Translate an OpenAI embeddings request to an Ollama ``/api/embed`` body.""" + return { + "model": payload.get("model"), + "input": payload.get("input", ""), + } 
+ + +def _completion_id() -> str: + """Generate an OpenAI-style completion id.""" + return f"chatcmpl-{uuid.uuid4().hex}" + + +def ollama_chat_chunk_to_openai( + chunk: dict[str, object], *, completion_id: str, model: str, created: int +) -> dict[str, object]: + """Translate one Ollama ``/api/chat`` NDJSON frame to an OpenAI SSE chunk.""" + done = bool(chunk.get("done")) + if done: + return { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": _as_int(chunk.get("prompt_eval_count")), + "completion_tokens": _as_int(chunk.get("eval_count")), + "total_tokens": _as_int(chunk.get("prompt_eval_count")) + + _as_int(chunk.get("eval_count")), + }, + } + message = chunk.get("message") + content = "" + if isinstance(message, dict): + content = str(message.get("content", "")) + return { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}], + } + + +def ollama_generate_chunk_to_openai( + chunk: dict[str, object], *, completion_id: str, model: str, created: int +) -> dict[str, object]: + """Translate one Ollama ``/api/generate`` NDJSON frame to an OpenAI text chunk.""" + done = bool(chunk.get("done")) + if done: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model, + "choices": [{"index": 0, "text": "", "finish_reason": "stop"}], + "usage": { + "prompt_tokens": _as_int(chunk.get("prompt_eval_count")), + "completion_tokens": _as_int(chunk.get("eval_count")), + "total_tokens": _as_int(chunk.get("prompt_eval_count")) + + _as_int(chunk.get("eval_count")), + }, + } + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model, + "choices": [{"index": 0, "text": str(chunk.get("response", "")), "finish_reason": None}], + } 
def ollama_chat_to_openai(payload: dict[str, object]) -> dict[str, object]:
    """Convert a complete (non-streamed) Ollama chat reply to an OpenAI completion.

    The assistant text comes from ``payload["message"]["content"]`` (empty
    when ``message`` is not a mapping). The finish reason is always
    ``"stop"``.
    """
    message = payload.get("message")
    text = str(message.get("content", "")) if isinstance(message, dict) else ""
    tokens_in = _as_int(payload.get("prompt_eval_count"))
    tokens_out = _as_int(payload.get("eval_count"))
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": text},
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": tokens_in,
        "completion_tokens": tokens_out,
        "total_tokens": tokens_in + tokens_out,
    }
    return {
        "id": _completion_id(),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": str(payload.get("model", "")),
        "choices": [choice],
        "usage": usage,
    }


def ollama_generate_to_openai(payload: dict[str, object]) -> dict[str, object]:
    """Convert a complete (non-streamed) Ollama generate reply to an OpenAI text completion."""
    tokens_in = _as_int(payload.get("prompt_eval_count"))
    tokens_out = _as_int(payload.get("eval_count"))
    choice = {
        "index": 0,
        "text": str(payload.get("response", "")),
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": tokens_in,
        "completion_tokens": tokens_out,
        "total_tokens": tokens_in + tokens_out,
    }
    return {
        "id": _completion_id(),
        "object": "text_completion",
        "created": int(time.time()),
        "model": str(payload.get("model", "")),
        "choices": [choice],
        "usage": usage,
    }


def ollama_embed_to_openai(payload: dict[str, object], model: str) -> dict[str, object]:
    """Reshape an Ollama ``/api/embed`` reply into the OpenAI embeddings list.

    ``payload["embeddings"]`` is used when it is a list; anything else
    produces an empty ``data`` array. Only prompt tokens are reported.
    """
    raw = payload.get("embeddings")
    vectors: list[list[float]] = raw if isinstance(raw, list) else []
    tokens_in = _as_int(payload.get("prompt_eval_count"))
    entries = []
    for position, vector in enumerate(vectors):
        entries.append({"object": "embedding", "index": position, "embedding": vector})
    return {
        "object": "list",
        "data": entries,
        "model": model,
        "usage": {"prompt_tokens": tokens_in, "total_tokens": tokens_in},
    }


def models_to_openai_list(names: list[str]) -> dict[str, object]:
    """Render model names as an OpenAI ``/v1/models`` list.

    Every entry shares one ``created`` timestamp taken at call time and is
    attributed to ``neuronetz``.
    """
    stamp = int(time.time())
    entries = []
    for model_name in names:
        entries.append(
            {"id": model_name, "object": "model", "created": stamp, "owned_by": "neuronetz"}
        )
    return {"object": "list", "data": entries}


def new_completion_id() -> str:
    """Public wrapper that mints a completion id for a streaming response."""
    return _completion_id()
class ConcurrencyLimiter:
    """Redis-backed concurrency cap with a self-healing TTL guard.

    A permit is one unit on a Redis counter. ``acquire`` increments it and
    refreshes the key's TTL; ``release`` decrements it. The TTL lets the
    counter expire if permits are never released.
    """

    def __init__(self, client: redis.Redis) -> None:
        self._client = client

    async def acquire(self, scope_key: str, limit: int, ttl_s: int) -> bool:
        """Try to acquire a permit; return False (deny) if at capacity.

        The counter is incremented, the TTL guard refreshed, and the
        increment rolled back when the new value exceeds ``limit``.

        If the TTL refresh fails after the increment succeeded, the
        increment is now rolled back (best-effort) before failing closed.
        Previously that path left the counter raised, and on a fresh key
        also left it without any TTL, so the permit could never self-heal.

        Args:
            scope_key: Redis key holding the counter.
            limit: Maximum number of simultaneously held permits.
            ttl_s: Guard TTL, in seconds, applied to the counter key.

        Raises:
            DependencyUnavailableError: Redis failed during the increment
                or the TTL refresh; the caller must deny.
        """
        try:
            count = await cast("Awaitable[int]", self._client.incr(scope_key))
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"concurrency redis error: {exc!r}"
            ) from exc
        try:
            await self._client.expire(scope_key, ttl_s)
        except RedisError as exc:
            # We hold an increment that is not TTL-guarded; give it back
            # before denying so it cannot pin the semaphore.
            try:
                await self._client.decr(scope_key)
            except RedisError:
                pass  # Redis is unreachable; nothing more can be done here.
            raise DependencyUnavailableError(
                internal_detail=f"concurrency redis error: {exc!r}"
            ) from exc
        if count > limit:
            # Over capacity: undo our increment and deny.
            try:
                await self._client.decr(scope_key)
            except RedisError:
                # Permit self-heals via the TTL guard; denial still stands.
                return False
            return False
        return True

    async def release(self, scope_key: str) -> None:
        """Release a previously acquired permit (best-effort; never raises).

        If the counter would go negative (for instance because the key
        expired in the meantime) it is reset to zero.
        """
        try:
            count = await cast("Awaitable[int]", self._client.decr(scope_key))
            if count < 0:
                await self._client.set(scope_key, 0)
        except RedisError:
            # The TTL guard will reclaim the permit; releasing must not break the
            # request that already completed.
            return
+local total = 0 +local data = redis.call('ZRANGEBYSCORE', key, now - window, now) +for i = 1, #data do + local sep = string.find(data[i], ':[^:]*$') + if sep then + total = total + tonumber(string.sub(data[i], sep + 1)) + else + total = total + 1 + end +end + +if total + cost > limit then + local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') + local retry = window + if oldest[2] then + retry = (tonumber(oldest[2]) + window) - now + if retry < 0 then retry = 0 end + end + return {0, total, retry} +end + +redis.call('ZADD', key, now, member .. ':' .. cost) +redis.call('PEXPIRE', key, window) +return {1, total + cost, 0} +""" + + +@dataclass(frozen=True, slots=True) +class RateLimitResult: + """Outcome of a rate-limit check.""" + + allowed: bool + limit: int + remaining: int + retry_after_s: int | None + + +class SlidingWindowLimiter: + """Redis-backed sliding-window limiter (atomic via Lua).""" + + def __init__(self, client: redis.Redis) -> None: + self._client = client + self._script = client.register_script(_LUA) + + async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult: + """Atomically record a hit of ``cost`` and report admission. + + Raises :class:`DependencyUnavailableError` if Redis cannot be reached, so + the caller fails closed. 
+ """ + now_ms = int(time.time() * 1000) + window_ms = window_s * 1000 + member = f"{now_ms}-{id(object())}" + try: + raw = await self._script( + keys=[key], args=[now_ms, window_ms, limit, cost, member] + ) + except RedisError as exc: + raise DependencyUnavailableError( + internal_detail=f"ratelimit redis error: {exc!r}" + ) from exc + allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2])) + allowed = allowed_i == 1 + remaining = max(limit - used_after, 0) + retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000) + return RateLimitResult( + allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s + ) + + +__all__ = ["RateLimitResult", "SlidingWindowLimiter"] diff --git a/src/neuronetz_gateway/revocation.py b/src/neuronetz_gateway/revocation.py new file mode 100644 index 0000000..abca4bd --- /dev/null +++ b/src/neuronetz_gateway/revocation.py @@ -0,0 +1,97 @@ +"""Key-revocation NOTIFY listener (SPEC §4.5). + +Console (or the gateway's own CLI) revokes a key by inserting into +``gateway.revocations``; an ``AFTER INSERT`` trigger fires +``pg_notify('key_revoked', key_id)``. This background task LISTENs on that +channel and, on each notification, evicts the Redis auth-cache entry for the +revoked key's prefix so the next request misses the cache, re-reads the DB, +finds the key non-active, and is rejected — making revocation effective within +one Redis RTT without any cross-service HTTP. + +The listener resolves ``key_id -> prefix`` via a short DB lookup (the NOTIFY +payload is the key id, but the cache is keyed by prefix). It is resilient: a +dropped connection is retried with backoff. 
async def revocation_listener(
    settings: Settings,
    redis_client: redis.Redis,
    sessionmaker: async_sessionmaker[AsyncSession],
) -> None:
    """LISTEN on ``key_revoked`` and evict the Redis cache on each notification.

    Runs until cancelled. Each pass of the outer loop opens a raw asyncpg
    connection, subscribes to the channel, and then parks until asyncpg reports
    the connection as terminated. Any failure (including that termination) is
    logged, followed by a two-second pause and a fresh connection attempt.
    Cancellation is re-raised untouched; the connection is closed on every exit
    path.
    """
    dsn = _asyncpg_dsn(settings.database_url)
    # Strong references to in-flight evictions. The event loop itself only
    # keeps weak references to tasks, so a task nobody else references can be
    # garbage-collected before it finishes.
    in_flight: set[asyncio.Task[None]] = set()
    while True:
        conn = None
        try:
            conn = await asyncpg.connect(dsn)
            closed = asyncio.Event()

            def _on_notify(
                _c: object, _pid: int, _channel: str, payload: str
            ) -> None:
                # Schedule the async eviction without blocking the callback,
                # and hold the task until it completes.
                task = asyncio.create_task(_evict(payload, sessionmaker, redis_client))
                in_flight.add(task)
                task.add_done_callback(in_flight.discard)

            def _on_terminate(_c: object, _closed: asyncio.Event = closed) -> None:
                # Bound as a default so a late callback from an old connection
                # can never set the event belonging to a newer one.
                _closed.set()

            conn.add_termination_listener(_on_terminate)
            await conn.add_listener(_CHANNEL, _on_notify)
            _log.info("revocation_listener_started")
            # Notifications arrive via the callback; wake up only when the
            # connection goes away. Waiting on an event that nothing sets would
            # leave the listener parked forever on a dead connection.
            await closed.wait()
            # Route the drop through the shared reconnect/backoff branch below.
            raise ConnectionError("revocation listener connection closed")
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001 - reconnect on any failure
            _log.warning("revocation_listener_reconnect", error=str(exc))
            await asyncio.sleep(2)
        finally:
            if conn is not None:
                with contextlib.suppress(Exception):
                    await conn.close()
async def _chat_or_generate(
    request: Request, pipeline: PipelineDep, *, path: str, scope: str
) -> Response:
    """Shared handler behind ``/api/chat`` and ``/api/generate``.

    Reads the JSON body, runs the pipeline's scope, endpoint, model, body and
    limit checks in that order, then relays the call upstream. The request is
    streamed unless the body carries a falsy ``stream`` value; a missing
    ``stream`` key means streaming.
    """
    payload = await read_json_body(request, pipeline.settings)
    requested_model = model_of(payload)

    # All enforcement happens before anything is sent upstream.
    pipeline.check_scope(scope)
    pipeline.check_endpoint(path)
    pipeline.check_model(requested_model)
    pipeline.validate_body(payload)
    await pipeline.enforce_limits()

    wants_stream = bool(payload.get("stream", True))
    relay = pipeline.stream_native if wants_stream else pipeline.request_native
    return await relay("POST", path, payload, requested_model)
@router.get("/tags")
async def tags(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
    """List the discovered models that fall inside the caller's effective set.

    The effective set is resolved from the principal's ``allow_all_models`` /
    ``allowed_models`` and the discovery cache's names. Each discovered model
    in that set is rendered as one entry of the ``models`` array, in discovery
    order.
    """
    permitted = resolve_effective_models(
        allow_all=principal.allow_all_models,
        allowed_models=principal.allowed_models,
        discovered=discovery.names,
    )

    entries = []
    for found in discovery.models:
        if found.name not in permitted:
            continue
        details = {
            "family": found.family,
            "parameter_size": found.parameter_size,
            "quantization_level": found.quantization,
        }
        entries.append(
            {
                "name": found.name,
                "model": found.name,
                "modified_at": found.modified_at,
                "size": found.size_bytes,
                "details": details,
            }
        )
    return JSONResponse({"models": entries})
@router.api_route(
    "/{rest:path}",
    methods=["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"],
    include_in_schema=False,
)
async def catch_all(rest: str) -> Response:
    """Reject every ``/api/*`` path that no earlier route handled.

    Hard-blocked and unknown paths both end in an :class:`AuthorizationError`;
    only the internal detail differs, so the caller sees the same generic
    forbidden response either way and cannot learn which upstream endpoints
    exist (SPEC §6.2, §13.6).
    """
    requested = f"/api/{rest}"
    # The reason only feeds the internal detail, never the client response.
    reason = "hard-blocked" if is_hard_blocked(requested) else "unrouted upstream path"
    raise AuthorizationError(internal_detail=f"{reason} {requested}")
@router.post("/chat/completions")
async def chat_completions(request: Request, pipeline: PipelineDep) -> Response:
    """Serve OpenAI ``/v1/chat/completions`` by way of Ollama ``/api/chat``.

    The OpenAI payload is translated to a native body, run through the same
    pipeline checks as the native route, and the upstream answer is translated
    back. Streaming is opt-in: it is used only when the OpenAI payload carries
    a truthy ``stream`` value.
    """
    openai_request = await read_json_body(request, pipeline.settings)
    native_body = openai_chat_to_ollama(openai_request)
    model = model_of(native_body)

    pipeline.check_scope("chat")
    pipeline.check_endpoint("/api/chat")
    pipeline.check_model(model)
    pipeline.validate_body(native_body)
    await pipeline.enforce_limits()

    wants_stream = bool(openai_request.get("stream", False))
    if not wants_stream:
        return await pipeline.request_translated(
            "POST", "/api/chat", native_body, model, ollama_chat_to_openai
        )

    # One id and one creation timestamp are shared by every chunk of a stream.
    completion_id = new_completion_id()
    created = int(time.time())
    return await pipeline.stream_openai(
        "POST",
        "/api/chat",
        native_body,
        model,
        lambda chunk: ollama_chat_chunk_to_openai(
            chunk, completion_id=completion_id, model=model, created=created
        ),
    )
@router.post("/embeddings")
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
    """Serve OpenAI ``/v1/embeddings`` by way of Ollama ``/api/embed``.

    Translates the OpenAI payload to a native body, runs the scope, endpoint,
    model and limit checks, then issues a non-streamed upstream request whose
    result is converted back to the OpenAI embedding shape.
    """
    upstream_path = "/api/embed"
    openai_request = await read_json_body(request, pipeline.settings)
    native_body = openai_embeddings_to_ollama(openai_request)
    model = model_of(native_body)

    pipeline.check_scope("embeddings")
    pipeline.check_endpoint(upstream_path)
    pipeline.check_model(model)
    await pipeline.enforce_limits()

    def to_openai(obj):
        # The translator needs the model name alongside the upstream object.
        return ollama_embed_to_openai(obj, model)

    return await pipeline.request_translated(
        "POST", upstream_path, native_body, model, to_openai
    )