tests: unit + integration suite (99 tests; ruff + mypy --strict clean)

Real test bodies (not stubs), driven against an in-process httpx.ASGITransport
override of the gateway's get_ollama_client dependency pointing at
tests/integration/mock_ollama.py.

Unit (target 100% on auth/, ratelimit/, budget/):
- argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change
- key format/uniqueness/prefix extraction
- token counter (prompt_eval_count + eval_count, embeddings, missing-counts)
- translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks,
  /v1/models list shape)
- allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/
  empty-discovered)
- discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None)
- sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted)

Integration (testcontainers postgres + redis + in-process mock Ollama):
- auth flow (no/malformed/wrong key all return identical sanitized 401)
- proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked
  endpoints uniformly 403)
- openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models)
- model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered,
  /v1/models filtered, unpermitted-but-installed = nonexistent = 403,
  empty cache denies even allow_all)
- rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200)
- budget (decrement + headers; pre-burned counter blocks next request)
- revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s)

Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py:
the per-hit ZSET member is built from id(object()), which is not unique:
CPython frees the temporary object immediately and typically reuses its
address, so consecutive calls can return the same id and same-millisecond hits
overwrite each other instead of stacking. To be fixed in a follow-up commit.
This commit is contained in:
Stephan Berbig
2026-05-26 20:52:33 +02:00
parent 6a92bc8ce9
commit 844b02aade
23 changed files with 2567 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Integration test package."""

View File

@@ -0,0 +1,338 @@
"""Shared integration fixtures: a fully-wired gateway app driven in-process.
The integration tests build the real :func:`neuronetz_gateway.app.create_app`
(lifespan + AuthMiddleware + routes) against testcontainers Postgres + Redis,
with the upstream Ollama overridden to the in-process mock via
``get_ollama_client`` (the override contract documented in ``deps.py``).
The fixture also creates a tenant + API key directly through the repositories
so tests don't depend on the CLI being present, and seeds the in-process
discovery cache with the mock catalogue so the model allowlist resolves without
relying on the background poller racing against the test.
Skips cleanly when Docker is unavailable (the postgres/redis container fixtures
self-skip). All env vars are scoped to the fixture and the settings cache is
cleared so the app reads them fresh.
"""
from __future__ import annotations
import asyncio
import secrets
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
import httpx
import pytest
import pytest_asyncio
from argon2 import PasswordHasher
from fastapi import FastAPI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from testcontainers.postgres import PostgresContainer
from testcontainers.redis import RedisContainer
from neuronetz_gateway.app import create_app
from neuronetz_gateway.auth.hashing import build_hasher, hash_secret
from neuronetz_gateway.auth.keys import generate_key
from neuronetz_gateway.config import Settings, get_settings
from neuronetz_gateway.db.repositories import (
ApiKeyRepository,
KeyLimitRepository,
TenantRepository,
)
from neuronetz_gateway.deps import get_ollama_client
from neuronetz_gateway.proxy.discovery import DiscoveredModel, names_of
from neuronetz_gateway.proxy.ollama import OllamaClient
from tests.integration.mock_ollama import DEFAULT_MODELS, create_mock_ollama
# --- fixtures ---------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class IntegrationKey:
    """A test API key + the tenant it belongs to.

    Returned by ``_create_tenant_and_key``. The instance is frozen, so a test
    cannot accidentally mutate the credentials it was handed.
    """

    # Id of the tenant row the key was created under.
    tenant_id: uuid.UUID
    # Tenant name: either the caller-supplied one or a generated ``acme-<hex>``.
    tenant_name: str
    # Id of the API key row.
    key_id: uuid.UUID
    # The complete bearer secret; only its hash is handed to the key repository.
    full_key: str
    # Prefix of the generated key, as stored alongside the hash.
    prefix: str
@dataclass(slots=True)
class IntegrationApp:
    """Built-app + driver bundle handed to integration tests.

    Produced by the ``integration_app`` fixture once the app's lifespan has
    started, so everything taken from ``app.state`` is already populated.
    """

    # The gateway application returned by ``create_app``.
    app: FastAPI
    # Base URL that ASGI clients should use when talking to ``app``.
    base_url: str
    # Settings instance the app was built from.
    settings: Settings
    # ``app.state.db_engine`` as set up by the lifespan.
    engine: AsyncEngine
    # ``app.state.db_sessionmaker``; call it to open a session.
    sessionmaker: async_sessionmaker[Any]
    # URL of the Redis test container (database 0).
    redis_url: str
    # asyncpg-flavoured SQLAlchemy URL of the Postgres test container.
    postgres_url: str
    # ``app.state.hasher`` as set up by the lifespan.
    hasher: PasswordHasher
    # The in-process mock Ollama app standing in for the upstream.
    mock_upstream: FastAPI
# Provenance note for the DDL (schema + enums + tables + indexes + trigger) that
# ``_create_schema_and_tables`` runs. That helper issues the statements directly
# so the test process does not have to invoke the alembic CLI; it is meant to be
# kept in sync with ``alembic/versions/0001_initial.py``.
# NOTE: this constant is only a human-readable pointer, not SQL, and nothing in
# the visible part of this module reads it.
_DDL_PATH_TO_MIGRATION = (
    "Built from db.models.Base.metadata + SPEC §5 trigger; see alembic 0001."
)
async def _create_schema_and_tables(engine: AsyncEngine) -> None:
    """Provision the ``gateway`` schema objects inside a single transaction.

    Runs, in order: ``CREATE SCHEMA IF NOT EXISTS``, the three Postgres enum
    types, ``Base.metadata.create_all``, and the ``key_revoked`` NOTIFY
    function + trigger on ``revocations`` (SPEC §5). Each statement tolerates
    objects that already exist, so repeated calls are harmless.
    """
    from neuronetz_gateway.db.models import GATEWAY_SCHEMA, Base

    # The models reference these enums with create_type=False, so the metadata
    # will not create them itself; they have to exist before create_all runs.
    enum_types = (
        ("key_status", ("active", "disabled", "revoked")),
        ("tenant_status", ("active", "suspended", "closed")),
        ("budget_period", ("day", "month", "total")),
    )

    async with engine.begin() as conn:
        await conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{GATEWAY_SCHEMA}"'))

        for name, values in enum_types:
            values_sql = ", ".join(f"'{v}'" for v in values)
            create_enum = text(
                f"""
DO $$ BEGIN
CREATE TYPE "{GATEWAY_SCHEMA}".{name} AS ENUM ({values_sql});
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
"""
            )
            await conn.execute(create_enum)

        await conn.run_sync(Base.metadata.create_all)

        # The NOTIFY trigger on gateway.revocations: (re)define the function,
        # drop any previous trigger, then attach a fresh one.
        trigger_statements = (
            f"""
CREATE OR REPLACE FUNCTION "{GATEWAY_SCHEMA}".notify_key_revoked()
RETURNS trigger AS $$
BEGIN
PERFORM pg_notify('key_revoked', NEW.key_id::text);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""",
            f'DROP TRIGGER IF EXISTS trg_notify_key_revoked '
            f'ON "{GATEWAY_SCHEMA}".revocations',
            f"""
