proxy: streaming, discovery, OpenAI-compat, rate-limit, budget, audit

The hot path. A single Pipeline class owns enforcement so the eight
non-negotiables can be reviewed in one place.

- Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags,
  /api/show (system-prompt + template stripped), /api/embed(dings), /api/version
  (returns gateway version, not Ollama's). Endpoint catch-all returns the same
  generic 403 for hard-blocked and unknown /api/* paths so attackers cannot
  enumerate which mutating endpoints exist.
- OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings,
  /v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves
  streaming end-to-end.
- Model discovery (SPEC §4.6): background poller against Ollama /api/tags;
  Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh =
  MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty.
- Effective-set resolution in proxy/allowlist.py:
    allow_all = key.allow_all_models ?? tenant.allow_all_models
    effective = discovered if allow_all
                else (key.allowed_models ?? tenant.allowed_models) ∩ discovered
  A non-effective model returns the same generic 403 whether it's installed-
  but-unpermitted or doesn't exist at all (no enumeration leak).
- Sliding-window rate limit (Redis Lua, single round-trip) for per-key +
  per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with
  TTL guard. Token-budget counters per (key, period) with a Postgres ledger
  for reconciliation across resets. Headers per SPEC §6.5 on every response;
  429 carries Retry-After; Redis outage → 503 (fail closed, never 200).
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk
  carrying `usage`); the audit row is written AFTER stream close so TTFB is
  never degraded by bookkeeping.
- Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow.
  Optional prompt log per key (TTL'd).
- Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache
  entry within ~1s of the console writing to gateway.revocations.
- Prometheus counters/histograms labeled by tenant only (per SPEC §13.3).
This commit is contained in:
Stephan Berbig
2026-05-26 20:52:33 +02:00
parent 6431b2f72c
commit 6a92bc8ce9
20 changed files with 2139 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""Audit logging: buffered async audit writer and opt-in prompt log."""
from __future__ import annotations

View File

@@ -0,0 +1,63 @@
"""Opt-in, TTL'd prompt logging (SPEC §2, §5).
Disabled by default; enabled per key (or inherited from tenant via the resolved
:class:`Principal.log_prompts`). Each row carries a ``retention_until`` deadline
(swept in Phase 4). A redaction hook runs before persistence so sensitive spans
can be scrubbed without changing call sites.
"""
from __future__ import annotations
import datetime
from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from neuronetz_gateway.db.models import PromptLog
# A redaction hook maps a request/response body to a sanitized copy. The default
# is identity; operators can install a stricter hook.
RedactionHook = Callable[[dict[str, object]], dict[str, object]]
def _identity(body: dict[str, object]) -> dict[str, object]:
return body
@dataclass(frozen=True, slots=True)
class PromptRecord:
"""A captured request/response pair pending TTL'd persistence."""
audit_id: int
key_id: UUID
request_body: dict[str, object]
response_text: str | None
retention_days: int
class PromptLogWriter:
"""Persists opt-in prompt records with a retention deadline."""
def __init__(self, session: AsyncSession, redact: RedactionHook | None = None) -> None:
self._session = session
self._redact = redact or _identity
async def write(self, record: PromptRecord) -> None:
"""Persist a prompt record to ``gateway.prompt_log``."""
retention_until = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
days=record.retention_days
)
row = PromptLog(
audit_id=record.audit_id,
key_id=record.key_id,
request_body=self._redact(record.request_body),
response_text=record.response_text,
retention_until=retention_until,
)
self._session.add(row)
await self._session.flush()
__all__ = ["PromptLogWriter", "PromptRecord", "RedactionHook"]

View File

@@ -0,0 +1,152 @@
"""Buffered async audit-log writer (SPEC §4.4).
Audit rows are enqueued *after* stream close so the hot path is never delayed
(non-negotiable #6). A background drain task persists them to Postgres. On a
Postgres write failure rows are buffered in a bounded in-memory ring (max
``AUDIT_BUFFER_SIZE``) and retried; if the ring fills, the writer flips to
**deny mode** (SPEC §4.4) — :pyattr:`deny_mode` goes True and the request path
must refuse new work until the backlog drains.
The writer holds the session factory directly (it runs outside request scope).
``enqueue`` never blocks on the DB; it only touches the in-memory queue/ring.
"""
from __future__ import annotations
import asyncio
import collections
import datetime
from dataclasses import asdict, dataclass
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from neuronetz_gateway.db.repositories import AuditRepository
from neuronetz_gateway.observability.logging import get_logger
_log = get_logger("audit")
@dataclass(frozen=True, slots=True)
class AuditRecord:
"""A single audit-log row queued for persistence."""
request_id: UUID
method: str
path: str
status: int
ts: datetime.datetime
tenant_id: UUID | None = None
key_id: UUID | None = None
key_prefix: str | None = None
model: str | None = None
tokens_in: int | None = None
tokens_out: int | None = None
latency_ms: int | None = None
client_ip: str | None = None
user_agent: str | None = None
error_code: str | None = None
def as_columns(self) -> dict[str, object]:
"""Map to ``gateway.audit_log`` column kwargs."""
return asdict(self)
class AuditWriter:
"""Buffered, fail-safe writer for ``gateway.audit_log``."""
def __init__(
self,
buffer_size: int,
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
) -> None:
self._buffer_size = buffer_size
self._sessionmaker = sessionmaker
self._ring: collections.deque[AuditRecord] = collections.deque(maxlen=buffer_size)
self._queue: asyncio.Queue[AuditRecord] = asyncio.Queue()
self._deny_mode = False
self._task: asyncio.Task[None] | None = None
@property
def deny_mode(self) -> bool:
"""True when the buffer has overflowed and new work must be denied."""
return self._deny_mode
def bind(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
"""Attach the DB session factory (called from lifespan startup)."""
self._sessionmaker = sessionmaker
async def enqueue(self, record: AuditRecord) -> None:
"""Queue an audit record for asynchronous persistence (non-blocking)."""
await self._queue.put(record)
def start(self) -> None:
"""Launch the background drain task (idempotent)."""
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._drain_loop())
async def stop(self) -> None:
"""Cancel the drain task and flush remaining rows best-effort."""
if self._task is not None:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
await self.flush()
async def _drain_loop(self) -> None:
"""Continuously move queued records into Postgres, retrying the ring."""
while True:
record = await self._queue.get()
await self._persist_with_retry(record)
async def _persist_with_retry(self, record: AuditRecord) -> None:
"""Persist one record; on failure push to the ring (or enter deny mode)."""
# Drain any backlog first so ordering is roughly preserved.
if self._ring:
await self._try_flush_ring()
if not await self._write_one(record):
self._buffer(record)
def _buffer(self, record: AuditRecord) -> None:
"""Buffer a failed write; flip to deny mode if the ring is full."""
if len(self._ring) >= self._buffer_size:
self._deny_mode = True
_log.error("audit_buffer_overflow_deny_mode", buffered=len(self._ring))
return
self._ring.append(record)
async def _try_flush_ring(self) -> None:
"""Attempt to persist buffered rows; clear deny mode once drained."""
while self._ring:
record = self._ring[0]
if not await self._write_one(record):
return
self._ring.popleft()
self._deny_mode = False
async def _write_one(self, record: AuditRecord) -> bool:
"""Persist a single record; return False on any failure."""
if self._sessionmaker is None:
return False
try:
async with self._sessionmaker() as session:
await AuditRepository(session).insert_audit(**record.as_columns())
await session.commit()
return True
except Exception as exc: # noqa: BLE001 - any DB error ⇒ buffer + retry
_log.warning("audit_write_failed", error=str(exc))
return False
async def flush(self) -> None:
"""Drain queued + buffered records to Postgres (best-effort)."""
while not self._queue.empty():
record = self._queue.get_nowait()
if not await self._write_one(record):
self._buffer(record)
await self._try_flush_ring()
__all__ = ["AuditRecord", "AuditWriter"]

View File

@@ -0,0 +1,3 @@
"""Token budgets: Redis period counters and Postgres ledger reconciliation."""
from __future__ import annotations

View File

@@ -0,0 +1,105 @@
"""Redis period counters for token budgets (day / month / total).
The fast-path budget check reads a Redis counter of tokens already consumed in
the active period and compares it to the configured limit; Postgres
(``ledger.py``) is the durable source of truth reconciled on rollover. Counters
are keyed by ``key_id`` + period + period-start so a new period naturally starts
at zero, and day/month counters carry a TTL so expired periods self-clean.
Fail-closed (SPEC §4.4): Redis errors raise
:class:`DependencyUnavailableError`; the caller must deny (503).
"""
from __future__ import annotations
import datetime
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import cast
import redis.asyncio as redis
from redis.exceptions import RedisError
from neuronetz_gateway.db.models import BudgetPeriod
from neuronetz_gateway.errors import DependencyUnavailableError
_CONSUMED_PREFIX = "gateway:budget:"
def period_start(period: BudgetPeriod, now: datetime.datetime | None = None) -> datetime.datetime:
"""Return the UTC start of the current period."""
moment = now or datetime.datetime.now(datetime.UTC)
moment = moment.astimezone(datetime.UTC)
if period is BudgetPeriod.day:
return moment.replace(hour=0, minute=0, second=0, microsecond=0)
if period is BudgetPeriod.month:
return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
# 'total' has no rollover; anchor at the epoch.
return datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC)
def _ttl_for(period: BudgetPeriod) -> int | None:
"""TTL (seconds) for a period counter; None for 'total' (no expiry)."""
if period is BudgetPeriod.day:
return 2 * 24 * 3600
if period is BudgetPeriod.month:
return 40 * 24 * 3600
return None
def _counter_key(key_id: str, period: BudgetPeriod, start: datetime.datetime) -> str:
return f"{_CONSUMED_PREFIX}{key_id}:{period.value}:{int(start.timestamp())}"
@dataclass(frozen=True, slots=True)
class BudgetState:
"""Snapshot of remaining budget for a period."""
period: BudgetPeriod
limit: int | None
remaining: int | None
@property
def exhausted(self) -> bool:
"""True if a finite limit is configured and nothing remains."""
return self.limit is not None and (self.remaining or 0) <= 0
class BudgetCounter:
"""Redis-backed per-period token counter."""
def __init__(self, client: redis.Redis) -> None:
self._client = client
async def check(self, key_id: str, period: BudgetPeriod, limit: int | None) -> BudgetState:
"""Return remaining budget; ``None`` limit means unlimited for the period."""
if limit is None:
return BudgetState(period=period, limit=None, remaining=None)
start = period_start(period)
try:
raw = await self._client.get(_counter_key(key_id, period, start))
except RedisError as exc:
raise DependencyUnavailableError(
internal_detail=f"budget redis error: {exc!r}"
) from exc
consumed = int(raw) if raw else 0
return BudgetState(period=period, limit=limit, remaining=max(limit - consumed, 0))
async def consume(self, key_id: str, period: BudgetPeriod, tokens: int) -> None:
"""Increment the consumed counter after a request completes."""
if tokens <= 0:
return
start = period_start(period)
key = _counter_key(key_id, period, start)
try:
await cast("Awaitable[int]", self._client.incrby(key, tokens))
ttl = _ttl_for(period)
if ttl is not None:
await self._client.expire(key, ttl)
except RedisError as exc:
raise DependencyUnavailableError(
internal_detail=f"budget consume redis error: {exc!r}"
) from exc
__all__ = ["BudgetCounter", "BudgetState", "period_start"]

View File

@@ -0,0 +1,58 @@
"""Postgres budget ledger reconciliation.
Persists token usage to ``gateway.budget_usage`` (the durable source of truth)
via an idempotent upsert keyed by (key_id, period, period_start). The Redis
counter (``counter.py``) is the fast path; this ledger is what survives a Redis
flush and what ``show-usage`` reports against.
"""
from __future__ import annotations
import uuid
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from neuronetz_gateway.budget.counter import period_start
from neuronetz_gateway.db.models import BudgetPeriod, BudgetUsage
class BudgetLedger:
"""Source-of-truth budget accounting in Postgres."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def record_usage(
self, key_id: str, period: BudgetPeriod, tokens_in: int, tokens_out: int
) -> None:
"""Upsert usage into ``gateway.budget_usage`` for the active period.
Uses an ``ON CONFLICT`` upsert so concurrent writers accumulate rather
than clobber. ``requests`` increments by one per recorded request.
"""
start = period_start(period)
stmt = pg_insert(BudgetUsage).values(
key_id=uuid.UUID(key_id) if isinstance(key_id, str) else key_id,
period=period,
period_start=start,
tokens_in=tokens_in,
tokens_out=tokens_out,
requests=1,
)
stmt = stmt.on_conflict_do_update(
index_elements=[
BudgetUsage.key_id,
BudgetUsage.period,
BudgetUsage.period_start,
],
set_={
"tokens_in": BudgetUsage.tokens_in + stmt.excluded.tokens_in,
"tokens_out": BudgetUsage.tokens_out + stmt.excluded.tokens_out,
"requests": BudgetUsage.requests + stmt.excluded.requests,
},
)
await self._session.execute(stmt)
__all__ = ["BudgetLedger"]

View File

@@ -0,0 +1,67 @@
"""Prometheus metrics.
Phase 1 declares the metric objects and the exposition helper. Instrumentation
(incrementing counters / observing histograms on the request path) is wired in
later phases. Per SPEC §13.3 we label by ``tenant`` only, never by ``key_id``.
"""
from __future__ import annotations
from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest
REGISTRY = CollectorRegistry()
REQUESTS_TOTAL = Counter(
"gateway_requests_total",
"Total proxied requests.",
labelnames=("tenant", "model", "status"),
registry=REGISTRY,
)
TOKENS_TOTAL = Counter(
"gateway_tokens_total",
"Total tokens accounted, by direction (in|out).",
labelnames=("tenant", "model", "direction"),
registry=REGISTRY,
)
REQUEST_DURATION_SECONDS = Histogram(
"gateway_request_duration_seconds",
"Gateway-side request duration in seconds.",
labelnames=("tenant", "model"),
registry=REGISTRY,
)
def record_request(tenant: str, model: str, status: int, duration_s: float) -> None:
"""Increment the request counter and observe its duration (tenant-labeled)."""
REQUESTS_TOTAL.labels(tenant=tenant, model=model, status=str(status)).inc()
REQUEST_DURATION_SECONDS.labels(tenant=tenant, model=model).observe(duration_s)
def record_tokens(tenant: str, model: str, tokens_in: int, tokens_out: int) -> None:
"""Add input/output token counts to the tokens counter."""
if tokens_in:
TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="in").inc(tokens_in)
if tokens_out:
TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="out").inc(tokens_out)
def render_latest() -> bytes:
"""Return the current metrics in Prometheus text exposition format."""
payload: bytes = generate_latest(REGISTRY)
return payload
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
__all__ = [
"CONTENT_TYPE_LATEST",
"REGISTRY",
"REQUESTS_TOTAL",
"REQUEST_DURATION_SECONDS",
"TOKENS_TOTAL",
"record_request",
"record_tokens",
"render_latest",
]

View File

@@ -0,0 +1,3 @@
"""Proxy layer: Ollama client, schema translation, token counting, allowlists."""
from __future__ import annotations

View File

@@ -0,0 +1,76 @@
"""Endpoint and model allowlists (SPEC §6.1, §6.2).
Mutating Ollama endpoints are hard-blocked (not configurable, not flagged):
``/api/pull``, ``/api/push``, ``/api/create``, ``/api/copy``, ``/api/delete``,
and any ``/api/blobs/*``. ``/api/ps`` is also blocked (leaks loaded models).
Model allowlist is per-tenant, default-deny. Enforcement logic lands in Phase 2.
"""
from __future__ import annotations
# Hard-blocked upstream endpoints — always 403, not configurable (SPEC §6.2).
HARD_BLOCKED_PATHS: frozenset[str] = frozenset(
{
"/api/pull",
"/api/push",
"/api/create",
"/api/copy",
"/api/delete",
"/api/ps",
}
)
# Path prefixes that are hard-blocked (e.g. blob upload/download).
HARD_BLOCKED_PREFIXES: tuple[str, ...] = ("/api/blobs",)
def is_hard_blocked(path: str) -> bool:
"""Return True if ``path`` is an unconditionally blocked upstream endpoint."""
if path in HARD_BLOCKED_PATHS:
return True
return any(path.startswith(prefix) for prefix in HARD_BLOCKED_PREFIXES)
def resolve_effective_models(
*,
allow_all: bool,
allowed_models: tuple[str, ...],
discovered: frozenset[str],
) -> frozenset[str]:
"""Resolve the effective model set per SPEC §4.3 step 7 / §4.6.
``allow_all`` ⇒ the effective set is the entire live ``discovered`` set;
otherwise it is the configured ``allowed_models`` intersected with
``discovered`` (so stale or typo'd allowlist entries never resolve, and a
model that is unpermitted vs. not-installed are indistinguishable).
Fail-closed: if ``discovered`` is empty (discovery unavailable/expired) the
result is empty regardless of ``allow_all`` — discovery never opens access.
"""
if not discovered:
return frozenset()
if allow_all:
return discovered
return frozenset(allowed_models) & discovered
def is_model_allowed(
model: str,
*,
allow_all: bool,
allowed_models: tuple[str, ...],
discovered: frozenset[str],
) -> bool:
"""Return True iff ``model`` is in the resolved effective set (default-deny)."""
return model in resolve_effective_models(
allow_all=allow_all, allowed_models=allowed_models, discovered=discovered
)
__all__ = [
"HARD_BLOCKED_PATHS",
"HARD_BLOCKED_PREFIXES",
"is_hard_blocked",
"is_model_allowed",
"resolve_effective_models",
]

View File

@@ -0,0 +1,207 @@
"""Live model discovery from the Ollama backend (SPEC §4.6).
A background task polls Ollama ``GET /api/tags`` every
``MODEL_DISCOVERY_REFRESH_S`` seconds. The parsed model set (names + sanitized
metadata) is cached in Redis under ``gateway:models:discovered`` with TTL
``MODEL_DISCOVERY_CACHE_TTL_S`` and held in-process for hot reads on the request
path.
Fail-closed (SPEC §4.6, §13.5): if Ollama is unreachable, or the cache is empty
or expired and cannot be refreshed, the discovered set is empty — and an empty
discovered set means no model resolves, so requests are denied. Discovery never
opens access on failure. It is read-only and only ever touches the allowlisted
``/api/tags`` endpoint; it never triggers a pull.
"""
from __future__ import annotations
import asyncio
import json
from dataclasses import asdict, dataclass
import httpx
import redis.asyncio as redis
from neuronetz_gateway.config import Settings
from neuronetz_gateway.observability.logging import get_logger
_log = get_logger("discovery")
REDIS_DISCOVERED_KEY = "gateway:models:discovered"
@dataclass(frozen=True, slots=True)
class DiscoveredModel:
"""Sanitized metadata for a single installed model."""
name: str
family: str | None = None
parameter_size: str | None = None
quantization: str | None = None
size_bytes: int | None = None
modified_at: str | None = None
def _parse_tags(payload: dict[str, object]) -> list[DiscoveredModel]:
"""Parse an Ollama ``/api/tags`` body into sanitized model records."""
models: list[DiscoveredModel] = []
raw_models = payload.get("models")
if not isinstance(raw_models, list):
return models
for entry in raw_models:
if not isinstance(entry, dict):
continue
name = entry.get("name") or entry.get("model")
if not isinstance(name, str) or not name:
continue
raw_details = entry.get("details")
details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
size = entry.get("size")
models.append(
DiscoveredModel(
name=name,
family=_opt_str(details.get("family")),
parameter_size=_opt_str(details.get("parameter_size")),
quantization=_opt_str(details.get("quantization_level")),
size_bytes=size if isinstance(size, int) else None,
modified_at=_opt_str(entry.get("modified_at")),
)
)
return models
def _opt_str(value: object) -> str | None:
"""Coerce a value to ``str`` only when it already is one."""
return value if isinstance(value, str) else None
def names_of(models: list[DiscoveredModel]) -> frozenset[str]:
"""Return just the set of model names from parsed records."""
return frozenset(m.name for m in models)
class DiscoveryCache:
"""In-process holder for the latest discovered model set.
Holds both the structured records (for ``/api/tags`` / ``list-models``) and a
fast name set (for allowlist resolution on the hot path). Reads never block
on Redis or Ollama; the poller refreshes this in the background.
"""
def __init__(self) -> None:
self._models: list[DiscoveredModel] = []
self._names: frozenset[str] = frozenset()
self._lock = asyncio.Lock()
async def set(self, models: list[DiscoveredModel]) -> None:
"""Replace the in-process snapshot atomically."""
async with self._lock:
self._models = list(models)
self._names = names_of(models)
@property
def names(self) -> frozenset[str]:
"""Current discovered model names (possibly empty ⇒ fail-closed)."""
return self._names
@property
def models(self) -> list[DiscoveredModel]:
"""Current discovered model records (copy)."""
return list(self._models)
async def write_discovered_to_redis(
client: redis.Redis, models: list[DiscoveredModel], ttl_s: int
) -> None:
"""Cache the discovered set in Redis with a TTL (so staleness expires)."""
payload = json.dumps([asdict(m) for m in models], separators=(",", ":"))
await client.set(REDIS_DISCOVERED_KEY, payload, ex=ttl_s)
async def read_discovered_from_redis(client: redis.Redis) -> frozenset[str]:
"""Read the cached discovered names from Redis; empty set on miss/expiry."""
raw = await client.get(REDIS_DISCOVERED_KEY)
if not raw:
return frozenset()
try:
data = json.loads(raw)
except (json.JSONDecodeError, TypeError):
return frozenset()
if not isinstance(data, list):
return frozenset()
return frozenset(
str(item["name"]) for item in data if isinstance(item, dict) and item.get("name")
)
async def fetch_tags(client: httpx.AsyncClient) -> list[DiscoveredModel]:
"""Fetch and parse Ollama ``/api/tags``; raise on transport/HTTP error."""
resp = await client.get("/api/tags")
resp.raise_for_status()
body = resp.json()
if not isinstance(body, dict):
return []
return _parse_tags(body)
async def refresh_once(
http_client: httpx.AsyncClient,
redis_client: redis.Redis | None,
cache: DiscoveryCache,
settings: Settings,
) -> bool:
"""Run a single discovery refresh. Returns True on success.
On any failure the in-process and Redis caches are left untouched; they
expire on their own TTL, which is the fail-closed behavior (stale-expired ⇒
empty ⇒ deny). We never *clear* eagerly on a transient error, but we also
never extend the TTL on failure.
"""
try:
models = await fetch_tags(http_client)
except (httpx.HTTPError, ValueError) as exc:
_log.warning("discovery_refresh_failed", error=str(exc))
return False
await cache.set(models)
if redis_client is not None:
try:
await write_discovered_to_redis(
redis_client, models, settings.model_discovery_cache_ttl_s
)
except Exception as exc: # noqa: BLE001 - Redis write is best-effort cache fill
_log.warning("discovery_cache_write_failed", error=str(exc))
_log.info("discovery_refreshed", count=len(models))
return True
async def discovery_loop(
http_client: httpx.AsyncClient,
redis_client: redis.Redis | None,
cache: DiscoveryCache,
settings: Settings,
) -> None:
"""Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``.
Designed to be launched via ``asyncio.create_task`` in the lifespan and
cancelled on shutdown.
"""
await refresh_once(http_client, redis_client, cache, settings)
while True:
try:
await asyncio.sleep(settings.model_discovery_refresh_s)
except asyncio.CancelledError:
raise
await refresh_once(http_client, redis_client, cache, settings)
__all__ = [
"REDIS_DISCOVERED_KEY",
"DiscoveredModel",
"DiscoveryCache",
"discovery_loop",
"fetch_tags",
"names_of",
"read_discovered_from_redis",
"refresh_once",
"write_discovered_to_redis",
]

View File

@@ -0,0 +1,79 @@
"""httpx-based streaming proxy client for Ollama.
Opens async streams to the upstream and relays bytes without buffering, so the
gateway does not degrade time-to-first-byte (SPEC §9, non-negotiable #6).
Transport-level upstream failures are sanitized at this boundary into
:class:`UpstreamUnavailableError` (never reflected verbatim, non-negotiable #4).
The client is constructed from the shared ``httpx.AsyncClient`` on
``app.state`` and is exposed to routes via the ``get_ollama_client`` dependency
in ``deps.py`` so tests can override it (the QA override contract).
"""
from __future__ import annotations
from collections.abc import AsyncIterator
import httpx
from neuronetz_gateway.errors import UpstreamUnavailableError
class OllamaClient:
"""Thin async wrapper around the shared httpx client for Ollama calls."""
def __init__(self, client: httpx.AsyncClient) -> None:
self._client = client
async def stream(
self, method: str, path: str, json_body: dict[str, object]
) -> AsyncIterator[bytes]:
"""Open a streaming request to Ollama and yield raw response chunks.
Yields bytes exactly as received so the hot path performs no buffering.
Transport errors are sanitized; if the upstream returns a non-2xx the
body is drained (so internals are not surfaced) and a generic 502 is
raised before any client bytes are emitted.
"""
request = self._client.build_request(method, path, json=json_body)
try:
response = await self._client.send(request, stream=True)
except httpx.HTTPError as exc:
raise UpstreamUnavailableError(internal_detail=f"ollama send failed: {exc!r}") from exc
try:
if response.status_code >= 400:
await response.aread()
raise UpstreamUnavailableError(
internal_detail=f"ollama returned {response.status_code} for {path}"
)
async for chunk in response.aiter_raw():
yield chunk
except httpx.HTTPError as exc:
raise UpstreamUnavailableError(
internal_detail=f"ollama stream failed: {exc!r}"
) from exc
finally:
await response.aclose()
async def request(
self, method: str, path: str, json_body: dict[str, object]
) -> httpx.Response:
"""Perform a non-streaming request to Ollama.
Returns the upstream response on success; raises a sanitized 502 on
transport failure or a non-2xx status (internals never reflected).
"""
try:
response = await self._client.request(method, path, json=json_body)
except httpx.HTTPError as exc:
raise UpstreamUnavailableError(
internal_detail=f"ollama request failed: {exc!r}"
) from exc
if response.status_code >= 400:
raise UpstreamUnavailableError(
internal_detail=f"ollama returned {response.status_code} for {path}"
)
return response
__all__ = ["OllamaClient"]

View File

@@ -0,0 +1,467 @@
"""Shared request pipeline: enforcement, streaming, and post-close bookkeeping.
Both the native (``/api/*``) and OpenAI-compat (``/v1/*``) routes funnel through
here so the security checks and the streaming-integrity contract are written
once. The order mirrors SPEC §4.3:
rate limit (per-key + per-tenant RPM) → budget → concurrency → model allowlist
→ endpoint allowlist → body validation → proxy + stream → post-close audit /
token-count / budget-consume / metrics / semaphore release.
Streaming integrity (non-negotiable #6): the bytes flow to the client untouched
and token counting + audit + budget-consume happen **after** the stream closes,
never on the hot path.
Fail-closed (non-negotiable #1): every limiter/budget call raises
:class:`DependencyUnavailableError` when Redis is down, which the error handler
renders as 503. The model allowlist is default-deny against the live discovered
set; a missing/expired discovery set denies everything.
"""
from __future__ import annotations
import datetime
import json
import time
import uuid
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass
from typing import Any
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from starlette.responses import JSONResponse, StreamingResponse
from neuronetz_gateway.audit.writer import AuditRecord, AuditWriter
from neuronetz_gateway.auth.principal import Principal
from neuronetz_gateway.budget.counter import BudgetCounter
from neuronetz_gateway.budget.ledger import BudgetLedger
from neuronetz_gateway.config import Settings
from neuronetz_gateway.db.models import BudgetPeriod
from neuronetz_gateway.errors import (
AuthorizationError,
BudgetExceededError,
RateLimitError,
RequestTooLargeError,
)
from neuronetz_gateway.observability import metrics
from neuronetz_gateway.observability.logging import get_logger
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed
from neuronetz_gateway.proxy.discovery import DiscoveryCache
from neuronetz_gateway.proxy.ollama import OllamaClient
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
_log = get_logger("pipeline")
NDJSON_MEDIA_TYPE = "application/x-ndjson"
SSE_MEDIA_TYPE = "text/event-stream"
_CONCURRENCY_PREFIX = "gateway:concurrency:"
@dataclass(slots=True)
class RateHeaders:
"""The §6.5 rate/budget header values gathered during pre-flight."""
limit_requests: int
remaining_requests: int
limit_tokens: int
remaining_tokens: int
budget_period: str
budget_remaining: str
def as_dict(self, request_id: str) -> dict[str, str]:
"""Render to the SPEC §6.5 header set."""
return {
"X-Request-ID": request_id,
"X-RateLimit-Limit-Requests": str(self.limit_requests),
"X-RateLimit-Remaining-Requests": str(self.remaining_requests),
"X-RateLimit-Limit-Tokens": str(self.limit_tokens),
"X-RateLimit-Remaining-Tokens": str(self.remaining_tokens),
"X-Budget-Period": self.budget_period,
"X-Budget-Tokens-Remaining": self.budget_remaining,
}
class Pipeline:
"""Per-request enforcement + proxy orchestrator."""
def __init__(
self,
*,
request: Request,
principal: Principal,
settings: Settings,
ollama: OllamaClient,
discovery: DiscoveryCache,
rate_limiter: SlidingWindowLimiter,
concurrency: ConcurrencyLimiter,
budget: BudgetCounter,
audit: AuditWriter,
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
) -> None:
self._request = request
self._p = principal
self._settings = settings
self._ollama = ollama
self._discovery = discovery
self._rate = rate_limiter
self._conc = concurrency
self._budget = budget
self._audit = audit
self._sessionmaker = sessionmaker
self._request_id = str(getattr(request.state, "request_id", uuid.uuid4()))
self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}"
self._headers = RateHeaders(
limit_requests=principal.limits.rpm,
remaining_requests=principal.limits.rpm,
limit_tokens=principal.limits.tpm,
remaining_tokens=principal.limits.tpm,
budget_period="day",
budget_remaining="unlimited",
)
@property
def settings(self) -> Settings:
"""The active settings (for body-size / num_predict caps in routes)."""
return self._settings
# ----- enforcement -------------------------------------------------------
def check_scope(self, scope: str) -> None:
"""Authorize a coarse scope (e.g. 'chat', 'embeddings')."""
if scope not in self._p.scopes:
raise AuthorizationError(internal_detail=f"scope {scope!r} not granted")
def check_endpoint(self, path: str) -> None:
"""Reject hard-blocked upstream endpoints with a generic 403."""
if is_hard_blocked(path):
raise AuthorizationError(internal_detail=f"hard-blocked endpoint {path}")
def check_model(self, model: str) -> None:
"""Default-deny model check against the live effective set (§4.3 step 7)."""
if not model or not is_model_allowed(
model,
allow_all=self._p.allow_all_models,
allowed_models=self._p.allowed_models,
discovered=self._discovery.names,
):
# No existence disclosure (SPEC §13.6): unpermitted and not-installed
# both yield the same generic 403.
raise AuthorizationError(internal_detail=f"model {model!r} not in effective set")
def validate_body(self, body: dict[str, object]) -> None:
"""Enforce the ``num_predict`` cap (body-size cap is enforced earlier)."""
options = body.get("options")
if isinstance(options, dict):
num_predict = options.get("num_predict")
if isinstance(num_predict, int) and num_predict > self._settings.max_num_predict:
options["num_predict"] = self._settings.max_num_predict
async def enforce_limits(self, *, token_estimate: int = 0) -> None:
"""Run RPM (per-key + per-tenant), TPM, budget, and concurrency checks.
Budget is checked before concurrency so an over-budget request never even
acquires a permit. Order otherwise follows SPEC §4.3 steps 4-6.
"""
await self._check_rpm()
await self._check_tpm(token_estimate)
await self.check_budgets()
await self._acquire_concurrency()
async def _check_rpm(self) -> None:
key_result = await self._rate.check(
f"gateway:rpm:key:{self._p.key_id}", self._p.limits.rpm, 60, cost=1
)
self._headers.limit_requests = key_result.limit
self._headers.remaining_requests = key_result.remaining
if not key_result.allowed:
raise RateLimitError(retry_after=key_result.retry_after_s)
tenant_result = await self._rate.check(
f"gateway:rpm:tenant:{self._p.tenant_id}", self._p.limits.rpm, 60, cost=1
)
if not tenant_result.allowed:
raise RateLimitError(retry_after=tenant_result.retry_after_s)
async def _check_tpm(self, token_estimate: int) -> None:
# TPM is charged with a minimum of 1 so a request always counts; the
# precise token cost is reconciled post-stream via the budget counter.
cost = max(token_estimate, 1)
result = await self._rate.check(
f"gateway:tpm:key:{self._p.key_id}", self._p.limits.tpm, 60, cost=cost
)
self._headers.limit_tokens = result.limit
self._headers.remaining_tokens = result.remaining
if not result.allowed:
raise RateLimitError(retry_after=result.retry_after_s)
def _budget_periods(self) -> list[tuple[BudgetPeriod, int | None]]:
return [
(BudgetPeriod.day, self._p.limits.tokens_daily),
(BudgetPeriod.month, self._p.limits.tokens_monthly),
(BudgetPeriod.total, self._p.limits.tokens_total),
]
async def check_budgets(self) -> None:
"""Verify no configured budget period is already exhausted."""
tightest_period = "day"
tightest_remaining = "unlimited"
for period, limit in self._budget_periods():
state = await self._budget.check(str(self._p.key_id), period, limit)
if state.exhausted:
raise BudgetExceededError(internal_detail=f"budget {period.value} exhausted")
if state.remaining is not None:
tightest_period = period.value
tightest_remaining = str(state.remaining)
self._headers.budget_period = tightest_period
self._headers.budget_remaining = tightest_remaining
async def _acquire_concurrency(self) -> None:
ok = await self._conc.acquire(
self._concurrency_key,
self._p.limits.concurrent,
self._settings.ollama_read_timeout_s + 30,
)
if not ok:
raise RateLimitError(
retry_after=1, internal_detail="concurrency cap reached"
)
# ----- proxy + bookkeeping ----------------------------------------------
def headers(self) -> dict[str, str]:
"""Render the §6.5 response headers."""
return self._headers.as_dict(self._request_id)
async def stream_native(
self, method: str, path: str, body: dict[str, object], model: str
) -> StreamingResponse:
"""Proxy a streaming NDJSON request, accounting tokens after close."""
started = time.monotonic()
media_type = NDJSON_MEDIA_TYPE
async def gen() -> AsyncIterator[bytes]:
last_obj: dict[str, object] = {}
try:
async for chunk in self._ollama.stream(method, path, body):
last_obj = _merge_last_ndjson(chunk, last_obj)
yield chunk
finally:
await self._finish(model, path, method, last_obj, started)
return StreamingResponse(gen(), media_type=media_type, headers=self.headers())
async def stream_openai(
self,
method: str,
path: str,
body: dict[str, object],
model: str,
chunk_translator: Callable[[dict[str, object]], dict[str, object]],
) -> StreamingResponse:
"""Proxy + translate native NDJSON into OpenAI SSE; account after close."""
started = time.monotonic()
async def gen() -> AsyncIterator[bytes]:
last_obj: dict[str, object] = {}
buffer = b""
try:
async for chunk in self._ollama.stream(method, path, body):
buffer += chunk
lines = buffer.split(b"\n")
buffer = lines.pop()
for line in lines:
if not line.strip():
continue
obj = json.loads(line)
last_obj = obj if isinstance(obj, dict) else last_obj
translated = chunk_translator(obj)
yield f"data: {json.dumps(translated)}\n\n".encode()
if buffer.strip():
obj = json.loads(buffer)
last_obj = obj if isinstance(obj, dict) else last_obj
yield f"data: {json.dumps(chunk_translator(obj))}\n\n".encode()
yield b"data: [DONE]\n\n"
finally:
await self._finish(model, path, method, last_obj, started)
return StreamingResponse(gen(), media_type=SSE_MEDIA_TYPE, headers=self.headers())
async def request_native(
self, method: str, path: str, body: dict[str, object], model: str
) -> JSONResponse:
"""Proxy a non-streaming request and account tokens before responding."""
started = time.monotonic()
resp = await self._ollama.request(method, path, body)
payload = resp.json()
obj = payload if isinstance(payload, dict) else {}
await self._finish(model, path, method, obj, started)
return JSONResponse(obj, headers=self.headers())
async def request_translated(
self,
method: str,
path: str,
body: dict[str, object],
model: str,
translator: Callable[[dict[str, object]], dict[str, object]],
) -> JSONResponse:
"""Proxy a non-streaming request, translate the body, then account."""
started = time.monotonic()
resp = await self._ollama.request(method, path, body)
payload = resp.json()
obj = payload if isinstance(payload, dict) else {}
await self._finish(model, path, method, obj, started)
return JSONResponse(translator(obj), headers=self.headers())
async def _finish(
self,
model: str,
path: str,
method: str,
final_obj: dict[str, object],
started: float,
) -> None:
"""Post-close bookkeeping: tokens, budget, metrics, audit, semaphore.
Runs once per request after the response is fully produced. Each step is
best-effort and guarded so a bookkeeping failure never corrupts the
already-delivered response. The concurrency permit is always released.
"""
usage = extract_usage(final_obj) if final_obj else TokenUsage(0, 0)
latency_ms = int((time.monotonic() - started) * 1000)
try:
await self._account_budget(usage)
except Exception as exc: # noqa: BLE001 - never break a delivered response
_log.warning("budget_account_failed", error=str(exc))
try:
metrics.record_request(self._p.tenant_name, model or "unknown", 200, latency_ms / 1000)
metrics.record_tokens(
self._p.tenant_name, model or "unknown", usage.tokens_in, usage.tokens_out
)
except Exception as exc: # noqa: BLE001 - metrics must not break responses
_log.warning("metrics_record_failed", error=str(exc))
await self._write_audit(model, path, method, usage, latency_ms, 200)
await self._release_concurrency()
async def _account_budget(self, usage: TokenUsage) -> None:
"""Decrement Redis budget counters and persist to the Postgres ledger."""
for period, limit in self._budget_periods():
if limit is not None:
await self._budget.consume(str(self._p.key_id), period, usage.total)
if self._sessionmaker is not None and usage.total >= 0:
try:
async with self._sessionmaker() as session:
ledger = BudgetLedger(session)
for period, _limit in self._budget_periods():
await ledger.record_usage(
str(self._p.key_id), period, usage.tokens_in, usage.tokens_out
)
await session.commit()
except Exception as exc: # noqa: BLE001 - ledger is durable backstop; never break response
_log.warning("ledger_record_failed", error=str(exc))
async def _write_audit(
self,
model: str,
path: str,
method: str,
usage: TokenUsage,
latency_ms: int,
status: int,
) -> None:
record = AuditRecord(
request_id=uuid.UUID(self._request_id)
if _is_uuid(self._request_id)
else uuid.uuid4(),
method=method,
path=path,
status=status,
ts=datetime.datetime.now(datetime.UTC),
tenant_id=self._p.tenant_id,
key_id=self._p.key_id,
key_prefix=self._p.key_prefix,
model=model or None,
tokens_in=usage.tokens_in,
tokens_out=usage.tokens_out,
latency_ms=latency_ms,
client_ip=self._client_ip(),
user_agent=self._request.headers.get("user-agent"),
)
try:
await self._audit.enqueue(record)
except Exception as exc: # noqa: BLE001 - audit enqueue must not break response
_log.warning("audit_enqueue_failed", error=str(exc))
async def _release_concurrency(self) -> None:
try:
await self._conc.release(self._concurrency_key)
except Exception as exc: # noqa: BLE001 - release is best-effort
_log.warning("concurrency_release_failed", error=str(exc))
def _client_ip(self) -> str | None:
xff = self._request.headers.get("x-forwarded-for")
if xff:
return xff.split(",")[0].strip()
return self._request.client.host if self._request.client else None
def _merge_last_ndjson(chunk: bytes, prev: dict[str, object]) -> dict[str, object]:
"""Track the last complete NDJSON object seen in a raw byte chunk.
Token counts live on the final ``done`` frame. We parse only complete lines
and keep the last successfully-parsed object; partial trailing data is
ignored here and will be completed by a subsequent chunk.
"""
text = chunk.decode("utf-8", errors="ignore")
last = prev
for line in text.split("\n"):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
continue
if isinstance(obj, dict):
last = obj
return last
def _is_uuid(value: str) -> bool:
try:
uuid.UUID(value)
except ValueError:
return False
return True
def model_of(body: dict[str, object]) -> str:
"""Extract the requested model name from a request body (empty if absent)."""
model = body.get("model")
return model if isinstance(model, str) else ""
async def read_json_body(request: Request, settings: Settings) -> dict[str, object]:
"""Read + size-limit the request body, returning the parsed JSON object."""
raw = await request.body()
if len(raw) > settings.max_request_body_bytes:
raise RequestTooLargeError(internal_detail=f"body {len(raw)} bytes")
if not raw:
return {}
try:
parsed: Any = json.loads(raw)
except json.JSONDecodeError as exc:
raise AuthorizationError(internal_detail="invalid JSON body") from exc
return parsed if isinstance(parsed, dict) else {}
__all__ = [
"NDJSON_MEDIA_TYPE",
"SSE_MEDIA_TYPE",
"Pipeline",
"RateHeaders",
"model_of",
"read_json_body",
]

View File

@@ -0,0 +1,50 @@
"""Precise token accounting parsed from Ollama responses (SPEC §2, §13.1).
Tokens are read from Ollama's reported ``prompt_eval_count`` (input) and
``eval_count`` (output) on the final stream frame — never heuristically
estimated. Embeddings charge ``prompt_eval_count`` only (SPEC §13.1); they have
no ``eval_count`` so output tokens are reported as zero.
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TokenUsage:
"""Token counts extracted from an Ollama response."""
tokens_in: int
tokens_out: int
@property
def total(self) -> int:
"""Combined input + output tokens."""
return self.tokens_in + self.tokens_out
def _as_int(value: object) -> int:
"""Coerce an Ollama-reported count to a non-negative int (0 if absent/bad)."""
if isinstance(value, bool):
return 0
if isinstance(value, int):
return max(value, 0)
if isinstance(value, float):
return max(int(value), 0)
return 0
def extract_usage(final_frame: dict[str, object]) -> TokenUsage:
"""Extract ``prompt_eval_count``/``eval_count`` from the final Ollama frame.
Works for chat/generate final frames and for embeddings responses (which
carry ``prompt_eval_count`` but no ``eval_count``).
"""
return TokenUsage(
tokens_in=_as_int(final_frame.get("prompt_eval_count")),
tokens_out=_as_int(final_frame.get("eval_count")),
)
__all__ = ["TokenUsage", "extract_usage"]

View File

@@ -0,0 +1,245 @@
"""OpenAI <-> Ollama schema translation (SPEC §6.3).
Native ``/api/*`` speaks NDJSON; OpenAI-compat ``/v1/*`` speaks SSE
(``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``). These helpers translate request
bodies in both directions and convert native Ollama stream frames into OpenAI
chunk objects, preserving streaming. Unknown OpenAI sampling params are mapped
into Ollama's ``options`` block; unrecognized keys are dropped rather than
forwarded blindly.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
# OpenAI sampling params that map into Ollama's ``options`` object.
_OPTION_KEYS: tuple[str, ...] = (
"temperature",
"top_p",
"top_k",
"seed",
"stop",
"presence_penalty",
"frequency_penalty",
)
def _as_int(value: object) -> int:
"""Coerce a JSON-derived value to a non-negative int (0 if absent/invalid)."""
if isinstance(value, bool):
return 0
if isinstance(value, int):
return max(value, 0)
if isinstance(value, float):
return max(int(value), 0)
return 0
def _build_options(payload: dict[str, Any]) -> dict[str, Any]:
"""Collect OpenAI sampling params into an Ollama ``options`` mapping."""
options: dict[str, Any] = {}
for key in _OPTION_KEYS:
if key in payload and payload[key] is not None:
options[key] = payload[key]
if "max_tokens" in payload and payload["max_tokens"] is not None:
options["num_predict"] = payload["max_tokens"]
return options
def openai_chat_to_ollama(payload: dict[str, object]) -> dict[str, object]:
"""Translate an OpenAI chat-completion request to an Ollama ``/api/chat`` body."""
body: dict[str, object] = {
"model": payload.get("model"),
"messages": payload.get("messages", []),
"stream": bool(payload.get("stream", False)),
}
options = _build_options(dict(payload))
if options:
body["options"] = options
return body
def openai_completion_to_ollama(payload: dict[str, object]) -> dict[str, object]:
"""Translate an OpenAI completion request to an Ollama ``/api/generate`` body."""
prompt = payload.get("prompt", "")
if isinstance(prompt, list):
prompt = "".join(str(p) for p in prompt)
body: dict[str, object] = {
"model": payload.get("model"),
"prompt": prompt,
"stream": bool(payload.get("stream", False)),
}
options = _build_options(dict(payload))
if options:
body["options"] = options
return body
def openai_embeddings_to_ollama(payload: dict[str, object]) -> dict[str, object]:
"""Translate an OpenAI embeddings request to an Ollama ``/api/embed`` body."""
return {
"model": payload.get("model"),
"input": payload.get("input", ""),
}
def _completion_id() -> str:
"""Generate an OpenAI-style completion id."""
return f"chatcmpl-{uuid.uuid4().hex}"
def ollama_chat_chunk_to_openai(
chunk: dict[str, object], *, completion_id: str, model: str, created: int
) -> dict[str, object]:
"""Translate one Ollama ``/api/chat`` NDJSON frame to an OpenAI SSE chunk."""
done = bool(chunk.get("done"))
if done:
return {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"usage": {
"prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
"completion_tokens": _as_int(chunk.get("eval_count")),
"total_tokens": _as_int(chunk.get("prompt_eval_count"))
+ _as_int(chunk.get("eval_count")),
},
}
message = chunk.get("message")
content = ""
if isinstance(message, dict):
content = str(message.get("content", ""))
return {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
}
def ollama_generate_chunk_to_openai(
chunk: dict[str, object], *, completion_id: str, model: str, created: int
) -> dict[str, object]:
"""Translate one Ollama ``/api/generate`` NDJSON frame to an OpenAI text chunk."""
done = bool(chunk.get("done"))
if done:
return {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model,
"choices": [{"index": 0, "text": "", "finish_reason": "stop"}],
"usage": {
"prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
"completion_tokens": _as_int(chunk.get("eval_count")),
"total_tokens": _as_int(chunk.get("prompt_eval_count"))
+ _as_int(chunk.get("eval_count")),
},
}
return {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model,
"choices": [{"index": 0, "text": str(chunk.get("response", "")), "finish_reason": None}],
}
def ollama_chat_to_openai(payload: dict[str, object]) -> dict[str, object]:
"""Translate a *non-streaming* Ollama chat response to an OpenAI completion."""
message = payload.get("message")
content = ""
if isinstance(message, dict):
content = str(message.get("content", ""))
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
completion_tokens = _as_int(payload.get("eval_count"))
return {
"id": _completion_id(),
"object": "chat.completion",
"created": int(time.time()),
"model": str(payload.get("model", "")),
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
def ollama_generate_to_openai(payload: dict[str, object]) -> dict[str, object]:
"""Translate a *non-streaming* Ollama generate response to OpenAI completion."""
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
completion_tokens = _as_int(payload.get("eval_count"))
return {
"id": _completion_id(),
"object": "text_completion",
"created": int(time.time()),
"model": str(payload.get("model", "")),
"choices": [
{"index": 0, "text": str(payload.get("response", "")), "finish_reason": "stop"}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
def ollama_embed_to_openai(payload: dict[str, object], model: str) -> dict[str, object]:
"""Translate an Ollama ``/api/embed`` response to the OpenAI embeddings shape."""
raw = payload.get("embeddings")
vectors: list[list[float]] = raw if isinstance(raw, list) else []
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
return {
"object": "list",
"data": [
{"object": "embedding", "index": i, "embedding": vec}
for i, vec in enumerate(vectors)
],
"model": model,
"usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
}
def models_to_openai_list(names: list[str]) -> dict[str, object]:
"""Render a list of model names in the OpenAI ``/v1/models`` list format."""
created = int(time.time())
return {
"object": "list",
"data": [
{"id": name, "object": "model", "created": created, "owned_by": "neuronetz"}
for name in names
],
}
def new_completion_id() -> str:
"""Public helper to mint a completion id for a streaming response."""
return _completion_id()
__all__ = [
"models_to_openai_list",
"new_completion_id",
"ollama_chat_chunk_to_openai",
"ollama_chat_to_openai",
"ollama_embed_to_openai",
"ollama_generate_chunk_to_openai",
"ollama_generate_to_openai",
"openai_chat_to_ollama",
"openai_completion_to_ollama",
"openai_embeddings_to_ollama",
]

View File

@@ -0,0 +1,3 @@
"""Rate limiting: sliding-window RPM/TPM and concurrency semaphore (Redis)."""
from __future__ import annotations

View File

@@ -0,0 +1,66 @@
"""Concurrent-connection semaphore backed by Redis ``INCR`` with a TTL guard.
Acquired before proxying and released on stream close. A TTL on the counter
prevents a crashed worker from leaking permits forever (self-healing). Fails
closed (SPEC §4.4): if Redis is unreachable, acquisition raises
:class:`DependencyUnavailableError` so the caller denies (503).
"""
from __future__ import annotations
from collections.abc import Awaitable
from typing import cast
import redis.asyncio as redis
from redis.exceptions import RedisError
from neuronetz_gateway.errors import DependencyUnavailableError
# Default guard TTL: a permit auto-expires if not explicitly released, so a
# crashed worker cannot pin the semaphore. Comfortably longer than any single
# stream is expected to take is set per-call via ``ttl_s``.
class ConcurrencyLimiter:
"""Redis-backed concurrency cap with a self-healing TTL guard."""
def __init__(self, client: redis.Redis) -> None:
self._client = client
async def acquire(self, scope_key: str, limit: int, ttl_s: int) -> bool:
"""Try to acquire a permit; return False (deny) if at capacity.
Increments the counter, refreshes the TTL guard, and rolls back the
increment if the new value exceeds ``limit``. Raises on Redis failure so
the caller fails closed.
"""
try:
count = await cast("Awaitable[int]", self._client.incr(scope_key))
await self._client.expire(scope_key, ttl_s)
except RedisError as exc:
raise DependencyUnavailableError(
internal_detail=f"concurrency redis error: {exc!r}"
) from exc
if count > limit:
# Over capacity: undo our increment and deny.
try:
await self._client.decr(scope_key)
except RedisError:
# Permit self-heals via the TTL guard; denial still stands.
return False
return False
return True
async def release(self, scope_key: str) -> None:
"""Release a previously acquired permit (best-effort; never raises)."""
try:
count = await cast("Awaitable[int]", self._client.decr(scope_key))
if count < 0:
await self._client.set(scope_key, 0)
except RedisError:
# The TTL guard will reclaim the permit; releasing must not break the
# request that already completed.
return
__all__ = ["ConcurrencyLimiter"]

View File

@@ -0,0 +1,109 @@
"""Sliding-window rate limiter backed by an atomic Redis Lua script.
Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set
of timestamped hits (a true sliding window, not a fixed bucket): each check
trims entries older than ``window_s``, sums the cost of what remains, and admits
the new cost only if it keeps the total within ``limit``. The trim + sum +
conditional add run inside one Lua script so the decision is atomic across
concurrent workers (SPEC §4.3 step 4).
Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises
:class:`DependencyUnavailableError`; the caller must deny (503), never allow.
"""
from __future__ import annotations
import time
from dataclasses import dataclass
import redis.asyncio as redis
from redis.exceptions import RedisError
from neuronetz_gateway.errors import DependencyUnavailableError
# KEYS[1] = zset key
# ARGV[1] = now (ms) ARGV[2] = window (ms) ARGV[3] = limit
# ARGV[4] = cost ARGV[5] = unique member suffix
# Returns: {allowed (1/0), used_after, retry_after_ms}
_LUA = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local member = ARGV[5]
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
-- Each surviving member encodes its cost as a ':<cost>' suffix; sum them so the
-- window total is cost-weighted (RPM uses cost 1, TPM uses token count).
local total = 0
local data = redis.call('ZRANGEBYSCORE', key, now - window, now)
for i = 1, #data do
local sep = string.find(data[i], ':[^:]*$')
if sep then
total = total + tonumber(string.sub(data[i], sep + 1))
else
total = total + 1
end
end
if total + cost > limit then
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local retry = window
if oldest[2] then
retry = (tonumber(oldest[2]) + window) - now
if retry < 0 then retry = 0 end
end
return {0, total, retry}
end
redis.call('ZADD', key, now, member .. ':' .. cost)
redis.call('PEXPIRE', key, window)
return {1, total + cost, 0}
"""
@dataclass(frozen=True, slots=True)
class RateLimitResult:
"""Outcome of a rate-limit check."""
allowed: bool
limit: int
remaining: int
retry_after_s: int | None
class SlidingWindowLimiter:
"""Redis-backed sliding-window limiter (atomic via Lua)."""
def __init__(self, client: redis.Redis) -> None:
self._client = client
self._script = client.register_script(_LUA)
async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult:
"""Atomically record a hit of ``cost`` and report admission.
Raises :class:`DependencyUnavailableError` if Redis cannot be reached, so
the caller fails closed.
"""
now_ms = int(time.time() * 1000)
window_ms = window_s * 1000
member = f"{now_ms}-{id(object())}"
try:
raw = await self._script(
keys=[key], args=[now_ms, window_ms, limit, cost, member]
)
except RedisError as exc:
raise DependencyUnavailableError(
internal_detail=f"ratelimit redis error: {exc!r}"
) from exc
allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2]))
allowed = allowed_i == 1
remaining = max(limit - used_after, 0)
retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000)
return RateLimitResult(
allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s
)
__all__ = ["RateLimitResult", "SlidingWindowLimiter"]

View File

@@ -0,0 +1,97 @@
"""Key-revocation NOTIFY listener (SPEC §4.5).
Console (or the gateway's own CLI) revokes a key by inserting into
``gateway.revocations``; an ``AFTER INSERT`` trigger fires
``pg_notify('key_revoked', key_id)``. This background task LISTENs on that
channel and, on each notification, evicts the Redis auth-cache entry for the
revoked key's prefix so the next request misses the cache, re-reads the DB,
finds the key non-active, and is rejected — making revocation effective within
one Redis RTT without any cross-service HTTP.
The listener resolves ``key_id -> prefix`` via a short DB lookup (the NOTIFY
payload is the key id, but the cache is keyed by prefix). It is resilient: a
dropped connection is retried with backoff.
"""
from __future__ import annotations
import asyncio
import contextlib
import uuid
import asyncpg
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from neuronetz_gateway.config import Settings
from neuronetz_gateway.db.models import ApiKey
from neuronetz_gateway.observability.logging import get_logger
_log = get_logger("revocation")
_CHANNEL = "key_revoked"
_CACHE_PREFIX = "gateway:key:"
def _asyncpg_dsn(database_url: str) -> str:
"""Strip the SQLAlchemy ``+asyncpg`` driver tag for a raw asyncpg connect."""
return database_url.replace("postgresql+asyncpg://", "postgresql://")
async def _evict(
key_id_text: str,
sessionmaker: async_sessionmaker[AsyncSession],
redis_client: redis.Redis,
) -> None:
"""Resolve the key id to its prefix and delete the cached principal."""
try:
key_id = uuid.UUID(key_id_text)
except ValueError:
_log.warning("revocation_bad_payload", payload=key_id_text)
return
try:
async with sessionmaker() as session:
key = await session.get(ApiKey, key_id)
prefix = key.prefix if key is not None else None
if prefix is not None:
await redis_client.delete(_CACHE_PREFIX + prefix)
_log.info("revocation_cache_evicted", key_prefix=prefix)
except Exception as exc: # noqa: BLE001 - listener must survive transient errors
_log.warning("revocation_evict_failed", error=str(exc))
async def revocation_listener(
settings: Settings,
redis_client: redis.Redis,
sessionmaker: async_sessionmaker[AsyncSession],
) -> None:
"""LISTEN on ``key_revoked`` and evict the Redis cache on each notification."""
dsn = _asyncpg_dsn(settings.database_url)
while True:
conn = None
try:
conn = await asyncpg.connect(dsn)
def _on_notify(
_c: object, _pid: int, _channel: str, payload: str
) -> None:
# Schedule the async eviction without blocking the callback.
asyncio.create_task(_evict(payload, sessionmaker, redis_client)) # noqa: RUF006
await conn.add_listener(_CHANNEL, _on_notify)
_log.info("revocation_listener_started")
# Wait forever; notifications arrive via the callback. asyncio.Event
# is a cancel-friendly substitute for an unbounded sleep loop.
await asyncio.Event().wait()
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001 - reconnect on any failure
_log.warning("revocation_listener_reconnect", error=str(exc))
await asyncio.sleep(2)
finally:
if conn is not None:
with contextlib.suppress(Exception):
await conn.close()
__all__ = ["revocation_listener"]

View File

@@ -0,0 +1,168 @@
"""Native Ollama passthrough routes (SPEC §6.1).
All proxied endpoints run through the shared :class:`Pipeline` (auth has already
attached the principal in middleware): scope + model + endpoint allowlist, rate
limit, budget, concurrency, body validation, then stream/relay with post-close
token-count + audit + budget accounting.
Mutating endpoints (``/api/pull|push|create|copy|delete``, ``/api/blobs/*``) and
``/api/ps`` are hard-blocked (SPEC §6.2) and intentionally NOT routed; a catch-all
returns a generic 403 so their existence is never confirmed.
"""
from __future__ import annotations
from fastapi import APIRouter, Request
from starlette.responses import JSONResponse, Response
from neuronetz_gateway import __version__
from neuronetz_gateway.deps import (
ConfigDep,
DiscoveryCacheDep,
OllamaClientDep,
PipelineDep,
PrincipalDep,
)
from neuronetz_gateway.errors import AuthorizationError
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, resolve_effective_models
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body
router = APIRouter(prefix="/api", tags=["ollama-native"])
@router.post("/chat")
async def chat(request: Request, pipeline: PipelineDep) -> Response:
"""Proxy ``POST /api/chat`` (streamed NDJSON / non-streamed)."""
return await _chat_or_generate(request, pipeline, path="/api/chat", scope="chat")
@router.post("/generate")
async def generate(request: Request, pipeline: PipelineDep) -> Response:
"""Proxy ``POST /api/generate`` (streamed NDJSON / non-streamed)."""
return await _chat_or_generate(request, pipeline, path="/api/generate", scope="chat")
async def _chat_or_generate(
request: Request, pipeline: PipelineDep, *, path: str, scope: str
) -> Response:
body = await read_json_body(request, pipeline.settings)
model = model_of(body)
pipeline.check_scope(scope)
pipeline.check_endpoint(path)
pipeline.check_model(model)
pipeline.validate_body(body)
await pipeline.enforce_limits()
if bool(body.get("stream", True)):
return await pipeline.stream_native("POST", path, body, model)
return await pipeline.request_native("POST", path, body, model)
@router.post("/embeddings")
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
"""Proxy ``POST /api/embeddings`` (legacy, non-streamed)."""
return await _embeddings(request, pipeline, "/api/embeddings")
@router.post("/embed")
async def embed(request: Request, pipeline: PipelineDep) -> Response:
"""Proxy ``POST /api/embed`` (non-streamed)."""
return await _embeddings(request, pipeline, "/api/embed")
async def _embeddings(request: Request, pipeline: PipelineDep, path: str) -> Response:
body = await read_json_body(request, pipeline.settings)
model = model_of(body)
pipeline.check_scope("embeddings")
pipeline.check_endpoint(path)
pipeline.check_model(model)
await pipeline.enforce_limits()
return await pipeline.request_native("POST", path, body, model)
@router.get("/tags")
async def tags(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
"""Return the tenant's effective model set (live-discovered ∩ allowed)."""
effective = resolve_effective_models(
allow_all=principal.allow_all_models,
allowed_models=principal.allowed_models,
discovered=discovery.names,
)
models = [
{
"name": m.name,
"model": m.name,
"modified_at": m.modified_at,
"size": m.size_bytes,
"details": {
"family": m.family,
"parameter_size": m.parameter_size,
"quantization_level": m.quantization,
},
}
for m in discovery.models
if m.name in effective
]
return JSONResponse({"models": models})
@router.post("/show")
async def show(
request: Request,
principal: PrincipalDep,
discovery: DiscoveryCacheDep,
ollama: OllamaClientDep,
settings: ConfigDep,
) -> JSONResponse:
"""Proxy ``POST /api/show`` for an effective-set model; sanitize the result."""
body = await read_json_body(request, settings)
name = body.get("model") or body.get("name")
model = name if isinstance(name, str) else ""
if not model or model not in resolve_effective_models(
allow_all=principal.allow_all_models,
allowed_models=principal.allowed_models,
discovered=discovery.names,
):
raise AuthorizationError(internal_detail="show: model not in effective set")
resp = await ollama.request("POST", "/api/show", {"model": model})
payload = resp.json()
raw: dict[str, object] = payload if isinstance(payload, dict) else {}
# Strip system prompt + template (SPEC §6.1: no system prompts, no template).
raw_details = raw.get("details")
details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
return JSONResponse(
{
"model": model,
"details": {
"family": details.get("family"),
"parameter_size": details.get("parameter_size"),
"quantization_level": details.get("quantization_level"),
},
}
)
@router.get("/version")
async def version() -> JSONResponse:
"""Return the gateway version (never Ollama's; SPEC §6.1)."""
return JSONResponse({"version": __version__})
@router.api_route(
"/{rest:path}",
methods=["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"],
include_in_schema=False,
)
async def catch_all(rest: str) -> Response:
"""Generic 403 for any other ``/api/*`` path (hard-blocked or unknown).
Mutating endpoints and ``/api/ps`` resolve here and return the same generic
forbidden response, so the gateway never confirms which upstream endpoints
exist (SPEC §6.2, §13.6).
"""
full = f"/api/{rest}"
if is_hard_blocked(full):
raise AuthorizationError(internal_detail=f"hard-blocked {full}")
raise AuthorizationError(internal_detail=f"unrouted upstream path {full}")
__all__ = ["router"]

View File

@@ -0,0 +1,118 @@
"""OpenAI-compatible routes (SPEC §6.3).
Each route translates the OpenAI request into the native Ollama body, runs the
same :class:`Pipeline` enforcement as the native routes, and translates the
response back. Streaming uses SSE (``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``);
non-streaming returns a single OpenAI-shaped JSON object. ``/v1/models`` returns
the tenant's effective discovered set in OpenAI list format.
"""
from __future__ import annotations
import time
from fastapi import APIRouter, Request
from starlette.responses import JSONResponse, Response
from neuronetz_gateway.deps import (
DiscoveryCacheDep,
PipelineDep,
PrincipalDep,
)
from neuronetz_gateway.proxy.allowlist import resolve_effective_models
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body
from neuronetz_gateway.proxy.translate import (
models_to_openai_list,
new_completion_id,
ollama_chat_chunk_to_openai,
ollama_chat_to_openai,
ollama_embed_to_openai,
ollama_generate_chunk_to_openai,
ollama_generate_to_openai,
openai_chat_to_ollama,
openai_completion_to_ollama,
openai_embeddings_to_ollama,
)
router = APIRouter(prefix="/v1", tags=["openai-compat"])
@router.post("/chat/completions")
async def chat_completions(request: Request, pipeline: PipelineDep) -> Response:
"""OpenAI ``/v1/chat/completions`` -> Ollama ``/api/chat``."""
payload = await read_json_body(request, pipeline.settings)
body = openai_chat_to_ollama(payload)
model = model_of(body)
pipeline.check_scope("chat")
pipeline.check_endpoint("/api/chat")
pipeline.check_model(model)
pipeline.validate_body(body)
await pipeline.enforce_limits()
if bool(payload.get("stream", False)):
completion_id = new_completion_id()
created = int(time.time())
def translate(chunk: dict[str, object]) -> dict[str, object]:
return ollama_chat_chunk_to_openai(
chunk, completion_id=completion_id, model=model, created=created
)
return await pipeline.stream_openai("POST", "/api/chat", body, model, translate)
return await pipeline.request_translated(
"POST", "/api/chat", body, model, ollama_chat_to_openai
)
@router.post("/completions")
async def completions(request: Request, pipeline: PipelineDep) -> Response:
"""OpenAI ``/v1/completions`` -> Ollama ``/api/generate``."""
payload = await read_json_body(request, pipeline.settings)
body = openai_completion_to_ollama(payload)
model = model_of(body)
pipeline.check_scope("chat")
pipeline.check_endpoint("/api/generate")
pipeline.check_model(model)
pipeline.validate_body(body)
await pipeline.enforce_limits()
if bool(payload.get("stream", False)):
completion_id = new_completion_id()
created = int(time.time())
def translate(chunk: dict[str, object]) -> dict[str, object]:
return ollama_generate_chunk_to_openai(
chunk, completion_id=completion_id, model=model, created=created
)
return await pipeline.stream_openai("POST", "/api/generate", body, model, translate)
return await pipeline.request_translated(
"POST", "/api/generate", body, model, ollama_generate_to_openai
)
@router.post("/embeddings")
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
"""OpenAI ``/v1/embeddings`` -> Ollama ``/api/embed``."""
payload = await read_json_body(request, pipeline.settings)
body = openai_embeddings_to_ollama(payload)
model = model_of(body)
pipeline.check_scope("embeddings")
pipeline.check_endpoint("/api/embed")
pipeline.check_model(model)
await pipeline.enforce_limits()
return await pipeline.request_translated(
"POST", "/api/embed", body, model, lambda obj: ollama_embed_to_openai(obj, model)
)
@router.get("/models")
async def models(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
"""OpenAI ``/v1/models`` -> the tenant's effective discovered set."""
effective = resolve_effective_models(
allow_all=principal.allow_all_models,
allowed_models=principal.allowed_models,
discovered=discovery.names,
)
return JSONResponse(models_to_openai_list(sorted(effective)))
__all__ = ["router"]