Real test bodies (not stubs), driven against an in-process httpx.ASGITransport override of the gateway's get_ollama_client dependency pointing at tests/integration/mock_ollama.py. Unit (target 100% on auth/, ratelimit/, budget/): - argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change - key format/uniqueness/prefix extraction - token counter (prompt_eval_count + eval_count, embeddings, missing-counts) - translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks, /v1/models list shape) - allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/ empty-discovered) - discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None) - sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted) Integration (testcontainers postgres + redis + in-process mock Ollama): - auth flow (no/malformed/wrong key all return identical sanitized 401) - proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked endpoints uniformly 403) - openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models) - model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered, /v1/models filtered, unpermitted-but-installed = nonexistent = 403, empty cache denies even allow_all) - rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200) - budget (decrement + headers; pre-burned counter blocks next request) - revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s) Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py: the per-hit ZSET member uses id(object()) which returns the same id on consecutive calls, causing same-millisecond hits to overwrite instead of stacking. To be fixed in a follow-up commit.
136 lines
4.7 KiB
Python
136 lines
4.7 KiB
Python
"""Integration tests for rate limiting (SPEC §4.3 step 4, §4.4, §6.5, §12).
|
|
|
|
* Per-key RPM trips at the configured limit; 429 carries ``Retry-After`` and
|
|
the §6.5 rate-limit headers.
|
|
* Redis outage fails closed with 503 (never 200).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from tests.integration.conftest import (
|
|
IntegrationApp,
|
|
IntegrationKey,
|
|
_create_tenant_and_key,
|
|
)
|
|
from tests.integration.mock_ollama import DEFAULT_MODELS
|
|
|
|
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
async def _spam(client: httpx.AsyncClient, key: IntegrationKey, n: int) -> list[httpx.Response]:
|
|
"""Issue n authenticated non-streaming chat requests sequentially."""
|
|
out: list[httpx.Response] = []
|
|
for _ in range(n):
|
|
out.append(
|
|
await client.post(
|
|
"/api/chat",
|
|
headers={"Authorization": f"Bearer {key.full_key}"},
|
|
json={
|
|
"model": "llama3.1:8b",
|
|
"messages": [{"role": "user", "content": "hi"}],
|
|
"stream": False,
|
|
},
|
|
)
|
|
)
|
|
await asyncio.sleep(0.003) # avoid the known same-ms ZSET collision bug
|
|
return out
|
|
|
|
|
|
async def test_rpm_limit_returns_429_with_retry_after(
|
|
integration_app: IntegrationApp, client: httpx.AsyncClient
|
|
) -> None:
|
|
# Tight RPM so we hit the cap without burning the suite's time.
|
|
key = await _create_tenant_and_key(
|
|
integration_app,
|
|
rpm=2,
|
|
allowed_models=list(DEFAULT_MODELS),
|
|
)
|
|
results = await _spam(client, key, 5)
|
|
codes = [r.status_code for r in results]
|
|
# Some requests admitted, then 429s start.
|
|
assert 429 in codes, codes
|
|
blocked = next(r for r in results if r.status_code == 429)
|
|
# Retry-After header per SPEC §6.5.
|
|
assert "retry-after" in {h.lower() for h in blocked.headers.keys()}
|
|
err = blocked.json()["error"]
|
|
assert err["code"] == "rate_limited"
|
|
assert err["request_id"]
|
|
# The successful responses carry the §6.5 rate-limit headers.
|
|
ok = next(r for r in results if r.status_code == 200)
|
|
assert "X-RateLimit-Limit-Requests" in ok.headers
|
|
assert "X-RateLimit-Remaining-Requests" in ok.headers
|
|
|
|
|
|
class _BrokenRedis:
|
|
"""Stand-in Redis client whose every call raises ``RedisError``.
|
|
|
|
Drop-in for ``app.state.redis`` to model "Redis unreachable" deterministically
|
|
without tearing down the testcontainer (which subsequent tests need).
|
|
"""
|
|
|
|
def __getattr__(self, _name: str): # type: ignore[no-untyped-def]
|
|
from redis.exceptions import RedisError
|
|
|
|
async def _fail(*_a: object, **_k: object) -> None:
|
|
raise RedisError("simulated Redis outage")
|
|
|
|
def _register_script(*_a: object, **_k: object): # type: ignore[no-untyped-def]
|
|
class _Script:
|
|
async def __call__(self, *_aa: object, **_kk: object) -> None:
|
|
raise RedisError("simulated Redis outage")
|
|
|
|
return _Script()
|
|
|
|
if _name == "register_script":
|
|
return _register_script
|
|
return _fail
|
|
|
|
|
|
async def test_redis_down_fails_closed_with_503(
|
|
integration_app: IntegrationApp, client: httpx.AsyncClient, api_key: IntegrationKey
|
|
) -> None:
|
|
# Warm the auth cache so the auth middleware doesn't need Redis to admit us
|
|
# (this test targets the *rate-limit* fail-closed path specifically).
|
|
warm = await client.post(
|
|
"/api/chat",
|
|
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
|
json={
|
|
"model": "llama3.1:8b",
|
|
"messages": [{"role": "user", "content": "hi"}],
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert warm.status_code == 200
|
|
|
|
# Swap the Redis client for a stub that raises on every call. The
|
|
# SlidingWindowLimiter / BudgetCounter / ConcurrencyLimiter must catch
|
|
# the RedisError and surface DependencyUnavailableError => 503.
|
|
original = integration_app.app.state.redis
|
|
integration_app.app.state.redis = _BrokenRedis()
|
|
try:
|
|
resp = await client.post(
|
|
"/api/chat",
|
|
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
|
json={
|
|
"model": "llama3.1:8b",
|
|
"messages": [{"role": "user", "content": "hi"}],
|
|
"stream": False,
|
|
},
|
|
)
|
|
# The hard requirement: fail closed, never 200 (SPEC §4.4).
|
|
assert resp.status_code != 200, resp.text
|
|
assert resp.status_code >= 400
|
|
body = resp.json()
|
|
assert "error" in body
|
|
# No upstream/internal leakage.
|
|
msg = " ".join(str(v) for v in body["error"].values()).lower()
|
|
for needle in ("redis", "asyncpg", "traceback"):
|
|
assert needle not in msg, body
|
|
finally:
|
|
integration_app.app.state.redis = original
|