CREATE TRIGGER trg_notify_key_revoked
AFTER INSERT ON "{GATEWAY_SCHEMA}".revocations
FOR EACH ROW EXECUTE FUNCTION "{GATEWAY_SCHEMA}".notify_key_revoked()
""",
        )
        for statement in trigger_statements:
            await conn.execute(text(statement))
@pytest_asyncio.fixture()
async def integration_app(
    postgres_container: PostgresContainer,
    redis_container: RedisContainer,
    monkeypatch: pytest.MonkeyPatch,
) -> AsyncIterator[IntegrationApp]:
    """Wire up the full gateway app against real Postgres + Redis + mock Ollama.

    Yields an :class:`IntegrationApp` while the app's lifespan is running.
    Cleanup (closing the mock upstream client, clearing the settings cache,
    disposing the bootstrap engine) is guaranteed by ``try``/``finally``, so a
    failure or ``pytest.skip`` during setup cannot leak the client or leave
    settings built from this fixture's environment cached for later tests.
    """
    # ---- URLs from the testcontainers (asyncpg driver for SQLAlchemy) ----
    pg_raw = postgres_container.get_connection_url()
    for prefix in ("postgresql+psycopg2://", "postgresql+psycopg://", "postgresql://"):
        if pg_raw.startswith(prefix):
            pg_url = "postgresql+asyncpg://" + pg_raw[len(prefix) :]
            break
    else:
        # Unknown scheme: hand the URL through untouched.
        pg_url = pg_raw
    redis_host = redis_container.get_container_host_ip()
    redis_port = redis_container.get_exposed_port(6379)
    redis_url = f"redis://{redis_host}:{redis_port}/0"

    # ---- Settings env (cache cleared below so create_app reads these) ----
    monkeypatch.setenv("DATABASE_URL", pg_url)
    monkeypatch.setenv("REDIS_URL", redis_url)
    monkeypatch.setenv("OLLAMA_BASE_URL", "http://ollama")  # any URL; overridden
    monkeypatch.setenv("GATEWAY_LOG_LEVEL", "WARNING")
    monkeypatch.setenv("ARGON2_TIME_COST", "1")
    monkeypatch.setenv("ARGON2_MEMORY_COST_KIB", "8")
    monkeypatch.setenv("ARGON2_PARALLELISM", "1")

    # ---- In-process mock upstream; built first so ``finally`` can close it ----
    mock_upstream = create_mock_ollama()
    transport = httpx.ASGITransport(app=mock_upstream)
    mock_http = httpx.AsyncClient(transport=transport, base_url="http://ollama")

    get_settings.cache_clear()
    try:
        settings = get_settings()

        # ---- Migrate schema directly (no alembic CLI needed). ----
        bootstrap_engine = create_async_engine(pg_url)
        try:
            await _create_schema_and_tables(bootstrap_engine)
        finally:
            await bootstrap_engine.dispose()

        # ---- Build the app and point its upstream at the mock Ollama. ----
        app = create_app(settings)
        app.dependency_overrides[get_ollama_client] = lambda: OllamaClient(mock_http)

        base_url = "http://gateway"

        # ASGITransport does not drive the lifespan automatically; run it
        # explicitly so app.state is populated (hasher / discovery_cache /
        # sessionmaker / etc.) before any test traffic. Equivalent to what
        # uvicorn does on startup.
        async with app.router.lifespan_context(app):
            # The lifespan installs app.state.discovery_cache, so seeding has to
            # happen after startup. Mirroring the mock catalogue here keeps the
            # model-allowlist checks deterministic instead of racing the
            # background poller.
            models = [
                DiscoveredModel(name=n, family=n.split(":", 1)[0]) for n in DEFAULT_MODELS
            ]
            cache = app.state.discovery_cache
            await cache.set(models)

            # Sanity: hasher and sessionmaker were placed on app.state by lifespan.
            if getattr(app.state, "hasher", None) is None:
                pytest.skip("pending Backend: app.state.hasher not set by lifespan")
            if getattr(app.state, "db_sessionmaker", None) is None:
                pytest.skip("pending Backend: db_sessionmaker not set by lifespan")

            yield IntegrationApp(
                app=app,
                base_url=base_url,
                settings=settings,
                engine=app.state.db_engine,
                sessionmaker=app.state.db_sessionmaker,
                redis_url=redis_url,
                postgres_url=pg_url,
                hasher=app.state.hasher,
                mock_upstream=mock_upstream,
            )
    finally:
        # Runs on normal teardown and on any setup failure/skip alike.
        await mock_http.aclose()
        get_settings.cache_clear()
@pytest_asyncio.fixture()
async def client(integration_app: IntegrationApp) -> AsyncIterator[httpx.AsyncClient]:
    """Yield an httpx client that talks to the wired gateway app over ASGI.

    The client is closed when the test finishes.
    """
    gateway_transport = httpx.ASGITransport(app=integration_app.app)
    gateway_client = httpx.AsyncClient(
        transport=gateway_transport,
        base_url=integration_app.base_url,
    )
    async with gateway_client:
        yield gateway_client
async def _create_tenant_and_key(
    integration_app: IntegrationApp,
    *,
    tenant_name: str | None = None,
    allow_all_models: bool = False,
    allowed_models: list[str] | None = None,
    tokens_daily: int | None = None,
    rpm: int = 60,
    tpm: int = 100_000,
) -> IntegrationKey:
    """Create a tenant and one API key for it straight through the repositories.

    Stands in for the bootstrap CLI so tests do not depend on it. Only the hash
    of the generated key is passed to the repository; the full key is returned
    to the caller and is not logged or persisted here.

    Args:
        integration_app: Wired app bundle supplying settings and sessions.
        tenant_name: Tenant name; a random ``acme-<hex>`` name when omitted.
        allow_all_models: Whether the tenant opts in to every discovered model.
        allowed_models: Explicit allowlist; ``None`` leaves it unset.
        tokens_daily: Daily token budget for the key; ``None`` means no budget
            row is written.
        rpm: Tenant requests-per-minute limit.
        tpm: Tenant tokens-per-minute limit.

    Returns:
        The ids, name, prefix and full secret of the newly created key.
    """
    generated = generate_key()
    key_hash = hash_secret(build_hasher(integration_app.settings), generated.full_key)
    name = tenant_name or f"acme-{secrets.token_hex(4)}"

    async with integration_app.sessionmaker() as session:
        tenant_repo = TenantRepository(session)
        tenant = await tenant_repo.create(
            name=name,
            rpm=rpm,
            tpm=tpm,
            concurrent=8,
            allow_all_models=allow_all_models,
        )

        # Model settings are written only when the caller asked for something.
        wants_model_settings = allow_all_models or allowed_models is not None
        if wants_model_settings:
            await tenant_repo.set_models(
                tenant.id,
                allowed_models=allowed_models,
                allow_all_models=allow_all_models,
            )

        key = await ApiKeyRepository(session).create(
            tenant_id=tenant.id,
            prefix=generated.prefix,
            key_hash=key_hash,
            name="test-key",
            scopes=["chat", "embeddings"],
        )

        if tokens_daily is not None:
            await KeyLimitRepository(session).upsert_budget(
                key.id,
                tokens_daily=tokens_daily,
                tokens_monthly=None,
                tokens_total=None,
            )

        await session.commit()
        created = IntegrationKey(
            tenant_id=tenant.id,
            tenant_name=name,
            key_id=key.id,
            full_key=generated.full_key,
            prefix=generated.prefix,
        )
    return created
@pytest_asyncio.fixture()
async def api_key(integration_app: IntegrationApp) -> IntegrationKey:
    """Default tenant + key whose allowlist names every mock-catalogue model."""
    catalogue = list(DEFAULT_MODELS)
    created = await _create_tenant_and_key(
        integration_app,
        allow_all_models=False,
        allowed_models=catalogue,
    )
    return created
@pytest_asyncio.fixture()
async def allow_all_key(integration_app: IntegrationApp) -> IntegrationKey:
    """Tenant + key with ``allow_all_models`` enabled and an empty allowlist."""
    created = await _create_tenant_and_key(
        integration_app,
        allowed_models=[],
        allow_all_models=True,
    )
    return created
@pytest_asyncio.fixture()
async def restricted_key(integration_app: IntegrationApp) -> IntegrationKey:
    """Default-deny tenant + key allowed to use only ``llama3.1:8b``."""
    only_model = ["llama3.1:8b"]
    created = await _create_tenant_and_key(
        integration_app,
        allowed_models=only_model,
        allow_all_models=False,
    )
    return created
# Names this module exports. The private ``_create_tenant_and_key`` helper is
# listed on purpose so individual tests can build extra keys without reaching
# back through pytest's fixture machinery.
__all__ = [
    "IntegrationApp",
    "IntegrationKey",
    "_create_tenant_and_key",
    "allow_all_key",
    "api_key",
    "client",
    "integration_app",
    "restricted_key",
]
# ``asyncio`` and ``names_of`` are imported above but not otherwise used in this
# module; referencing them here keeps unused-import warnings quiet.
_ = (asyncio, names_of)

View File

@@ -0,0 +1,401 @@
"""A realistic in-process mock of the Ollama HTTP API for integration tests.
This is the backbone of the Phase 2+ integration suite: it lets the gateway
proxy against a deterministic upstream with no GPU, no model download, and no
real Ollama process. It emulates the subset of the Ollama API the gateway
proxies (SPEC §6.1) plus the response *shapes* the gateway depends on:
* ``POST /api/chat`` - NDJSON streaming and non-streaming
* ``POST /api/generate`` - NDJSON streaming and non-streaming
* ``POST /api/embeddings``- non-streaming (legacy field name ``embedding``)
* ``POST /api/embed`` - non-streaming (newer field name ``embeddings``)
* ``GET /api/tags`` - model list
* ``GET /api/version`` - upstream version (gateway overrides this anyway)
Token-accounting realism: the final NDJSON object of every chat/generate
response carries ``prompt_eval_count`` and ``eval_count`` (and a few of the
sibling timing fields Ollama emits) so the gateway's token counter can be
exercised for real (SPEC §4.3 step 12, §12 token-count acceptance criterion).
OpenAI-compat note: this mock speaks *native Ollama* (NDJSON). The gateway is
responsible for translating to OpenAI SSE (``data: {...}\\n\\n`` ... ``data:
[DONE]``). To make SSE assertions easy without standing up the full gateway,
this module also exposes :func:`ollama_chunks_to_openai_sse`, a pure helper
that converts a sequence of native Ollama chat chunks into the SSE byte stream
the gateway is expected to emit.
Usage contract
--------------
Build the ASGI app and serve it with an ASGI transport so no real socket is
needed::
import httpx
from tests.integration.mock_ollama import create_mock_ollama
app = create_mock_ollama()
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://ollama") as client:
# non-streaming
r = await client.post("/api/chat", json={"model": "llama3.1:8b",
"messages": [{"role": "user",
"content": "hi"}],
"stream": False})
body = r.json()
assert body["prompt_eval_count"] >= 0 and body["eval_count"] >= 0
# streaming (NDJSON, one JSON object per line)
async with client.stream("POST", "/api/chat",
json={"model": "llama3.1:8b",
"messages": [{"role": "user", "content": "hi"}],
"stream": True}) as resp:
async for line in resp.aiter_lines():
if line:
obj = json.loads(line) # last obj has done=True + counts
The reply text of a chat/generate response is deterministic: ``Echo: <prompt>``
for a non-empty prompt (the last user message for ``/api/chat``), or ``Hello from
mock Ollama.`` when the prompt is empty. Token counts are whitespace-split word
counts, which keeps them stable and assertions simple. Override the reply by
passing ``reply_text`` in the request body (non-standard test hook, ignored by
real Ollama).
A pytest fixture :func:`mock_ollama_app` is also provided for direct use.
"""
from __future__ import annotations
import json
from collections.abc import AsyncIterator, Iterable
from datetime import UTC, datetime
from typing import Any
import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
# Content type used for the mock's streaming chat/generate responses
# (newline-delimited JSON, one object per line).
NDJSON_MEDIA_TYPE = "application/x-ndjson"
# Content type of server-sent-event streams. Not referenced in the visible part
# of this module; presumably provided for tests that check the gateway's SSE
# output (TODO confirm).
SSE_MEDIA_TYPE = "text/event-stream"
# A small, fixed model catalogue the gateway can filter against in tests.
DEFAULT_MODELS: tuple[str, ...] = ("llama3.1:8b", "mistral:7b", "nomic-embed-text")
def _now_iso() -> str:
return datetime.now(UTC).isoformat().replace("+00:00", "Z")
def _reply_for(prompt: str, override: str | None) -> str:
"""Deterministic reply text for a given prompt.
Real Ollama generates text; for tests we want determinism. Default reply is
a fixed canned sentence; tests may pass ``reply_text`` to control it.
"""
if override is not None:
return override
if not prompt:
return "Hello from mock Ollama."
return f"Echo: {prompt}"
def _tokenize(text: str) -> list[str]:
"""Whitespace tokenizer good enough for deterministic chunking/counts."""
return text.split()
def _final_metrics(prompt_tokens: int, completion_tokens: int) -> dict[str, Any]:
"""The timing/usage fields Ollama attaches to the terminal stream object.
The gateway's token counter reads ``prompt_eval_count`` and ``eval_count``;
the rest are included for realism so tests can assert they are *not* leaked
to clients if the gateway is supposed to strip them.
"""
return {
"total_duration": 1_234_567_890,
"load_duration": 12_345_678,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": 23_456_789,
"eval_count": completion_tokens,
"eval_duration": 34_567_890,
}
def _chat_chunk(
    model: str,
    *,
    content: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """Build one ``/api/chat`` NDJSON object.

    Intermediate frames carry only the assistant message piece. The terminal
    frame (``done=True``) additionally gets ``done_reason`` and the usage
    metrics built from the two token counts.
    """
    frame: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "message": {"role": "assistant", "content": content},
        "done": done,
    }
    if not done:
        return frame
    return {
        **frame,
        "done_reason": "stop",
        **_final_metrics(prompt_tokens, completion_tokens),
    }
def _generate_chunk(
    model: str,
    *,
    response: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """Build one ``/api/generate`` NDJSON object.

    Intermediate frames carry only the response piece. The terminal frame
    (``done=True``) additionally gets ``done_reason``, a dummy ``context`` list
    and the usage metrics built from the two token counts.
    """
    frame: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "response": response,
        "done": done,
    }
    if not done:
        return frame
    return {
        **frame,
        "done_reason": "stop",
        "context": [1, 2, 3],
        **_final_metrics(prompt_tokens, completion_tokens),
    }
async def _ndjson_stream(objects: Iterable[dict[str, Any]]) -> AsyncIterator[bytes]:
"""Serialize objects as newline-delimited JSON bytes (Ollama stream format)."""
for obj in objects:
yield (json.dumps(obj) + "\n").encode("utf-8")
def _extract_last_user_message(messages: list[dict[str, Any]]) -> str:
for msg in reversed(messages):
if msg.get("role") == "user":
content = msg.get("content", "")
return content if isinstance(content, str) else ""
return ""
def create_mock_ollama(models: Iterable[str] = DEFAULT_MODELS) -> FastAPI:
    """Build and return a FastAPI app emulating the Ollama API.

    Routes: ``POST /api/chat``, ``POST /api/generate``, ``POST /api/embeddings``,
    ``POST /api/embed``, ``GET /api/tags``, ``POST /api/show`` and
    ``GET /api/version``. Chat and generate stream NDJSON unless the request
    body sets ``stream`` to a falsy value.

    Parameters
    ----------
    models:
        Catalogue returned by ``/api/tags``. Defaults to :data:`DEFAULT_MODELS`.
    """
    catalogue = tuple(models)
    app = FastAPI(title="mock-ollama", docs_url=None, redoc_url=None)

    @app.post("/api/chat")
    async def chat(request: Request) -> Any:
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "llama3.1:8b")
        prompt = _extract_last_user_message(payload.get("messages", []))
        reply = _reply_for(prompt, payload.get("reply_text"))
        usage = {
            "prompt_tokens": len(_tokenize(prompt)),
            "completion_tokens": len(_tokenize(reply)),
        }

        if not payload.get("stream", True):
            # Non-streaming: a single terminal object carrying the whole reply.
            return JSONResponse(_chat_chunk(model, content=reply, done=True, **usage))

        # Streaming: one frame per word (a leading space on all but the first),
        # followed by an empty terminal frame that carries the counts.
        pieces = _tokenize(reply) or [""]
        frames = [
            _chat_chunk(model, content=piece if idx == 0 else f" {piece}", done=False)
            for idx, piece in enumerate(pieces)
        ]
        frames.append(_chat_chunk(model, content="", done=True, **usage))
        return StreamingResponse(_ndjson_stream(frames), media_type=NDJSON_MEDIA_TYPE)

    @app.post("/api/generate")
    async def generate(request: Request) -> Any:
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "llama3.1:8b")
        prompt = payload.get("prompt", "")
        reply = _reply_for(prompt, payload.get("reply_text"))
        usage = {
            "prompt_tokens": len(_tokenize(prompt)),
            "completion_tokens": len(_tokenize(reply)),
        }

        if not payload.get("stream", True):
            return JSONResponse(_generate_chunk(model, response=reply, done=True, **usage))

        pieces = _tokenize(reply) or [""]
        frames = [
            _generate_chunk(model, response=piece if idx == 0 else f" {piece}", done=False)
            for idx, piece in enumerate(pieces)
        ]
        frames.append(_generate_chunk(model, response="", done=True, **usage))
        return StreamingResponse(_ndjson_stream(frames), media_type=NDJSON_MEDIA_TYPE)

    @app.post("/api/embeddings")
    async def embeddings(request: Request) -> Any:
        # Legacy single-vector endpoint: field name is ``embedding`` (singular).
        payload: dict[str, Any] = await request.json()
        token_count = len(_tokenize(payload.get("prompt", "")))
        # Ollama does not return eval_count for embeddings (SPEC §13.1); only
        # prompt_eval_count is meaningful for cost accounting.
        result = {
            "embedding": [0.0, 0.1, 0.2, 0.3],
            "prompt_eval_count": token_count,
        }
        return JSONResponse(result)

    @app.post("/api/embed")
    async def embed(request: Request) -> Any:
        # Newer batch endpoint: field name is ``embeddings`` (plural list).
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "nomic-embed-text")
        raw_input = payload.get("input", "")
        # A bare value is treated as a batch of one.
        batch = raw_input if isinstance(raw_input, list) else [raw_input]
        token_count = sum(len(_tokenize(str(item))) for item in batch)
        result = {
            "model": model,
            "embeddings": [[0.0, 0.1, 0.2, 0.3] for _ in batch],
            "prompt_eval_count": token_count,
        }
        return JSONResponse(result)

    @app.get("/api/tags")
    async def tags() -> Any:
        def describe(name: str) -> dict[str, Any]:
            # One catalogue entry in the shape of an Ollama model listing.
            return {
                "name": name,
                "model": name,
                "modified_at": _now_iso(),
                "size": 4_700_000_000,
                "digest": "sha256:deadbeef",
                "details": {
                    "family": name.split(":", 1)[0],
                    "parameter_size": "8B",
                    "quantization_level": "Q4_0",
                },
            }

        return JSONResponse({"models": [describe(name) for name in catalogue]})

    @app.post("/api/show")
    async def show(request: Request) -> Any:
        payload: dict[str, Any] = await request.json()
        name = payload.get("model") or payload.get("name", "llama3.1:8b")
        family = str(name).split(":", 1)[0]
        # Real Ollama returns system prompt + template here; the gateway is
        # expected to strip those. They are included so a sanitization test can
        # assert they don't reach the client.
        result = {
            "modelfile": "FROM llama3.1:8b",
            "template": "{{ .System }} {{ .Prompt }}",
            "system": "You are a secret internal system prompt.",
            "details": {"family": family, "parameter_size": "8B"},
        }
        return JSONResponse(result)

    @app.get("/api/version")
    async def version() -> Any:
        # The gateway overrides this with its own version (SPEC §6.1); the mock
        # returns a plausible upstream version so a test can confirm it is NOT
        # the value the gateway reports.
        return JSONResponse({"version": "0.5.7"})

    return app
def ollama_chunks_to_openai_sse(
chunks: Iterable[dict[str, Any]],
*,
model: str = "llama3.1:8b",
completion_id: str = "chatcmpl-mock",
) -> bytes:
"""Convert native Ollama chat chunks into the OpenAI SSE byte stream.
Pure helper that mirrors what the gateway's OpenAI-compat translation layer
must emit (SPEC §6.3): one ``data: {...}\\n\\n`` event per delta, a final
event carrying ``finish_reason``/``usage``, then ``data: [DONE]\\n\\n``.
Useful for asserting expected SSE output in translation tests without
standing up the full gateway. Phase 3 may replace this with golden fixtures.
"""
created = int(datetime.now(UTC).timestamp())
events: list[str] = []
for chunk in chunks:
message = chunk.get("message", {})
content = message.get("content", "")
done = chunk.get("done", False)
if done:
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"usage": {
"prompt_tokens": chunk.get("prompt_eval_count", 0),
"completion_tokens": chunk.get("eval_count", 0),
"total_tokens": chunk.get("prompt_eval_count", 0)
+ chunk.get("eval_count", 0),
},
}
else:
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
}
events.append(f"data: {json.dumps(payload)}\n\n")
events.append("data: [DONE]\n\n")
return "".join(events).encode("utf-8")
@pytest.fixture()
def mock_ollama_app() -> FastAPI:
    """Provide a freshly built mock Ollama ASGI app with the default catalogue."""
    fresh_app = create_mock_ollama()
    return fresh_app

View File

@@ -0,0 +1,106 @@
"""Integration tests for the auth flow (SPEC §4.3 steps 2-3, §12, §3).
* No ``Authorization`` header => 401, sanitized body, X-Request-ID present.
* Wrong/malformed Bearer => 401, identical sanitized body (no enumeration).
* Valid key against a route that hits Ollama => 200 (or generic non-2xx
that is NOT a leakage of upstream internals — body has only
``error.code``/``message``/``request_id``).
Drives the real ``create_app()`` against testcontainer Postgres + Redis with
the upstream Ollama overridden to the in-process mock.
"""
from __future__ import annotations
import httpx
import pytest
from tests.integration.conftest import IntegrationApp, IntegrationKey
pytestmark = pytest.mark.asyncio
def _assert_sanitized_error(body: dict[str, object], request_id_header: str | None) -> None:
"""The error body must carry only safe fields, and echo the request id."""
assert "error" in body
err = body["error"]
assert isinstance(err, dict)
# Exactly the three documented fields — no upstream detail, no stack trace.
assert set(err.keys()) <= {"code", "message", "request_id"}
assert isinstance(err["code"], str) and err["code"]
assert isinstance(err["message"], str) and err["message"]
# No mention of Ollama, Postgres, Redis, or Python tracebacks anywhere.
blob = " ".join(str(v) for v in err.values()).lower()
for needle in ("ollama", "postgres", "redis", "traceback", "asyncpg", "sqlalchemy"):
assert needle not in blob, f"leaked internal token {needle!r}: {body}"
# The request id from the response header round-trips into the body.
if request_id_header is not None:
assert err["request_id"] == request_id_header
async def test_missing_bearer_returns_401(
    client: httpx.AsyncClient, integration_app: IntegrationApp
) -> None:
    """A request without an Authorization header gets a sanitized 401."""
    payload = {"model": "llama3.1:8b", "messages": []}
    resp = await client.post("/api/chat", json=payload)
    assert resp.status_code == 401
    # The request id header is always present (SPEC §6.5).
    request_id = resp.headers.get("X-Request-ID")
    assert request_id
    _assert_sanitized_error(resp.json(), request_id)
async def test_malformed_bearer_returns_401(client: httpx.AsyncClient) -> None:
    """An Authorization header with a non-Bearer scheme gets a sanitized 401."""
    bad_auth = {"Authorization": "NotBearer foo"}
    payload = {"model": "llama3.1:8b", "messages": []}
    resp = await client.post("/api/chat", headers=bad_auth, json=payload)
    assert resp.status_code == 401
    _assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
async def test_wrong_key_returns_401(client: httpx.AsyncClient) -> None:
    """A well-formed key that does not exist in the DB gets a sanitized 401."""
    # Well-formed prefix shape but no such key in the DB.
    unknown_auth = {"Authorization": "Bearer nz_doesnotexistxxxxxxxxxxxxxxxxxxxxxxxxxx"}
    payload = {"model": "llama3.1:8b", "messages": []}
    resp = await client.post("/api/chat", headers=unknown_auth, json=payload)
    assert resp.status_code == 401
    _assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
async def test_valid_key_authenticates(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """A valid key reaches the mock upstream and gets the chat response back."""
    auth = {"Authorization": f"Bearer {api_key.full_key}"}
    payload = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }
    resp = await client.post("/api/chat", headers=auth, json=payload)
    # 200 means authenticated AND the model passed the allowlist AND the proxy
    # forwarded to mock Ollama; the response text is shown on failure.
    assert resp.status_code == 200, resp.text
    assert resp.headers.get("X-Request-ID")
    # The mock returns the native chat shape: model name plus a message object.
    body = resp.json()
    assert body.get("model") == "llama3.1:8b"
    assert "message" in body
async def test_unauthenticated_error_response_has_no_leakage(
    client: httpx.AsyncClient,
) -> None:
    """An unauthenticated ``/api/generate`` call leaks no upstream details."""
    payload = {"model": "llama3.1:8b"}
    resp = await client.post("/api/generate", json=payload)
    assert resp.status_code == 401
    # The body must never mention upstream details, even by accident.
    _assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
async def test_healthz_does_not_require_auth(client: httpx.AsyncClient) -> None:
    """``/healthz`` answers 200 with status ``ok`` without any credentials."""
    resp = await client.get("/healthz")
    assert resp.status_code == 200
    health = resp.json()
    assert health["status"] == "ok"

View File

@@ -0,0 +1,83 @@
"""Integration tests for token budgets (SPEC §4.3 step 5, §6.5, §12).
* A request returns the SPEC §6.5 budget headers
(``X-Budget-Period``, ``X-Budget-Tokens-Remaining``).
* When the daily budget is exhausted the next request is blocked with a
sanitized ``budget_exceeded`` error.
"""
from __future__ import annotations
import asyncio
import httpx
import pytest
from neuronetz_gateway.budget.counter import BudgetCounter
from neuronetz_gateway.db.models import BudgetPeriod
from tests.integration.conftest import (
IntegrationApp,
_create_tenant_and_key,
)
from tests.integration.mock_ollama import DEFAULT_MODELS
pytestmark = pytest.mark.asyncio
async def _chat(client: httpx.AsyncClient, key_full: str) -> httpx.Response:
    """POST a fixed non-streaming chat request authenticated with *key_full*."""
    auth = {"Authorization": f"Bearer {key_full}"}
    payload = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": False,
    }
    return await client.post("/api/chat", headers=auth, json=payload)
async def test_budget_headers_present_on_response(
    integration_app: IntegrationApp, client: httpx.AsyncClient
) -> None:
    """A successful request carries the SPEC §6.5 budget headers."""
    key = await _create_tenant_and_key(
        integration_app,
        allowed_models=list(DEFAULT_MODELS),
        tokens_daily=1_000_000,
    )
    resp = await _chat(client, key.full_key)
    assert resp.status_code == 200
    # SPEC §6.5
    period = resp.headers.get("X-Budget-Period")
    remaining = resp.headers.get("X-Budget-Tokens-Remaining")
    assert period in {"day", "month", "total"}
    assert remaining is not None
async def test_budget_blocks_when_exhausted(
    integration_app: IntegrationApp, client: httpx.AsyncClient
) -> None:
    """A key whose daily budget is already spent is refused with a sanitized error."""
    # Tiny daily budget, so the pre-burn below is guaranteed to exceed it.
    key = await _create_tenant_and_key(
        integration_app,
        allowed_models=list(DEFAULT_MODELS),
        tokens_daily=1,
    )

    # Pre-burn the Redis budget counter so the request is blocked
    # deterministically (don't depend on post-stream accounting timing).
    counter = BudgetCounter(integration_app.app.state.redis)
    await counter.consume(str(key.key_id), BudgetPeriod.day, 1000)
    # Give Redis a moment so the request observes the consumed value.
    await asyncio.sleep(0.01)

    resp = await _chat(client, key.full_key)

    # Must not be a 200: fail-closed with a descriptive error.
    assert resp.status_code != 200
    error = resp.json()["error"]
    assert error["code"] in {"budget_exceeded", "rate_limited"}
    assert error["request_id"]
    # Message is descriptive but sanitized (no upstream / internal details).
    message = error["message"].lower()
    assert all(
        needle not in message for needle in ("ollama", "redis", "postgres", "traceback")
    )

View File

@@ -0,0 +1,129 @@
"""Integration tests for live model discovery + the effective set (SPEC §4.6, §12).
Covers the acceptance criteria around discovery:
* ``allow_all_models`` tenant sees every installed model in ``/api/tags`` and
``/v1/models``.
* Default-deny tenant sees only ``allowed_models ∩ discovered``.
* Request for a model outside the effective set => 403 with a generic body
(no existence disclosure: installed-but-unpermitted vs not-installed are
indistinguishable, SPEC §13.6).
* Discovery unavailable (empty cache) => deny, even for ``allow_all``.
"""
from __future__ import annotations
import httpx
import pytest
from neuronetz_gateway.proxy.discovery import DiscoveredModel
from tests.integration.conftest import IntegrationApp, IntegrationKey
from tests.integration.mock_ollama import DEFAULT_MODELS
pytestmark = pytest.mark.asyncio
async def test_allow_all_tenant_sees_all_discovered(
    client: httpx.AsyncClient, allow_all_key: IntegrationKey
) -> None:
    """``/api/tags`` for the ``allow_all`` key lists every name in ``DEFAULT_MODELS``."""
    auth_headers = {"Authorization": f"Bearer {allow_all_key.full_key}"}
    listing = await client.get("/api/tags", headers=auth_headers)
    assert listing.status_code == 200

    listed = {entry["name"] for entry in listing.json()["models"]}
    # Superset, not equality: extra listed models are tolerated.
    assert listed.issuperset(DEFAULT_MODELS)
async def test_default_deny_tenant_sees_only_allowed_intersect_discovered(
    client: httpx.AsyncClient, restricted_key: IntegrationKey
) -> None:
    """``/api/tags`` for the restricted key lists exactly its one allowlisted model."""
    auth_headers = {"Authorization": f"Bearer {restricted_key.full_key}"}
    listing = await client.get("/api/tags", headers=auth_headers)
    assert listing.status_code == 200

    visible = {entry["name"] for entry in listing.json()["models"]}
    # The ``restricted_key`` fixture allowlists only llama3.1:8b.
    assert {"llama3.1:8b"} == visible
async def test_v1_models_filtered_by_effective_set(
    client: httpx.AsyncClient, restricted_key: IntegrationKey
) -> None:
    """``/v1/models`` applies the same filtering: only llama3.1:8b is returned."""
    auth_headers = {"Authorization": f"Bearer {restricted_key.full_key}"}
    listing = await client.get("/v1/models", headers=auth_headers)
    assert listing.status_code == 200

    model_ids = {entry["id"] for entry in listing.json()["data"]}
    assert {"llama3.1:8b"} == model_ids
async def test_request_for_unpermitted_model_returns_403(
    client: httpx.AsyncClient, restricted_key: IntegrationKey
) -> None:
    """A chat request for a model outside the key's allowlist is answered with 403.

    ``mistral:7b`` is installed in the mock catalogue but is not allowlisted for
    this tenant; the expected body is the generic ``forbidden`` error, the same
    one used for a model that does not exist at all (SPEC §13.6).
    """
    auth_headers = {"Authorization": f"Bearer {restricted_key.full_key}"}
    chat_request = {
        "model": "mistral:7b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }

    denied = await client.post("/api/chat", headers=auth_headers, json=chat_request)

    assert denied.status_code == 403
    error = denied.json()["error"]
    assert error["code"] == "forbidden"
    assert error["request_id"]
async def test_request_for_nonexistent_model_returns_same_generic_403(
    client: httpx.AsyncClient, allow_all_key: IntegrationKey
) -> None:
    """A model name that is not installed is rejected with 403 ``forbidden``.

    For an ``allow_all`` tenant the effective set is whatever was discovered,
    so an unknown name falls outside it just like an unpermitted one.
    """
    auth_headers = {"Authorization": f"Bearer {allow_all_key.full_key}"}
    chat_request = {
        "model": "ghost-model-not-installed",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }

    denied = await client.post("/api/chat", headers=auth_headers, json=chat_request)

    assert denied.status_code == 403
    assert denied.json()["error"]["code"] == "forbidden"
async def test_discovery_unavailable_denies_even_allow_all(
    client: httpx.AsyncClient,
    integration_app: IntegrationApp,
    allow_all_key: IntegrationKey,
) -> None:
    """With an empty discovery cache even an ``allow_all`` key is denied.

    The app's discovery cache is overwritten with an empty list to simulate a
    stale/expired cache (fail-closed per SPEC §4.6, §13.5), then repopulated
    from ``DEFAULT_MODELS`` so later tests are unaffected by ordering.
    """
    discovery_cache = integration_app.app.state.discovery_cache
    auth_headers = {"Authorization": f"Bearer {allow_all_key.full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }

    await discovery_cache.set([])
    try:
        denied = await client.post("/api/chat", headers=auth_headers, json=chat_request)
        assert denied.status_code == 403
        assert denied.json()["error"]["code"] == "forbidden"

        # Listing still answers 200, but with nothing in it.
        listing = await client.get("/api/tags", headers=auth_headers)
        assert listing.status_code == 200
        assert listing.json()["models"] == []
    finally:
        # Family is the part of the model name before the first ":".
        restored = [
            DiscoveredModel(name=model_name, family=model_name.partition(":")[0])
            for model_name in DEFAULT_MODELS
        ]
        await discovery_cache.set(restored)

View File

@@ -0,0 +1,85 @@
"""Integration tests for the OpenAI-compatible surface (SPEC §6.3, §12).
* ``/v1/chat/completions`` streaming SSE: every event is ``data: {...}\\n\\n``
and the stream terminates with ``data: [DONE]\\n\\n``.
* Non-streaming ``/v1/chat/completions`` returns the OpenAI ``chat.completion``
shape with a single ``choices[0].message`` and ``usage``.
* ``/v1/models`` returns the tenant's *effective* discovered set in the
OpenAI model-list format.
"""
from __future__ import annotations
import json
import httpx
import pytest
from tests.integration.conftest import IntegrationKey
from tests.integration.mock_ollama import DEFAULT_MODELS
pytestmark = pytest.mark.asyncio
async def test_chat_completions_sse_ends_with_done(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Streaming ``/v1/chat/completions`` yields ``data:`` lines ending in ``[DONE]``."""
    done_marker = "data: [DONE]"
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
    }

    async with client.stream(
        "POST",
        "/v1/chat/completions",
        headers={"Authorization": f"Bearer {api_key.full_key}"},
        json=chat_request,
    ) as stream:
        assert stream.status_code == 200
        assert "text/event-stream" in stream.headers.get("content-type", "")
        # Blank separator lines between SSE events are dropped.
        events: list[str] = [line async for line in stream.aiter_lines() if line]

    # SSE framing: every retained line is a ``data: `` line.
    assert all(event.startswith("data: ") for event in events), events
    assert events[-1] == done_marker

    # Decode the first real chunk to confirm the OpenAI delta shape.
    first_chunk = next(event for event in events if event != done_marker)
    chunk = json.loads(first_chunk.removeprefix("data: "))
    assert chunk["object"] == "chat.completion.chunk"
    assert chunk["choices"][0]["index"] == 0
