proxy: streaming, discovery, OpenAI-compat, rate-limit, budget, audit
The hot path. A single Pipeline class owns enforcement so the eight
non-negotiables can be reviewed in one place.
- Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags,
/api/show (system-prompt + template stripped), /api/embed(dings), /api/version
(returns gateway version, not Ollama's). Endpoint catch-all returns the same
generic 403 for hard-blocked and unknown /api/* paths so attackers cannot
enumerate which mutating endpoints exist.
- OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings,
/v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves
streaming end-to-end.
- Model discovery (SPEC §4.6): background poller against Ollama /api/tags;
Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh =
MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty.
- Effective-set resolution in proxy/allowlist.py:
allow_all = key.allow_all_models ?? tenant.allow_all_models
effective = discovered if allow_all
else (key.allowed_models ?? tenant.allowed_models) ∩ discovered
A non-effective model returns the same generic 403 whether it's installed-
but-unpermitted or doesn't exist at all (no enumeration leak).
- Sliding-window rate limit (Redis Lua, single round-trip) for per-key +
per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with
TTL guard. Token-budget counters per (key, period) with a Postgres ledger
for reconciliation across resets. Headers per SPEC §6.5 on every response;
429 carries Retry-After; Redis outage → 503 (fail closed, never 200).
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk
carrying `usage`); the audit row is written AFTER stream close so TTFB is
never degraded by bookkeeping.
- Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow.
Optional prompt log per key (TTL'd).
- Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache
entry within ~1s of the console writing to gateway.revocations.
- Prometheus counters/histograms labeled by tenant only (per SPEC §13.3).
This commit is contained in:
3
src/neuronetz_gateway/audit/__init__.py
Normal file
3
src/neuronetz_gateway/audit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Audit logging: buffered async audit writer and opt-in prompt log."""
|
||||
|
||||
from __future__ import annotations
|
||||
63
src/neuronetz_gateway/audit/prompt_log.py
Normal file
63
src/neuronetz_gateway/audit/prompt_log.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Opt-in, TTL'd prompt logging (SPEC §2, §5).
|
||||
|
||||
Disabled by default; enabled per key (or inherited from tenant via the resolved
|
||||
:class:`Principal.log_prompts`). Each row carries a ``retention_until`` deadline
|
||||
(swept in Phase 4). A redaction hook runs before persistence so sensitive spans
|
||||
can be scrubbed without changing call sites.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from neuronetz_gateway.db.models import PromptLog
|
||||
|
||||
# A redaction hook maps a request/response body to a sanitized copy. The default
|
||||
# is identity; operators can install a stricter hook.
|
||||
RedactionHook = Callable[[dict[str, object]], dict[str, object]]
|
||||
|
||||
|
||||
def _identity(body: dict[str, object]) -> dict[str, object]:
|
||||
return body
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PromptRecord:
|
||||
"""A captured request/response pair pending TTL'd persistence."""
|
||||
|
||||
audit_id: int
|
||||
key_id: UUID
|
||||
request_body: dict[str, object]
|
||||
response_text: str | None
|
||||
retention_days: int
|
||||
|
||||
|
||||
class PromptLogWriter:
|
||||
"""Persists opt-in prompt records with a retention deadline."""
|
||||
|
||||
def __init__(self, session: AsyncSession, redact: RedactionHook | None = None) -> None:
|
||||
self._session = session
|
||||
self._redact = redact or _identity
|
||||
|
||||
async def write(self, record: PromptRecord) -> None:
|
||||
"""Persist a prompt record to ``gateway.prompt_log``."""
|
||||
retention_until = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
|
||||
days=record.retention_days
|
||||
)
|
||||
row = PromptLog(
|
||||
audit_id=record.audit_id,
|
||||
key_id=record.key_id,
|
||||
request_body=self._redact(record.request_body),
|
||||
response_text=record.response_text,
|
||||
retention_until=retention_until,
|
||||
)
|
||||
self._session.add(row)
|
||||
await self._session.flush()
|
||||
|
||||
|
||||
__all__ = ["PromptLogWriter", "PromptRecord", "RedactionHook"]
|
||||
152
src/neuronetz_gateway/audit/writer.py
Normal file
152
src/neuronetz_gateway/audit/writer.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Buffered async audit-log writer (SPEC §4.4).
|
||||
|
||||
Audit rows are enqueued *after* stream close so the hot path is never delayed
|
||||
(non-negotiable #6). A background drain task persists them to Postgres. On a
|
||||
Postgres write failure rows are buffered in a bounded in-memory ring (max
|
||||
``AUDIT_BUFFER_SIZE``) and retried; if the ring fills, the writer flips to
|
||||
**deny mode** (SPEC §4.4) — :pyattr:`deny_mode` goes True and the request path
|
||||
must refuse new work until the backlog drains.
|
||||
|
||||
The writer holds the session factory directly (it runs outside request scope).
|
||||
``enqueue`` never blocks on the DB; it only touches the in-memory queue/ring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import datetime
|
||||
from dataclasses import asdict, dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from neuronetz_gateway.db.repositories import AuditRepository
|
||||
from neuronetz_gateway.observability.logging import get_logger
|
||||
|
||||
_log = get_logger("audit")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AuditRecord:
|
||||
"""A single audit-log row queued for persistence."""
|
||||
|
||||
request_id: UUID
|
||||
method: str
|
||||
path: str
|
||||
status: int
|
||||
ts: datetime.datetime
|
||||
tenant_id: UUID | None = None
|
||||
key_id: UUID | None = None
|
||||
key_prefix: str | None = None
|
||||
model: str | None = None
|
||||
tokens_in: int | None = None
|
||||
tokens_out: int | None = None
|
||||
latency_ms: int | None = None
|
||||
client_ip: str | None = None
|
||||
user_agent: str | None = None
|
||||
error_code: str | None = None
|
||||
|
||||
def as_columns(self) -> dict[str, object]:
|
||||
"""Map to ``gateway.audit_log`` column kwargs."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
class AuditWriter:
|
||||
"""Buffered, fail-safe writer for ``gateway.audit_log``."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
|
||||
) -> None:
|
||||
self._buffer_size = buffer_size
|
||||
self._sessionmaker = sessionmaker
|
||||
self._ring: collections.deque[AuditRecord] = collections.deque(maxlen=buffer_size)
|
||||
self._queue: asyncio.Queue[AuditRecord] = asyncio.Queue()
|
||||
self._deny_mode = False
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
@property
|
||||
def deny_mode(self) -> bool:
|
||||
"""True when the buffer has overflowed and new work must be denied."""
|
||||
return self._deny_mode
|
||||
|
||||
def bind(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
||||
"""Attach the DB session factory (called from lifespan startup)."""
|
||||
self._sessionmaker = sessionmaker
|
||||
|
||||
async def enqueue(self, record: AuditRecord) -> None:
|
||||
"""Queue an audit record for asynchronous persistence (non-blocking)."""
|
||||
await self._queue.put(record)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Launch the background drain task (idempotent)."""
|
||||
if self._task is None or self._task.done():
|
||||
self._task = asyncio.create_task(self._drain_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Cancel the drain task and flush remaining rows best-effort."""
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
await self.flush()
|
||||
|
||||
async def _drain_loop(self) -> None:
|
||||
"""Continuously move queued records into Postgres, retrying the ring."""
|
||||
while True:
|
||||
record = await self._queue.get()
|
||||
await self._persist_with_retry(record)
|
||||
|
||||
async def _persist_with_retry(self, record: AuditRecord) -> None:
|
||||
"""Persist one record; on failure push to the ring (or enter deny mode)."""
|
||||
# Drain any backlog first so ordering is roughly preserved.
|
||||
if self._ring:
|
||||
await self._try_flush_ring()
|
||||
if not await self._write_one(record):
|
||||
self._buffer(record)
|
||||
|
||||
def _buffer(self, record: AuditRecord) -> None:
|
||||
"""Buffer a failed write; flip to deny mode if the ring is full."""
|
||||
if len(self._ring) >= self._buffer_size:
|
||||
self._deny_mode = True
|
||||
_log.error("audit_buffer_overflow_deny_mode", buffered=len(self._ring))
|
||||
return
|
||||
self._ring.append(record)
|
||||
|
||||
async def _try_flush_ring(self) -> None:
|
||||
"""Attempt to persist buffered rows; clear deny mode once drained."""
|
||||
while self._ring:
|
||||
record = self._ring[0]
|
||||
if not await self._write_one(record):
|
||||
return
|
||||
self._ring.popleft()
|
||||
self._deny_mode = False
|
||||
|
||||
async def _write_one(self, record: AuditRecord) -> bool:
|
||||
"""Persist a single record; return False on any failure."""
|
||||
if self._sessionmaker is None:
|
||||
return False
|
||||
try:
|
||||
async with self._sessionmaker() as session:
|
||||
await AuditRepository(session).insert_audit(**record.as_columns())
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001 - any DB error ⇒ buffer + retry
|
||||
_log.warning("audit_write_failed", error=str(exc))
|
||||
return False
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Drain queued + buffered records to Postgres (best-effort)."""
|
||||
while not self._queue.empty():
|
||||
record = self._queue.get_nowait()
|
||||
if not await self._write_one(record):
|
||||
self._buffer(record)
|
||||
await self._try_flush_ring()
|
||||
|
||||
|
||||
__all__ = ["AuditRecord", "AuditWriter"]
|
||||
3
src/neuronetz_gateway/budget/__init__.py
Normal file
3
src/neuronetz_gateway/budget/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Token budgets: Redis period counters and Postgres ledger reconciliation."""
|
||||
|
||||
from __future__ import annotations
|
||||
105
src/neuronetz_gateway/budget/counter.py
Normal file
105
src/neuronetz_gateway/budget/counter.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Redis period counters for token budgets (day / month / total).
|
||||
|
||||
The fast-path budget check reads a Redis counter of tokens already consumed in
|
||||
the active period and compares it to the configured limit; Postgres
|
||||
(``ledger.py``) is the durable source of truth reconciled on rollover. Counters
|
||||
are keyed by ``key_id`` + period + period-start so a new period naturally starts
|
||||
at zero, and day/month counters carry a TTL so expired periods self-clean.
|
||||
|
||||
Fail-closed (SPEC §4.4): Redis errors raise
|
||||
:class:`DependencyUnavailableError`; the caller must deny (503).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from neuronetz_gateway.db.models import BudgetPeriod
|
||||
from neuronetz_gateway.errors import DependencyUnavailableError
|
||||
|
||||
_CONSUMED_PREFIX = "gateway:budget:"
|
||||
|
||||
|
||||
def period_start(period: BudgetPeriod, now: datetime.datetime | None = None) -> datetime.datetime:
|
||||
"""Return the UTC start of the current period."""
|
||||
moment = now or datetime.datetime.now(datetime.UTC)
|
||||
moment = moment.astimezone(datetime.UTC)
|
||||
if period is BudgetPeriod.day:
|
||||
return moment.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
if period is BudgetPeriod.month:
|
||||
return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
# 'total' has no rollover; anchor at the epoch.
|
||||
return datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC)
|
||||
|
||||
|
||||
def _ttl_for(period: BudgetPeriod) -> int | None:
|
||||
"""TTL (seconds) for a period counter; None for 'total' (no expiry)."""
|
||||
if period is BudgetPeriod.day:
|
||||
return 2 * 24 * 3600
|
||||
if period is BudgetPeriod.month:
|
||||
return 40 * 24 * 3600
|
||||
return None
|
||||
|
||||
|
||||
def _counter_key(key_id: str, period: BudgetPeriod, start: datetime.datetime) -> str:
|
||||
return f"{_CONSUMED_PREFIX}{key_id}:{period.value}:{int(start.timestamp())}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class BudgetState:
|
||||
"""Snapshot of remaining budget for a period."""
|
||||
|
||||
period: BudgetPeriod
|
||||
limit: int | None
|
||||
remaining: int | None
|
||||
|
||||
@property
|
||||
def exhausted(self) -> bool:
|
||||
"""True if a finite limit is configured and nothing remains."""
|
||||
return self.limit is not None and (self.remaining or 0) <= 0
|
||||
|
||||
|
||||
class BudgetCounter:
|
||||
"""Redis-backed per-period token counter."""
|
||||
|
||||
def __init__(self, client: redis.Redis) -> None:
|
||||
self._client = client
|
||||
|
||||
async def check(self, key_id: str, period: BudgetPeriod, limit: int | None) -> BudgetState:
|
||||
"""Return remaining budget; ``None`` limit means unlimited for the period."""
|
||||
if limit is None:
|
||||
return BudgetState(period=period, limit=None, remaining=None)
|
||||
start = period_start(period)
|
||||
try:
|
||||
raw = await self._client.get(_counter_key(key_id, period, start))
|
||||
except RedisError as exc:
|
||||
raise DependencyUnavailableError(
|
||||
internal_detail=f"budget redis error: {exc!r}"
|
||||
) from exc
|
||||
consumed = int(raw) if raw else 0
|
||||
return BudgetState(period=period, limit=limit, remaining=max(limit - consumed, 0))
|
||||
|
||||
async def consume(self, key_id: str, period: BudgetPeriod, tokens: int) -> None:
|
||||
"""Increment the consumed counter after a request completes."""
|
||||
if tokens <= 0:
|
||||
return
|
||||
start = period_start(period)
|
||||
key = _counter_key(key_id, period, start)
|
||||
try:
|
||||
await cast("Awaitable[int]", self._client.incrby(key, tokens))
|
||||
ttl = _ttl_for(period)
|
||||
if ttl is not None:
|
||||
await self._client.expire(key, ttl)
|
||||
except RedisError as exc:
|
||||
raise DependencyUnavailableError(
|
||||
internal_detail=f"budget consume redis error: {exc!r}"
|
||||
) from exc
|
||||
|
||||
|
||||
__all__ = ["BudgetCounter", "BudgetState", "period_start"]
|
||||
58
src/neuronetz_gateway/budget/ledger.py
Normal file
58
src/neuronetz_gateway/budget/ledger.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Postgres budget ledger reconciliation.
|
||||
|
||||
Persists token usage to ``gateway.budget_usage`` (the durable source of truth)
|
||||
via an idempotent upsert keyed by (key_id, period, period_start). The Redis
|
||||
counter (``counter.py``) is the fast path; this ledger is what survives a Redis
|
||||
flush and what ``show-usage`` reports against.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from neuronetz_gateway.budget.counter import period_start
|
||||
from neuronetz_gateway.db.models import BudgetPeriod, BudgetUsage
|
||||
|
||||
|
||||
class BudgetLedger:
|
||||
"""Source-of-truth budget accounting in Postgres."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def record_usage(
|
||||
self, key_id: str, period: BudgetPeriod, tokens_in: int, tokens_out: int
|
||||
) -> None:
|
||||
"""Upsert usage into ``gateway.budget_usage`` for the active period.
|
||||
|
||||
Uses an ``ON CONFLICT`` upsert so concurrent writers accumulate rather
|
||||
than clobber. ``requests`` increments by one per recorded request.
|
||||
"""
|
||||
start = period_start(period)
|
||||
stmt = pg_insert(BudgetUsage).values(
|
||||
key_id=uuid.UUID(key_id) if isinstance(key_id, str) else key_id,
|
||||
period=period,
|
||||
period_start=start,
|
||||
tokens_in=tokens_in,
|
||||
tokens_out=tokens_out,
|
||||
requests=1,
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[
|
||||
BudgetUsage.key_id,
|
||||
BudgetUsage.period,
|
||||
BudgetUsage.period_start,
|
||||
],
|
||||
set_={
|
||||
"tokens_in": BudgetUsage.tokens_in + stmt.excluded.tokens_in,
|
||||
"tokens_out": BudgetUsage.tokens_out + stmt.excluded.tokens_out,
|
||||
"requests": BudgetUsage.requests + stmt.excluded.requests,
|
||||
},
|
||||
)
|
||||
await self._session.execute(stmt)
|
||||
|
||||
|
||||
__all__ = ["BudgetLedger"]
|
||||
67
src/neuronetz_gateway/observability/metrics.py
Normal file
67
src/neuronetz_gateway/observability/metrics.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Prometheus metrics.
|
||||
|
||||
Phase 1 declares the metric objects and the exposition helper. Instrumentation
|
||||
(incrementing counters / observing histograms on the request path) is wired in
|
||||
later phases. Per SPEC §13.3 we label by ``tenant`` only, never by ``key_id``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest
|
||||
|
||||
REGISTRY = CollectorRegistry()
|
||||
|
||||
REQUESTS_TOTAL = Counter(
|
||||
"gateway_requests_total",
|
||||
"Total proxied requests.",
|
||||
labelnames=("tenant", "model", "status"),
|
||||
registry=REGISTRY,
|
||||
)
|
||||
|
||||
TOKENS_TOTAL = Counter(
|
||||
"gateway_tokens_total",
|
||||
"Total tokens accounted, by direction (in|out).",
|
||||
labelnames=("tenant", "model", "direction"),
|
||||
registry=REGISTRY,
|
||||
)
|
||||
|
||||
REQUEST_DURATION_SECONDS = Histogram(
|
||||
"gateway_request_duration_seconds",
|
||||
"Gateway-side request duration in seconds.",
|
||||
labelnames=("tenant", "model"),
|
||||
registry=REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
def record_request(tenant: str, model: str, status: int, duration_s: float) -> None:
|
||||
"""Increment the request counter and observe its duration (tenant-labeled)."""
|
||||
REQUESTS_TOTAL.labels(tenant=tenant, model=model, status=str(status)).inc()
|
||||
REQUEST_DURATION_SECONDS.labels(tenant=tenant, model=model).observe(duration_s)
|
||||
|
||||
|
||||
def record_tokens(tenant: str, model: str, tokens_in: int, tokens_out: int) -> None:
|
||||
"""Add input/output token counts to the tokens counter."""
|
||||
if tokens_in:
|
||||
TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="in").inc(tokens_in)
|
||||
if tokens_out:
|
||||
TOKENS_TOTAL.labels(tenant=tenant, model=model, direction="out").inc(tokens_out)
|
||||
|
||||
|
||||
def render_latest() -> bytes:
|
||||
"""Return the current metrics in Prometheus text exposition format."""
|
||||
payload: bytes = generate_latest(REGISTRY)
|
||||
return payload
|
||||
|
||||
|
||||
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
|
||||
|
||||
__all__ = [
|
||||
"CONTENT_TYPE_LATEST",
|
||||
"REGISTRY",
|
||||
"REQUESTS_TOTAL",
|
||||
"REQUEST_DURATION_SECONDS",
|
||||
"TOKENS_TOTAL",
|
||||
"record_request",
|
||||
"record_tokens",
|
||||
"render_latest",
|
||||
]
|
||||
3
src/neuronetz_gateway/proxy/__init__.py
Normal file
3
src/neuronetz_gateway/proxy/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Proxy layer: Ollama client, schema translation, token counting, allowlists."""
|
||||
|
||||
from __future__ import annotations
|
||||
76
src/neuronetz_gateway/proxy/allowlist.py
Normal file
76
src/neuronetz_gateway/proxy/allowlist.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Endpoint and model allowlists (SPEC §6.1, §6.2).
|
||||
|
||||
Mutating Ollama endpoints are hard-blocked (not configurable, not flagged):
|
||||
``/api/pull``, ``/api/push``, ``/api/create``, ``/api/copy``, ``/api/delete``,
|
||||
and any ``/api/blobs/*``. ``/api/ps`` is also blocked (leaks loaded models).
|
||||
Model allowlist is per-tenant, default-deny. Enforcement logic lands in Phase 2.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Hard-blocked upstream endpoints — always 403, not configurable (SPEC §6.2).
|
||||
HARD_BLOCKED_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/pull",
|
||||
"/api/push",
|
||||
"/api/create",
|
||||
"/api/copy",
|
||||
"/api/delete",
|
||||
"/api/ps",
|
||||
}
|
||||
)
|
||||
|
||||
# Path prefixes that are hard-blocked (e.g. blob upload/download).
|
||||
HARD_BLOCKED_PREFIXES: tuple[str, ...] = ("/api/blobs",)
|
||||
|
||||
|
||||
def is_hard_blocked(path: str) -> bool:
|
||||
"""Return True if ``path`` is an unconditionally blocked upstream endpoint."""
|
||||
if path in HARD_BLOCKED_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in HARD_BLOCKED_PREFIXES)
|
||||
|
||||
|
||||
def resolve_effective_models(
|
||||
*,
|
||||
allow_all: bool,
|
||||
allowed_models: tuple[str, ...],
|
||||
discovered: frozenset[str],
|
||||
) -> frozenset[str]:
|
||||
"""Resolve the effective model set per SPEC §4.3 step 7 / §4.6.
|
||||
|
||||
``allow_all`` ⇒ the effective set is the entire live ``discovered`` set;
|
||||
otherwise it is the configured ``allowed_models`` intersected with
|
||||
``discovered`` (so stale or typo'd allowlist entries never resolve, and a
|
||||
model that is unpermitted vs. not-installed are indistinguishable).
|
||||
|
||||
Fail-closed: if ``discovered`` is empty (discovery unavailable/expired) the
|
||||
result is empty regardless of ``allow_all`` — discovery never opens access.
|
||||
"""
|
||||
if not discovered:
|
||||
return frozenset()
|
||||
if allow_all:
|
||||
return discovered
|
||||
return frozenset(allowed_models) & discovered
|
||||
|
||||
|
||||
def is_model_allowed(
|
||||
model: str,
|
||||
*,
|
||||
allow_all: bool,
|
||||
allowed_models: tuple[str, ...],
|
||||
discovered: frozenset[str],
|
||||
) -> bool:
|
||||
"""Return True iff ``model`` is in the resolved effective set (default-deny)."""
|
||||
return model in resolve_effective_models(
|
||||
allow_all=allow_all, allowed_models=allowed_models, discovered=discovered
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HARD_BLOCKED_PATHS",
|
||||
"HARD_BLOCKED_PREFIXES",
|
||||
"is_hard_blocked",
|
||||
"is_model_allowed",
|
||||
"resolve_effective_models",
|
||||
]
|
||||
207
src/neuronetz_gateway/proxy/discovery.py
Normal file
207
src/neuronetz_gateway/proxy/discovery.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Live model discovery from the Ollama backend (SPEC §4.6).
|
||||
|
||||
A background task polls Ollama ``GET /api/tags`` every
|
||||
``MODEL_DISCOVERY_REFRESH_S`` seconds. The parsed model set (names + sanitized
|
||||
metadata) is cached in Redis under ``gateway:models:discovered`` with TTL
|
||||
``MODEL_DISCOVERY_CACHE_TTL_S`` and held in-process for hot reads on the request
|
||||
path.
|
||||
|
||||
Fail-closed (SPEC §4.6, §13.5): if Ollama is unreachable, or the cache is empty
|
||||
or expired and cannot be refreshed, the discovered set is empty — and an empty
|
||||
discovered set means no model resolves, so requests are denied. Discovery never
|
||||
opens access on failure. It is read-only and only ever touches the allowlisted
|
||||
``/api/tags`` endpoint; it never triggers a pull.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import httpx
|
||||
import redis.asyncio as redis
|
||||
|
||||
from neuronetz_gateway.config import Settings
|
||||
from neuronetz_gateway.observability.logging import get_logger
|
||||
|
||||
_log = get_logger("discovery")
|
||||
|
||||
REDIS_DISCOVERED_KEY = "gateway:models:discovered"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DiscoveredModel:
|
||||
"""Sanitized metadata for a single installed model."""
|
||||
|
||||
name: str
|
||||
family: str | None = None
|
||||
parameter_size: str | None = None
|
||||
quantization: str | None = None
|
||||
size_bytes: int | None = None
|
||||
modified_at: str | None = None
|
||||
|
||||
|
||||
def _parse_tags(payload: dict[str, object]) -> list[DiscoveredModel]:
|
||||
"""Parse an Ollama ``/api/tags`` body into sanitized model records."""
|
||||
models: list[DiscoveredModel] = []
|
||||
raw_models = payload.get("models")
|
||||
if not isinstance(raw_models, list):
|
||||
return models
|
||||
for entry in raw_models:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = entry.get("name") or entry.get("model")
|
||||
if not isinstance(name, str) or not name:
|
||||
continue
|
||||
raw_details = entry.get("details")
|
||||
details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
|
||||
size = entry.get("size")
|
||||
models.append(
|
||||
DiscoveredModel(
|
||||
name=name,
|
||||
family=_opt_str(details.get("family")),
|
||||
parameter_size=_opt_str(details.get("parameter_size")),
|
||||
quantization=_opt_str(details.get("quantization_level")),
|
||||
size_bytes=size if isinstance(size, int) else None,
|
||||
modified_at=_opt_str(entry.get("modified_at")),
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _opt_str(value: object) -> str | None:
|
||||
"""Coerce a value to ``str`` only when it already is one."""
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def names_of(models: list[DiscoveredModel]) -> frozenset[str]:
|
||||
"""Return just the set of model names from parsed records."""
|
||||
return frozenset(m.name for m in models)
|
||||
|
||||
|
||||
class DiscoveryCache:
|
||||
"""In-process holder for the latest discovered model set.
|
||||
|
||||
Holds both the structured records (for ``/api/tags`` / ``list-models``) and a
|
||||
fast name set (for allowlist resolution on the hot path). Reads never block
|
||||
on Redis or Ollama; the poller refreshes this in the background.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._models: list[DiscoveredModel] = []
|
||||
self._names: frozenset[str] = frozenset()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def set(self, models: list[DiscoveredModel]) -> None:
|
||||
"""Replace the in-process snapshot atomically."""
|
||||
async with self._lock:
|
||||
self._models = list(models)
|
||||
self._names = names_of(models)
|
||||
|
||||
@property
|
||||
def names(self) -> frozenset[str]:
|
||||
"""Current discovered model names (possibly empty ⇒ fail-closed)."""
|
||||
return self._names
|
||||
|
||||
@property
|
||||
def models(self) -> list[DiscoveredModel]:
|
||||
"""Current discovered model records (copy)."""
|
||||
return list(self._models)
|
||||
|
||||
|
||||
async def write_discovered_to_redis(
|
||||
client: redis.Redis, models: list[DiscoveredModel], ttl_s: int
|
||||
) -> None:
|
||||
"""Cache the discovered set in Redis with a TTL (so staleness expires)."""
|
||||
payload = json.dumps([asdict(m) for m in models], separators=(",", ":"))
|
||||
await client.set(REDIS_DISCOVERED_KEY, payload, ex=ttl_s)
|
||||
|
||||
|
||||
async def read_discovered_from_redis(client: redis.Redis) -> frozenset[str]:
|
||||
"""Read the cached discovered names from Redis; empty set on miss/expiry."""
|
||||
raw = await client.get(REDIS_DISCOVERED_KEY)
|
||||
if not raw:
|
||||
return frozenset()
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return frozenset()
|
||||
if not isinstance(data, list):
|
||||
return frozenset()
|
||||
return frozenset(
|
||||
str(item["name"]) for item in data if isinstance(item, dict) and item.get("name")
|
||||
)
|
||||
|
||||
|
||||
async def fetch_tags(client: httpx.AsyncClient) -> list[DiscoveredModel]:
|
||||
"""Fetch and parse Ollama ``/api/tags``; raise on transport/HTTP error."""
|
||||
resp = await client.get("/api/tags")
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
if not isinstance(body, dict):
|
||||
return []
|
||||
return _parse_tags(body)
|
||||
|
||||
|
||||
async def refresh_once(
|
||||
http_client: httpx.AsyncClient,
|
||||
redis_client: redis.Redis | None,
|
||||
cache: DiscoveryCache,
|
||||
settings: Settings,
|
||||
) -> bool:
|
||||
"""Run a single discovery refresh. Returns True on success.
|
||||
|
||||
On any failure the in-process and Redis caches are left untouched; they
|
||||
expire on their own TTL, which is the fail-closed behavior (stale-expired ⇒
|
||||
empty ⇒ deny). We never *clear* eagerly on a transient error, but we also
|
||||
never extend the TTL on failure.
|
||||
"""
|
||||
try:
|
||||
models = await fetch_tags(http_client)
|
||||
except (httpx.HTTPError, ValueError) as exc:
|
||||
_log.warning("discovery_refresh_failed", error=str(exc))
|
||||
return False
|
||||
await cache.set(models)
|
||||
if redis_client is not None:
|
||||
try:
|
||||
await write_discovered_to_redis(
|
||||
redis_client, models, settings.model_discovery_cache_ttl_s
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - Redis write is best-effort cache fill
|
||||
_log.warning("discovery_cache_write_failed", error=str(exc))
|
||||
_log.info("discovery_refreshed", count=len(models))
|
||||
return True
|
||||
|
||||
|
||||
async def discovery_loop(
|
||||
http_client: httpx.AsyncClient,
|
||||
redis_client: redis.Redis | None,
|
||||
cache: DiscoveryCache,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
"""Background poller: refresh now, then every ``MODEL_DISCOVERY_REFRESH_S``.
|
||||
|
||||
Designed to be launched via ``asyncio.create_task`` in the lifespan and
|
||||
cancelled on shutdown.
|
||||
"""
|
||||
await refresh_once(http_client, redis_client, cache, settings)
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(settings.model_discovery_refresh_s)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
await refresh_once(http_client, redis_client, cache, settings)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"REDIS_DISCOVERED_KEY",
|
||||
"DiscoveredModel",
|
||||
"DiscoveryCache",
|
||||
"discovery_loop",
|
||||
"fetch_tags",
|
||||
"names_of",
|
||||
"read_discovered_from_redis",
|
||||
"refresh_once",
|
||||
"write_discovered_to_redis",
|
||||
]
|
||||
79
src/neuronetz_gateway/proxy/ollama.py
Normal file
79
src/neuronetz_gateway/proxy/ollama.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""httpx-based streaming proxy client for Ollama.
|
||||
|
||||
Opens async streams to the upstream and relays bytes without buffering, so the
|
||||
gateway does not degrade time-to-first-byte (SPEC §9, non-negotiable #6).
|
||||
Transport-level upstream failures are sanitized at this boundary into
|
||||
:class:`UpstreamUnavailableError` (never reflected verbatim, non-negotiable #4).
|
||||
|
||||
The client is constructed from the shared ``httpx.AsyncClient`` on
|
||||
``app.state`` and is exposed to routes via the ``get_ollama_client`` dependency
|
||||
in ``deps.py`` so tests can override it (the QA override contract).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from neuronetz_gateway.errors import UpstreamUnavailableError
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Thin async wrapper around the shared httpx client for Ollama calls."""
|
||||
|
||||
def __init__(self, client: httpx.AsyncClient) -> None:
|
||||
self._client = client
|
||||
|
||||
async def stream(
|
||||
self, method: str, path: str, json_body: dict[str, object]
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""Open a streaming request to Ollama and yield raw response chunks.
|
||||
|
||||
Yields bytes exactly as received so the hot path performs no buffering.
|
||||
Transport errors are sanitized; if the upstream returns a non-2xx the
|
||||
body is drained (so internals are not surfaced) and a generic 502 is
|
||||
raised before any client bytes are emitted.
|
||||
"""
|
||||
request = self._client.build_request(method, path, json=json_body)
|
||||
try:
|
||||
response = await self._client.send(request, stream=True)
|
||||
except httpx.HTTPError as exc:
|
||||
raise UpstreamUnavailableError(internal_detail=f"ollama send failed: {exc!r}") from exc
|
||||
try:
|
||||
if response.status_code >= 400:
|
||||
await response.aread()
|
||||
raise UpstreamUnavailableError(
|
||||
internal_detail=f"ollama returned {response.status_code} for {path}"
|
||||
)
|
||||
async for chunk in response.aiter_raw():
|
||||
yield chunk
|
||||
except httpx.HTTPError as exc:
|
||||
raise UpstreamUnavailableError(
|
||||
internal_detail=f"ollama stream failed: {exc!r}"
|
||||
) from exc
|
||||
finally:
|
||||
await response.aclose()
|
||||
|
||||
async def request(
|
||||
self, method: str, path: str, json_body: dict[str, object]
|
||||
) -> httpx.Response:
|
||||
"""Perform a non-streaming request to Ollama.
|
||||
|
||||
Returns the upstream response on success; raises a sanitized 502 on
|
||||
transport failure or a non-2xx status (internals never reflected).
|
||||
"""
|
||||
try:
|
||||
response = await self._client.request(method, path, json=json_body)
|
||||
except httpx.HTTPError as exc:
|
||||
raise UpstreamUnavailableError(
|
||||
internal_detail=f"ollama request failed: {exc!r}"
|
||||
) from exc
|
||||
if response.status_code >= 400:
|
||||
raise UpstreamUnavailableError(
|
||||
internal_detail=f"ollama returned {response.status_code} for {path}"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
__all__ = ["OllamaClient"]
|
||||
467
src/neuronetz_gateway/proxy/pipeline.py
Normal file
467
src/neuronetz_gateway/proxy/pipeline.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""Shared request pipeline: enforcement, streaming, and post-close bookkeeping.
|
||||
|
||||
Both the native (``/api/*``) and OpenAI-compat (``/v1/*``) routes funnel through
|
||||
here so the security checks and the streaming-integrity contract are written
|
||||
once. The order mirrors SPEC §4.3:
|
||||
|
||||
rate limit (per-key + per-tenant RPM) → budget → concurrency → model allowlist
|
||||
→ endpoint allowlist → body validation → proxy + stream → post-close audit /
|
||||
token-count / budget-consume / metrics / semaphore release.
|
||||
|
||||
Streaming integrity (non-negotiable #6): the bytes flow to the client untouched
|
||||
and token counting + audit + budget-consume happen **after** the stream closes,
|
||||
never on the hot path.
|
||||
|
||||
Fail-closed (non-negotiable #1): every limiter/budget call raises
|
||||
:class:`DependencyUnavailableError` when Redis is down, which the error handler
|
||||
renders as 503. The model allowlist is default-deny against the live discovered
|
||||
set; a missing/expired discovery set denies everything.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from neuronetz_gateway.audit.writer import AuditRecord, AuditWriter
|
||||
from neuronetz_gateway.auth.principal import Principal
|
||||
from neuronetz_gateway.budget.counter import BudgetCounter
|
||||
from neuronetz_gateway.budget.ledger import BudgetLedger
|
||||
from neuronetz_gateway.config import Settings
|
||||
from neuronetz_gateway.db.models import BudgetPeriod
|
||||
from neuronetz_gateway.errors import (
|
||||
AuthorizationError,
|
||||
BudgetExceededError,
|
||||
RateLimitError,
|
||||
RequestTooLargeError,
|
||||
)
|
||||
from neuronetz_gateway.observability import metrics
|
||||
from neuronetz_gateway.observability.logging import get_logger
|
||||
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, is_model_allowed
|
||||
from neuronetz_gateway.proxy.discovery import DiscoveryCache
|
||||
from neuronetz_gateway.proxy.ollama import OllamaClient
|
||||
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
|
||||
from neuronetz_gateway.ratelimit.concurrency import ConcurrencyLimiter
|
||||
from neuronetz_gateway.ratelimit.sliding_window import SlidingWindowLimiter
|
||||
|
||||
_log = get_logger("pipeline")
|
||||
|
||||
NDJSON_MEDIA_TYPE = "application/x-ndjson"
|
||||
SSE_MEDIA_TYPE = "text/event-stream"
|
||||
_CONCURRENCY_PREFIX = "gateway:concurrency:"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RateHeaders:
|
||||
"""The §6.5 rate/budget header values gathered during pre-flight."""
|
||||
|
||||
limit_requests: int
|
||||
remaining_requests: int
|
||||
limit_tokens: int
|
||||
remaining_tokens: int
|
||||
budget_period: str
|
||||
budget_remaining: str
|
||||
|
||||
def as_dict(self, request_id: str) -> dict[str, str]:
|
||||
"""Render to the SPEC §6.5 header set."""
|
||||
return {
|
||||
"X-Request-ID": request_id,
|
||||
"X-RateLimit-Limit-Requests": str(self.limit_requests),
|
||||
"X-RateLimit-Remaining-Requests": str(self.remaining_requests),
|
||||
"X-RateLimit-Limit-Tokens": str(self.limit_tokens),
|
||||
"X-RateLimit-Remaining-Tokens": str(self.remaining_tokens),
|
||||
"X-Budget-Period": self.budget_period,
|
||||
"X-Budget-Tokens-Remaining": self.budget_remaining,
|
||||
}
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Per-request enforcement + proxy orchestrator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request: Request,
|
||||
principal: Principal,
|
||||
settings: Settings,
|
||||
ollama: OllamaClient,
|
||||
discovery: DiscoveryCache,
|
||||
rate_limiter: SlidingWindowLimiter,
|
||||
concurrency: ConcurrencyLimiter,
|
||||
budget: BudgetCounter,
|
||||
audit: AuditWriter,
|
||||
sessionmaker: async_sessionmaker[AsyncSession] | None = None,
|
||||
) -> None:
|
||||
self._request = request
|
||||
self._p = principal
|
||||
self._settings = settings
|
||||
self._ollama = ollama
|
||||
self._discovery = discovery
|
||||
self._rate = rate_limiter
|
||||
self._conc = concurrency
|
||||
self._budget = budget
|
||||
self._audit = audit
|
||||
self._sessionmaker = sessionmaker
|
||||
self._request_id = str(getattr(request.state, "request_id", uuid.uuid4()))
|
||||
self._concurrency_key = f"{_CONCURRENCY_PREFIX}{principal.tenant_id}"
|
||||
self._headers = RateHeaders(
|
||||
limit_requests=principal.limits.rpm,
|
||||
remaining_requests=principal.limits.rpm,
|
||||
limit_tokens=principal.limits.tpm,
|
||||
remaining_tokens=principal.limits.tpm,
|
||||
budget_period="day",
|
||||
budget_remaining="unlimited",
|
||||
)
|
||||
|
||||
@property
|
||||
def settings(self) -> Settings:
|
||||
"""The active settings (for body-size / num_predict caps in routes)."""
|
||||
return self._settings
|
||||
|
||||
# ----- enforcement -------------------------------------------------------
|
||||
|
||||
def check_scope(self, scope: str) -> None:
|
||||
"""Authorize a coarse scope (e.g. 'chat', 'embeddings')."""
|
||||
if scope not in self._p.scopes:
|
||||
raise AuthorizationError(internal_detail=f"scope {scope!r} not granted")
|
||||
|
||||
def check_endpoint(self, path: str) -> None:
|
||||
"""Reject hard-blocked upstream endpoints with a generic 403."""
|
||||
if is_hard_blocked(path):
|
||||
raise AuthorizationError(internal_detail=f"hard-blocked endpoint {path}")
|
||||
|
||||
def check_model(self, model: str) -> None:
|
||||
"""Default-deny model check against the live effective set (§4.3 step 7)."""
|
||||
if not model or not is_model_allowed(
|
||||
model,
|
||||
allow_all=self._p.allow_all_models,
|
||||
allowed_models=self._p.allowed_models,
|
||||
discovered=self._discovery.names,
|
||||
):
|
||||
# No existence disclosure (SPEC §13.6): unpermitted and not-installed
|
||||
# both yield the same generic 403.
|
||||
raise AuthorizationError(internal_detail=f"model {model!r} not in effective set")
|
||||
|
||||
def validate_body(self, body: dict[str, object]) -> None:
|
||||
"""Enforce the ``num_predict`` cap (body-size cap is enforced earlier)."""
|
||||
options = body.get("options")
|
||||
if isinstance(options, dict):
|
||||
num_predict = options.get("num_predict")
|
||||
if isinstance(num_predict, int) and num_predict > self._settings.max_num_predict:
|
||||
options["num_predict"] = self._settings.max_num_predict
|
||||
|
||||
async def enforce_limits(self, *, token_estimate: int = 0) -> None:
|
||||
"""Run RPM (per-key + per-tenant), TPM, budget, and concurrency checks.
|
||||
|
||||
Budget is checked before concurrency so an over-budget request never even
|
||||
acquires a permit. Order otherwise follows SPEC §4.3 steps 4-6.
|
||||
"""
|
||||
await self._check_rpm()
|
||||
await self._check_tpm(token_estimate)
|
||||
await self.check_budgets()
|
||||
await self._acquire_concurrency()
|
||||
|
||||
async def _check_rpm(self) -> None:
|
||||
key_result = await self._rate.check(
|
||||
f"gateway:rpm:key:{self._p.key_id}", self._p.limits.rpm, 60, cost=1
|
||||
)
|
||||
self._headers.limit_requests = key_result.limit
|
||||
self._headers.remaining_requests = key_result.remaining
|
||||
if not key_result.allowed:
|
||||
raise RateLimitError(retry_after=key_result.retry_after_s)
|
||||
tenant_result = await self._rate.check(
|
||||
f"gateway:rpm:tenant:{self._p.tenant_id}", self._p.limits.rpm, 60, cost=1
|
||||
)
|
||||
if not tenant_result.allowed:
|
||||
raise RateLimitError(retry_after=tenant_result.retry_after_s)
|
||||
|
||||
async def _check_tpm(self, token_estimate: int) -> None:
|
||||
# TPM is charged with a minimum of 1 so a request always counts; the
|
||||
# precise token cost is reconciled post-stream via the budget counter.
|
||||
cost = max(token_estimate, 1)
|
||||
result = await self._rate.check(
|
||||
f"gateway:tpm:key:{self._p.key_id}", self._p.limits.tpm, 60, cost=cost
|
||||
)
|
||||
self._headers.limit_tokens = result.limit
|
||||
self._headers.remaining_tokens = result.remaining
|
||||
if not result.allowed:
|
||||
raise RateLimitError(retry_after=result.retry_after_s)
|
||||
|
||||
def _budget_periods(self) -> list[tuple[BudgetPeriod, int | None]]:
|
||||
return [
|
||||
(BudgetPeriod.day, self._p.limits.tokens_daily),
|
||||
(BudgetPeriod.month, self._p.limits.tokens_monthly),
|
||||
(BudgetPeriod.total, self._p.limits.tokens_total),
|
||||
]
|
||||
|
||||
async def check_budgets(self) -> None:
|
||||
"""Verify no configured budget period is already exhausted."""
|
||||
tightest_period = "day"
|
||||
tightest_remaining = "unlimited"
|
||||
for period, limit in self._budget_periods():
|
||||
state = await self._budget.check(str(self._p.key_id), period, limit)
|
||||
if state.exhausted:
|
||||
raise BudgetExceededError(internal_detail=f"budget {period.value} exhausted")
|
||||
if state.remaining is not None:
|
||||
tightest_period = period.value
|
||||
tightest_remaining = str(state.remaining)
|
||||
self._headers.budget_period = tightest_period
|
||||
self._headers.budget_remaining = tightest_remaining
|
||||
|
||||
async def _acquire_concurrency(self) -> None:
|
||||
ok = await self._conc.acquire(
|
||||
self._concurrency_key,
|
||||
self._p.limits.concurrent,
|
||||
self._settings.ollama_read_timeout_s + 30,
|
||||
)
|
||||
if not ok:
|
||||
raise RateLimitError(
|
||||
retry_after=1, internal_detail="concurrency cap reached"
|
||||
)
|
||||
|
||||
# ----- proxy + bookkeeping ----------------------------------------------
|
||||
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Render the §6.5 response headers."""
|
||||
return self._headers.as_dict(self._request_id)
|
||||
|
||||
async def stream_native(
|
||||
self, method: str, path: str, body: dict[str, object], model: str
|
||||
) -> StreamingResponse:
|
||||
"""Proxy a streaming NDJSON request, accounting tokens after close."""
|
||||
started = time.monotonic()
|
||||
media_type = NDJSON_MEDIA_TYPE
|
||||
|
||||
async def gen() -> AsyncIterator[bytes]:
|
||||
last_obj: dict[str, object] = {}
|
||||
try:
|
||||
async for chunk in self._ollama.stream(method, path, body):
|
||||
last_obj = _merge_last_ndjson(chunk, last_obj)
|
||||
yield chunk
|
||||
finally:
|
||||
await self._finish(model, path, method, last_obj, started)
|
||||
|
||||
return StreamingResponse(gen(), media_type=media_type, headers=self.headers())
|
||||
|
||||
async def stream_openai(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict[str, object],
|
||||
model: str,
|
||||
chunk_translator: Callable[[dict[str, object]], dict[str, object]],
|
||||
) -> StreamingResponse:
|
||||
"""Proxy + translate native NDJSON into OpenAI SSE; account after close."""
|
||||
started = time.monotonic()
|
||||
|
||||
async def gen() -> AsyncIterator[bytes]:
|
||||
last_obj: dict[str, object] = {}
|
||||
buffer = b""
|
||||
try:
|
||||
async for chunk in self._ollama.stream(method, path, body):
|
||||
buffer += chunk
|
||||
lines = buffer.split(b"\n")
|
||||
buffer = lines.pop()
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
last_obj = obj if isinstance(obj, dict) else last_obj
|
||||
translated = chunk_translator(obj)
|
||||
yield f"data: {json.dumps(translated)}\n\n".encode()
|
||||
if buffer.strip():
|
||||
obj = json.loads(buffer)
|
||||
last_obj = obj if isinstance(obj, dict) else last_obj
|
||||
yield f"data: {json.dumps(chunk_translator(obj))}\n\n".encode()
|
||||
yield b"data: [DONE]\n\n"
|
||||
finally:
|
||||
await self._finish(model, path, method, last_obj, started)
|
||||
|
||||
return StreamingResponse(gen(), media_type=SSE_MEDIA_TYPE, headers=self.headers())
|
||||
|
||||
async def request_native(
|
||||
self, method: str, path: str, body: dict[str, object], model: str
|
||||
) -> JSONResponse:
|
||||
"""Proxy a non-streaming request and account tokens before responding."""
|
||||
started = time.monotonic()
|
||||
resp = await self._ollama.request(method, path, body)
|
||||
payload = resp.json()
|
||||
obj = payload if isinstance(payload, dict) else {}
|
||||
await self._finish(model, path, method, obj, started)
|
||||
return JSONResponse(obj, headers=self.headers())
|
||||
|
||||
async def request_translated(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict[str, object],
|
||||
model: str,
|
||||
translator: Callable[[dict[str, object]], dict[str, object]],
|
||||
) -> JSONResponse:
|
||||
"""Proxy a non-streaming request, translate the body, then account."""
|
||||
started = time.monotonic()
|
||||
resp = await self._ollama.request(method, path, body)
|
||||
payload = resp.json()
|
||||
obj = payload if isinstance(payload, dict) else {}
|
||||
await self._finish(model, path, method, obj, started)
|
||||
return JSONResponse(translator(obj), headers=self.headers())
|
||||
|
||||
async def _finish(
|
||||
self,
|
||||
model: str,
|
||||
path: str,
|
||||
method: str,
|
||||
final_obj: dict[str, object],
|
||||
started: float,
|
||||
) -> None:
|
||||
"""Post-close bookkeeping: tokens, budget, metrics, audit, semaphore.
|
||||
|
||||
Runs once per request after the response is fully produced. Each step is
|
||||
best-effort and guarded so a bookkeeping failure never corrupts the
|
||||
already-delivered response. The concurrency permit is always released.
|
||||
"""
|
||||
usage = extract_usage(final_obj) if final_obj else TokenUsage(0, 0)
|
||||
latency_ms = int((time.monotonic() - started) * 1000)
|
||||
try:
|
||||
await self._account_budget(usage)
|
||||
except Exception as exc: # noqa: BLE001 - never break a delivered response
|
||||
_log.warning("budget_account_failed", error=str(exc))
|
||||
try:
|
||||
metrics.record_request(self._p.tenant_name, model or "unknown", 200, latency_ms / 1000)
|
||||
metrics.record_tokens(
|
||||
self._p.tenant_name, model or "unknown", usage.tokens_in, usage.tokens_out
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - metrics must not break responses
|
||||
_log.warning("metrics_record_failed", error=str(exc))
|
||||
await self._write_audit(model, path, method, usage, latency_ms, 200)
|
||||
await self._release_concurrency()
|
||||
|
||||
async def _account_budget(self, usage: TokenUsage) -> None:
|
||||
"""Decrement Redis budget counters and persist to the Postgres ledger."""
|
||||
for period, limit in self._budget_periods():
|
||||
if limit is not None:
|
||||
await self._budget.consume(str(self._p.key_id), period, usage.total)
|
||||
if self._sessionmaker is not None and usage.total >= 0:
|
||||
try:
|
||||
async with self._sessionmaker() as session:
|
||||
ledger = BudgetLedger(session)
|
||||
for period, _limit in self._budget_periods():
|
||||
await ledger.record_usage(
|
||||
str(self._p.key_id), period, usage.tokens_in, usage.tokens_out
|
||||
)
|
||||
await session.commit()
|
||||
except Exception as exc: # noqa: BLE001 - ledger is durable backstop; never break response
|
||||
_log.warning("ledger_record_failed", error=str(exc))
|
||||
|
||||
async def _write_audit(
|
||||
self,
|
||||
model: str,
|
||||
path: str,
|
||||
method: str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: int,
|
||||
status: int,
|
||||
) -> None:
|
||||
record = AuditRecord(
|
||||
request_id=uuid.UUID(self._request_id)
|
||||
if _is_uuid(self._request_id)
|
||||
else uuid.uuid4(),
|
||||
method=method,
|
||||
path=path,
|
||||
status=status,
|
||||
ts=datetime.datetime.now(datetime.UTC),
|
||||
tenant_id=self._p.tenant_id,
|
||||
key_id=self._p.key_id,
|
||||
key_prefix=self._p.key_prefix,
|
||||
model=model or None,
|
||||
tokens_in=usage.tokens_in,
|
||||
tokens_out=usage.tokens_out,
|
||||
latency_ms=latency_ms,
|
||||
client_ip=self._client_ip(),
|
||||
user_agent=self._request.headers.get("user-agent"),
|
||||
)
|
||||
try:
|
||||
await self._audit.enqueue(record)
|
||||
except Exception as exc: # noqa: BLE001 - audit enqueue must not break response
|
||||
_log.warning("audit_enqueue_failed", error=str(exc))
|
||||
|
||||
async def _release_concurrency(self) -> None:
|
||||
try:
|
||||
await self._conc.release(self._concurrency_key)
|
||||
except Exception as exc: # noqa: BLE001 - release is best-effort
|
||||
_log.warning("concurrency_release_failed", error=str(exc))
|
||||
|
||||
def _client_ip(self) -> str | None:
|
||||
xff = self._request.headers.get("x-forwarded-for")
|
||||
if xff:
|
||||
return xff.split(",")[0].strip()
|
||||
return self._request.client.host if self._request.client else None
|
||||
|
||||
|
||||
def _merge_last_ndjson(chunk: bytes, prev: dict[str, object]) -> dict[str, object]:
|
||||
"""Track the last complete NDJSON object seen in a raw byte chunk.
|
||||
|
||||
Token counts live on the final ``done`` frame. We parse only complete lines
|
||||
and keep the last successfully-parsed object; partial trailing data is
|
||||
ignored here and will be completed by a subsequent chunk.
|
||||
"""
|
||||
text = chunk.decode("utf-8", errors="ignore")
|
||||
last = prev
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(obj, dict):
|
||||
last = obj
|
||||
return last
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def model_of(body: dict[str, object]) -> str:
|
||||
"""Extract the requested model name from a request body (empty if absent)."""
|
||||
model = body.get("model")
|
||||
return model if isinstance(model, str) else ""
|
||||
|
||||
|
||||
async def read_json_body(request: Request, settings: Settings) -> dict[str, object]:
|
||||
"""Read + size-limit the request body, returning the parsed JSON object."""
|
||||
raw = await request.body()
|
||||
if len(raw) > settings.max_request_body_bytes:
|
||||
raise RequestTooLargeError(internal_detail=f"body {len(raw)} bytes")
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
parsed: Any = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise AuthorizationError(internal_detail="invalid JSON body") from exc
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NDJSON_MEDIA_TYPE",
|
||||
"SSE_MEDIA_TYPE",
|
||||
"Pipeline",
|
||||
"RateHeaders",
|
||||
"model_of",
|
||||
"read_json_body",
|
||||
]
|
||||
50
src/neuronetz_gateway/proxy/token_counter.py
Normal file
50
src/neuronetz_gateway/proxy/token_counter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Precise token accounting parsed from Ollama responses (SPEC §2, §13.1).
|
||||
|
||||
Tokens are read from Ollama's reported ``prompt_eval_count`` (input) and
|
||||
``eval_count`` (output) on the final stream frame — never heuristically
|
||||
estimated. Embeddings charge ``prompt_eval_count`` only (SPEC §13.1); they have
|
||||
no ``eval_count`` so output tokens are reported as zero.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TokenUsage:
|
||||
"""Token counts extracted from an Ollama response."""
|
||||
|
||||
tokens_in: int
|
||||
tokens_out: int
|
||||
|
||||
@property
|
||||
def total(self) -> int:
|
||||
"""Combined input + output tokens."""
|
||||
return self.tokens_in + self.tokens_out
|
||||
|
||||
|
||||
def _as_int(value: object) -> int:
|
||||
"""Coerce an Ollama-reported count to a non-negative int (0 if absent/bad)."""
|
||||
if isinstance(value, bool):
|
||||
return 0
|
||||
if isinstance(value, int):
|
||||
return max(value, 0)
|
||||
if isinstance(value, float):
|
||||
return max(int(value), 0)
|
||||
return 0
|
||||
|
||||
|
||||
def extract_usage(final_frame: dict[str, object]) -> TokenUsage:
|
||||
"""Extract ``prompt_eval_count``/``eval_count`` from the final Ollama frame.
|
||||
|
||||
Works for chat/generate final frames and for embeddings responses (which
|
||||
carry ``prompt_eval_count`` but no ``eval_count``).
|
||||
"""
|
||||
return TokenUsage(
|
||||
tokens_in=_as_int(final_frame.get("prompt_eval_count")),
|
||||
tokens_out=_as_int(final_frame.get("eval_count")),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TokenUsage", "extract_usage"]
|
||||
245
src/neuronetz_gateway/proxy/translate.py
Normal file
245
src/neuronetz_gateway/proxy/translate.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""OpenAI <-> Ollama schema translation (SPEC §6.3).
|
||||
|
||||
Native ``/api/*`` speaks NDJSON; OpenAI-compat ``/v1/*`` speaks SSE
|
||||
(``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``). These helpers translate request
|
||||
bodies in both directions and convert native Ollama stream frames into OpenAI
|
||||
chunk objects, preserving streaming. Unknown OpenAI sampling params are mapped
|
||||
into Ollama's ``options`` block; unrecognized keys are dropped rather than
|
||||
forwarded blindly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
# OpenAI sampling params that map into Ollama's ``options`` object.
|
||||
_OPTION_KEYS: tuple[str, ...] = (
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"seed",
|
||||
"stop",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
)
|
||||
|
||||
|
||||
def _as_int(value: object) -> int:
|
||||
"""Coerce a JSON-derived value to a non-negative int (0 if absent/invalid)."""
|
||||
if isinstance(value, bool):
|
||||
return 0
|
||||
if isinstance(value, int):
|
||||
return max(value, 0)
|
||||
if isinstance(value, float):
|
||||
return max(int(value), 0)
|
||||
return 0
|
||||
|
||||
|
||||
def _build_options(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Collect OpenAI sampling params into an Ollama ``options`` mapping."""
|
||||
options: dict[str, Any] = {}
|
||||
for key in _OPTION_KEYS:
|
||||
if key in payload and payload[key] is not None:
|
||||
options[key] = payload[key]
|
||||
if "max_tokens" in payload and payload["max_tokens"] is not None:
|
||||
options["num_predict"] = payload["max_tokens"]
|
||||
return options
|
||||
|
||||
|
||||
def openai_chat_to_ollama(payload: dict[str, object]) -> dict[str, object]:
|
||||
"""Translate an OpenAI chat-completion request to an Ollama ``/api/chat`` body."""
|
||||
body: dict[str, object] = {
|
||||
"model": payload.get("model"),
|
||||
"messages": payload.get("messages", []),
|
||||
"stream": bool(payload.get("stream", False)),
|
||||
}
|
||||
options = _build_options(dict(payload))
|
||||
if options:
|
||||
body["options"] = options
|
||||
return body
|
||||
|
||||
|
||||
def openai_completion_to_ollama(payload: dict[str, object]) -> dict[str, object]:
|
||||
"""Translate an OpenAI completion request to an Ollama ``/api/generate`` body."""
|
||||
prompt = payload.get("prompt", "")
|
||||
if isinstance(prompt, list):
|
||||
prompt = "".join(str(p) for p in prompt)
|
||||
body: dict[str, object] = {
|
||||
"model": payload.get("model"),
|
||||
"prompt": prompt,
|
||||
"stream": bool(payload.get("stream", False)),
|
||||
}
|
||||
options = _build_options(dict(payload))
|
||||
if options:
|
||||
body["options"] = options
|
||||
return body
|
||||
|
||||
|
||||
def openai_embeddings_to_ollama(payload: dict[str, object]) -> dict[str, object]:
|
||||
"""Translate an OpenAI embeddings request to an Ollama ``/api/embed`` body."""
|
||||
return {
|
||||
"model": payload.get("model"),
|
||||
"input": payload.get("input", ""),
|
||||
}
|
||||
|
||||
|
||||
def _completion_id() -> str:
|
||||
"""Generate an OpenAI-style completion id."""
|
||||
return f"chatcmpl-{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def ollama_chat_chunk_to_openai(
|
||||
chunk: dict[str, object], *, completion_id: str, model: str, created: int
|
||||
) -> dict[str, object]:
|
||||
"""Translate one Ollama ``/api/chat`` NDJSON frame to an OpenAI SSE chunk."""
|
||||
done = bool(chunk.get("done"))
|
||||
if done:
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
|
||||
"completion_tokens": _as_int(chunk.get("eval_count")),
|
||||
"total_tokens": _as_int(chunk.get("prompt_eval_count"))
|
||||
+ _as_int(chunk.get("eval_count")),
|
||||
},
|
||||
}
|
||||
message = chunk.get("message")
|
||||
content = ""
|
||||
if isinstance(message, dict):
|
||||
content = str(message.get("content", ""))
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
|
||||
}
|
||||
|
||||
|
||||
def ollama_generate_chunk_to_openai(
|
||||
chunk: dict[str, object], *, completion_id: str, model: str, created: int
|
||||
) -> dict[str, object]:
|
||||
"""Translate one Ollama ``/api/generate`` NDJSON frame to an OpenAI text chunk."""
|
||||
done = bool(chunk.get("done"))
|
||||
if done:
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "text": "", "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": _as_int(chunk.get("prompt_eval_count")),
|
||||
"completion_tokens": _as_int(chunk.get("eval_count")),
|
||||
"total_tokens": _as_int(chunk.get("prompt_eval_count"))
|
||||
+ _as_int(chunk.get("eval_count")),
|
||||
},
|
||||
}
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "text": str(chunk.get("response", "")), "finish_reason": None}],
|
||||
}
|
||||
|
||||
|
||||
def ollama_chat_to_openai(payload: dict[str, object]) -> dict[str, object]:
|
||||
"""Translate a *non-streaming* Ollama chat response to an OpenAI completion."""
|
||||
message = payload.get("message")
|
||||
content = ""
|
||||
if isinstance(message, dict):
|
||||
content = str(message.get("content", ""))
|
||||
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
|
||||
completion_tokens = _as_int(payload.get("eval_count"))
|
||||
return {
|
||||
"id": _completion_id(),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": str(payload.get("model", "")),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def ollama_generate_to_openai(payload: dict[str, object]) -> dict[str, object]:
|
||||
"""Translate a *non-streaming* Ollama generate response to OpenAI completion."""
|
||||
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
|
||||
completion_tokens = _as_int(payload.get("eval_count"))
|
||||
return {
|
||||
"id": _completion_id(),
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": str(payload.get("model", "")),
|
||||
"choices": [
|
||||
{"index": 0, "text": str(payload.get("response", "")), "finish_reason": "stop"}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def ollama_embed_to_openai(payload: dict[str, object], model: str) -> dict[str, object]:
|
||||
"""Translate an Ollama ``/api/embed`` response to the OpenAI embeddings shape."""
|
||||
raw = payload.get("embeddings")
|
||||
vectors: list[list[float]] = raw if isinstance(raw, list) else []
|
||||
prompt_tokens = _as_int(payload.get("prompt_eval_count"))
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"object": "embedding", "index": i, "embedding": vec}
|
||||
for i, vec in enumerate(vectors)
|
||||
],
|
||||
"model": model,
|
||||
"usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
|
||||
}
|
||||
|
||||
|
||||
def models_to_openai_list(names: list[str]) -> dict[str, object]:
|
||||
"""Render a list of model names in the OpenAI ``/v1/models`` list format."""
|
||||
created = int(time.time())
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"id": name, "object": "model", "created": created, "owned_by": "neuronetz"}
|
||||
for name in names
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def new_completion_id() -> str:
|
||||
"""Public helper to mint a completion id for a streaming response."""
|
||||
return _completion_id()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"models_to_openai_list",
|
||||
"new_completion_id",
|
||||
"ollama_chat_chunk_to_openai",
|
||||
"ollama_chat_to_openai",
|
||||
"ollama_embed_to_openai",
|
||||
"ollama_generate_chunk_to_openai",
|
||||
"ollama_generate_to_openai",
|
||||
"openai_chat_to_ollama",
|
||||
"openai_completion_to_ollama",
|
||||
"openai_embeddings_to_ollama",
|
||||
]
|
||||
3
src/neuronetz_gateway/ratelimit/__init__.py
Normal file
3
src/neuronetz_gateway/ratelimit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Rate limiting: sliding-window RPM/TPM and concurrency semaphore (Redis)."""
|
||||
|
||||
from __future__ import annotations
|
||||
66
src/neuronetz_gateway/ratelimit/concurrency.py
Normal file
66
src/neuronetz_gateway/ratelimit/concurrency.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Concurrent-connection semaphore backed by Redis ``INCR`` with a TTL guard.
|
||||
|
||||
Acquired before proxying and released on stream close. A TTL on the counter
|
||||
prevents a crashed worker from leaking permits forever (self-healing). Fails
|
||||
closed (SPEC §4.4): if Redis is unreachable, acquisition raises
|
||||
:class:`DependencyUnavailableError` so the caller denies (503).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from typing import cast
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from neuronetz_gateway.errors import DependencyUnavailableError
|
||||
|
||||
# Default guard TTL: a permit auto-expires if not explicitly released, so a
|
||||
# crashed worker cannot pin the semaphore. Comfortably longer than any single
|
||||
# stream is expected to take is set per-call via ``ttl_s``.
|
||||
|
||||
|
||||
class ConcurrencyLimiter:
|
||||
"""Redis-backed concurrency cap with a self-healing TTL guard."""
|
||||
|
||||
def __init__(self, client: redis.Redis) -> None:
|
||||
self._client = client
|
||||
|
||||
async def acquire(self, scope_key: str, limit: int, ttl_s: int) -> bool:
|
||||
"""Try to acquire a permit; return False (deny) if at capacity.
|
||||
|
||||
Increments the counter, refreshes the TTL guard, and rolls back the
|
||||
increment if the new value exceeds ``limit``. Raises on Redis failure so
|
||||
the caller fails closed.
|
||||
"""
|
||||
try:
|
||||
count = await cast("Awaitable[int]", self._client.incr(scope_key))
|
||||
await self._client.expire(scope_key, ttl_s)
|
||||
except RedisError as exc:
|
||||
raise DependencyUnavailableError(
|
||||
internal_detail=f"concurrency redis error: {exc!r}"
|
||||
) from exc
|
||||
if count > limit:
|
||||
# Over capacity: undo our increment and deny.
|
||||
try:
|
||||
await self._client.decr(scope_key)
|
||||
except RedisError:
|
||||
# Permit self-heals via the TTL guard; denial still stands.
|
||||
return False
|
||||
return False
|
||||
return True
|
||||
|
||||
async def release(self, scope_key: str) -> None:
|
||||
"""Release a previously acquired permit (best-effort; never raises)."""
|
||||
try:
|
||||
count = await cast("Awaitable[int]", self._client.decr(scope_key))
|
||||
if count < 0:
|
||||
await self._client.set(scope_key, 0)
|
||||
except RedisError:
|
||||
# The TTL guard will reclaim the permit; releasing must not break the
|
||||
# request that already completed.
|
||||
return
|
||||
|
||||
|
||||
__all__ = ["ConcurrencyLimiter"]
|
||||
109
src/neuronetz_gateway/ratelimit/sliding_window.py
Normal file
109
src/neuronetz_gateway/ratelimit/sliding_window.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Sliding-window rate limiter backed by an atomic Redis Lua script.
|
||||
|
||||
Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set
|
||||
of timestamped hits (a true sliding window, not a fixed bucket): each check
|
||||
trims entries older than ``window_s``, sums the cost of what remains, and admits
|
||||
the new cost only if it keeps the total within ``limit``. The trim + sum +
|
||||
conditional add run inside one Lua script so the decision is atomic across
|
||||
concurrent workers (SPEC §4.3 step 4).
|
||||
|
||||
Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises
|
||||
:class:`DependencyUnavailableError`; the caller must deny (503), never allow.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from neuronetz_gateway.errors import DependencyUnavailableError
|
||||
|
||||
# KEYS[1] = zset key
|
||||
# ARGV[1] = now (ms) ARGV[2] = window (ms) ARGV[3] = limit
|
||||
# ARGV[4] = cost ARGV[5] = unique member suffix
|
||||
# Returns: {allowed (1/0), used_after, retry_after_ms}
|
||||
_LUA = """
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
local cost = tonumber(ARGV[4])
|
||||
local member = ARGV[5]
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
|
||||
-- Each surviving member encodes its cost as a ':<cost>' suffix; sum them so the
|
||||
-- window total is cost-weighted (RPM uses cost 1, TPM uses token count).
|
||||
local total = 0
|
||||
local data = redis.call('ZRANGEBYSCORE', key, now - window, now)
|
||||
for i = 1, #data do
|
||||
local sep = string.find(data[i], ':[^:]*$')
|
||||
if sep then
|
||||
total = total + tonumber(string.sub(data[i], sep + 1))
|
||||
else
|
||||
total = total + 1
|
||||
end
|
||||
end
|
||||
|
||||
if total + cost > limit then
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
||||
local retry = window
|
||||
if oldest[2] then
|
||||
retry = (tonumber(oldest[2]) + window) - now
|
||||
if retry < 0 then retry = 0 end
|
||||
end
|
||||
return {0, total, retry}
|
||||
end
|
||||
|
||||
redis.call('ZADD', key, now, member .. ':' .. cost)
|
||||
redis.call('PEXPIRE', key, window)
|
||||
return {1, total + cost, 0}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimitResult:
|
||||
"""Outcome of a rate-limit check."""
|
||||
|
||||
allowed: bool
|
||||
limit: int
|
||||
remaining: int
|
||||
retry_after_s: int | None
|
||||
|
||||
|
||||
class SlidingWindowLimiter:
|
||||
"""Redis-backed sliding-window limiter (atomic via Lua)."""
|
||||
|
||||
def __init__(self, client: redis.Redis) -> None:
|
||||
self._client = client
|
||||
self._script = client.register_script(_LUA)
|
||||
|
||||
async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult:
|
||||
"""Atomically record a hit of ``cost`` and report admission.
|
||||
|
||||
Raises :class:`DependencyUnavailableError` if Redis cannot be reached, so
|
||||
the caller fails closed.
|
||||
"""
|
||||
now_ms = int(time.time() * 1000)
|
||||
window_ms = window_s * 1000
|
||||
member = f"{now_ms}-{id(object())}"
|
||||
try:
|
||||
raw = await self._script(
|
||||
keys=[key], args=[now_ms, window_ms, limit, cost, member]
|
||||
)
|
||||
except RedisError as exc:
|
||||
raise DependencyUnavailableError(
|
||||
internal_detail=f"ratelimit redis error: {exc!r}"
|
||||
) from exc
|
||||
allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2]))
|
||||
allowed = allowed_i == 1
|
||||
remaining = max(limit - used_after, 0)
|
||||
retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000)
|
||||
return RateLimitResult(
|
||||
allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["RateLimitResult", "SlidingWindowLimiter"]
|
||||
97
src/neuronetz_gateway/revocation.py
Normal file
97
src/neuronetz_gateway/revocation.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Key-revocation NOTIFY listener (SPEC §4.5).
|
||||
|
||||
Console (or the gateway's own CLI) revokes a key by inserting into
|
||||
``gateway.revocations``; an ``AFTER INSERT`` trigger fires
|
||||
``pg_notify('key_revoked', key_id)``. This background task LISTENs on that
|
||||
channel and, on each notification, evicts the Redis auth-cache entry for the
|
||||
revoked key's prefix so the next request misses the cache, re-reads the DB,
|
||||
finds the key non-active, and is rejected — making revocation effective within
|
||||
one Redis RTT without any cross-service HTTP.
|
||||
|
||||
The listener resolves ``key_id -> prefix`` via a short DB lookup (the NOTIFY
|
||||
payload is the key id, but the cache is keyed by prefix). It is resilient: a
|
||||
dropped connection is retried with backoff.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import uuid
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from neuronetz_gateway.config import Settings
|
||||
from neuronetz_gateway.db.models import ApiKey
|
||||
from neuronetz_gateway.observability.logging import get_logger
|
||||
|
||||
_log = get_logger("revocation")
|
||||
|
||||
_CHANNEL = "key_revoked"
|
||||
_CACHE_PREFIX = "gateway:key:"
|
||||
|
||||
|
||||
def _asyncpg_dsn(database_url: str) -> str:
|
||||
"""Strip the SQLAlchemy ``+asyncpg`` driver tag for a raw asyncpg connect."""
|
||||
return database_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
|
||||
async def _evict(
|
||||
key_id_text: str,
|
||||
sessionmaker: async_sessionmaker[AsyncSession],
|
||||
redis_client: redis.Redis,
|
||||
) -> None:
|
||||
"""Resolve the key id to its prefix and delete the cached principal."""
|
||||
try:
|
||||
key_id = uuid.UUID(key_id_text)
|
||||
except ValueError:
|
||||
_log.warning("revocation_bad_payload", payload=key_id_text)
|
||||
return
|
||||
try:
|
||||
async with sessionmaker() as session:
|
||||
key = await session.get(ApiKey, key_id)
|
||||
prefix = key.prefix if key is not None else None
|
||||
if prefix is not None:
|
||||
await redis_client.delete(_CACHE_PREFIX + prefix)
|
||||
_log.info("revocation_cache_evicted", key_prefix=prefix)
|
||||
except Exception as exc: # noqa: BLE001 - listener must survive transient errors
|
||||
_log.warning("revocation_evict_failed", error=str(exc))
|
||||
|
||||
|
||||
async def revocation_listener(
|
||||
settings: Settings,
|
||||
redis_client: redis.Redis,
|
||||
sessionmaker: async_sessionmaker[AsyncSession],
|
||||
) -> None:
|
||||
"""LISTEN on ``key_revoked`` and evict the Redis cache on each notification."""
|
||||
dsn = _asyncpg_dsn(settings.database_url)
|
||||
while True:
|
||||
conn = None
|
||||
try:
|
||||
conn = await asyncpg.connect(dsn)
|
||||
|
||||
def _on_notify(
|
||||
_c: object, _pid: int, _channel: str, payload: str
|
||||
) -> None:
|
||||
# Schedule the async eviction without blocking the callback.
|
||||
asyncio.create_task(_evict(payload, sessionmaker, redis_client)) # noqa: RUF006
|
||||
|
||||
await conn.add_listener(_CHANNEL, _on_notify)
|
||||
_log.info("revocation_listener_started")
|
||||
# Wait forever; notifications arrive via the callback. asyncio.Event
|
||||
# is a cancel-friendly substitute for an unbounded sleep loop.
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001 - reconnect on any failure
|
||||
_log.warning("revocation_listener_reconnect", error=str(exc))
|
||||
await asyncio.sleep(2)
|
||||
finally:
|
||||
if conn is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await conn.close()
|
||||
|
||||
|
||||
__all__ = ["revocation_listener"]
|
||||
168
src/neuronetz_gateway/routes/ollama_native.py
Normal file
168
src/neuronetz_gateway/routes/ollama_native.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Native Ollama passthrough routes (SPEC §6.1).
|
||||
|
||||
All proxied endpoints run through the shared :class:`Pipeline` (auth has already
|
||||
attached the principal in middleware): scope + model + endpoint allowlist, rate
|
||||
limit, budget, concurrency, body validation, then stream/relay with post-close
|
||||
token-count + audit + budget accounting.
|
||||
|
||||
Mutating endpoints (``/api/pull|push|create|copy|delete``, ``/api/blobs/*``) and
|
||||
``/api/ps`` are hard-blocked (SPEC §6.2) and intentionally NOT routed; a catch-all
|
||||
returns a generic 403 so their existence is never confirmed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from neuronetz_gateway import __version__
|
||||
from neuronetz_gateway.deps import (
|
||||
ConfigDep,
|
||||
DiscoveryCacheDep,
|
||||
OllamaClientDep,
|
||||
PipelineDep,
|
||||
PrincipalDep,
|
||||
)
|
||||
from neuronetz_gateway.errors import AuthorizationError
|
||||
from neuronetz_gateway.proxy.allowlist import is_hard_blocked, resolve_effective_models
|
||||
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["ollama-native"])
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""Proxy ``POST /api/chat`` (streamed NDJSON / non-streamed)."""
|
||||
return await _chat_or_generate(request, pipeline, path="/api/chat", scope="chat")
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""Proxy ``POST /api/generate`` (streamed NDJSON / non-streamed)."""
|
||||
return await _chat_or_generate(request, pipeline, path="/api/generate", scope="chat")
|
||||
|
||||
|
||||
async def _chat_or_generate(
|
||||
request: Request, pipeline: PipelineDep, *, path: str, scope: str
|
||||
) -> Response:
|
||||
body = await read_json_body(request, pipeline.settings)
|
||||
model = model_of(body)
|
||||
pipeline.check_scope(scope)
|
||||
pipeline.check_endpoint(path)
|
||||
pipeline.check_model(model)
|
||||
pipeline.validate_body(body)
|
||||
await pipeline.enforce_limits()
|
||||
if bool(body.get("stream", True)):
|
||||
return await pipeline.stream_native("POST", path, body, model)
|
||||
return await pipeline.request_native("POST", path, body, model)
|
||||
|
||||
|
||||
@router.post("/embeddings")
|
||||
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""Proxy ``POST /api/embeddings`` (legacy, non-streamed)."""
|
||||
return await _embeddings(request, pipeline, "/api/embeddings")
|
||||
|
||||
|
||||
@router.post("/embed")
|
||||
async def embed(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""Proxy ``POST /api/embed`` (non-streamed)."""
|
||||
return await _embeddings(request, pipeline, "/api/embed")
|
||||
|
||||
|
||||
async def _embeddings(request: Request, pipeline: PipelineDep, path: str) -> Response:
|
||||
body = await read_json_body(request, pipeline.settings)
|
||||
model = model_of(body)
|
||||
pipeline.check_scope("embeddings")
|
||||
pipeline.check_endpoint(path)
|
||||
pipeline.check_model(model)
|
||||
await pipeline.enforce_limits()
|
||||
return await pipeline.request_native("POST", path, body, model)
|
||||
|
||||
|
||||
@router.get("/tags")
|
||||
async def tags(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
|
||||
"""Return the tenant's effective model set (live-discovered ∩ allowed)."""
|
||||
effective = resolve_effective_models(
|
||||
allow_all=principal.allow_all_models,
|
||||
allowed_models=principal.allowed_models,
|
||||
discovered=discovery.names,
|
||||
)
|
||||
models = [
|
||||
{
|
||||
"name": m.name,
|
||||
"model": m.name,
|
||||
"modified_at": m.modified_at,
|
||||
"size": m.size_bytes,
|
||||
"details": {
|
||||
"family": m.family,
|
||||
"parameter_size": m.parameter_size,
|
||||
"quantization_level": m.quantization,
|
||||
},
|
||||
}
|
||||
for m in discovery.models
|
||||
if m.name in effective
|
||||
]
|
||||
return JSONResponse({"models": models})
|
||||
|
||||
|
||||
@router.post("/show")
|
||||
async def show(
|
||||
request: Request,
|
||||
principal: PrincipalDep,
|
||||
discovery: DiscoveryCacheDep,
|
||||
ollama: OllamaClientDep,
|
||||
settings: ConfigDep,
|
||||
) -> JSONResponse:
|
||||
"""Proxy ``POST /api/show`` for an effective-set model; sanitize the result."""
|
||||
body = await read_json_body(request, settings)
|
||||
name = body.get("model") or body.get("name")
|
||||
model = name if isinstance(name, str) else ""
|
||||
if not model or model not in resolve_effective_models(
|
||||
allow_all=principal.allow_all_models,
|
||||
allowed_models=principal.allowed_models,
|
||||
discovered=discovery.names,
|
||||
):
|
||||
raise AuthorizationError(internal_detail="show: model not in effective set")
|
||||
resp = await ollama.request("POST", "/api/show", {"model": model})
|
||||
payload = resp.json()
|
||||
raw: dict[str, object] = payload if isinstance(payload, dict) else {}
|
||||
# Strip system prompt + template (SPEC §6.1: no system prompts, no template).
|
||||
raw_details = raw.get("details")
|
||||
details: dict[str, object] = raw_details if isinstance(raw_details, dict) else {}
|
||||
return JSONResponse(
|
||||
{
|
||||
"model": model,
|
||||
"details": {
|
||||
"family": details.get("family"),
|
||||
"parameter_size": details.get("parameter_size"),
|
||||
"quantization_level": details.get("quantization_level"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def version() -> JSONResponse:
|
||||
"""Return the gateway version (never Ollama's; SPEC §6.1)."""
|
||||
return JSONResponse({"version": __version__})
|
||||
|
||||
|
||||
@router.api_route(
|
||||
"/{rest:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def catch_all(rest: str) -> Response:
|
||||
"""Generic 403 for any other ``/api/*`` path (hard-blocked or unknown).
|
||||
|
||||
Mutating endpoints and ``/api/ps`` resolve here and return the same generic
|
||||
forbidden response, so the gateway never confirms which upstream endpoints
|
||||
exist (SPEC §6.2, §13.6).
|
||||
"""
|
||||
full = f"/api/{rest}"
|
||||
if is_hard_blocked(full):
|
||||
raise AuthorizationError(internal_detail=f"hard-blocked {full}")
|
||||
raise AuthorizationError(internal_detail=f"unrouted upstream path {full}")
|
||||
|
||||
|
||||
__all__ = ["router"]
|
||||
118
src/neuronetz_gateway/routes/openai_compat.py
Normal file
118
src/neuronetz_gateway/routes/openai_compat.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""OpenAI-compatible routes (SPEC §6.3).
|
||||
|
||||
Each route translates the OpenAI request into the native Ollama body, runs the
|
||||
same :class:`Pipeline` enforcement as the native routes, and translates the
|
||||
response back. Streaming uses SSE (``data: {...}\\n\\n`` … ``data: [DONE]\\n\\n``);
|
||||
non-streaming returns a single OpenAI-shaped JSON object. ``/v1/models`` returns
|
||||
the tenant's effective discovered set in OpenAI list format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from neuronetz_gateway.deps import (
|
||||
DiscoveryCacheDep,
|
||||
PipelineDep,
|
||||
PrincipalDep,
|
||||
)
|
||||
from neuronetz_gateway.proxy.allowlist import resolve_effective_models
|
||||
from neuronetz_gateway.proxy.pipeline import model_of, read_json_body
|
||||
from neuronetz_gateway.proxy.translate import (
|
||||
models_to_openai_list,
|
||||
new_completion_id,
|
||||
ollama_chat_chunk_to_openai,
|
||||
ollama_chat_to_openai,
|
||||
ollama_embed_to_openai,
|
||||
ollama_generate_chunk_to_openai,
|
||||
ollama_generate_to_openai,
|
||||
openai_chat_to_ollama,
|
||||
openai_completion_to_ollama,
|
||||
openai_embeddings_to_ollama,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/v1", tags=["openai-compat"])
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def chat_completions(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""OpenAI ``/v1/chat/completions`` -> Ollama ``/api/chat``."""
|
||||
payload = await read_json_body(request, pipeline.settings)
|
||||
body = openai_chat_to_ollama(payload)
|
||||
model = model_of(body)
|
||||
pipeline.check_scope("chat")
|
||||
pipeline.check_endpoint("/api/chat")
|
||||
pipeline.check_model(model)
|
||||
pipeline.validate_body(body)
|
||||
await pipeline.enforce_limits()
|
||||
if bool(payload.get("stream", False)):
|
||||
completion_id = new_completion_id()
|
||||
created = int(time.time())
|
||||
|
||||
def translate(chunk: dict[str, object]) -> dict[str, object]:
|
||||
return ollama_chat_chunk_to_openai(
|
||||
chunk, completion_id=completion_id, model=model, created=created
|
||||
)
|
||||
|
||||
return await pipeline.stream_openai("POST", "/api/chat", body, model, translate)
|
||||
return await pipeline.request_translated(
|
||||
"POST", "/api/chat", body, model, ollama_chat_to_openai
|
||||
)
|
||||
|
||||
|
||||
@router.post("/completions")
|
||||
async def completions(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""OpenAI ``/v1/completions`` -> Ollama ``/api/generate``."""
|
||||
payload = await read_json_body(request, pipeline.settings)
|
||||
body = openai_completion_to_ollama(payload)
|
||||
model = model_of(body)
|
||||
pipeline.check_scope("chat")
|
||||
pipeline.check_endpoint("/api/generate")
|
||||
pipeline.check_model(model)
|
||||
pipeline.validate_body(body)
|
||||
await pipeline.enforce_limits()
|
||||
if bool(payload.get("stream", False)):
|
||||
completion_id = new_completion_id()
|
||||
created = int(time.time())
|
||||
|
||||
def translate(chunk: dict[str, object]) -> dict[str, object]:
|
||||
return ollama_generate_chunk_to_openai(
|
||||
chunk, completion_id=completion_id, model=model, created=created
|
||||
)
|
||||
|
||||
return await pipeline.stream_openai("POST", "/api/generate", body, model, translate)
|
||||
return await pipeline.request_translated(
|
||||
"POST", "/api/generate", body, model, ollama_generate_to_openai
|
||||
)
|
||||
|
||||
|
||||
@router.post("/embeddings")
|
||||
async def embeddings(request: Request, pipeline: PipelineDep) -> Response:
|
||||
"""OpenAI ``/v1/embeddings`` -> Ollama ``/api/embed``."""
|
||||
payload = await read_json_body(request, pipeline.settings)
|
||||
body = openai_embeddings_to_ollama(payload)
|
||||
model = model_of(body)
|
||||
pipeline.check_scope("embeddings")
|
||||
pipeline.check_endpoint("/api/embed")
|
||||
pipeline.check_model(model)
|
||||
await pipeline.enforce_limits()
|
||||
return await pipeline.request_translated(
|
||||
"POST", "/api/embed", body, model, lambda obj: ollama_embed_to_openai(obj, model)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def models(principal: PrincipalDep, discovery: DiscoveryCacheDep) -> JSONResponse:
|
||||
"""OpenAI ``/v1/models`` -> the tenant's effective discovered set."""
|
||||
effective = resolve_effective_models(
|
||||
allow_all=principal.allow_all_models,
|
||||
allowed_models=principal.allowed_models,
|
||||
discovered=discovery.names,
|
||||
)
|
||||
return JSONResponse(models_to_openai_list(sorted(effective)))
|
||||
|
||||
|
||||
__all__ = ["router"]
|
||||
Reference in New Issue
Block a user