proxy: streaming, discovery, OpenAI-compat, rate-limit, budget, audit

The hot path. A single Pipeline class owns enforcement so the eight
non-negotiables can be reviewed in one place.

- Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags,
  /api/show (system-prompt + template stripped), /api/embed(dings), /api/version
  (returns gateway version, not Ollama's). Endpoint catch-all returns the same
  generic 403 for hard-blocked and unknown /api/* paths so attackers cannot
  enumerate which mutating endpoints exist.
- OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings,
  /v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves
  streaming end-to-end.
- Model discovery (SPEC §4.6): background poller against Ollama /api/tags;
  Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh =
  MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty.
- Effective-set resolution in proxy/allowlist.py:
      allow_all = key.allow_all_models ?? tenant.allow_all_models
      effective = discovered if allow_all
                  else (key.allowed_models ?? tenant.allowed_models) ∩ discovered
  A non-effective model returns the same generic 403 whether it's
  installed-but-unpermitted or doesn't exist at all (no enumeration leak).
- Sliding-window rate limit (Redis Lua, single round-trip) for per-key +
  per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with
  TTL guard. Token-budget counters per (key, period) with a Postgres ledger
  for reconciliation across resets. Headers per SPEC §6.5 on every response;
  429 carries Retry-After; Redis outage → 503 (fail closed, never 200).
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk
  carrying `usage`); the audit row is written AFTER stream close so TTFB is
  never degraded by bookkeeping.
- Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow.
  Optional prompt log per key (TTL'd).
- Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache
  entry within ~1s of the console writing to gateway.revocations.
- Prometheus counters/histograms labeled by tenant only (per SPEC §13.3).
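
The token-counting bullet can be illustrated with a minimal sketch. It assumes the final Ollama NDJSON object carries `prompt_eval_count` and `eval_count`, and that the OpenAI-style SSE chunk carries `usage.prompt_tokens` and `usage.completion_tokens`. The helper names `usage_from_ndjson` and `usage_from_sse` are hypothetical and do not appear in this commit.

```python
import json


def usage_from_ndjson(lines):
    """Return (tokens_in, tokens_out) from the NDJSON object with done=true."""
    for line in lines:
        if not line.strip():
            continue
        obj = json.loads(line)
        if obj.get("done"):
            # Assumed field names for the final Ollama object.
            return obj.get("prompt_eval_count", 0), obj.get("eval_count", 0)
    return None  # stream ended without a final object


def usage_from_sse(lines):
    """Return (tokens_in, tokens_out) from the last SSE chunk carrying `usage`."""
    found = None
    for line in lines:
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            break
        usage = json.loads(payload).get("usage")
        if usage:
            found = (usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0))
    return found
```

Both helpers only read lines that have already been forwarded, which is consistent with counting after the stream closes rather than on the hot path.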
3
src/neuronetz_gateway/audit/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Audit logging: buffered async audit writer and opt-in prompt log."""

from __future__ import annotations
63
src/neuronetz_gateway/audit/prompt_log.py
Normal file
@@ -0,0 +1,63 @@
"""Opt-in, TTL'd prompt logging (SPEC §2, §5).

Disabled by default; enabled per key (or inherited from tenant via the resolved
:class:`Principal.log_prompts`). Each row carries a ``retention_until`` deadline
(swept in Phase 4). A redaction hook runs before persistence so sensitive spans
can be scrubbed without changing call sites.
"""

from __future__ import annotations

import datetime
from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from neuronetz_gateway.db.models import PromptLog

# A redaction hook maps a request/response body to a sanitized copy. The default
# is identity; operators can install a stricter hook.
RedactionHook = Callable[[dict[str, object]], dict[str, object]]


def _identity(body: dict[str, object]) -> dict[str, object]:
    return body


@dataclass(frozen=True, slots=True)
class PromptRecord:
    """A captured request/response pair pending TTL'd persistence."""

    audit_id: int
    key_id: UUID
    request_body: dict[str, object]
    response_text: str | None
    retention_days: int


class PromptLogWriter:
    """Persists opt-in prompt records with a retention deadline."""

    def __init__(self, session: AsyncSession, redact: RedactionHook | None = None) -> None:
        self._session = session
        self._redact = redact or _identity

    async def write(self, record: PromptRecord) -> None:
        """Persist a prompt record to ``gateway.prompt_log``."""
        retention_until = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
            days=record.retention_days
        )
        row = PromptLog(
            audit_id=record.audit_id,
            key_id=record.key_id,
            request_body=self._redact(record.request_body),
            response_text=record.response_text,
            retention_until=retention_until,
        )
        self._session.add(row)
        await self._session.flush()


__all__ = ["PromptLogWriter", "PromptRecord", "RedactionHook"]
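
The module docstring says operators can install a stricter redaction hook than the identity default. A minimal sketch of such a hook follows. The factory name `make_key_masker`, the masked key names, and the model name are hypothetical; only the `RedactionHook` shape is taken from the diff.

```python
from collections.abc import Callable

RedactionHook = Callable[[dict[str, object]], dict[str, object]]


def make_key_masker(sensitive: frozenset[str]) -> RedactionHook:
    """Build a hook that replaces the values of the named top-level keys."""

    def redact(body: dict[str, object]) -> dict[str, object]:
        # Build a new dict so the caller's body is never mutated.
        return {k: ("[REDACTED]" if k in sensitive else v) for k, v in body.items()}

    return redact


hook = make_key_masker(frozenset({"prompt", "system"}))
original = {"model": "llama3", "prompt": "my secret", "stream": True}
cleaned = hook(original)
```

Going by the constructor signature in the diff, such a hook would presumably be passed as `PromptLogWriter(session, redact=hook)`. Note that `write` applies the hook to `request_body` only; `response_text` is stored as given.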
152
src/neuronetz_gateway/audit/writer.py
Normal file
@@ -0,0 +1,152 @@
"""Buffered async audit-log writer (SPEC §4.4).

Audit rows are enqueued *after* stream close so the hot path is never delayed
(non-negotiable #6). A background drain task persists them to Postgres. On a
Postgres write failure rows are buffered in a bounded in-memory ring (max
``AUDIT_BUFFER_SIZE``) and retried; if the ring fills, the writer flips to
**deny mode** (SPEC §4.4); :pyattr:`deny_mode` goes True and the request path
must refuse new work until the backlog drains.

The writer holds the session factory directly (it runs outside request scope).
``enqueue`` never blocks on the DB; it only touches the in-memory queue/ring.
"""

from __future__ import annotations

import asyncio
import collections
import datetime
from dataclasses import asdict, dataclass
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from neuronetz_gateway.db.repositories import AuditRepository
from neuronetz_gateway.observability.logging import get_logger

_log = get_logger("audit")


@dataclass(frozen=True, slots=True)
class AuditRecord:
    """A single audit-log row queued for persistence."""

    request_id: UUID
    method: str
    path: str
    status: int
    ts: datetime.datetime
    tenant_id: UUID | None = None
    key_id: UUID | None = None
    key_prefix: str | None = None
    model: str | None = None
    tokens_in: int | None = None
    tokens_out: int | None = None
    latency_ms: int | None = None
    client_ip: str | None = None
    user_agent: str | None = None
    error_code: str | None = None

    def as_columns(self) -> dict[str, object]:
        """Map to ``gateway.audit_log`` column kwargs."""
        return asdict(self)


class AuditWriter:
    """Buffered, fail-safe writer for ``gateway.audit_log``."""

    def __init__(
        self,
        buffer_size: int,
        sessionmaker: async_sessionmaker[AsyncSession] | None = None,
    ) -> None:
        self._buffer_size = buffer_size
        self._sessionmaker = sessionmaker
        self._ring: collections.deque[AuditRecord] = collections.deque(maxlen=buffer_size)
        self._queue: asyncio.Queue[AuditRecord] = asyncio.Queue()
        self._deny_mode = False
        self._task: asyncio.Task[None] | None = None

    @property
    def deny_mode(self) -> bool:
        """True when the buffer has overflowed and new work must be denied."""
        return self._deny_mode

    def bind(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
        """Attach the DB session factory (called from lifespan startup)."""
        self._sessionmaker = sessionmaker

    async def enqueue(self, record: AuditRecord) -> None:
        """Queue an audit record for asynchronous persistence (non-blocking)."""
        await self._queue.put(record)

    def start(self) -> None:
        """Launch the background drain task (idempotent)."""
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._drain_loop())

    async def stop(self) -> None:
        """Cancel the drain task and flush remaining rows best-effort."""
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        await self.flush()

    async def _drain_loop(self) -> None:
        """Continuously move queued records into Postgres, retrying the ring."""
        while True:
            record = await self._queue.get()
            await self._persist_with_retry(record)

    async def _persist_with_retry(self, record: AuditRecord) -> None:
        """Persist one record; on failure push to the ring (or enter deny mode)."""
        # Drain any backlog first so ordering is roughly preserved.
        if self._ring:
            await self._try_flush_ring()
        if not await self._write_one(record):
            self._buffer(record)

    def _buffer(self, record: AuditRecord) -> None:
        """Buffer a failed write; flip to deny mode if the ring is full."""
        if len(self._ring) >= self._buffer_size:
            self._deny_mode = True
            _log.error("audit_buffer_overflow_deny_mode", buffered=len(self._ring))
            return
        self._ring.append(record)

    async def _try_flush_ring(self) -> None:
        """Attempt to persist buffered rows; clear deny mode once drained."""
        while self._ring:
            record = self._ring[0]
            if not await self._write_one(record):
                return
            self._ring.popleft()
        self._deny_mode = False

    async def _write_one(self, record: AuditRecord) -> bool:
        """Persist a single record; return False on any failure."""
        if self._sessionmaker is None:
            return False
        try:
            async with self._sessionmaker() as session:
                await AuditRepository(session).insert_audit(**record.as_columns())
                await session.commit()
            return True
        except Exception as exc:  # noqa: BLE001 - any DB error ⇒ buffer + retry
            _log.warning("audit_write_failed", error=str(exc))
            return False

    async def flush(self) -> None:
        """Drain queued + buffered records to Postgres (best-effort)."""
        while not self._queue.empty():
            record = self._queue.get_nowait()
            if not await self._write_one(record):
                self._buffer(record)
        await self._try_flush_ring()


__all__ = ["AuditRecord", "AuditWriter"]
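
The overflow rule in `_buffer` and `_try_flush_ring` can be modelled without asyncio or a database. The sketch below is a toy model, not the writer itself: the class `RingWithDenyMode` and its `write_one` callback are hypothetical stand-ins, and records are plain strings.

```python
import collections


class RingWithDenyMode:
    """Toy model of the writer's overflow rule (no asyncio, no database)."""

    def __init__(self, buffer_size: int) -> None:
        self._buffer_size = buffer_size
        self._ring: collections.deque[str] = collections.deque(maxlen=buffer_size)
        self.deny_mode = False

    def buffer(self, record: str) -> None:
        # A full ring flips deny mode; the record is not appended.
        if len(self._ring) >= self._buffer_size:
            self.deny_mode = True
            return
        self._ring.append(record)

    def flush(self, write_one) -> None:
        # Retry from the head; stop at the first failure to keep order.
        while self._ring:
            if not write_one(self._ring[0]):
                return
            self._ring.popleft()
        self.deny_mode = False


ring = RingWithDenyMode(buffer_size=2)
for row in ("a", "b", "c"):
    ring.buffer(row)
denied_after_overflow = ring.deny_mode

persisted: list[str] = []
ring.flush(lambda row: persisted.append(row) is None)  # a write that always succeeds
```

As the diff reads, the record that triggers the overflow (here `"c"`) is not kept in the ring, so deny mode is what prevents further audit rows from being lost while Postgres is unavailable.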
3
src/neuronetz_gateway/budget/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Token budgets: Redis period counters and Postgres ledger reconciliation."""

from __future__ import annotations
105
src/neuronetz_gateway/budget/counter.py
Normal file
@@ -0,0 +1,105 @@
"""Redis period counters for token budgets (day / month / total).

The fast-path budget check reads a Redis counter of tokens already consumed in
the active period and compares it to the configured limit; Postgres
(``ledger.py``) is the durable source of truth reconciled on rollover. Counters
are keyed by ``key_id`` + period + period-start so a new period naturally starts
at zero, and day/month counters carry a TTL so expired periods self-clean.

Fail-closed (SPEC §4.4): Redis errors raise
:class:`DependencyUnavailableError`; the caller must deny (503).
"""

from __future__ import annotations

import datetime
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import cast

import redis.asyncio as redis
from redis.exceptions import RedisError

from neuronetz_gateway.db.models import BudgetPeriod
from neuronetz_gateway.errors import DependencyUnavailableError

_CONSUMED_PREFIX = "gateway:budget:"


def period_start(period: BudgetPeriod, now: datetime.datetime | None = None) -> datetime.datetime:
    """Return the UTC start of the current period."""
    moment = now or datetime.datetime.now(datetime.UTC)
    moment = moment.astimezone(datetime.UTC)
    if period is BudgetPeriod.day:
        return moment.replace(hour=0, minute=0, second=0, microsecond=0)
    if period is BudgetPeriod.month:
        return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    # 'total' has no rollover; anchor at the epoch.
    return datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC)


def _ttl_for(period: BudgetPeriod) -> int | None:
    """TTL (seconds) for a period counter; None for 'total' (no expiry)."""
    if period is BudgetPeriod.day:
        return 2 * 24 * 3600
    if period is BudgetPeriod.month:
        return 40 * 24 * 3600
    return None


def _counter_key(key_id: str, period: BudgetPeriod, start: datetime.datetime) -> str:
    return f"{_CONSUMED_PREFIX}{key_id}:{period.value}:{int(start.timestamp())}"


@dataclass(frozen=True, slots=True)
class BudgetState:
    """Snapshot of remaining budget for a period."""

    period: BudgetPeriod
    limit: int | None
    remaining: int | None

    @property
    def exhausted(self) -> bool:
        """True if a finite limit is configured and nothing remains."""
        return self.limit is not None and (self.remaining or 0) <= 0


class BudgetCounter:
    """Redis-backed per-period token counter."""

    def __init__(self, client: redis.Redis) -> None:
        self._client = client

    async def check(self, key_id: str, period: BudgetPeriod, limit: int | None) -> BudgetState:
        """Return remaining budget; ``None`` limit means unlimited for the period."""
        if limit is None:
            return BudgetState(period=period, limit=None, remaining=None)
        start = period_start(period)
        try:
            raw = await self._client.get(_counter_key(key_id, period, start))
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"budget redis error: {exc!r}"
            ) from exc
        consumed = int(raw) if raw else 0
        return BudgetState(period=period, limit=limit, remaining=max(limit - consumed, 0))

    async def consume(self, key_id: str, period: BudgetPeriod, tokens: int) -> None:
        """Increment the consumed counter after a request completes."""
        if tokens <= 0:
            return
        start = period_start(period)
        key = _counter_key(key_id, period, start)
        try:
            await cast("Awaitable[int]", self._client.incrby(key, tokens))
            ttl = _ttl_for(period)
            if ttl is not None:
                await self._client.expire(key, ttl)
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"budget consume redis error: {exc!r}"
            ) from exc


__all__ = ["BudgetCounter", "BudgetState", "period_start"]
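
To see why keying by period start makes a new period begin at zero, the sketch below restates `period_start` and the key format with a stand-in `BudgetPeriod` enum. The real enum lives in `db.models` and is not shown in this diff, so its member values here are an assumption; `counter_key` mirrors the private `_counter_key`, and the key id `"k1"` is illustrative.

```python
import datetime
import enum

UTC = datetime.timezone.utc  # equivalent to datetime.UTC on newer Pythons


class BudgetPeriod(enum.Enum):  # stand-in; member values are assumed
    day = "day"
    month = "month"
    total = "total"


def period_start(period, now):
    moment = now.astimezone(UTC)
    if period is BudgetPeriod.day:
        return moment.replace(hour=0, minute=0, second=0, microsecond=0)
    if period is BudgetPeriod.month:
        return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    return datetime.datetime(1970, 1, 1, tzinfo=UTC)  # 'total' never rolls over


def counter_key(key_id, period, start):
    return f"gateway:budget:{key_id}:{period.value}:{int(start.timestamp())}"


now = datetime.datetime(2024, 3, 15, 13, 45, tzinfo=UTC)
tomorrow = now + datetime.timedelta(days=1)

day_key = counter_key("k1", BudgetPeriod.day, period_start(BudgetPeriod.day, now))
next_day_key = counter_key("k1", BudgetPeriod.day, period_start(BudgetPeriod.day, tomorrow))
month_key = counter_key("k1", BudgetPeriod.month, period_start(BudgetPeriod.month, now))
next_month_key = counter_key("k1", BudgetPeriod.month, period_start(BudgetPeriod.month, tomorrow))
total_key = counter_key("k1", BudgetPeriod.total, period_start(BudgetPeriod.total, now))
```

The day key changes at the UTC midnight boundary while the month key stays the same within the month, so no explicit reset is needed; the TTLs from `_ttl_for` only clean up keys that are no longer addressed.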
58
src/neuronetz_gateway/budget/ledger.py
Normal file
@@ -0,0 +1,58 @@
"""Postgres budget ledger reconciliation.

Persists token usage to ``gateway.budget_usage`` (the durable source of truth)
via an idempotent upsert keyed by (key_id, period, period_start). The Redis
counter (``counter.py``) is the fast path; this ledger is what survives a Redis
flush and what ``show-usage`` reports against.
"""

from __future__ import annotations

import uuid

from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession

from neuronetz_gateway.budget.counter import period_start
from neuronetz_gateway.db.models import BudgetPeriod, BudgetUsage


class BudgetLedger:
    """Source-of-truth budget accounting in Postgres."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def record_usage(
        self, key_id: str, period: BudgetPeriod, tokens_in: int, tokens_out: int
    ) -> None:
        """Upsert usage into ``gateway.budget_usage`` for the active period.

        Uses an ``ON CONFLICT`` upsert so concurrent writers accumulate rather
        than clobber. ``requests`` increments by one per recorded request.
        """
        start = period_start(period)
        stmt = pg_insert(BudgetUsage).values(
            key_id=uuid.UUID(key_id) if isinstance(key_id, str) else key_id,
            period=period,
            period_start=start,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            requests=1,
        )
        stmt = stmt.on_conflict_do_update(
            index_elements=[
                BudgetUsage.key_id,
                BudgetUsage.period,
                BudgetUsage.period_start,
            ],
            set_={
                "tokens_in": BudgetUsage.tokens_in + stmt.excluded.tokens_in,
                "tokens_out": BudgetUsage.tokens_out + stmt.excluded.tokens_out,
                "requests": BudgetUsage.requests + stmt.excluded.requests,
            },
        )
        await self._session.execute(stmt)


__all__ = ["BudgetLedger"]
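
The accumulate-on-conflict behaviour of `record_usage` can be tried without Postgres. The sketch below swaps in SQLite's `ON CONFLICT ... DO UPDATE` (available in SQLite 3.24 and later) through the standard-library `sqlite3` module. It is an analogy under that substitution, not the statement SQLAlchemy emits, and the table layout is an assumption based on the column names in the diff.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE budget_usage (
        key_id TEXT NOT NULL,
        period TEXT NOT NULL,
        period_start TEXT NOT NULL,
        tokens_in INTEGER NOT NULL,
        tokens_out INTEGER NOT NULL,
        requests INTEGER NOT NULL,
        PRIMARY KEY (key_id, period, period_start)
    )
    """
)

# On a key collision, add the incoming ("excluded") values to the stored row.
UPSERT = """
INSERT INTO budget_usage (key_id, period, period_start, tokens_in, tokens_out, requests)
VALUES (?, ?, ?, ?, ?, 1)
ON CONFLICT (key_id, period, period_start) DO UPDATE SET
    tokens_in = tokens_in + excluded.tokens_in,
    tokens_out = tokens_out + excluded.tokens_out,
    requests = requests + excluded.requests
"""

conn.execute(UPSERT, ("k1", "day", "2024-03-15", 10, 20))
conn.execute(UPSERT, ("k1", "day", "2024-03-15", 5, 7))
row = conn.execute("SELECT tokens_in, tokens_out, requests FROM budget_usage").fetchone()
```

Two writes for the same (key, period, period start) leave one row holding the sums. One consequence worth keeping in mind: replaying the same request adds to the totals again, so any deduplication would have to happen before the ledger is called.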
67
src/neuronetz_gateway/observability/metrics.py
Normal file
@@ -0,0 +1,67 @@
"""Prometheus metrics.

Phase 1 declares the metric objects and the exposition helper. Instrumentation
(incrementing counters / observing histograms on the request path) is wired in
later phases. Per SPEC §13.3 we label by ``tenant`` only, never by ``key_id``.
"""

from __future__ import annotations

from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest

REGISTRY = CollectorRegistry()

REQUESTS_TOTAL = Counter(
    "gateway_requests_total",
    "Total proxied requests.",
    labelnames=("tenant", "model", "status"),
    registry=REGISTRY,
)

TOKENS_TOTAL = Counter(
    "gateway_tokens_total",
    "Total tokens accounted, by direction (in|out).",
    labelnames=("tenant", "model", "direction"),
    registry=REGISTRY,
)

REQUEST_DURATION_SECONDS = Histogram(
    "gateway_request_duration_seconds",
    "Gateway-side request duration in seconds.",
    labelnames=("tenant", "model"),
    registry=REGISTRY,
)


def record_request(tenant: str, model: str, status: int, duration_s: float) -> None:
    """Increment the request counter and observe its duration (tenant-labeled)."""
    REQUESTS_TOTAL.labels(tenant=tenant, model=model, status=str(status)).inc()
    REQUEST_DURATION_SECONDS.labels(tenant=tenant, model=model).observe(duration_s)


def record_tokens(tenant: str, model: str, tokens_in: int, tokens_out: int) -> None:
    """Add input/output token counts to the tokens counter."""
    if tokens_in:
        TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="in").inc(tokens_in)
    if tokens_out:
        TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="out").inc(tokens_out)


def render_latest() -> bytes:
    """Return the current metrics in Prometheus text exposition format."""
    payload: bytes = generate_latest(REGISTRY)
    return payload


CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"

__all__ = [
    "CONTENT_TYPE_LATEST",
    "REGISTRY",
    "REQUESTS_TOTAL",
    "REQUEST_DURATION_SECONDS",
    "TOKENS_TOTAL",
    "record_request",
    "record_tokens",
    "render_latest",
]
3
src/neuronetz_gateway/proxy/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Proxy layer: Ollama client, schema translation, token counting, allowlists."""

from __future__ import annotations
76
src/neuronetz_gateway/proxy/allowlist.py
Normal file
@@ -0,0 +1,76 @@
"""Endpoint and model allowlists (SPEC §6.1, §6.2).

Mutating Ollama endpoints are hard-blocked (not configurable, not flagged):
``/api/pull``, ``/api/push``, ``/api/create``, ``/api/copy``, ``/api/delete``,
and any ``/api/blobs/*``. ``/api/ps`` is also blocked (leaks loaded models).
Model allowlist is per-tenant, default-deny. Enforcement logic lands in Phase 2.
"""

from __future__ import annotations

# Hard-blocked upstream endpoints: always 403, not configurable (SPEC §6.2).
HARD_BLOCKED_PATHS: frozenset[str] = frozenset(
    {
        "/api/pull",
        "/api/push",
        "/api/create",
        "/api/copy",
        "/api/delete",
        "/api/ps",
    }
)

# Path prefixes that are hard-blocked (e.g. blob upload/download).
HARD_BLOCKED_PREFIXES: tuple[str, ...] = ("/api/blobs",)


def is_hard_blocked(path: str) -> bool:
    """Return True if ``path`` is an unconditionally blocked upstream endpoint."""
    if path in HARD_BLOCKED_PATHS:
        return True
    return any(path.startswith(prefix) for prefix in HARD_BLOCKED_PREFIXES)


def resolve_effective_models(
    *,
    allow_all: bool,
    allowed_models: tuple[str, ...],
    discovered: frozenset[str],
) -> frozenset[str]:
    """Resolve the effective model set per SPEC §4.3 step 7 / §4.6.

    ``allow_all`` ⇒ the effective set is the entire live ``discovered`` set;
    otherwise it is the configured ``allowed_models`` intersected with
    ``discovered`` (so stale or typo'd allowlist entries never resolve, and a
    model that is unpermitted vs. not-installed are indistinguishable).

    Fail-closed: if ``discovered`` is empty (discovery unavailable/expired) the
    result is empty regardless of ``allow_all``; discovery never opens access.
    """
    if not discovered:
        return frozenset()
    if allow_all:
        return discovered
    return frozenset(allowed_models) & discovered


def is_model_allowed(
    model: str,
    *,
    allow_all: bool,
    allowed_models: tuple[str, ...],
    discovered: frozenset[str],
) -> bool:
    """Return True iff ``model`` is in the resolved effective set (default-deny)."""
    return model in resolve_effective_models(
        allow_all=allow_all, allowed_models=allowed_models, discovered=discovered
    )


__all__ = [
    "HARD_BLOCKED_PATHS",
    "HARD_BLOCKED_PREFIXES",
    "is_hard_blocked",
    "is_model_allowed",
    "resolve_effective_models",
]
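
A short worked example of the resolution rules may help. The sketch repeats the three branches of `resolve_effective_models` so it runs on its own, and adds a hypothetical `coalesce` helper for the `key ?? tenant` fallback described in the commit message (how the real code performs that fallback is not shown here). Model names are illustrative.

```python
def resolve_effective_models(*, allow_all, allowed_models, discovered):
    # Same three branches as the function in the diff.
    if not discovered:
        return frozenset()
    if allow_all:
        return discovered
    return frozenset(allowed_models) & discovered


def coalesce(key_value, tenant_value):
    """Hypothetical helper: the key-level value wins unless it is None."""
    return tenant_value if key_value is None else key_value


discovered = frozenset({"llama3:8b", "mistral:7b"})

# Key has no override, so the tenant allowlist applies; the typo'd entry drops out.
restricted = resolve_effective_models(
    allow_all=coalesce(None, False),
    allowed_models=coalesce(None, ("llama3:8b", "lama3:8b")),
    discovered=discovered,
)

# A key-level allow_all overrides the tenant setting.
everything = resolve_effective_models(
    allow_all=coalesce(True, False), allowed_models=(), discovered=discovered
)

# Empty discovery fails closed even with allow_all.
closed = resolve_effective_models(
    allow_all=True, allowed_models=("llama3:8b",), discovered=frozenset()
)
```

Because the allowlist is intersected with the live set, a request for `lama3:8b` and a request for an installed but unlisted model both miss the effective set, which is what allows a single generic 403 for both cases.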
207
src/neuronetz_gateway/proxy/discovery.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Live model discovery from the Ollama backend (SPEC §4.6).
|
||||||
|
|
||||||
|
A background task polls Ollama ``GET /api/tags`` every
|
||||||
|
``MODEL_DISCOVERY_REFRESH_S`` seconds. The parsed model set (names + sanitized
|
||||||
|
metadata) is cached in Redis under ``gateway:models:discovered`` with TTL
|
||||||
|
``MODEL_DISCOVERY_CACHE_TTL_S`` and held in-process for hot reads on the request
|
||||||
|
path.
|
||||||
|
|
||||||
|
Fail-closed (SPEC §4.6, §13.5): if Ollama is unreachable, or the cache is empty
|
||||||
|
or expired and cannot be refreshed, the discovered set is empty — and an empty
|
||||||
|
discovered set means no model resolves, so requests are denied. Discovery never
|
||||||
|
opens access on failure. It is read-only and only ever touches the allowlisted
|
||||||
|
``/api/tags`` endpoint; it never triggers a pull.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import redis.asyncio as redis
|
||||||
|
|
||||||
|
from neuronetz_gateway.config import Settings
|
||||||
|
from neuronetz_gateway.observability.logging import get_logger
|
||||||
|
|
||||||
|
_log = get_logger("discovery")
|
||||||
|
|
||||||
|
REDIS_DISCOVERED_KEY = "gateway:models:discovered"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class DiscoveredModel:
|
||||||
|
"""Sanitized metadata for a single installed model."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
family: str | None = None
|
||||||
|
parameter_size: str | None = None
|
||||||
|
quantization: str | None = None
|
||||||
|
size_bytes: int | None = None
|
||||||
|
modified_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tags(payload: dict[str, object]) -> list[DiscoveredModel]:
|
||||||
|
"""Parse an Ollama ``/api/tags`` body into sanitized model records."""
|
||||||
|
models: list[DiscoveredModel] = []
|
||||||
|
raw_models = payload.get("models")
|
||||||
|
if not isinstance(raw_models, list):
|
||||||
|
return models
|
||||||
|
for entry in raw_models:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
name = entry.get("name") or entry.get("model")
|
||||||
|
if not isinstance(name, str) or not name:
|
||||||
|
continue
|
||||||
|
raw_details = entry.get("details")
|
||||||
|
details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
|
||||||
|
size = entry.get("size")
|
||||||
|
models.append(
|
||||||
|
DiscoveredModel(
|
||||||
|
name=name,
|
||||||
|
family=_opt_str(details.get("family")),
|
||||||
|
parameter_size=_opt_str(details.get("parameter_size")),
|
||||||
|
quantization=_opt_str(details.get("quantization_level")),
|
||||||
|
size_bytes=size if isinstance(size, int) else None,
|
||||||
|
modified_at=_opt_str(entry.get("modified_at")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def _opt_str(value: object) -> str | None:
|
||||||
|
"""Coerce a value to ``str`` only when it already is one."""
|
||||||
|
return value if isinstance(value, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
def names_of(models: list[DiscoveredModel]) -> frozenset[str]:
|
||||||
|
"""Return just the set of model names from parsed records."""
|
||||||
|
return frozenset(m.name for m in models)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoveryCache:
|
||||||
|
"""In-process holder for the latest discovered model set.
|
||||||
|
|
||||||
|
Holds both the structured records (for ``/api/tags`` / ``list-models``) and a
|
||||||
|
fast name set (for allowlist resolution on the hot path). Reads never block
|
||||||
|
on Redis or Ollama; the poller refreshes this in the background.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._models: list[DiscoveredModel] = []
|
||||||
|
self._names: frozenset[str] = frozenset()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def set(self, models: list[DiscoveredModel]) -> None:
|
||||||
|
"""Replace the in-process snapshot atomically."""
|
||||||
|
async with self._lock:
|
||||||
|
self._models = list(models)
|
||||||
|
self._names = names_of(models)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def names(self) -> frozenset[str]:
|
||||||
|
"""Current discovered model names (possibly empty ⇒ fail-closed)."""
|
||||||
|
return self._names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def models(self) -> list[DiscoveredModel]:
|
||||||
|
"""Current discovered model records (copy)."""
|
||||||
|
return list(self._models)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_discovered_to_redis(
|
||||||
|
client: redis.Redis, models: list[DiscoveredModel], ttl_s: int
|
||||||
|
) -> None:
|
||||||
|
"""Cache the discovered set in Redis with a TTL (so staleness expires)."""
|
||||||
|
payload = json.dumps([asdict(m) for m in models], separators=(",", ":"))
|
||||||
|
await client.set(REDIS_DISCOVERED_KEY, payload, ex=ttl_s)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_discovered_from_redis(client: redis.Redis) -> frozenset[str]:
|
||||||
|
"""Read the cached discovered names from Redis; empty set on miss/expiry."""
|
||||||
|
raw = await client.get(REDIS_DISCOVERED_KEY)
|
||||||
|
if not raw:
|
||||||
|
return frozenset()
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return frozenset()
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return frozenset()
|
||||||
|
return frozenset(
|
||||||
|
str(item["name"]) for item in data if isinstance(item, dict) and item.get("name")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_tags(client: httpx.AsyncClient) -> list[DiscoveredModel]:
|
||||||
|
"""Fetch and parse Ollama ``/api/tags``; raise on transport/HTTP error."""
|
||||||
|
resp = await client.get("/api/tags")
|
||||||
|
resp.raise_for_status()
|
||||||
|
body = resp.json()
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
return []
|
||||||
|
return _parse_tags(body)
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_once(
|
||||||
|
http_client: httpx.AsyncClient,
|
||||||
|
redis_client: redis.Redis | None,
|
||||||
|
cache: DiscoveryCache,
|
||||||
|
settings: Settings,
|
||||||
|
) -> bool:
|
||||||
|
"""Run a single discovery refresh. Returns True on success.
|
||||||
|
|
||||||
|
On any failure the in-process and Redis caches are left untouched; they
|
||||||
|
expire on their own TTL, which is the fail-closed behavior (stale-expired ⇒
|
||||||
|
empty ⇒ deny). We never *clear* eagerly on a transient error, but we also
|
||||||
|
never extend the TTL on failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
models = await fetch_tags(http_client)
|
||||||
|
except (httpx.HTTPError, ValueError) as exc:
|
||||||
|
_log.warning("discovery_refresh_failed", error=str(exc))
|
||||||
|
return False
|
||||||
|
await cache.set(models)
|
||||||
|
if redis_client is not None:
|
||||||
|
try:
|
||||||
|
await write_discovered_to_redis(
|
||||||
|
redis_client, models, settings.model_discovery_cache_ttl_s
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - Redis write is best-effort cache fill
|
||||||
|
_log.warning("discovery_cache_write_failed", error=str(exc))
|
||||||
|
_log.info("discovery_refreshed", count=len(models))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def discovery_loop(
    http_client: httpx.AsyncClient,
    redis_client: redis.Redis | None,
    cache: DiscoveryCache,
    settings: Settings,
) -> None:
    """Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``.

    Designed to be launched via ``asyncio.create_task`` in the lifespan and
    cancelled on shutdown.
    """
    await refresh_once(http_client, redis_client, cache, settings)
    while True:
        # CancelledError raised by the sleep propagates unchanged on shutdown.
        await asyncio.sleep(settings.model_discovery_refresh_s)
        await refresh_once(http_client, redis_client, cache, settings)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"REDIS_DISCOVERED_KEY",
|
||||||
|
"DiscoveredModel",
|
||||||
|
"DiscoveryCache",
|
||||||
|
"discovery_loop",
|
||||||
|
"fetch_tags",
|
||||||
|
"names_of",
|
||||||
|
"read_discovered_from_redis",
|
||||||
|
"refresh_once",
|
||||||
|
"write_discovered_to_redis",
|
||||||
|
]
|
||||||
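The Redis helpers in this file store the discovered set as a compact JSON array and read back only the model names, treating a miss or a malformed payload as an empty set (which the allowlist then denies). A minimal sketch of that round trip follows. It assumes a hypothetical in-memory `FakeRedis` stand-in, a simplified `DiscoveredModel` with only `name` and `size` fields, and an illustrative key name; none of these are taken from the gateway's real code.

```python
import asyncio
import json
from dataclasses import asdict, dataclass

REDIS_DISCOVERED_KEY = "gateway:discovered_models"  # illustrative key name (assumption)


@dataclass(frozen=True)
class DiscoveredModel:
    """Simplified stand-in; the real record's fields are not shown here."""

    name: str
    size: int = 0


class FakeRedis:
    """Hypothetical in-memory stand-in exposing async get/set."""

    def __init__(self):
        self.store = {}

    async def set(self, key, value, ex=None):
        self.store[key] = value  # the TTL is ignored in this sketch

    async def get(self, key):
        return self.store.get(key)


async def write_discovered(client, models, ttl_s):
    payload = json.dumps([asdict(m) for m in models], separators=(",", ":"))
    await client.set(REDIS_DISCOVERED_KEY, payload, ex=ttl_s)


async def read_discovered(client):
    raw = await client.get(REDIS_DISCOVERED_KEY)
    if not raw:
        return frozenset()
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return frozenset()  # corrupt cache entry is treated like a miss
    if not isinstance(data, list):
        return frozenset()
    return frozenset(
        str(item["name"]) for item in data if isinstance(item, dict) and item.get("name")
    )


async def demo():
    client = FakeRedis()
    empty = await read_discovered(client)  # nothing cached yet
    await write_discovered(client, [DiscoveredModel("llama3:8b"), DiscoveredModel("phi3")], 300)
    names = await read_discovered(client)
    client.store[REDIS_DISCOVERED_KEY] = "{not json"  # simulate a corrupt payload
    corrupt = await read_discovered(client)
    return empty, names, corrupt


EMPTY, NAMES, CORRUPT = asyncio.run(demo())
```

Both the miss and the corrupt payload collapse to an empty set, so a caller cannot tell them apart and denies in either case.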
79
src/neuronetz_gateway/proxy/ollama.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""httpx-based streaming proxy client for Ollama.
|
||||||
|
|
||||||
|
Opens async streams to the upstream and relays bytes without buffering, so the
|
||||||
|
gateway does not degrade time-to-first-byte (SPEC §9, non-negotiable #6).
|
||||||
|
Transport-level upstream failures are sanitized at this boundary into
|
||||||
|
:class:`UpstreamUnavailableError` (never reflected verbatim, non-negotiable #4).
|
||||||
|
|
||||||
|
The client is constructed from the shared ``httpx.AsyncClient`` on
|
||||||
|
``app.state`` and is exposed to routes via the ``get_ollama_client`` dependency
|
||||||
|
in ``deps.py`` so tests can override it (the QA override contract).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from neuronetz_gateway.errors import UpstreamUnavailableError
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaClient:
|
||||||
|
"""Thin async wrapper around the shared httpx client for Ollama calls."""
|
||||||
|
|
||||||
|
def __init__(self, client: httpx.AsyncClient) -> None:
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def stream(
|
||||||
|
self, method: str, path: str, json_body: dict[str, object]
|
||||||
|
) -> AsyncIterator[bytes]:
|
||||||
|
"""Open a streaming request to Ollama and yield raw response chunks.
|
||||||
|
|
||||||
|
Yields bytes exactly as received so the hot path performs no buffering.
|
||||||
|
Transport errors are sanitized; if the upstream returns a 4xx/5xx status the
|
||||||
|
body is drained (so internals are not surfaced) and a generic 502 is
|
||||||
|
raised before any client bytes are emitted.
|
||||||
|
"""
|
||||||
|
request = self._client.build_request(method, path, json=json_body)
|
||||||
|
try:
|
||||||
|
response = await self._client.send(request, stream=True)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise UpstreamUnavailableError(internal_detail=f"ollama send failed: {exc!r}") from exc
|
||||||
|
try:
|
||||||
|
if response.status_code >= 400:
|
||||||
|
await response.aread()
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"ollama returned {response.status_code} for {path}"
|
||||||
|
)
|
||||||
|
async for chunk in response.aiter_raw():
|
||||||
|
yield chunk
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"ollama stream failed: {exc!r}"
|
||||||
|
) from exc
|
||||||
|
finally:
|
||||||
|
await response.aclose()
|
||||||
|
|
||||||
|
async def request(
|
||||||
|
self, method: str, path: str, json_body: dict[str, object]
|
||||||
|
) -> httpx.Response:
|
||||||
|
"""Perform a non-streaming request to Ollama.
|
||||||
|
|
||||||
|
Returns the upstream response on success; raises a sanitized 502 on
|
||||||
|
transport failure or a 4xx/5xx status (internals never reflected).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self._client.request(method, path, json=json_body)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"ollama request failed: {exc!r}"
|
||||||
|
) from exc
|
||||||
|
if response.status_code >= 400:
|
||||||
|
raise UpstreamUnavailableError(
|
||||||
|
internal_detail=f"ollama returned {response.status_code} for {path}"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["OllamaClient"]
|
||||||
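`OllamaClient.stream` inspects the upstream status before its first `yield`, so a failing upstream surfaces as an exception before the caller has relayed any bytes. The sketch below illustrates only that ordering with a plain async generator. `UpstreamUnavailable`, `relay`, and `collect` are hypothetical names for illustration and do not use httpx or the gateway's error classes.

```python
import asyncio


class UpstreamUnavailable(Exception):
    """Hypothetical stand-in for the gateway's sanitized upstream error."""


async def relay(status, body_chunks):
    # The status is checked before the first yield, so an upstream failure
    # raises before the consumer has received a single chunk.
    if status >= 400:
        raise UpstreamUnavailable("upstream error")  # generic text, no upstream body
    for chunk in body_chunks:
        yield chunk


async def collect(status, chunks):
    received = []
    try:
        async for chunk in relay(status, chunks):
            received.append(chunk)
    except UpstreamUnavailable:
        return received, "failed"
    return received, "ok"


OK = asyncio.run(collect(200, [b"a", b"b"]))
BAD = asyncio.run(collect(500, [b"secret upstream detail"]))
```

In the failing case nothing from the upstream body reaches the consumer, which is the property the docstring describes.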
467
src/neuronetz_gateway/proxy/pipeline.py
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
"""Shared request pipeline: enforcement, streaming, and post-close bookkeeping.
|
||||||
|
|
||||||
|
Both the native (``/api/*``) and OpenAI-compat (``/v1/*``) routes funnel through
|
||||||
|
here so the security checks and the streaming-integrity contract are written
|
||||||
|
once. The order mirrors SPEC §4.3:
|
||||||
|
|
||||||
|
rate limit (per-key + per-tenant RPM) → budget → concurrency → model allowlist
|
||||||
|
→ endpoint allowlist → body validation → proxy + stream → post-close audit /
|
||||||
|
token-count / budget-consume / metrics / semaphore release.
|
||||||
|
|
||||||
|
Streaming integrity (non-negotiable #6): the bytes flow to the client untouched
|
||||||
|
and token counting + audit + budget-consume happen **after** the stream closes,
|
||||||
|
never on the hot path.
|
||||||
|
|
||||||
|
Fail-closed (non-negotiable #1): every limiter/budget call raises
|
||||||
|
:class:`DependencyUnavailableError` when Redis is down, which the error handler
|
||||||
|
renders as 503. The model allowlist is default-deny against the live discovered
|
||||||
|
set; a missing/expired discovery set denies everything.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncIterator, Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from starlette.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
|
from neuronetz_gateway.audit.writer import AuditRecord, AuditWriter
|
||||||
|
from neuronetz_gateway.auth.principal import Principal
|
||||||
|
from neuronetz_gateway.budget.counter import BudgetCounter
|
||||||
|
from neuronetz_gateway.budget.ledger import BudgetLedger
|
||||||
|
from neuronetz_gateway.config import Settings
|
||||||
|
from neuronetz_gateway.db.models import BudgetPeriod
|
||||||
|
from neuronetz_gateway.errors import (
|
||||||
|
AuthorizationError,
|
||||||
|
BudgetExceededError,
|
||||||
|
RateLimitError,
|
||||||
|
RequestTooLargeError,
|
||||||
|
)
|
||||||
|
from neuronetz_gateway.observability import metrics
|
||||||
|
from neuronetz_gateway.observability.logging import get_logger
|
||||||
|
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed
|
||||||
|
from neuronetz_gateway.proxy.discovery import DiscoveryCache
|
||||||
|
from neuronetz_gateway.proxy.ollama import OllamaClient
|
||||||
|
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
|
||||||
|
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
|
||||||
|
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
|
||||||
|
|
||||||
|
_log = get_logger("pipeline")
|
||||||
|
|
||||||
|
NDJSON_MEDIA_TYPE = "application/x-ndjson"
|
||||||
|
SSE_MEDIA_TYPE = "text/event-stream"
|
||||||
|
_CONCURRENCY_PREFIX = "gateway:concurrency:"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class RateHeaders:
|
||||||
|
"""The §6.5 rate/budget header values gathered during pre-flight."""
|
||||||
|
|
||||||
|
limit_requests: int
|
||||||
|
remaining_requests: int
|
||||||
|
limit_tokens: int
|
||||||
|
remaining_tokens: int
|
||||||
|
budget_period: str
|
||||||
|
budget_remaining: str
|
||||||
|
|
||||||
|
def as_dict(self, request_id: str) -> dict[str, str]:
|
||||||
|
"""Render to the SPEC §6.5 header set."""
|
||||||
|
return {
|
||||||
|
"X-Request-ID": request_id,
|
||||||
|
"X-RateLimit-Limit-Requests": str(self.limit_requests),
|
||||||
|
"X-RateLimit-Remaining-Requests": str(self.remaining_requests),
|
||||||
|
"X-RateLimit-Limit-Tokens": str(self.limit_tokens),
|
||||||
|
"X-RateLimit-Remaining-Tokens": str(self.remaining_tokens),
|
||||||
|
"X-Budget-Period": self.budget_period,
|
||||||
|
"X-Budget-Tokens-Remaining": self.budget_remaining,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
"""Per-request enforcement + proxy orchestrator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
principal: Principal,
|
||||||
|
settings: Settings,
|
||||||
|
ollama: OllamaClient,
|
||||||
|
discovery: DiscoveryCache,
|
||||||
|
rate_limiter: SlidingWindowLimiter,
|
||||||
|
concurrency: ConcurrencyLimiter,
|
||||||
|
budget: BudgetCounter,
|
||||||
|
audit: AuditWriter,
|
||||||
|
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._request = request
|
||||||
|
self._p = principal
|
||||||
|
self._settings = settings
|
||||||
|
self._ollama = ollama
|
||||||
|
self._discovery = discovery
|
||||||
|
self._rate = rate_limiter
|
||||||
|
self._conc = concurrency
|
||||||
|
self._budget = budget
|
||||||
|
self._audit = audit
|
||||||
|
self._sessionmaker = sessionmaker
|
||||||
|
self._request_id = str(getattr(request.state, "request_id", uuid.uuid4()))
|
||||||
|
self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}"
|
||||||
|
self._headers = RateHeaders(
|
||||||
|
limit_requests=principal.limits.rpm,
|
||||||
|
remaining_requests=principal.limits.rpm,
|
||||||
|
limit_tokens=principal.limits.tpm,
|
||||||
|
remaining_tokens=principal.limits.tpm,
|
||||||
|
budget_period="day",
|
||||||
|
budget_remaining="unlimited",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def settings(self) -> Settings:
|
||||||
|
"""The active settings (for body-size / num_predict caps in routes)."""
|
||||||
|
return self._settings
|
||||||
|
|
||||||
|
# ----- enforcement -------------------------------------------------------
|
||||||
|
|
||||||
|
def check_scope(self, scope: str) -> None:
|
||||||
|
"""Authorize a coarse scope (e.g. 'chat', 'embeddings')."""
|
||||||
|
if scope not in self._p.scopes:
|
||||||
|
raise AuthorizationError(internal_detail=f"scope {scope!r} not granted")
|
||||||
|
|
||||||
|
def check_endpoint(self, path: str) -> None:
|
||||||
|
"""Reject hard-blocked upstream endpoints with a generic 403."""
|
||||||
|
if is_hard_blocked(path):
|
||||||
|
raise AuthorizationError(internal_detail=f"hard-blocked endpoint {path}")
|
||||||
|
|
||||||
|
def check_model(self, model: str) -> None:
|
||||||
|
"""Default-deny model check against the live effective set (§4.3 step 7)."""
|
||||||
|
if not model or not is_model_allowed(
|
||||||
|
model,
|
||||||
|
allow_all=self._p.allow_all_models,
|
||||||
|
allowed_models=self._p.allowed_models,
|
||||||
|
discovered=self._discovery.names,
|
||||||
|
):
|
||||||
|
# No existence disclosure (SPEC §13.6): unpermitted and not-installed
|
||||||
|
# both yield the same generic 403.
|
||||||
|
raise AuthorizationError(internal_detail=f"model {model!r} not in effective set")
|
||||||
|
|
||||||
|
def validate_body(self, body: dict[str, object]) -> None:
|
||||||
|
"""Enforce the ``num_predict`` cap (body-size cap is enforced earlier)."""
|
||||||
|
options = body.get("options")
|
||||||
|
if isinstance(options, dict):
|
||||||
|
num_predict = options.get("num_predict")
|
||||||
|
if isinstance(num_predict, int) and num_predict > self._settings.max_num_predict:
|
||||||
|
options["num_predict"] = self._settings.max_num_predict
|
||||||
|
|
||||||
|
async def enforce_limits(self, *, token_estimate: int = 0) -> None:
|
||||||
|
"""Run RPM (per-key + per-tenant), TPM, budget, and concurrency checks.
|
||||||
|
|
||||||
|
Budget is checked before concurrency so an over-budget request never even
|
||||||
|
acquires a permit. Order otherwise follows SPEC §4.3 steps 4-6.
|
||||||
|
"""
|
||||||
|
await self._check_rpm()
|
||||||
|
await self._check_tpm(token_estimate)
|
||||||
|
await self.check_budgets()
|
||||||
|
await self._acquire_concurrency()
|
||||||
|
|
||||||
|
async def _check_rpm(self) -> None:
|
||||||
|
key_result = await self._rate.check(
|
||||||
|
f"gateway:rpm:key:{self._p.key_id}", self._p.limits.rpm, 60, cost=1
|
||||||
|
)
|
||||||
|
self._headers.limit_requests = key_result.limit
|
||||||
|
self._headers.remaining_requests = key_result.remaining
|
||||||
|
if not key_result.allowed:
|
||||||
|
raise RateLimitError(retry_after=key_result.retry_after_s)
|
||||||
|
tenant_result = await self._rate.check(
|
||||||
|
f"gateway:rpm:tenant:{self._p.tenant_id}", self._p.limits.rpm, 60, cost=1
|
||||||
|
)
|
||||||
|
if not tenant_result.allowed:
|
||||||
|
raise RateLimitError(retry_after=tenant_result.retry_after_s)
|
||||||
|
|
||||||
|
async def _check_tpm(self, token_estimate: int) -> None:
|
||||||
|
# TPM is charged with a minimum of 1 so a request always counts; the
|
||||||
|
# precise token cost is reconciled post-stream via the budget counter.
|
||||||
|
cost = max(token_estimate, 1)
|
||||||
|
result = await self._rate.check(
|
||||||
|
f"gateway:tpm:key:{self._p.key_id}", self._p.limits.tpm, 60, cost=cost
|
||||||
|
)
|
||||||
|
self._headers.limit_tokens = result.limit
|
||||||
|
self._headers.remaining_tokens = result.remaining
|
||||||
|
if not result.allowed:
|
||||||
|
raise RateLimitError(retry_after=result.retry_after_s)
|
||||||
|
|
||||||
|
def _budget_periods(self) -> list[tuple[BudgetPeriod, int | None]]:
|
||||||
|
return [
|
||||||
|
(BudgetPeriod.day, self._p.limits.tokens_daily),
|
||||||
|
(BudgetPeriod.month, self._p.limits.tokens_monthly),
|
||||||
|
(BudgetPeriod.total, self._p.limits.tokens_total),
|
||||||
|
]
|
||||||
|
|
||||||
|
    async def check_budgets(self) -> None:
        """Verify no configured budget period is already exhausted.

        Also records the period with the fewest remaining tokens for the
        ``X-Budget-*`` response headers.
        """
        tightest_period = "day"
        tightest_remaining: int | None = None
        for period, limit in self._budget_periods():
            state = await self._budget.check(str(self._p.key_id), period, limit)
            if state.exhausted:
                raise BudgetExceededError(internal_detail=f"budget {period.value} exhausted")
            # Keep the smallest remaining value, not merely the last one seen
            # (assumes ``state.remaining`` is an int when not None).
            if state.remaining is not None and (
                tightest_remaining is None or state.remaining < tightest_remaining
            ):
                tightest_period = period.value
                tightest_remaining = state.remaining
        self._headers.budget_period = tightest_period
        self._headers.budget_remaining = (
            "unlimited" if tightest_remaining is None else str(tightest_remaining)
        )
|
||||||
|
|
||||||
|
async def _acquire_concurrency(self) -> None:
|
||||||
|
ok = await self._conc.acquire(
|
||||||
|
self._concurrency_key,
|
||||||
|
self._p.limits.concurrent,
|
||||||
|
self._settings.ollama_read_timeout_s + 30,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise RateLimitError(
|
||||||
|
retry_after=1, internal_detail="concurrency cap reached"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----- proxy + bookkeeping ----------------------------------------------
|
||||||
|
|
||||||
|
def headers(self) -> dict[str, str]:
|
||||||
|
"""Render the §6.5 response headers."""
|
||||||
|
return self._headers.as_dict(self._request_id)
|
||||||
|
|
||||||
|
async def stream_native(
|
||||||
|
self, method: str, path: str, body: dict[str, object], model: str
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""Proxy a streaming NDJSON request, accounting tokens after close."""
|
||||||
|
started = time.monotonic()
|
||||||
|
media_type = NDJSON_MEDIA_TYPE
|
||||||
|
|
||||||
|
async def gen() -> AsyncIterator[bytes]:
|
||||||
|
last_obj: dict[str, object] = {}
|
||||||
|
try:
|
||||||
|
async for chunk in self._ollama.stream(method, path, body):
|
||||||
|
last_obj = _merge_last_ndjson(chunk, last_obj)
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
await self._finish(model, path, method, last_obj, started)
|
||||||
|
|
||||||
|
return StreamingResponse(gen(), media_type=media_type, headers=self.headers())
|
||||||
|
|
||||||
|
async def stream_openai(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
body: dict[str, object],
|
||||||
|
model: str,
|
||||||
|
chunk_translator: Callable[[dict[str, object]], dict[str, object]],
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""Proxy + translate native NDJSON into OpenAI SSE; account after close."""
|
||||||
|
started = time.monotonic()
|
||||||
|
|
||||||
|
async def gen() -> AsyncIterator[bytes]:
|
||||||
|
last_obj: dict[str, object] = {}
|
||||||
|
buffer = b""
|
||||||
|
try:
|
||||||
|
async for chunk in self._ollama.stream(method, path, body):
|
||||||
|
buffer += chunk
|
||||||
|
lines = buffer.split(b"\n")
|
||||||
|
buffer = lines.pop()
|
||||||
|
for line in lines:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
obj = json.loads(line)
|
||||||
|
last_obj = obj if isinstance(obj, dict) else last_obj
|
||||||
|
translated = chunk_translator(obj)
|
||||||
|
yield f"data: {json.dumps(translated)}\n\n".encode()
|
||||||
|
if buffer.strip():
|
||||||
|
obj = json.loads(buffer)
|
||||||
|
last_obj = obj if isinstance(obj, dict) else last_obj
|
||||||
|
yield f"data: {json.dumps(chunk_translator(obj))}\n\n".encode()
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
finally:
|
||||||
|
await self._finish(model, path, method, last_obj, started)
|
||||||
|
|
||||||
|
return StreamingResponse(gen(), media_type=SSE_MEDIA_TYPE, headers=self.headers())
|
||||||
|
|
||||||
|
async def request_native(
|
||||||
|
self, method: str, path: str, body: dict[str, object], model: str
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Proxy a non-streaming request and account tokens before responding."""
|
||||||
|
started = time.monotonic()
|
||||||
|
resp = await self._ollama.request(method, path, body)
|
||||||
|
payload = resp.json()
|
||||||
|
obj = payload if isinstance(payload, dict) else {}
|
||||||
|
await self._finish(model, path, method, obj, started)
|
||||||
|
return JSONResponse(obj, headers=self.headers())
|
||||||
|
|
||||||
|
async def request_translated(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
body: dict[str, object],
|
||||||
|
model: str,
|
||||||
|
translator: Callable[[dict[str, object]], dict[str, object]],
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Proxy a non-streaming request, translate the body, then account."""
|
||||||
|
started = time.monotonic()
|
||||||
|
resp = await self._ollama.request(method, path, body)
|
||||||
|
payload = resp.json()
|
||||||
|
obj = payload if isinstance(payload, dict) else {}
|
||||||
|
await self._finish(model, path, method, obj, started)
|
||||||
|
return JSONResponse(translator(obj), headers=self.headers())
|
||||||
|
|
||||||
|
async def _finish(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
path: str,
|
||||||
|
method: str,
|
||||||
|
final_obj: dict[str, object],
|
||||||
|
started: float,
|
||||||
|
) -> None:
|
||||||
|
"""Post-close bookkeeping: tokens, budget, metrics, audit, semaphore.
|
||||||
|
|
||||||
|
Runs once per request after the response is fully produced. Each step is
|
||||||
|
best-effort and guarded so a bookkeeping failure never corrupts the
|
||||||
|
already-delivered response. The concurrency permit is always released.
|
||||||
|
"""
|
||||||
|
usage = extract_usage(final_obj) if final_obj else TokenUsage(0, 0)
|
||||||
|
latency_ms = int((time.monotonic() - started) * 1000)
|
||||||
|
try:
|
||||||
|
await self._account_budget(usage)
|
||||||
|
except Exception as exc: # noqa: BLE001 - never break a delivered response
|
||||||
|
_log.warning("budget_account_failed", error=str(exc))
|
||||||
|
try:
|
||||||
|
metrics.record_request(self._p.tenant_name, model or "unknown", 200, latency_ms / 1000)
|
||||||
|
metrics.record_tokens(
|
||||||
|
self._p.tenant_name, model or "unknown", usage.tokens_in, usage.tokens_out
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - metrics must not break responses
|
||||||
|
_log.warning("metrics_record_failed", error=str(exc))
|
||||||
|
await self._write_audit(model, path, method, usage, latency_ms, 200)
|
||||||
|
await self._release_concurrency()
|
||||||
|
|
||||||
|
async def _account_budget(self, usage: TokenUsage) -> None:
|
||||||
|
"""Decrement Redis budget counters and persist to the Postgres ledger."""
|
||||||
|
for period, limit in self._budget_periods():
|
||||||
|
if limit is not None:
|
||||||
|
await self._budget.consume(str(self._p.key_id), period, usage.total)
|
||||||
|
if self._sessionmaker is not None and usage.total >= 0:
|
||||||
|
try:
|
||||||
|
async with self._sessionmaker() as session:
|
||||||
|
ledger = BudgetLedger(session)
|
||||||
|
for period, _limit in self._budget_periods():
|
||||||
|
await ledger.record_usage(
|
||||||
|
str(self._p.key_id), period, usage.tokens_in, usage.tokens_out
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
except Exception as exc: # noqa: BLE001 - ledger is durable backstop; never break response
|
||||||
|
_log.warning("ledger_record_failed", error=str(exc))
|
||||||
|
|
||||||
|
async def _write_audit(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
path: str,
|
||||||
|
method: str,
|
||||||
|
usage: TokenUsage,
|
||||||
|
latency_ms: int,
|
||||||
|
status: int,
|
||||||
|
) -> None:
|
||||||
|
record = AuditRecord(
|
||||||
|
request_id=uuid.UUID(self._request_id)
|
||||||
|
if _is_uuid(self._request_id)
|
||||||
|
else uuid.uuid4(),
|
||||||
|
method=method,
|
||||||
|
path=path,
|
||||||
|
status=status,
|
||||||
|
ts=datetime.datetime.now(datetime.UTC),
|
||||||
|
tenant_id=self._p.tenant_id,
|
||||||
|
key_id=self._p.key_id,
|
||||||
|
key_prefix=self._p.key_prefix,
|
||||||
|
model=model or None,
|
||||||
|
tokens_in=usage.tokens_in,
|
||||||
|
tokens_out=usage.tokens_out,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
client_ip=self._client_ip(),
|
||||||
|
user_agent=self._request.headers.get("user-agent"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self._audit.enqueue(record)
|
||||||
|
except Exception as exc: # noqa: BLE001 - audit enqueue must not break response
|
||||||
|
_log.warning("audit_enqueue_failed", error=str(exc))
|
||||||
|
|
||||||
|
async def _release_concurrency(self) -> None:
|
||||||
|
try:
|
||||||
|
await self._conc.release(self._concurrency_key)
|
||||||
|
except Exception as exc: # noqa: BLE001 - release is best-effort
|
||||||
|
_log.warning("concurrency_release_failed", error=str(exc))
|
||||||
|
|
||||||
|
def _client_ip(self) -> str | None:
|
||||||
|
xff = self._request.headers.get("x-forwarded-for")
|
||||||
|
if xff:
|
||||||
|
return xff.split(",")[0].strip()
|
||||||
|
return self._request.client.host if self._request.client else None
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_last_ndjson(chunk: bytes, prev: dict[str, object]) -> dict[str, object]:
    """Track the last complete NDJSON object seen in a raw byte chunk.

    Token counts live on the final ``done`` frame. Each chunk is parsed on its
    own: every line that decodes as a JSON object becomes the new result, and
    anything else (including a frame split across two chunks) is skipped,
    leaving the most recent object in place. No partial data is carried over
    between calls.
    """
    text = chunk.decode("utf-8", errors="ignore")
    last = prev
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            obj = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            last = obj
    return last
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
try:
|
||||||
|
uuid.UUID(value)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def model_of(body: dict[str, object]) -> str:
|
||||||
|
"""Extract the requested model name from a request body (empty if absent)."""
|
||||||
|
model = body.get("model")
|
||||||
|
return model if isinstance(model, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
async def read_json_body(request: Request, settings: Settings) -> dict[str, object]:
|
||||||
|
"""Read + size-limit the request body, returning the parsed JSON object."""
|
||||||
|
raw = await request.body()
|
||||||
|
if len(raw) > settings.max_request_body_bytes:
|
||||||
|
raise RequestTooLargeError(internal_detail=f"body {len(raw)} bytes")
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed: Any = json.loads(raw)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise AuthorizationError(internal_detail="invalid JSON body") from exc
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"NDJSON_MEDIA_TYPE",
|
||||||
|
"SSE_MEDIA_TYPE",
|
||||||
|
"Pipeline",
|
||||||
|
"RateHeaders",
|
||||||
|
"model_of",
|
||||||
|
"read_json_body",
|
||||||
|
]
|
||||||
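`_merge_last_ndjson` parses each raw chunk independently, so a final `done` frame that happens to straddle two chunks would not parse in either half. One way to tolerate that, sketched here under stated assumptions, is to carry the unterminated tail between chunks while still forwarding the original bytes untouched. `LastFrameTracker` is a hypothetical helper, not part of the gateway, and the frames fed to it are made-up examples.

```python
import json


class LastFrameTracker:
    """Remember the last complete NDJSON object, tolerating frames split across chunks."""

    def __init__(self):
        self._tail = b""  # bytes after the most recent newline (an incomplete line)
        self.last = {}

    def feed(self, chunk):
        """Inspect a chunk without modifying it; the caller still forwards it as-is."""
        lines = (self._tail + chunk).split(b"\n")
        self._tail = lines.pop()  # b"" when the chunk ended on a newline
        for line in lines:
            self._parse(line)

    def finish(self):
        """Call after the stream closes; parses an unterminated final line, if any."""
        self._parse(self._tail)
        self._tail = b""
        return self.last

    def _parse(self, line):
        if not line.strip():
            return
        try:
            obj = json.loads(line)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return
        if isinstance(obj, dict):
            self.last = obj


tracker = LastFrameTracker()
# The final ``done`` frame is deliberately split across two chunks.
for part in (b'{"done":false}\n{"done":true,"prompt_e', b'val_count":12,"eval_count":30}\n'):
    tracker.feed(part)
final = tracker.finish()
```

The buffer holds at most one incomplete line, so the extra memory is bounded by the longest single frame rather than by the whole response.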
50
src/neuronetz_gateway/proxy/token_counter.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Precise token accounting parsed from Ollama responses (SPEC §2, §13.1).
|
||||||
|
|
||||||
|
Tokens are read from Ollama's reported ``prompt_eval_count`` (input) and
|
||||||
|
``eval_count`` (output) on the final stream frame — never heuristically
|
||||||
|
estimated. Embeddings charge ``prompt_eval_count`` only (SPEC §13.1); they have
|
||||||
|
no ``eval_count`` so output tokens are reported as zero.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class TokenUsage:
|
||||||
|
"""Token counts extracted from an Ollama response."""
|
||||||
|
|
||||||
|
tokens_in: int
|
||||||
|
tokens_out: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
"""Combined input + output tokens."""
|
||||||
|
return self.tokens_in + self.tokens_out
|
||||||
|
|
||||||
|
|
||||||
|
def _as_int(value: object) -> int:
|
||||||
|
"""Coerce an Ollama-reported count to a non-negative int (0 if absent/bad)."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return 0
|
||||||
|
if isinstance(value, int):
|
||||||
|
return max(value, 0)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return max(int(value), 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def extract_usage(final_frame: dict[str, object]) -> TokenUsage:
|
||||||
|
"""Extract ``prompt_eval_count``/``eval_count`` from the final Ollama frame.
|
||||||
|
|
||||||
|
Works for chat/generate final frames and for embeddings responses (which
|
||||||
|
carry ``prompt_eval_count`` but no ``eval_count``).
|
||||||
|
"""
|
||||||
|
return TokenUsage(
|
||||||
|
tokens_in=_as_int(final_frame.get("prompt_eval_count")),
|
||||||
|
tokens_out=_as_int(final_frame.get("eval_count")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TokenUsage", "extract_usage"]
|
||||||
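To make the coercion rules concrete, the snippet below is a condensed restatement of this module (so that it runs on its own) applied to three made-up frames: a chat final frame, an embeddings response without `eval_count`, and a frame carrying invalid values. The frame contents are illustrative, not captured Ollama output.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TokenUsage:
    tokens_in: int
    tokens_out: int

    @property
    def total(self) -> int:
        return self.tokens_in + self.tokens_out


def _as_int(value: object) -> int:
    # bool is a subclass of int, so it is rejected before the numeric check.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return 0
    return max(int(value), 0)


def extract_usage(final_frame: dict) -> TokenUsage:
    return TokenUsage(
        tokens_in=_as_int(final_frame.get("prompt_eval_count")),
        tokens_out=_as_int(final_frame.get("eval_count")),
    )


chat = extract_usage({"done": True, "prompt_eval_count": 12, "eval_count": 30})
embed = extract_usage({"prompt_eval_count": 7})  # embeddings: no eval_count
odd = extract_usage({"prompt_eval_count": True, "eval_count": -5})  # invalid values
```

Missing, boolean, and negative counts all collapse to zero, so a malformed frame under-counts rather than raising during post-close bookkeeping.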
245
src/neuronetz_gateway/proxy/translate.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""OpenAI <-> Ollama schema translation (SPEC §6.3).
|
||||||
|
|
||||||
|
Native ``/api/*`` speaks NDJSON; OpenAI-compat ``/v1/*`` speaks SSE
|
||||||
|
(``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``). These helpers translate request
|
||||||
|
bodies in both directions and convert native Ollama stream frames into OpenAI
|
||||||
|
chunk objects, preserving streaming. Recognized OpenAI sampling params are mapped
|
||||||
|
into Ollama's ``options`` block; unrecognized keys are dropped rather than
|
||||||
|
forwarded blindly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# OpenAI sampling params that map into Ollama's ``options`` object.
|
||||||
|
_OPTION_KEYS: tuple[str, ...] = (
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"seed",
|
||||||
|
"stop",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_int(value: object) -> int:
|
||||||
|
"""Coerce a JSON-derived value to a non-negative int (0 if absent/invalid)."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return 0
|
||||||
|
if isinstance(value, int):
|
||||||
|
return max(value, 0)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return max(int(value), 0)
|
    return 0


def _build_options(payload: dict[str, Any]) -> dict[str, Any]:
    """Collect OpenAI sampling params into an Ollama ``options`` mapping."""
    options: dict[str, Any] = {}
    for key in _OPTION_KEYS:
        if key in payload and payload[key] is not None:
            options[key] = payload[key]
    if "max_tokens" in payload and payload["max_tokens"] is not None:
        options["num_predict"] = payload["max_tokens"]
    return options


def openai_chat_to_ollama(payload: dict[str, object]) -> dict[str, object]:
    """Translate an OpenAI chat-completion request to an Ollama ``/api/chat`` body."""
    body: dict[str, object] = {
        "model": payload.get("model"),
        "messages": payload.get("messages", []),
        "stream": bool(payload.get("stream", False)),
    }
    options = _build_options(dict(payload))
    if options:
        body["options"] = options
    return body


def openai_completion_to_ollama(payload: dict[str, object]) -> dict[str, object]:
    """Translate an OpenAI completion request to an Ollama ``/api/generate`` body."""
    prompt = payload.get("prompt", "")
    if isinstance(prompt, list):
        prompt = "".join(str(p) for p in prompt)
    body: dict[str, object] = {
        "model": payload.get("model"),
        "prompt": prompt,
        "stream": bool(payload.get("stream", False)),
    }
    options = _build_options(dict(payload))
    if options:
        body["options"] = options
    return body


def openai_embeddings_to_ollama(payload: dict[str, object]) -> dict[str, object]:
    """Translate an OpenAI embeddings request to an Ollama ``/api/embed`` body."""
    return {
        "model": payload.get("model"),
        "input": payload.get("input", ""),
    }


def _completion_id() -> str:
    """Generate an OpenAI-style completion id."""
    return f"chatcmpl-{uuid.uuid4().hex}"


def ollama_chat_chunk_to_openai(
    chunk: dict[str, object], *, completion_id: str, model: str, created: int
) -> dict[str, object]:
    """Translate one Ollama ``/api/chat`` NDJSON frame to an OpenAI SSE chunk."""
    done = bool(chunk.get("done"))
    if done:
        return {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
                "completion_tokens": _as_int(chunk.get("eval_count")),
                "total_tokens": _as_int(chunk.get("prompt_eval_count"))
                + _as_int(chunk.get("eval_count")),
            },
        }
    message = chunk.get("message")
    content = ""
    if isinstance(message, dict):
        content = str(message.get("content", ""))
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
    }


def ollama_generate_chunk_to_openai(
    chunk: dict[str, object], *, completion_id: str, model: str, created: int
) -> dict[str, object]:
    """Translate one Ollama ``/api/generate`` NDJSON frame to an OpenAI text chunk."""
    done = bool(chunk.get("done"))
    if done:
        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "text": "", "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
                "completion_tokens": _as_int(chunk.get("eval_count")),
                "total_tokens": _as_int(chunk.get("prompt_eval_count"))
                + _as_int(chunk.get("eval_count")),
            },
        }
    return {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "text": str(chunk.get("response", "")), "finish_reason": None}],
    }


def ollama_chat_to_openai(payload: dict[str, object]) -> dict[str, object]:
    """Translate a *non-streaming* Ollama chat response to an OpenAI completion."""
    message = payload.get("message")
    content = ""
    if isinstance(message, dict):
        content = str(message.get("content", ""))
    prompt_tokens = _as_int(payload.get("prompt_eval_count"))
    completion_tokens = _as_int(payload.get("eval_count"))
    return {
        "id": _completion_id(),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": str(payload.get("model", "")),
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


def ollama_generate_to_openai(payload: dict[str, object]) -> dict[str, object]:
    """Translate a *non-streaming* Ollama generate response to OpenAI completion."""
    prompt_tokens = _as_int(payload.get("prompt_eval_count"))
    completion_tokens = _as_int(payload.get("eval_count"))
    return {
        "id": _completion_id(),
        "object": "text_completion",
        "created": int(time.time()),
        "model": str(payload.get("model", "")),
        "choices": [
            {"index": 0, "text": str(payload.get("response", "")), "finish_reason": "stop"}
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


def ollama_embed_to_openai(payload: dict[str, object], model: str) -> dict[str, object]:
    """Translate an Ollama ``/api/embed`` response to the OpenAI embeddings shape."""
    raw = payload.get("embeddings")
    vectors: list[list[float]] = raw if isinstance(raw, list) else []
    prompt_tokens = _as_int(payload.get("prompt_eval_count"))
    return {
        "object": "list",
        "data": [
            {"object": "embedding", "index": i, "embedding": vec}
            for i, vec in enumerate(vectors)
        ],
        "model": model,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }


def models_to_openai_list(names: list[str]) -> dict[str, object]:
    """Render a list of model names in the OpenAI ``/v1/models`` list format."""
    created = int(time.time())
    return {
        "object": "list",
        "data": [
            {"id": name, "object": "model", "created": created, "owned_by": "neuronetz"}
            for name in names
        ],
    }


def new_completion_id() -> str:
    """Public helper to mint a completion id for a streaming response."""
    return _completion_id()


__all__ = [
    "models_to_openai_list",
    "new_completion_id",
    "ollama_chat_chunk_to_openai",
    "ollama_chat_to_openai",
    "ollama_embed_to_openai",
    "ollama_generate_chunk_to_openai",
    "ollama_generate_to_openai",
    "openai_chat_to_ollama",
    "openai_completion_to_ollama",
    "openai_embeddings_to_ollama",
]
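To make the request-side mapping concrete, here is a minimal standalone sketch of the chat translation. It is an illustration rather than the module's code: `OPTION_KEYS` is an assumed stand-in for `_OPTION_KEYS` (whose real contents are defined earlier in the file and not shown here), and `build_options` / `to_ollama_chat` are hypothetical names.

```python
from typing import Any

# Assumption: a stand-in for the module's _OPTION_KEYS; the real tuple may differ.
OPTION_KEYS = ("temperature", "top_p", "seed", "stop")


def build_options(payload: dict[str, Any]) -> dict[str, Any]:
    """Copy non-null sampling params and rename max_tokens -> num_predict."""
    options = {k: payload[k] for k in OPTION_KEYS if payload.get(k) is not None}
    if payload.get("max_tokens") is not None:
        options["num_predict"] = payload["max_tokens"]
    return options


def to_ollama_chat(payload: dict[str, Any]) -> dict[str, Any]:
    """Build an Ollama /api/chat body from an OpenAI chat-completion request."""
    body: dict[str, Any] = {
        "model": payload.get("model"),
        "messages": payload.get("messages", []),
        # An absent stream flag is treated as non-streaming on the OpenAI side.
        "stream": bool(payload.get("stream", False)),
    }
    options = build_options(payload)
    if options:
        body["options"] = options
    return body


request = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "hi"}],
    "temperature": 0.2,
    "top_p": None,  # explicit nulls are dropped, not forwarded
    "max_tokens": 64,
}
body = to_ollama_chat(request)
```

With this input, `body["options"]` contains only `temperature` and `num_predict`; a request with no sampling parameters produces a body with no `options` key at all.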
3
src/neuronetz_gateway/ratelimit/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Rate limiting: sliding-window RPM/TPM and concurrency semaphore (Redis)."""

from __future__ import annotations
66
src/neuronetz_gateway/ratelimit/concurrency.py
Normal file
@@ -0,0 +1,66 @@
"""Concurrent-connection semaphore backed by Redis ``INCR`` with a TTL guard.

Acquired before proxying and released on stream close. A TTL on the counter
prevents a crashed worker from leaking permits forever (self-healing). Fails
closed (SPEC §4.4): if Redis is unreachable, acquisition raises
:class:`DependencyUnavailableError` so the caller denies (503).
"""

from __future__ import annotations

from collections.abc import Awaitable
from typing import cast

import redis.asyncio as redis
from redis.exceptions import RedisError

from neuronetz_gateway.errors import DependencyUnavailableError

# Guard TTL: a permit auto-expires if not explicitly released, so a crashed
# worker cannot pin the semaphore. The value, which should be comfortably longer
# than any single stream is expected to take, is set per-call via ``ttl_s``.


class ConcurrencyLimiter:
    """Redis-backed concurrency cap with a self-healing TTL guard."""

    def __init__(self, client: redis.Redis) -> None:
        self._client = client

    async def acquire(self, scope_key: str, limit: int, ttl_s: int) -> bool:
        """Try to acquire a permit; return False (deny) if at capacity.

        Increments the counter, refreshes the TTL guard, and rolls back the
        increment if the new value exceeds ``limit``. Raises on Redis failure so
        the caller fails closed.
        """
        try:
            count = await cast("Awaitable[int]", self._client.incr(scope_key))
            await self._client.expire(scope_key, ttl_s)
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"concurrency redis error: {exc!r}"
            ) from exc
        if count > limit:
            # Over capacity: undo our increment and deny.
            try:
                await self._client.decr(scope_key)
            except RedisError:
                # Permit self-heals via the TTL guard; denial still stands.
                return False
            return False
        return True

    async def release(self, scope_key: str) -> None:
        """Release a previously acquired permit (best-effort; never raises)."""
        try:
            count = await cast("Awaitable[int]", self._client.decr(scope_key))
            if count < 0:
                await self._client.set(scope_key, 0)
        except RedisError:
            # The TTL guard will reclaim the permit; releasing must not break the
            # request that already completed.
            return


__all__ = ["ConcurrencyLimiter"]
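The permit accounting in `ConcurrencyLimiter` can be modeled without Redis. The sketch below is a simplified, dict-backed stand-in (the class name `InMemoryLimiter` and the `in_use` helper are hypothetical): it mirrors the increment-then-roll-back decision in `acquire` and the clamp-at-zero behavior in `release`, and deliberately omits the TTL guard and the fail-closed error path.

```python
class InMemoryLimiter:
    """Dict-backed model of the INCR / rollback / DECR permit accounting."""

    def __init__(self) -> None:
        self._counts: dict[str, int] = {}

    def acquire(self, scope_key: str, limit: int) -> bool:
        # Increment first, then decide: the increment itself is the reservation.
        count = self._counts.get(scope_key, 0) + 1
        self._counts[scope_key] = count
        if count > limit:
            # Over capacity: undo our own increment and deny.
            self._counts[scope_key] = count - 1
            return False
        return True

    def release(self, scope_key: str) -> None:
        # Decrement, clamped at zero so a stray release cannot mint spare permits.
        self._counts[scope_key] = max(self._counts.get(scope_key, 0) - 1, 0)

    def in_use(self, scope_key: str) -> int:
        return self._counts.get(scope_key, 0)


limiter = InMemoryLimiter()
results = [limiter.acquire("key:abc", limit=2) for _ in range(3)]
limiter.release("key:abc")
after_release = limiter.acquire("key:abc", limit=2)
```

The third acquisition is denied and leaves the counter at the limit; once one holder releases, the next acquisition succeeds again.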
109
src/neuronetz_gateway/ratelimit/sliding_window.py
Normal file
@@ -0,0 +1,109 @@
"""Sliding-window rate limiter backed by an atomic Redis Lua script.

Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set
of timestamped hits (a true sliding window, not a fixed bucket): each check
trims entries older than ``window_s``, sums the cost of what remains, and admits
the new cost only if it keeps the total within ``limit``. The trim + sum +
conditional add run inside one Lua script so the decision is atomic across
concurrent workers (SPEC §4.3 step 4).

Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises
:class:`DependencyUnavailableError`; the caller must deny (503), never allow.
"""

from __future__ import annotations

import time
import uuid
from dataclasses import dataclass

import redis.asyncio as redis
from redis.exceptions import RedisError

from neuronetz_gateway.errors import DependencyUnavailableError

# KEYS[1] = zset key
# ARGV[1] = now (ms)   ARGV[2] = window (ms)   ARGV[3] = limit
# ARGV[4] = cost       ARGV[5] = unique member id (the script appends ':<cost>')
# Returns: {allowed (1/0), used_after, retry_after_ms}
_LUA = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local member = ARGV[5]

redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
-- Each surviving member encodes its cost as a ':<cost>' suffix; sum them so the
-- window total is cost-weighted (RPM uses cost 1, TPM uses token count).
local total = 0
local data = redis.call('ZRANGEBYSCORE', key, now - window, now)
for i = 1, #data do
    local sep = string.find(data[i], ':[^:]*$')
    if sep then
        total = total + tonumber(string.sub(data[i], sep + 1))
    else
        total = total + 1
    end
end

if total + cost > limit then
    local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
    local retry = window
    if oldest[2] then
        retry = (tonumber(oldest[2]) + window) - now
        if retry < 0 then retry = 0 end
    end
    return {0, total, retry}
end

redis.call('ZADD', key, now, member .. ':' .. cost)
redis.call('PEXPIRE', key, window)
return {1, total + cost, 0}
"""


@dataclass(frozen=True, slots=True)
class RateLimitResult:
    """Outcome of a rate-limit check."""

    allowed: bool
    limit: int
    remaining: int
    retry_after_s: int | None


class SlidingWindowLimiter:
    """Redis-backed sliding-window limiter (atomic via Lua)."""

    def __init__(self, client: redis.Redis) -> None:
        self._client = client
        self._script = client.register_script(_LUA)

    async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult:
        """Atomically record a hit of ``cost`` and report admission.

        Raises :class:`DependencyUnavailableError` if Redis cannot be reached, so
        the caller fails closed.
        """
        now_ms = int(time.time() * 1000)
        window_ms = window_s * 1000
        # A random suffix keeps members unique: ``id(object())`` can repeat
        # between calls, and two hits in the same millisecond that share a
        # member would collapse into one zset entry and be undercounted.
        member = f"{now_ms}-{uuid.uuid4().hex}"
        try:
            raw = await self._script(
                keys=[key], args=[now_ms, window_ms, limit, cost, member]
            )
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"ratelimit redis error: {exc!r}"
            ) from exc
        allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2]))
        allowed = allowed_i == 1
        remaining = max(limit - used_after, 0)
        retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000)
        return RateLimitResult(
            allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s
        )


__all__ = ["RateLimitResult", "SlidingWindowLimiter"]
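The trim, sum, and conditional-add steps performed by the Lua script can be traced with a pure-Python model. This is a sketch for reasoning about the algorithm only: the function name `check_window` and the list-of-tuples storage are hypothetical, hits are assumed to be appended in non-decreasing time order, and nothing here is atomic across processes the way the Redis script is.

```python
def check_window(
    hits: list[tuple[int, int]], now_ms: int, window_ms: int, limit: int, cost: int = 1
) -> tuple[bool, int, int]:
    """Return (allowed, used_after, retry_after_ms); mutates ``hits`` in place."""
    # Trim: drop hits at or before the window's left edge.
    hits[:] = [(ts, c) for ts, c in hits if ts > now_ms - window_ms]
    # Sum: the window total is cost-weighted (1 per request, or a token count).
    total = sum(c for _, c in hits)
    if total + cost > limit:
        # Denied: the oldest surviving hit decides when capacity next frees up.
        retry = max(hits[0][0] + window_ms - now_ms, 0) if hits else window_ms
        return False, total, retry
    # Admit: record the hit only when it fits.
    hits.append((now_ms, cost))
    return True, total + cost, 0


hits: list[tuple[int, int]] = []
first = check_window(hits, now_ms=0, window_ms=60_000, limit=3)
second = check_window(hits, now_ms=10_000, window_ms=60_000, limit=3)
third = check_window(hits, now_ms=20_000, window_ms=60_000, limit=3)
denied = check_window(hits, now_ms=30_000, window_ms=60_000, limit=3)
readmitted = check_window(hits, now_ms=60_000, window_ms=60_000, limit=3)
```

In this trace the fourth call is denied with a 30-second retry hint, because the hit recorded at t=0 leaves the window at t=60 000 ms; the call made at exactly that instant is admitted again.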
97
src/neuronetz_gateway/revocation.py
Normal file
@@ -0,0 +1,97 @@
"""Key-revocation NOTIFY listener (SPEC §4.5).

Console (or the gateway's own CLI) revokes a key by inserting into
``gateway.revocations``; an ``AFTER INSERT`` trigger fires
``pg_notify('key_revoked', key_id)``. This background task LISTENs on that
channel and, on each notification, evicts the Redis auth-cache entry for the
revoked key's prefix so the next request misses the cache, re-reads the DB,
finds the key non-active, and is rejected, making revocation effective within
one Redis RTT without any cross-service HTTP.

The listener resolves ``key_id -> prefix`` via a short DB lookup (the NOTIFY
payload is the key id, but the cache is keyed by prefix). It is resilient: on
any listener failure it reconnects after a fixed 2-second delay.
"""

from __future__ import annotations

import asyncio
import contextlib
import uuid

import asyncpg
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from neuronetz_gateway.config import Settings
from neuronetz_gateway.db.models import ApiKey
from neuronetz_gateway.observability.logging import get_logger

_log = get_logger("revocation")

_CHANNEL = "key_revoked"
_CACHE_PREFIX = "gateway:key:"

# In-flight eviction tasks. The event loop keeps only weak references to tasks,
# so each one is held here until it completes.
_pending_evictions: set[asyncio.Task[None]] = set()


def _asyncpg_dsn(database_url: str) -> str:
    """Strip the SQLAlchemy ``+asyncpg`` driver tag for a raw asyncpg connect."""
    return database_url.replace("postgresql+asyncpg://", "postgresql://")


async def _evict(
    key_id_text: str,
    sessionmaker: async_sessionmaker[AsyncSession],
    redis_client: redis.Redis,
) -> None:
    """Resolve the key id to its prefix and delete the cached principal."""
    try:
        key_id = uuid.UUID(key_id_text)
    except ValueError:
        _log.warning("revocation_bad_payload", payload=key_id_text)
        return
    try:
        async with sessionmaker() as session:
            key = await session.get(ApiKey, key_id)
            prefix = key.prefix if key is not None else None
        if prefix is not None:
            await redis_client.delete(_CACHE_PREFIX + prefix)
            _log.info("revocation_cache_evicted", key_prefix=prefix)
    except Exception as exc:  # noqa: BLE001 - listener must survive transient errors
        _log.warning("revocation_evict_failed", error=str(exc))


async def revocation_listener(
    settings: Settings,
    redis_client: redis.Redis,
    sessionmaker: async_sessionmaker[AsyncSession],
) -> None:
    """LISTEN on ``key_revoked`` and evict the Redis cache on each notification."""
    dsn = _asyncpg_dsn(settings.database_url)
    while True:
        conn = None
        try:
            conn = await asyncpg.connect(dsn)

            def _on_notify(
                _c: object, _pid: int, _channel: str, payload: str
            ) -> None:
                # Schedule the async eviction without blocking the callback,
                # keeping a strong reference until the task finishes.
                task = asyncio.create_task(_evict(payload, sessionmaker, redis_client))
                _pending_evictions.add(task)
                task.add_done_callback(_pending_evictions.discard)

            await conn.add_listener(_CHANNEL, _on_notify)
            _log.info("revocation_listener_started")
            # Wait forever; notifications arrive via the callback. asyncio.Event
            # is a cancel-friendly substitute for an unbounded sleep loop.
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001 - reconnect on any failure
            _log.warning("revocation_listener_reconnect", error=str(exc))
            await asyncio.sleep(2)
        finally:
            if conn is not None:
                with contextlib.suppress(Exception):
                    await conn.close()


__all__ = ["revocation_listener"]
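The eviction path (NOTIFY payload, then key id, then prefix, then cache delete) can be exercised without Postgres or Redis. The sketch below swaps both for plain dicts, so `evict`, `prefixes`, and `cache` are hypothetical stand-ins; only the `gateway:key:` cache prefix is taken from the module itself.

```python
import uuid

CACHE_PREFIX = "gateway:key:"  # mirrors the module's _CACHE_PREFIX


def evict(payload: str, prefixes: dict[uuid.UUID, str], cache: dict[str, str]) -> bool:
    """Return True if a cached principal was removed for the revoked key."""
    try:
        key_id = uuid.UUID(payload)
    except ValueError:
        # Malformed payload: nothing to evict, and the listener keeps running.
        return False
    prefix = prefixes.get(key_id)  # stands in for the short DB lookup
    if prefix is None:
        return False
    # The cache is keyed by prefix, not by key id.
    return cache.pop(CACHE_PREFIX + prefix, None) is not None


revoked_id = uuid.uuid4()
prefixes = {revoked_id: "nk_abc123"}
cache = {CACHE_PREFIX + "nk_abc123": "cached-principal"}
evicted = evict(str(revoked_id), prefixes, cache)
```

After the call the cache entry is gone, so the next request for that key would miss the cache and fall through to the database, where the key is no longer active.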
168
src/neuronetz_gateway/routes/ollama_native.py
Normal file
@@ -0,0 +1,168 @@
"""Native Ollama passthrough routes (SPEC §6.1).

All proxied endpoints run through the shared :class:`Pipeline` (auth has already
attached the principal in middleware): scope + model + endpoint allowlist, rate
limit, budget, concurrency, body validation, then stream/relay with post-close
token-count + audit + budget accounting.

Mutating endpoints (``/api/pull|push|create|copy|delete``, ``/api/blobs/*``) and
``/api/ps`` are hard-blocked (SPEC §6.2) and intentionally NOT routed; a catch-all
returns a generic 403 so their existence is never confirmed.
"""

from __future__ import annotations

from fastapi import APIRouter, Request
from starlette.responses import JSONResponse, Response

from neuronetz_gateway import __version__
from neuronetz_gateway.deps import (
    ConfigDep,
    DiscoveryCacheDep,
    OllamaClientDep,
    PipelineDep,
    PrincipalDep,
)
from neuronetz_gateway.errors import AuthorizationError
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, resolve_effective_models
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body

router = APIRouter(prefix="/api", tags=["ollama-native"])


@router.post("/chat")
async def chat(request: Request, pipeline: PipelineDep) -> Response:
    """Proxy ``POST /api/chat`` (streamed NDJSON / non-streamed)."""
    return await _chat_or_generate(request, pipeline, path="/api/chat", scope="chat")


@router.post("/generate")
async def generate(request: Request, pipeline: PipelineDep) -> Response:
    """Proxy ``POST /api/generate`` (streamed NDJSON / non-streamed)."""
    return await _chat_or_generate(request, pipeline, path="/api/generate", scope="chat")


async def _chat_or_generate(
    request: Request, pipeline: PipelineDep, *, path: str, scope: str
) -> Response:
    body = await read_json_body(request, pipeline.settings)
    model = model_of(body)
    pipeline.check_scope(scope)
    pipeline.check_endpoint(path)
    pipeline.check_model(model)
    pipeline.validate_body(body)
    await pipeline.enforce_limits()
    if bool(body.get("stream", True)):
        return await pipeline.stream_native("POST", path, body, model)
    return await pipeline.request_native("POST", path, body, model)


@router.post("/embeddings")
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
    """Proxy ``POST /api/embeddings`` (legacy, non-streamed)."""
    return await _embeddings(request, pipeline, "/api/embeddings")


@router.post("/embed")
async def embed(request: Request, pipeline: PipelineDep) -> Response:
    """Proxy ``POST /api/embed`` (non-streamed)."""
    return await _embeddings(request, pipeline, "/api/embed")


async def _embeddings(request: Request, pipeline: PipelineDep, path: str) -> Response:
    body = await read_json_body(request, pipeline.settings)
    model = model_of(body)
    pipeline.check_scope("embeddings")
    pipeline.check_endpoint(path)
    pipeline.check_model(model)
    await pipeline.enforce_limits()
    return await pipeline.request_native("POST", path, body, model)


@router.get("/tags")
async def tags(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
    """Return the tenant's effective model set (live-discovered ∩ allowed)."""
    effective = resolve_effective_models(
        allow_all=principal.allow_all_models,
        allowed_models=principal.allowed_models,
        discovered=discovery.names,
    )
    models = [
        {
            "name": m.name,
            "model": m.name,
            "modified_at": m.modified_at,
            "size": m.size_bytes,
            "details": {
                "family": m.family,
                "parameter_size": m.parameter_size,
                "quantization_level": m.quantization,
            },
        }
        for m in discovery.models
        if m.name in effective
    ]
    return JSONResponse({"models": models})


@router.post("/show")
async def show(
    request: Request,
    principal: PrincipalDep,
    discovery: DiscoveryCacheDep,
    ollama: OllamaClientDep,
    settings: ConfigDep,
) -> JSONResponse:
    """Proxy ``POST /api/show`` for an effective-set model; sanitize the result."""
    body = await read_json_body(request, settings)
    name = body.get("model") or body.get("name")
    model = name if isinstance(name, str) else ""
    if not model or model not in resolve_effective_models(
        allow_all=principal.allow_all_models,
        allowed_models=principal.allowed_models,
        discovered=discovery.names,
    ):
        raise AuthorizationError(internal_detail="show: model not in effective set")
    resp = await ollama.request("POST", "/api/show", {"model": model})
    payload = resp.json()
    raw: dict[str, object] = payload if isinstance(payload, dict) else {}
    # Strip system prompt + template (SPEC §6.1: no system prompts, no template).
    raw_details = raw.get("details")
    details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
    return JSONResponse(
        {
            "model": model,
            "details": {
                "family": details.get("family"),
                "parameter_size": details.get("parameter_size"),
                "quantization_level": details.get("quantization_level"),
            },
        }
    )


@router.get("/version")
async def version() -> JSONResponse:
    """Return the gateway version (never Ollama's; SPEC §6.1)."""
    return JSONResponse({"version": __version__})


@router.api_route(
    "/{rest:path}",
    methods=["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"],
    include_in_schema=False,
)
async def catch_all(rest: str) -> Response:
    """Generic 403 for any other ``/api/*`` path (hard-blocked or unknown).

    Mutating endpoints and ``/api/ps`` resolve here and return the same generic
    forbidden response, so the gateway never confirms which upstream endpoints
    exist (SPEC §6.2, §13.6).
    """
    full = f"/api/{rest}"
    if is_hard_blocked(full):
        raise AuthorizationError(internal_detail=f"hard-blocked {full}")
    raise AuthorizationError(internal_detail=f"unrouted upstream path {full}")


__all__ = ["router"]
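Both `/api/tags` and `/api/show` depend on the effective-set rule described in the commit message: the discovered set if `allow_all` is true, otherwise the intersection of the allowed models with the discovered set. The sketch below is a hypothetical re-implementation for illustration (`effective_models` is not the real `resolve_effective_models`, whose body is not shown here), and it assumes the key-over-tenant fallback has already been resolved onto the principal.

```python
def effective_models(
    *, allow_all: bool, allowed_models: frozenset[str], discovered: frozenset[str]
) -> frozenset[str]:
    """Models a principal may use: always a subset of what is actually installed."""
    if allow_all:
        return discovered
    # An allowed-but-not-installed model drops out, as does an
    # installed-but-unpermitted one; callers cannot tell the two apart.
    return allowed_models & discovered


discovered = frozenset({"llama3:8b", "mistral:7b"})
restricted = effective_models(
    allow_all=False,
    allowed_models=frozenset({"llama3:8b", "phi3:mini"}),
    discovered=discovered,
)
unrestricted = effective_models(
    allow_all=True, allowed_models=frozenset(), discovered=discovered
)
nothing_discovered = effective_models(
    allow_all=True, allowed_models=frozenset(), discovered=frozenset()
)
```

Because every result is bounded by `discovered`, an empty discovery result yields an empty effective set even for an allow-all principal, which matches the fail-closed behavior the commit message describes.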
118
src/neuronetz_gateway/routes/openai_compat.py
Normal file
@@ -0,0 +1,118 @@
"""OpenAI-compatible routes (SPEC §6.3).

Each route translates the OpenAI request into the native Ollama body, runs the
same :class:`Pipeline` enforcement as the native routes, and translates the
response back. Streaming uses SSE (``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``);
non-streaming returns a single OpenAI-shaped JSON object. ``/v1/models`` returns
the tenant's effective discovered set in OpenAI list format.
"""

from __future__ import annotations

import time

from fastapi import APIRouter, Request
from starlette.responses import JSONResponse, Response

from neuronetz_gateway.deps import (
    DiscoveryCacheDep,
    PipelineDep,
    PrincipalDep,
)
from neuronetz_gateway.proxy.allowlist import resolve_effective_models
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body
from neuronetz_gateway.proxy.translate import (
    models_to_openai_list,
    new_completion_id,
    ollama_chat_chunk_to_openai,
    ollama_chat_to_openai,
    ollama_embed_to_openai,
    ollama_generate_chunk_to_openai,
    ollama_generate_to_openai,
    openai_chat_to_ollama,
    openai_completion_to_ollama,
    openai_embeddings_to_ollama,
)

router = APIRouter(prefix="/v1", tags=["openai-compat"])


@router.post("/chat/completions")
async def chat_completions(request: Request, pipeline: PipelineDep) -> Response:
    """OpenAI ``/v1/chat/completions`` -> Ollama ``/api/chat``."""
    payload = await read_json_body(request, pipeline.settings)
    body = openai_chat_to_ollama(payload)
    model = model_of(body)
    pipeline.check_scope("chat")
    pipeline.check_endpoint("/api/chat")
    pipeline.check_model(model)
    pipeline.validate_body(body)
    await pipeline.enforce_limits()
    if bool(payload.get("stream", False)):
        completion_id = new_completion_id()
        created = int(time.time())

        def translate(chunk: dict[str, object]) -> dict[str, object]:
            return ollama_chat_chunk_to_openai(
                chunk, completion_id=completion_id, model=model, created=created
            )

        return await pipeline.stream_openai("POST", "/api/chat", body, model, translate)
    return await pipeline.request_translated(
        "POST", "/api/chat", body, model, ollama_chat_to_openai
    )


@router.post("/completions")
async def completions(request: Request, pipeline: PipelineDep) -> Response:
    """OpenAI ``/v1/completions`` -> Ollama ``/api/generate``."""
    payload = await read_json_body(request, pipeline.settings)
    body = openai_completion_to_ollama(payload)
    model = model_of(body)
    pipeline.check_scope("chat")
    pipeline.check_endpoint("/api/generate")
    pipeline.check_model(model)
    pipeline.validate_body(body)
    await pipeline.enforce_limits()
    if bool(payload.get("stream", False)):
        completion_id = new_completion_id()
        created = int(time.time())

        def translate(chunk: dict[str, object]) -> dict[str, object]:
            return ollama_generate_chunk_to_openai(
                chunk, completion_id=completion_id, model=model, created=created
            )

        return await pipeline.stream_openai("POST", "/api/generate", body, model, translate)
    return await pipeline.request_translated(
        "POST", "/api/generate", body, model, ollama_generate_to_openai
    )


@router.post("/embeddings")
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
    """OpenAI ``/v1/embeddings`` -> Ollama ``/api/embed``."""
    payload = await read_json_body(request, pipeline.settings)
    body = openai_embeddings_to_ollama(payload)
    model = model_of(body)
    pipeline.check_scope("embeddings")
    pipeline.check_endpoint("/api/embed")
    pipeline.check_model(model)
    await pipeline.enforce_limits()
    return await pipeline.request_translated(
        "POST", "/api/embed", body, model, lambda obj: ollama_embed_to_openai(obj, model)
    )


@router.get("/models")
async def models(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
    """OpenAI ``/v1/models`` -> the tenant's effective discovered set."""
    effective = resolve_effective_models(
        allow_all=principal.allow_all_models,
        allowed_models=principal.allowed_models,
        discovered=discovery.names,
    )
    return JSONResponse(models_to_openai_list(sorted(effective)))


__all__ = ["router"]
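The SSE wire format named in the module docstring (`data: {...}` events followed by a final `data: [DONE]`) can be shown in a few lines. This is only a sketch of the framing: `sse_frames` is a hypothetical helper, and how `Pipeline.stream_openai` actually serializes and flushes chunks is not shown in this diff.

```python
import json
from collections.abc import Iterable, Iterator


def sse_frames(chunks: Iterable[dict[str, object]]) -> Iterator[str]:
    """Yield one SSE event per translated chunk, then the [DONE] sentinel."""
    for chunk in chunks:
        # Compact separators keep each event on a single "data:" line;
        # the blank line (second newline) terminates the event.
        yield f"data: {json.dumps(chunk, separators=(',', ':'))}\n\n"
    yield "data: [DONE]\n\n"


frames = list(
    sse_frames(
        [
            {"choices": [{"index": 0, "delta": {"content": "Hi"}}]},
            {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]},
        ]
    )
)
```

A generator is used so that each event can be written to the client as soon as its upstream chunk arrives, rather than buffering the whole response.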