"""Sliding-window rate limiter backed by an atomic Redis Lua script. Enforces per-key and per-tenant RPM and per-key TPM. The window is a sorted set of timestamped hits (a true sliding window, not a fixed bucket): each check trims entries older than ``window_s``, sums the cost of what remains, and admits the new cost only if it keeps the total within ``limit``. The trim + sum + conditional add run inside one Lua script so the decision is atomic across concurrent workers (SPEC §4.3 step 4). Fail-closed (SPEC §4.4): if Redis is unavailable the limiter raises :class:`DependencyUnavailableError`; the caller must deny (503), never allow. """ from __future__ import annotations import time from dataclasses import dataclass import redis.asyncio as redis from redis.exceptions import RedisError from neuronetz_gateway.errors import DependencyUnavailableError # KEYS[1] = zset key # ARGV[1] = now (ms) ARGV[2] = window (ms) ARGV[3] = limit # ARGV[4] = cost ARGV[5] = unique member suffix # Returns: {allowed (1/0), used_after, retry_after_ms} _LUA = """ local key = KEYS[1] local now = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local limit = tonumber(ARGV[3]) local cost = tonumber(ARGV[4]) local member = ARGV[5] redis.call('ZREMRANGEBYSCORE', key, 0, now - window) -- Each surviving member encodes its cost as a ':' suffix; sum them so the -- window total is cost-weighted (RPM uses cost 1, TPM uses token count). local total = 0 local data = redis.call('ZRANGEBYSCORE', key, now - window, now) for i = 1, #data do local sep = string.find(data[i], ':[^:]*$') if sep then total = total + tonumber(string.sub(data[i], sep + 1)) else total = total + 1 end end if total + cost > limit then local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local retry = window if oldest[2] then retry = (tonumber(oldest[2]) + window) - now if retry < 0 then retry = 0 end end return {0, total, retry} end redis.call('ZADD', key, now, member .. ':' .. cost) redis.call('PEXPIRE', key, window) return {1, total + cost, 0} """ @dataclass(frozen=True, slots=True) class RateLimitResult: """Outcome of a rate-limit check.""" allowed: bool limit: int remaining: int retry_after_s: int | None class SlidingWindowLimiter: """Redis-backed sliding-window limiter (atomic via Lua).""" def __init__(self, client: redis.Redis) -> None: self._client = client self._script = client.register_script(_LUA) async def check(self, key: str, limit: int, window_s: int, cost: int = 1) -> RateLimitResult: """Atomically record a hit of ``cost`` and report admission. Raises :class:`DependencyUnavailableError` if Redis cannot be reached, so the caller fails closed. """ now_ms = int(time.time() * 1000) window_ms = window_s * 1000 member = f"{now_ms}-{id(object())}" try: raw = await self._script( keys=[key], args=[now_ms, window_ms, limit, cost, member] ) except RedisError as exc: raise DependencyUnavailableError( internal_detail=f"ratelimit redis error: {exc!r}" ) from exc allowed_i, used_after, retry_ms = (int(raw[0]), int(raw[1]), int(raw[2])) allowed = allowed_i == 1 remaining = max(limit - used_after, 0) retry_after_s = None if allowed else max(1, (retry_ms + 999) // 1000) return RateLimitResult( allowed=allowed, limit=limit, remaining=remaining, retry_after_s=retry_after_s ) __all__ = ["RateLimitResult", "SlidingWindowLimiter"]