async def test_chat_completions_non_streaming_shape(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Non-streaming ``/v1/chat/completions`` returns a ``chat.completion`` object."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }

    completion = await client.post(
        "/v1/chat/completions", headers=auth_headers, json=chat_request
    )

    assert completion.status_code == 200, completion.text
    payload = completion.json()
    assert payload["object"] == "chat.completion"
    assert payload["choices"][0]["message"]["role"] == "assistant"
    assert payload["usage"]["total_tokens"] >= 0
async def test_v1_models_returns_effective_set(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """``/v1/models`` answers with an OpenAI-style list covering ``DEFAULT_MODELS``."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    listing = await client.get("/v1/models", headers=auth_headers)
    assert listing.status_code == 200

    payload = listing.json()
    assert payload["object"] == "list"

    entries = payload["data"]
    # ``api_key``'s tenant was created with the full DEFAULT_MODELS allowlist.
    assert {entry["id"] for entry in entries}.issuperset(DEFAULT_MODELS)
    for entry in entries:
        assert entry["object"] == "model"

View File

@@ -0,0 +1,161 @@
"""Integration tests for the native Ollama proxy (streaming + non-streaming).
SPEC §4.3 steps 11-13, §6.1-6.2, §12: ``/api/chat`` and ``/api/generate`` stream
NDJSON end-to-end through to the in-process mock upstream; the gateway preserves
the byte stream untouched and the audit row that lands after stream close
carries token counts equal to the mock's ``prompt_eval_count`` + ``eval_count``;
``/api/pull`` (and the other mutating endpoints) returns 403 with a generic
sanitized body.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any
import httpx
import pytest
from sqlalchemy import select
from neuronetz_gateway.db.models import AuditLog
from tests.integration.conftest import IntegrationApp, IntegrationKey
pytestmark = pytest.mark.asyncio
async def _wait_for_audit(integration_app: IntegrationApp, key_id: Any) -> AuditLog | None:
    """Poll the audit table for up to 2s waiting for the buffered writer drain.

    Args:
        integration_app: Test application handle; its ``sessionmaker`` opens
            the DB sessions used for polling.
        key_id: Value matched against ``AuditLog.key_id``.

    Returns:
        The audit row with the highest ``id`` for ``key_id``, or ``None`` if no
        row showed up within the polling window.
    """
    # ``LIMIT 1`` is required: ``scalar_one_or_none()`` raises
    # ``MultipleResultsFound`` as soon as the key has more than one audit row,
    # whereas the ``ORDER BY id DESC`` shows the intent is "newest row".
    newest_for_key = (
        select(AuditLog)
        .where(AuditLog.key_id == key_id)
        .order_by(AuditLog.id.desc())
        .limit(1)
    )
    for _ in range(40):  # 40 * 50ms = 2s
        async with integration_app.sessionmaker() as session:
            row: AuditLog | None = (
                await session.execute(newest_for_key)
            ).scalar_one_or_none()
        if row is not None:
            return row
        await asyncio.sleep(0.05)
    return None
async def test_chat_stream_ndjson_roundtrip(
    client: httpx.AsyncClient, integration_app: IntegrationApp, api_key: IntegrationKey
) -> None:
    """Streamed ``/api/chat`` yields NDJSON frames and a matching audit row.

    The final frame must be the ``done`` frame carrying the token counters;
    the audit row written for this key must report the same counters, model
    and path. If no audit row appears in time the test is skipped.
    """
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }

    async with client.stream(
        "POST",
        "/api/chat",
        headers={"Authorization": f"Bearer {api_key.full_key}"},
        json=chat_request,
    ) as stream:
        assert stream.status_code == 200
        assert "application/x-ndjson" in stream.headers.get("content-type", "")
        # One JSON document per non-empty line.
        frames: list[dict[str, Any]] = [
            json.loads(line) async for line in stream.aiter_lines() if line
        ]

    # mock_ollama emits one frame per word, then a final done frame.
    final = frames[-1]
    assert final.get("done") is True
    assert final.get("model") == "llama3.1:8b"
    # Token counters ride on the final frame.
    assert final.get("prompt_eval_count") is not None
    assert final.get("eval_count") is not None

    audit = await _wait_for_audit(integration_app, api_key.key_id)
    if audit is None:
        pytest.skip("pending Backend: audit row not flushed (drain may still be wiring up)")

    assert audit.tokens_in == final["prompt_eval_count"]
    assert audit.tokens_out == final["eval_count"]
    assert audit.model == "llama3.1:8b"
    assert audit.path == "/api/chat"
async def test_chat_non_streaming(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Non-streaming ``/api/chat`` returns one complete Ollama chat object."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hello there"}],
        "stream": False,
    }

    reply = await client.post("/api/chat", headers=auth_headers, json=chat_request)

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["model"] == "llama3.1:8b"
    assert isinstance(payload.get("message"), dict)
    assert payload.get("done") is True
    # Usage counters are present as integers.
    for counter in ("prompt_eval_count", "eval_count"):
        assert isinstance(payload.get(counter), int)
