The hot path. A single Pipeline class owns enforcement so the eight
non-negotiables can be reviewed in one place.
- Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags,
/api/show (system-prompt + template stripped), /api/embed(dings), /api/version
(returns gateway version, not Ollama's). Endpoint catch-all returns the same
generic 403 for hard-blocked and unknown /api/* paths so attackers cannot
enumerate which mutating endpoints exist.
- OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings,
/v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves
streaming end-to-end.
- Model discovery (SPEC §4.6): background poller against Ollama /api/tags;
Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh =
MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty.
- Effective-set resolution in proxy/allowlist.py:
      allow_all = key.allow_all_models ?? tenant.allow_all_models
      effective = discovered if allow_all
                  else (key.allowed_models ?? tenant.allowed_models) ∩ discovered
  A non-effective model returns the same generic 403 whether it is installed
  but unpermitted or does not exist at all (no enumeration leak).
- Sliding-window rate limit (Redis Lua, single round-trip) for per-key +
per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with
TTL guard. Token-budget counters per (key, period) with a Postgres ledger
for reconciliation across resets. Headers per SPEC §6.5 on every response;
429 carries Retry-After; Redis outage → 503 (fail closed, never 200).
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk
carrying `usage`); the audit row is written AFTER stream close so TTFB is
never degraded by bookkeeping.
- Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow.
Optional prompt log per key (TTL'd).
- Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache
entry within ~1s of the console writing to gateway.revocations.
- Prometheus counters/histograms labeled by tenant only (per SPEC §13.3).
Source listing: Python, 110 lines, 3.6 KiB.
"""Sliding-window rate limiter backed by an atomic Redis Lua script.

Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set
of timestamped hits (a true sliding window, not a fixed bucket): each check
trims entries older than ``window_s``, sums the cost of what remains, and admits
the new cost only if it keeps the total within ``limit``. The trim + sum +
conditional add run inside one Lua script so the decision is atomic across
concurrent workers (SPEC §4.3 step 4).

Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises
:class:`DependencyUnavailableError`; the caller must deny (503), never allow.
"""
|
|
|
|
from __future__ import annotations

import time
import uuid
from dataclasses import dataclass

import redis.asyncio as redis
from redis.exceptions import RedisError

from neuronetz_gateway.errors import DependencyUnavailableError
|
|
|
|
# Atomic sliding-window check, executed server-side by Redis as one script.
#
# KEYS[1] = zset key (score = hit timestamp in ms, member = '<unique id>:<cost>')
# ARGV[1] = now (ms)   ARGV[2] = window (ms)   ARGV[3] = limit
# ARGV[4] = cost       ARGV[5] = unique member id (the script appends ':<cost>')
# Returns: {allowed (1/0), used_after, retry_after_ms}
#   - allowed = 1: the hit was recorded; used_after includes the new cost.
#   - allowed = 0: nothing was recorded; used_after is the current window total
#     and retry_after_ms estimates when enough cost will have aged out.
_LUA = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local member = ARGV[5]

-- Each member encodes its cost as a ':<cost>' suffix (RPM uses cost 1, TPM uses
-- the token count). Members without a usable suffix count as a single hit.
local function weight(m)
  local sep = string.find(m, ':[^:]*$')
  if sep then
    return tonumber(string.sub(m, sep + 1)) or 1
  end
  return 1
end

-- Drop everything that has aged out of the window.
redis.call('ZREMRANGEBYSCORE', key, 0, now - window)

-- Read every surviving member, oldest first. No upper score bound: a hit
-- written by a worker whose clock runs slightly ahead of ours still counts,
-- otherwise clock skew would let callers exceed the limit.
local entries = redis.call('ZRANGE', key, 0, -1, 'WITHSCORES')
local total = 0
for i = 1, #entries, 2 do
  total = total + weight(entries[i])
end

if total + cost > limit then
  -- Default: a full window (also the answer when cost alone exceeds limit,
  -- because no amount of ageing-out can make room for it).
  local retry = window
  local need = total + cost - limit
  local freed = 0
  -- Walk oldest-first until enough cost would have expired to admit this hit;
  -- the expiry time of that member is the earliest useful retry.
  for i = 1, #entries, 2 do
    freed = freed + weight(entries[i])
    if freed >= need then
      retry = (tonumber(entries[i + 1]) + window) - now
      break
    end
  end
  if retry < 0 then retry = 0 end
  return {0, total, retry}
end

redis.call('ZADD', key, now, member .. ':' .. cost)
redis.call('PEXPIRE', key, window)
return {1, total + cost, 0}
"""
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class RateLimitResult:
|
|
"""Outcome of a rate-limit check."""
|
|
|
|
allowed: bool
|
|
limit: int
|
|
remaining: int
|
|
retry_after_s: int | None
|
|
|
|
|
|
class SlidingWindowLimiter:
    """Redis-backed sliding-window limiter (atomic via Lua).

    Every check is one call to the registered Lua script, which trims, sums and
    conditionally records the hit in a single step.
    """

    def __init__(self, client: redis.Redis) -> None:
        """Keep ``client`` and register the limiter's Lua script on it."""
        self._client = client
        self._script = client.register_script(_LUA)

    async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult:
        """Atomically record a hit of ``cost`` and report admission.

        Args:
            key: Redis key of the sorted set holding this window's hits.
            limit: Maximum total cost allowed inside the window.
            window_s: Window length in seconds.
            cost: Weight of this hit (defaults to 1).

        Returns:
            A :class:`RateLimitResult`. ``remaining`` is clamped at zero and
            ``retry_after_s`` is ``None`` when the hit was admitted, otherwise
            the script's retry hint rounded up to whole seconds (minimum 1).

        Raises:
            DependencyUnavailableError: if the script call raises a
                ``RedisError``, so the caller fails closed.
        """
        now_ms = int(time.time() * 1000)
        window_ms = window_s * 1000
        # The member must be unique per hit: ZADD on an existing member only
        # updates its score, so a collision would silently drop a hit from the
        # count. ``id()`` of a temporary is routinely reused by the allocator,
        # hence a random UUID instead.
        member = f"{now_ms}-{uuid.uuid4().hex}"
        try:
            raw = await self._script(
                keys=[key], args=[now_ms, window_ms, limit, cost, member]
            )
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"ratelimit redis error: {exc!r}"
            ) from exc

        # Script reply: {allowed (1/0), used_after, retry_after_ms}.
        allowed = int(raw[0]) == 1
        used_after = int(raw[1])
        retry_ms = int(raw[2])
        remaining = max(limit - used_after, 0)
        # Ceil to whole seconds, and never advertise a zero-second retry.
        retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000)
        return RateLimitResult(
            allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s
        )
|
|
|
|
|
|
# Public API of this module; the Lua script constant is intentionally not exported.
__all__ = ["RateLimitResult", "SlidingWindowLimiter"]
|