Files
neuronetz-gateway/src/neuronetz_gateway/ratelimit/sliding_window.py
Stephan Berbig 6a92bc8ce9 proxy: streaming, discovery, OpenAI-compat, rate-limit, budget, audit
The hot path. A single Pipeline class owns enforcement so the eight
non-negotiables can be reviewed in one place.

- Native /api/chat, /api/generate (NDJSON streaming + non-stream), /api/tags,
  /api/show (system-prompt + template stripped), /api/embed(dings), /api/version
  (returns gateway version, not Ollama's). Endpoint catch-all returns the same
  generic 403 for hard-blocked and unknown /api/* paths so attackers cannot
  enumerate which mutating endpoints exist.
- OpenAI-compat /v1/chat/completions, /v1/completions, /v1/embeddings,
  /v1/models with SSE (`data: {...}` + final `data: [DONE]`); preserves
  streaming end-to-end.
- Model discovery (SPEC §4.6): background poller against Ollama /api/tags;
  Redis + in-process cache (TTL = MODEL_DISCOVERY_CACHE_TTL_S, refresh =
  MODEL_DISCOVERY_REFRESH_S); fail-closed when the discovered set is empty.
- Effective-set resolution in proxy/allowlist.py:
    allow_all = key.allow_all_models ?? tenant.allow_all_models
    effective = discovered if allow_all
                else (key.allowed_models ?? tenant.allowed_models) ∩ discovered
  A non-effective model returns the same generic 403 whether it's installed-
  but-unpermitted or doesn't exist at all (no enumeration leak).
- Sliding-window rate limit (Redis Lua, single round-trip) for per-key +
  per-tenant RPM and per-key TPM. Redis-INCR/DECR concurrency semaphore with
  TTL guard. Token-budget counters per (key, period) with a Postgres ledger
  for reconciliation across resets. Headers per SPEC §6.5 on every response;
  429 carries Retry-After; Redis outage → 503 (fail closed, never 200).
- Token counting from the FINAL stream object (NDJSON `done` or the SSE chunk
  carrying `usage`); the audit row is written AFTER stream close so TTFB is
  never degraded by bookkeeping.
- Audit writer: asyncio.Queue + bounded ring buffer; deny-mode flip on overflow.
  Optional prompt log per key (TTL'd).
- Revocation listener: asyncpg LISTEN on key_revoked → evict the Redis cache
  entry within ~1s of the console writing to gateway.revocations.
- Prometheus counters/histograms labeled by tenant only (per SPEC §13.3).
2026-05-26 20:52:33 +02:00

110 lines
3.6 KiB
Python

"""Sliding-window rate limiter backed by an atomic Redis Lua script.
Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set
of timestamped hits (a true sliding window, not a fixed bucket): each check
trims entries older than ``window_s``, sums the cost of what remains, and admits
the new cost only if it keeps the total within ``limit``. The trim + sum +
conditional add run inside one Lua script so the decision is atomic across
concurrent workers (SPEC §4.3 step 4).
Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises
:class:`DependencyUnavailableError`; the caller must deny (503), never allow.
"""
from __future__ import annotations

import time
import uuid
from dataclasses import dataclass

import redis.asyncio as redis
from redis.exceptions import RedisError

from neuronetz_gateway.errors import DependencyUnavailableError
# Lua source for the atomic trim + sum + conditional-add rate-limit check.
#
# KEYS[1] = zset key
# ARGV[1] = now (ms)   ARGV[2] = window (ms)   ARGV[3] = limit
# ARGV[4] = cost       ARGV[5] = unique member id (':<cost>' is appended here)
# Returns: {allowed (1/0), used_after, retry_after_ms}
#   used_after is the window total after the call; on a deny it is the
#   unchanged total, because the rejected cost is not recorded.
#
# The sum runs up to '+inf' rather than `now`: `now` is sampled by the caller
# before the round-trip, so a concurrent caller (or a worker whose clock runs
# slightly ahead) can already have stored a hit with a later score. Capping the
# range at `now` would leave those hits uncounted and admit over the limit.
_LUA = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local member = ARGV[5]

redis.call('ZREMRANGEBYSCORE', key, 0, now - window)

-- Each surviving member encodes its cost as a ':<cost>' suffix; sum them so the
-- window total is cost-weighted (RPM uses cost 1, TPM uses token count).
local total = 0
local data = redis.call('ZRANGEBYSCORE', key, now - window, '+inf')
for i = 1, #data do
  local sep = string.find(data[i], ':[^:]*$')
  if sep then
    total = total + tonumber(string.sub(data[i], sep + 1))
  else
    total = total + 1
  end
end

if total + cost > limit then
  -- Denied: the earliest moment capacity can free up is when the oldest hit
  -- leaves the window. With an empty set, fall back to a full window.
  local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
  local retry = window
  if oldest[2] then
    retry = (tonumber(oldest[2]) + window) - now
    if retry < 0 then retry = 0 end
  end
  return {0, total, retry}
end

redis.call('ZADD', key, now, member .. ':' .. cost)
redis.call('PEXPIRE', key, window)
return {1, total + cost, 0}
"""
@dataclass(frozen=True, slots=True)
class RateLimitResult:
"""Outcome of a rate-limit check."""
allowed: bool
limit: int
remaining: int
retry_after_s: int | None
class SlidingWindowLimiter:
    """Redis-backed sliding-window limiter (atomic via Lua).

    The Lua script is wrapped once with ``register_script`` in the constructor
    and the resulting callable is reused for every check.
    """

    def __init__(self, client: redis.Redis) -> None:
        self._client = client
        self._script = client.register_script(_LUA)

    async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult:
        """Atomically record a hit of ``cost`` and report admission.

        Args:
            key: Redis key of the sorted set holding this window's hits.
            limit: Maximum summed cost allowed inside the window.
            window_s: Window length in seconds.
            cost: Weight of this hit (1 for a request, a token count for TPM).

        Returns:
            A :class:`RateLimitResult`. ``remaining`` is ``limit`` minus the
            window total reported by the script, floored at 0.
            ``retry_after_s`` is ``None`` when admitted; otherwise it is the
            script's wait rounded up to whole seconds and never less than 1.

        Raises:
            DependencyUnavailableError: If the Redis call raises
                :class:`RedisError`, so the caller fails closed.
        """
        now_ms = int(time.time() * 1000)
        window_ms = window_s * 1000
        # The member must be unique per hit: ZADD on an existing member only
        # updates its score, which would silently drop a hit from the count.
        # A random UUID is used because the address of a temporary object is
        # recycled by the allocator and collides within the same millisecond.
        member = f"{now_ms}-{uuid.uuid4().hex}"
        try:
            raw = await self._script(
                keys=[key], args=[now_ms, window_ms, limit, cost, member]
            )
        except RedisError as exc:
            raise DependencyUnavailableError(
                internal_detail=f"ratelimit redis error: {exc!r}"
            ) from exc

        allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2]))
        allowed = allowed_i == 1
        remaining = max(limit - used_after, 0)
        # Round up to whole seconds; a denied caller is always told to wait >= 1s.
        retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000)
        return RateLimitResult(
            allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s
        )
# Public API: the only names exported by ``from <this module> import *``.
__all__ = ["RateLimitResult", "SlidingWindowLimiter"]