async def test_generate_stream(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Streamed ``/api/generate`` ends with a ``done`` frame that has ``response``."""
    generate_request = {
        "model": "llama3.1:8b",
        "prompt": "once upon a time",
        "stream": True,
    }
    final_frame: dict[str, Any] = {}

    async with client.stream(
        "POST",
        "/api/generate",
        headers={"Authorization": f"Bearer {api_key.full_key}"},
        json=generate_request,
    ) as stream:
        assert stream.status_code == 200
        async for line in stream.aiter_lines():
            if not line:
                continue
            # Every frame is parsed; only the most recent one is kept.
            final_frame = json.loads(line)

    assert final_frame.get("done") is True
    assert "response" in final_frame  # generate uses 'response' not 'message'
@pytest.mark.parametrize(
    "path", ["/api/pull", "/api/push", "/api/create", "/api/copy", "/api/delete"]
)
async def test_mutating_endpoints_return_403(
    client: httpx.AsyncClient, api_key: IntegrationKey, path: str
) -> None:
    """POSTing to any mutating Ollama endpoint is answered with a generic 403."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}

    blocked = await client.post(path, headers=auth_headers, json={"name": "llama3.1:8b"})

    assert blocked.status_code == 403
    error = blocked.json()["error"]
    # Generic — no enumeration of what's blocked, no upstream details.
    assert error["code"] in {"forbidden", "not_found"}
    assert error["request_id"]
async def test_api_ps_blocked(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """``GET /api/ps`` is refused (it would leak loaded models per SPEC §6.1)."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    blocked = await client.get("/api/ps", headers=auth_headers)

    status = blocked.status_code
    assert status in {403, 404, 405}  # however the router renders it
    if status != 403:
        return
    # Only the 403 rendering is required to carry the structured error body.
    assert blocked.json()["error"]["code"] == "forbidden"
async def test_blobs_prefix_blocked(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Paths under ``/api/blobs/`` are refused with 403, 404 or 405."""
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    blocked = await client.get("/api/blobs/sha256:abc", headers=auth_headers)
    # Hard-blocked prefix — never reaches Ollama.
    assert blocked.status_code in {403, 404, 405}

View File

@@ -0,0 +1,135 @@
"""Integration tests for rate limiting (SPEC §4.3 step 4, §4.4, §6.5, §12).
* Per-key RPM trips at the configured limit; 429 carries ``Retry-After`` and
the §6.5 rate-limit headers.
* Redis outage fails closed with 503 (never 200).
"""
from __future__ import annotations
import asyncio
import httpx
import pytest
from tests.integration.conftest import (
IntegrationApp,
IntegrationKey,
_create_tenant_and_key,
)
from tests.integration.mock_ollama import DEFAULT_MODELS
pytestmark = pytest.mark.asyncio
async def _spam(client: httpx.AsyncClient, key: IntegrationKey, n: int) -> list[httpx.Response]:
    """Send ``n`` authenticated non-streaming chat requests one after another.

    Returns the responses in the order the requests were issued.
    """
    auth_headers = {"Authorization": f"Bearer {key.full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }
    responses: list[httpx.Response] = []
    for _ in range(n):
        response = await client.post("/api/chat", headers=auth_headers, json=chat_request)
        responses.append(response)
        # Space the hits out: avoid the known same-ms ZSET collision bug.
        await asyncio.sleep(0.003)
    return responses
async def test_rpm_limit_returns_429_with_retry_after(
    integration_app: IntegrationApp, client: httpx.AsyncClient
) -> None:
    """Exceeding a per-key RPM of 2 produces a 429 with ``Retry-After``.

    Five sequential requests are sent; at least one must be rejected with the
    ``rate_limited`` error, and an admitted one must expose the §6.5
    rate-limit headers.
    """
    # Tight RPM so the cap is reached with only a handful of requests.
    tenant_key = await _create_tenant_and_key(
        integration_app,
        rpm=2,
        allowed_models=list(DEFAULT_MODELS),
    )

    responses = await _spam(client, tenant_key, 5)
    statuses = [response.status_code for response in responses]
    assert 429 in statuses, statuses

    # First rejected response: Retry-After header (SPEC §6.5) + error body.
    rejected = next(response for response in responses if response.status_code == 429)
    assert any(name.lower() == "retry-after" for name in rejected.headers.keys())
    error = rejected.json()["error"]
    assert error["code"] == "rate_limited"
    assert error["request_id"]

    # First admitted response: carries the rate-limit bookkeeping headers.
    admitted = next(response for response in responses if response.status_code == 200)
    for header in ("X-RateLimit-Limit-Requests", "X-RateLimit-Remaining-Requests"):
        assert header in admitted.headers
class _BrokenRedis:
    """Stand-in Redis client whose every call raises ``RedisError``.

    Drop-in for ``app.state.redis`` to model "Redis unreachable" deterministically
    without tearing down the testcontainer (which subsequent tests need).

    Any ordinary attribute resolves to an async callable that raises
    ``RedisError``. ``register_script`` is special-cased: it is a plain callable
    returning an object whose awaited call raises ``RedisError``. Dunder
    attributes are *not* faked and raise ``AttributeError`` as usual.
    """

    def __getattr__(self, _name: str):  # type: ignore[no-untyped-def]
        # Protocol / introspection probes (``hasattr(obj, "__aenter__")``,
        # ``copy``, ``pickle`` ...) must see a missing attribute, not a
        # coroutine function pretending to implement the protocol.
        if _name.startswith("__") and _name.endswith("__"):
            raise AttributeError(_name)

        # Imported lazily so the stub has no import-time dependency on redis.
        from redis.exceptions import RedisError

        async def _fail(*_a: object, **_k: object) -> None:
            raise RedisError("simulated Redis outage")

        def _register_script(*_a: object, **_k: object):  # type: ignore[no-untyped-def]
            class _Script:
                async def __call__(self, *_aa: object, **_kk: object) -> None:
                    raise RedisError("simulated Redis outage")

            return _Script()

        if _name == "register_script":
            return _register_script
        return _fail
async def test_redis_down_fails_closed_with_503(
    integration_app: IntegrationApp, client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """With Redis failing on every call the gateway must not answer 200.

    ``app.state.redis`` is temporarily replaced by ``_BrokenRedis`` and put
    back afterwards. The assertions require an error status (>= 400) with a
    structured ``error`` body that does not mention internal components.
    """
    auth_headers = {"Authorization": f"Bearer {api_key.full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }

    # Warm the auth cache first so this exercises the rate-limit fail-closed
    # path rather than authentication.
    warmup = await client.post("/api/chat", headers=auth_headers, json=chat_request)
    assert warmup.status_code == 200

    # The limiter / budget / concurrency components are expected to turn the
    # RedisError into DependencyUnavailableError => 503.
    app_state = integration_app.app.state
    healthy_redis = app_state.redis
    app_state.redis = _BrokenRedis()
    try:
        failed = await client.post("/api/chat", headers=auth_headers, json=chat_request)

        # The hard requirement: fail closed, never 200 (SPEC §4.4).
        assert failed.status_code != 200, failed.text
        assert failed.status_code >= 400
        payload = failed.json()
        assert "error" in payload

        # No upstream/internal leakage anywhere in the error object's values.
        flattened = " ".join(str(value) for value in payload["error"].values()).lower()
        for forbidden_word in ("redis", "asyncpg", "traceback"):
            assert forbidden_word not in flattened, payload
    finally:
        app_state.redis = healthy_redis

View File

@@ -0,0 +1,67 @@
"""Integration tests for key revocation (SPEC §4.5, §12).
INSERT INTO ``gateway.revocations`` → AFTER-INSERT trigger fires
``pg_notify('key_revoked', key_id)`` → the gateway's revocation listener evicts
the Redis cache entry for that key prefix → the next request with that key
returns 401 within ~1 second. The key's DB status is also flipped to revoked so
a full DB re-lookup likewise denies (defense in depth).
"""
from __future__ import annotations
import asyncio
import httpx
import pytest
from sqlalchemy import update
from neuronetz_gateway.db.models import ApiKey, KeyStatus
from neuronetz_gateway.db.repositories import RevocationRepository
from tests.integration.conftest import IntegrationApp, IntegrationKey
pytestmark = pytest.mark.asyncio
async def _chat(client: httpx.AsyncClient, full_key: str) -> httpx.Response:
    """POST a minimal non-streaming ``/api/chat`` request authenticated by ``full_key``."""
    auth_headers = {"Authorization": f"Bearer {full_key}"}
    chat_request = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
    }
    response = await client.post("/api/chat", headers=auth_headers, json=chat_request)
    return response
async def test_revoked_key_rejected_within_one_second(
    integration_app: IntegrationApp,
    client: httpx.AsyncClient,
    api_key: IntegrationKey,
) -> None:
    """A revoked key must stop being accepted within one second.

    The key is first used successfully (warming the auth cache), then its DB
    status is set to revoked and a revocation row is inserted. The test polls
    the gateway until the key is rejected and fails if it is still accepted
    once one second of wall-clock time has passed.
    """
    # Warm the auth cache by making a successful request first.
    ok = await _chat(client, api_key.full_key)
    assert ok.status_code == 200

    # Revoke: flip status + insert into the outbox (fires NOTIFY via trigger).
    async with integration_app.sessionmaker() as session:
        await session.execute(
            update(ApiKey).where(ApiKey.id == api_key.key_id).values(status=KeyStatus.revoked)
        )
        await RevocationRepository(session).insert(api_key.key_id, "test")
        await session.commit()

    # Measure against the loop's monotonic clock so time spent inside the
    # requests counts towards the budget; summing only the sleep intervals
    # would let the test pass when eviction actually took longer than 1s.
    timeout = 1.0
    step = 0.05
    clock = asyncio.get_running_loop().time
    give_up_at = clock() + timeout
    while True:
        resp = await _chat(client, api_key.full_key)
        if resp.status_code != 200:
            # 401 (or 403; "key revoked" is a sanitized auth failure).
            assert resp.status_code in {401, 403}
            assert resp.json()["error"]["request_id"]
            return
        if clock() >= give_up_at:
            break
        await asyncio.sleep(step)
    pytest.fail(f"revoked key still accepted after {timeout}s")