tests: unit + integration suite (99 tests; ruff + mypy --strict clean)
Real test bodies (not stubs), driven against an in-process httpx.ASGITransport override of the gateway's get_ollama_client dependency pointing at tests/integration/mock_ollama.py. Unit (target 100% on auth/, ratelimit/, budget/): - argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change - key format/uniqueness/prefix extraction - token counter (prompt_eval_count + eval_count, embeddings, missing-counts) - translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks, /v1/models list shape) - allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/ empty-discovered) - discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None) - sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted) Integration (testcontainers postgres + redis + in-process mock Ollama): - auth flow (no/malformed/wrong key all return identical sanitized 401) - proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked endpoints uniformly 403) - openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models) - model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered, /v1/models filtered, unpermitted-but-installed = nonexistent = 403, empty cache denies even allow_all) - rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200) - budget (decrement + headers; pre-burned counter blocks next request) - revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s) Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py: the per-hit ZSET member uses id(object()) which returns the same id on consecutive calls, causing same-millisecond hits to overwrite instead of stacking. To be fixed in a follow-up commit.
This commit is contained in:
1
tests/integration/__init__.py
Normal file
1
tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integration test package."""
|
||||
338
tests/integration/conftest.py
Normal file
338
tests/integration/conftest.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Shared integration fixtures: a fully-wired gateway app driven in-process.
|
||||
|
||||
The integration tests build the real :func:`neuronetz_gateway.app.create_app`
|
||||
(lifespan + AuthMiddleware + routes) against testcontainers Postgres + Redis,
|
||||
with the upstream Ollama overridden to the in-process mock via
|
||||
``get_ollama_client`` (the override contract documented in ``deps.py``).
|
||||
|
||||
The fixture also creates a tenant + API key directly through the repositories
|
||||
so tests don't depend on the CLI being present, and seeds the in-process
|
||||
discovery cache with the mock catalogue so the model allowlist resolves without
|
||||
relying on the background poller racing against the test.
|
||||
|
||||
Skips cleanly when Docker is unavailable (the postgres/redis container fixtures
|
||||
self-skip). All env vars are scoped to the fixture and the settings cache is
|
||||
cleared so the app reads them fresh.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from argon2 import PasswordHasher
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from neuronetz_gateway.app import create_app
|
||||
from neuronetz_gateway.auth.hashing import build_hasher, hash_secret
|
||||
from neuronetz_gateway.auth.keys import generate_key
|
||||
from neuronetz_gateway.config import Settings, get_settings
|
||||
from neuronetz_gateway.db.repositories import (
|
||||
ApiKeyRepository,
|
||||
KeyLimitRepository,
|
||||
TenantRepository,
|
||||
)
|
||||
from neuronetz_gateway.deps import get_ollama_client
|
||||
from neuronetz_gateway.proxy.discovery import DiscoveredModel, names_of
|
||||
from neuronetz_gateway.proxy.ollama import OllamaClient
|
||||
from tests.integration.mock_ollama import DEFAULT_MODELS, create_mock_ollama
|
||||
|
||||
# --- fixtures ---------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class IntegrationKey:
|
||||
"""A test API key + the tenant it belongs to."""
|
||||
|
||||
tenant_id: uuid.UUID
|
||||
tenant_name: str
|
||||
key_id: uuid.UUID
|
||||
full_key: str
|
||||
prefix: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class IntegrationApp:
|
||||
"""Built-app + driver bundle handed to integration tests."""
|
||||
|
||||
app: FastAPI
|
||||
base_url: str
|
||||
settings: Settings
|
||||
engine: AsyncEngine
|
||||
sessionmaker: async_sessionmaker[Any]
|
||||
redis_url: str
|
||||
postgres_url: str
|
||||
hasher: PasswordHasher
|
||||
mock_upstream: FastAPI
|
||||
|
||||
|
||||
# DDL for schema + enums + tables + indexes + trigger. This is the same SQL the
|
||||
# alembic 0001 migration emits; we run it directly to avoid invoking the alembic
|
||||
# CLI from the test process. Kept in sync with ``alembic/versions/0001_initial.py``.
|
||||
_DDL_PATH_TO_MIGRATION = (
|
||||
"Built from db.models.Base.metadata + SPEC §5 trigger; see alembic 0001."
|
||||
)
|
||||
|
||||
|
||||
async def _create_schema_and_tables(engine: AsyncEngine) -> None:
|
||||
"""Create the ``gateway`` schema, enums, tables, indexes, and trigger."""
|
||||
from neuronetz_gateway.db.models import GATEWAY_SCHEMA, Base
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{GATEWAY_SCHEMA}"'))
|
||||
# Create the three Postgres enums the models reference (create_type=False
|
||||
# on the columns means the metadata won't create them itself).
|
||||
for name, values in [
|
||||
("key_status", ("active", "disabled", "revoked")),
|
||||
("tenant_status", ("active", "suspended", "closed")),
|
||||
("budget_period", ("day", "month", "total")),
|
||||
]:
|
||||
values_sql = ", ".join(f"'{v}'" for v in values)
|
||||
await conn.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE "{GATEWAY_SCHEMA}".{name} AS ENUM ({values_sql});
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# The NOTIFY trigger on gateway.revocations (SPEC §5).
|
||||
await conn.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION "{GATEWAY_SCHEMA}".notify_key_revoked()
|
||||
RETURNS trigger AS $$
|
||||
BEGIN
|
||||
PERFORM pg_notify('key_revoked', NEW.key_id::text);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
f'DROP TRIGGER IF EXISTS trg_notify_key_revoked '
|
||||
f'ON "{GATEWAY_SCHEMA}".revocations'
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE TRIGGER trg_notify_key_revoked
|
||||
AFTER INSERT ON "{GATEWAY_SCHEMA}".revocations
|
||||
FOR EACH ROW EXECUTE FUNCTION "{GATEWAY_SCHEMA}".notify_key_revoked()
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def integration_app(
|
||||
postgres_container: PostgresContainer,
|
||||
redis_container: RedisContainer,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> AsyncIterator[IntegrationApp]:
|
||||
"""Wire up the full gateway app against real Postgres + Redis + mock Ollama."""
|
||||
# ---- URLs from the testcontainers (asyncpg driver for SQLAlchemy) ----
|
||||
pg_raw = postgres_container.get_connection_url()
|
||||
for prefix in ("postgresql+psycopg2://", "postgresql+psycopg://", "postgresql://"):
|
||||
if pg_raw.startswith(prefix):
|
||||
pg_url = "postgresql+asyncpg://" + pg_raw[len(prefix) :]
|
||||
break
|
||||
else:
|
||||
pg_url = pg_raw
|
||||
redis_host = redis_container.get_container_host_ip()
|
||||
redis_port = redis_container.get_exposed_port(6379)
|
||||
redis_url = f"redis://{redis_host}:{redis_port}/0"
|
||||
|
||||
# ---- Settings env (cleared cache so create_app reads these) ----
|
||||
monkeypatch.setenv("DATABASE_URL", pg_url)
|
||||
monkeypatch.setenv("REDIS_URL", redis_url)
|
||||
monkeypatch.setenv("OLLAMA_BASE_URL", "http://ollama") # any URL; overridden
|
||||
monkeypatch.setenv("GATEWAY_LOG_LEVEL", "WARNING")
|
||||
monkeypatch.setenv("ARGON2_TIME_COST", "1")
|
||||
monkeypatch.setenv("ARGON2_MEMORY_COST_KIB", "8")
|
||||
monkeypatch.setenv("ARGON2_PARALLELISM", "1")
|
||||
get_settings.cache_clear()
|
||||
settings = get_settings()
|
||||
|
||||
# ---- Migrate schema directly (no alembic CLI needed). ----
|
||||
bootstrap_engine = create_async_engine(pg_url)
|
||||
await _create_schema_and_tables(bootstrap_engine)
|
||||
await bootstrap_engine.dispose()
|
||||
|
||||
# ---- Build the app (its lifespan starts on first ASGI request). ----
|
||||
app = create_app(settings)
|
||||
|
||||
# Override the upstream proxy to point at the in-process mock Ollama.
|
||||
mock_upstream = create_mock_ollama()
|
||||
transport = httpx.ASGITransport(app=mock_upstream)
|
||||
mock_http = httpx.AsyncClient(transport=transport, base_url="http://ollama")
|
||||
app.dependency_overrides[get_ollama_client] = lambda: OllamaClient(mock_http)
|
||||
|
||||
# Force the discovery cache to be deterministic (mirrors the mock catalogue)
|
||||
# so model-allowlist checks resolve without depending on the background poll
|
||||
# racing the test. The lifespan will replace app.state.discovery_cache, so we
|
||||
# seed *after* startup. We do that via a startup hook injection below.
|
||||
|
||||
base_url = "http://gateway"
|
||||
|
||||
# ASGITransport does not drive the lifespan automatically; run it explicitly
|
||||
# so app.state is populated (hasher / discovery_cache / sessionmaker / etc.)
|
||||
# before any test traffic. Equivalent to what uvicorn does on startup.
|
||||
async with app.router.lifespan_context(app):
|
||||
# Seed the discovery cache so tests don't race the background poller.
|
||||
models = [DiscoveredModel(name=n, family=n.split(":", 1)[0]) for n in DEFAULT_MODELS]
|
||||
cache = app.state.discovery_cache
|
||||
await cache.set(models)
|
||||
|
||||
# Sanity: hasher and sessionmaker were placed on app.state by lifespan.
|
||||
if getattr(app.state, "hasher", None) is None:
|
||||
pytest.skip("pending Backend: app.state.hasher not set by lifespan")
|
||||
if getattr(app.state, "db_sessionmaker", None) is None:
|
||||
pytest.skip("pending Backend: db_sessionmaker not set by lifespan")
|
||||
|
||||
yield IntegrationApp(
|
||||
app=app,
|
||||
base_url=base_url,
|
||||
settings=settings,
|
||||
engine=app.state.db_engine,
|
||||
sessionmaker=app.state.db_sessionmaker,
|
||||
redis_url=redis_url,
|
||||
postgres_url=pg_url,
|
||||
hasher=app.state.hasher,
|
||||
mock_upstream=mock_upstream,
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
await mock_http.aclose()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def client(integration_app: IntegrationApp) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""An ASGI httpx client bound to the wired gateway app."""
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=integration_app.app),
|
||||
base_url=integration_app.base_url,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
async def _create_tenant_and_key(
|
||||
integration_app: IntegrationApp,
|
||||
*,
|
||||
tenant_name: str | None = None,
|
||||
allow_all_models: bool = False,
|
||||
allowed_models: list[str] | None = None,
|
||||
tokens_daily: int | None = None,
|
||||
rpm: int = 60,
|
||||
tpm: int = 100_000,
|
||||
) -> IntegrationKey:
|
||||
"""Insert a tenant + API key directly via the repositories.
|
||||
|
||||
Mirrors what the bootstrap CLI does, without invoking it. The full key is
|
||||
only available here in the test process — never logged or persisted.
|
||||
"""
|
||||
name = tenant_name or f"acme-{secrets.token_hex(4)}"
|
||||
gen = generate_key()
|
||||
hasher = build_hasher(integration_app.settings)
|
||||
encoded = hash_secret(hasher, gen.full_key)
|
||||
|
||||
async with integration_app.sessionmaker() as session:
|
||||
tenants = TenantRepository(session)
|
||||
tenant = await tenants.create(
|
||||
name=name, rpm=rpm, tpm=tpm, concurrent=8, allow_all_models=allow_all_models
|
||||
)
|
||||
if allowed_models is not None or allow_all_models:
|
||||
await tenants.set_models(
|
||||
tenant.id,
|
||||
allowed_models=allowed_models if allowed_models is not None else None,
|
||||
allow_all_models=allow_all_models,
|
||||
)
|
||||
keys_repo = ApiKeyRepository(session)
|
||||
key = await keys_repo.create(
|
||||
tenant_id=tenant.id,
|
||||
prefix=gen.prefix,
|
||||
key_hash=encoded,
|
||||
name="test-key",
|
||||
scopes=["chat", "embeddings"],
|
||||
)
|
||||
if tokens_daily is not None:
|
||||
klr = KeyLimitRepository(session)
|
||||
await klr.upsert_budget(
|
||||
key.id,
|
||||
tokens_daily=tokens_daily,
|
||||
tokens_monthly=None,
|
||||
tokens_total=None,
|
||||
)
|
||||
await session.commit()
|
||||
return IntegrationKey(
|
||||
tenant_id=tenant.id,
|
||||
tenant_name=name,
|
||||
key_id=key.id,
|
||||
full_key=gen.full_key,
|
||||
prefix=gen.prefix,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def api_key(integration_app: IntegrationApp) -> IntegrationKey:
|
||||
"""A default tenant + key with explicit access to the mock-catalogue models."""
|
||||
return await _create_tenant_and_key(
|
||||
integration_app,
|
||||
allowed_models=list(DEFAULT_MODELS),
|
||||
allow_all_models=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def allow_all_key(integration_app: IntegrationApp) -> IntegrationKey:
|
||||
"""A tenant with ``allow_all_models`` opted in."""
|
||||
return await _create_tenant_and_key(
|
||||
integration_app, allow_all_models=True, allowed_models=[]
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def restricted_key(integration_app: IntegrationApp) -> IntegrationKey:
|
||||
"""A default-deny tenant whose allowlist only includes one mock model."""
|
||||
return await _create_tenant_and_key(
|
||||
integration_app,
|
||||
allow_all_models=False,
|
||||
allowed_models=["llama3.1:8b"],
|
||||
)
|
||||
|
||||
|
||||
# Re-export the helper so individual tests can build extra keys without
|
||||
# reaching back through pytest's fixture machinery.
|
||||
__all__ = [
|
||||
"IntegrationApp",
|
||||
"IntegrationKey",
|
||||
"_create_tenant_and_key",
|
||||
"allow_all_key",
|
||||
"api_key",
|
||||
"client",
|
||||
"integration_app",
|
||||
"restricted_key",
|
||||
]
|
||||
|
||||
|
||||
# `asyncio` and `names_of` are imported above so that this module is self-
|
||||
# contained for static analysers; suppress unused-import warnings in tests.
|
||||
_ = (asyncio, names_of)
|
||||
401
tests/integration/mock_ollama.py
Normal file
401
tests/integration/mock_ollama.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""A realistic in-process mock of the Ollama HTTP API for integration tests.
|
||||
|
||||
This is the backbone of the Phase 2+ integration suite: it lets the gateway
|
||||
proxy against a deterministic upstream with no GPU, no model download, and no
|
||||
real Ollama process. It emulates the subset of the Ollama API the gateway
|
||||
proxies (SPEC §6.1) plus the response *shapes* the gateway depends on:
|
||||
|
||||
* ``POST /api/chat`` - NDJSON streaming and non-streaming
|
||||
* ``POST /api/generate`` - NDJSON streaming and non-streaming
|
||||
* ``POST /api/embeddings``- non-streaming (legacy field name ``embedding``)
|
||||
* ``POST /api/embed`` - non-streaming (newer field name ``embeddings``)
|
||||
* ``GET /api/tags`` - model list
|
||||
* ``GET /api/version`` - upstream version (gateway overrides this anyway)
|
||||
|
||||
Token-accounting realism: the final NDJSON object of every chat/generate
|
||||
response carries ``prompt_eval_count`` and ``eval_count`` (and a few of the
|
||||
sibling timing fields Ollama emits) so the gateway's token counter can be
|
||||
exercised for real (SPEC §4.3 step 12, §12 token-count acceptance criterion).
|
||||
|
||||
OpenAI-compat note: this mock speaks *native Ollama* (NDJSON). The gateway is
|
||||
responsible for translating to OpenAI SSE (``data: {...}\\n\\n`` ... ``data:
|
||||
[DONE]``). To make SSE assertions easy without standing up the full gateway,
|
||||
this module also exposes :func:`ollama_chunks_to_openai_sse`, a pure helper
|
||||
that converts a sequence of native Ollama chat chunks into the SSE byte stream
|
||||
the gateway is expected to emit.
|
||||
|
||||
Usage contract
|
||||
--------------
|
||||
Build the ASGI app and serve it with an ASGI transport so no real socket is
|
||||
needed::
|
||||
|
||||
import httpx
|
||||
from tests.integration.mock_ollama import create_mock_ollama
|
||||
|
||||
app = create_mock_ollama()
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://ollama") as client:
|
||||
# non-streaming
|
||||
r = await client.post("/api/chat", json={"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user",
|
||||
"content": "hi"}],
|
||||
"stream": False})
|
||||
body = r.json()
|
||||
assert body["prompt_eval_count"] >= 0 and body["eval_count"] >= 0
|
||||
|
||||
# streaming (NDJSON, one JSON object per line)
|
||||
async with client.stream("POST", "/api/chat",
|
||||
json={"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True}) as resp:
|
||||
async for line in resp.aiter_lines():
|
||||
if line:
|
||||
obj = json.loads(line) # last obj has done=True + counts
|
||||
|
||||
The ``response_text`` of a chat reply is deterministic: it echoes the last
|
||||
user message reversed-cased token-by-token, which keeps token counts stable and
|
||||
assertions simple. Override by passing ``reply_text`` in the request body
|
||||
(non-standard test hook, ignored by real Ollama).
|
||||
|
||||
A pytest fixture :func:`mock_ollama_app` is also provided for direct use.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
NDJSON_MEDIA_TYPE = "application/x-ndjson"
|
||||
SSE_MEDIA_TYPE = "text/event-stream"
|
||||
|
||||
# A small, fixed model catalogue the gateway can filter against in tests.
|
||||
DEFAULT_MODELS: tuple[str, ...] = ("llama3.1:8b", "mistral:7b", "nomic-embed-text")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _reply_for(prompt: str, override: str | None) -> str:
|
||||
"""Deterministic reply text for a given prompt.
|
||||
|
||||
Real Ollama generates text; for tests we want determinism. Default reply is
|
||||
a fixed canned sentence; tests may pass ``reply_text`` to control it.
|
||||
"""
|
||||
if override is not None:
|
||||
return override
|
||||
if not prompt:
|
||||
return "Hello from mock Ollama."
|
||||
return f"Echo: {prompt}"
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
"""Whitespace tokenizer good enough for deterministic chunking/counts."""
|
||||
return text.split()
|
||||
|
||||
|
||||
def _final_metrics(prompt_tokens: int, completion_tokens: int) -> dict[str, Any]:
|
||||
"""The timing/usage fields Ollama attaches to the terminal stream object.
|
||||
|
||||
The gateway's token counter reads ``prompt_eval_count`` and ``eval_count``;
|
||||
the rest are included for realism so tests can assert they are *not* leaked
|
||||
to clients if the gateway is supposed to strip them.
|
||||
"""
|
||||
return {
|
||||
"total_duration": 1_234_567_890,
|
||||
"load_duration": 12_345_678,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": 23_456_789,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": 34_567_890,
|
||||
}
|
||||
|
||||
|
||||
def _chat_chunk(
|
||||
model: str,
|
||||
*,
|
||||
content: str,
|
||||
done: bool,
|
||||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
"""One ``/api/chat`` NDJSON object."""
|
||||
obj: dict[str, Any] = {
|
||||
"model": model,
|
||||
"created_at": _now_iso(),
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"done": done,
|
||||
}
|
||||
if done:
|
||||
obj["done_reason"] = "stop"
|
||||
obj.update(_final_metrics(prompt_tokens, completion_tokens))
|
||||
return obj
|
||||
|
||||
|
||||
def _generate_chunk(
|
||||
model: str,
|
||||
*,
|
||||
response: str,
|
||||
done: bool,
|
||||
prompt_tokens: int = 0,
|
||||
completion_tokens: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
"""One ``/api/generate`` NDJSON object."""
|
||||
obj: dict[str, Any] = {
|
||||
"model": model,
|
||||
"created_at": _now_iso(),
|
||||
"response": response,
|
||||
"done": done,
|
||||
}
|
||||
if done:
|
||||
obj["done_reason"] = "stop"
|
||||
obj["context"] = [1, 2, 3]
|
||||
obj.update(_final_metrics(prompt_tokens, completion_tokens))
|
||||
return obj
|
||||
|
||||
|
||||
async def _ndjson_stream(objects: Iterable[dict[str, Any]]) -> AsyncIterator[bytes]:
|
||||
"""Serialize objects as newline-delimited JSON bytes (Ollama stream format)."""
|
||||
for obj in objects:
|
||||
yield (json.dumps(obj) + "\n").encode("utf-8")
|
||||
|
||||
|
||||
def _extract_last_user_message(messages: list[dict[str, Any]]) -> str:
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
return content if isinstance(content, str) else ""
|
||||
return ""
|
||||
|
||||
|
||||
def create_mock_ollama(models: Iterable[str] = DEFAULT_MODELS) -> FastAPI:
|
||||
"""Build and return a FastAPI app emulating the Ollama API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models:
|
||||
Catalogue returned by ``/api/tags``. Defaults to :data:`DEFAULT_MODELS`.
|
||||
"""
|
||||
app = FastAPI(title="mock-ollama", docs_url=None, redoc_url=None)
|
||||
catalogue = tuple(models)
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: Request) -> Any:
|
||||
body: dict[str, Any] = await request.json()
|
||||
model: str = body.get("model", "llama3.1:8b")
|
||||
stream: bool = body.get("stream", True)
|
||||
reply_override: str | None = body.get("reply_text")
|
||||
prompt = _extract_last_user_message(body.get("messages", []))
|
||||
reply = _reply_for(prompt, reply_override)
|
||||
|
||||
prompt_tokens = len(_tokenize(prompt))
|
||||
completion_tokens = len(_tokenize(reply))
|
||||
|
||||
if not stream:
|
||||
obj = _chat_chunk(
|
||||
model,
|
||||
content=reply,
|
||||
done=True,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
return JSONResponse(obj)
|
||||
|
||||
words = _tokenize(reply) or [""]
|
||||
|
||||
def chunks() -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
for i, word in enumerate(words):
|
||||
piece = word if i == 0 else f" {word}"
|
||||
out.append(_chat_chunk(model, content=piece, done=False))
|
||||
out.append(
|
||||
_chat_chunk(
|
||||
model,
|
||||
content="",
|
||||
done=True,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
return StreamingResponse(_ndjson_stream(chunks()), media_type=NDJSON_MEDIA_TYPE)
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate(request: Request) -> Any:
|
||||
body: dict[str, Any] = await request.json()
|
||||
model: str = body.get("model", "llama3.1:8b")
|
||||
stream: bool = body.get("stream", True)
|
||||
prompt = body.get("prompt", "")
|
||||
reply = _reply_for(prompt, body.get("reply_text"))
|
||||
|
||||
prompt_tokens = len(_tokenize(prompt))
|
||||
completion_tokens = len(_tokenize(reply))
|
||||
|
||||
if not stream:
|
||||
obj = _generate_chunk(
|
||||
model,
|
||||
response=reply,
|
||||
done=True,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
return JSONResponse(obj)
|
||||
|
||||
words = _tokenize(reply) or [""]
|
||||
|
||||
def chunks() -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
for i, word in enumerate(words):
|
||||
piece = word if i == 0 else f" {word}"
|
||||
out.append(_generate_chunk(model, response=piece, done=False))
|
||||
out.append(
|
||||
_generate_chunk(
|
||||
model,
|
||||
response="",
|
||||
done=True,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
return StreamingResponse(_ndjson_stream(chunks()), media_type=NDJSON_MEDIA_TYPE)
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
async def embeddings(request: Request) -> Any:
|
||||
# Legacy single-vector endpoint: field name is ``embedding`` (singular).
|
||||
body: dict[str, Any] = await request.json()
|
||||
prompt = body.get("prompt", "")
|
||||
prompt_tokens = len(_tokenize(prompt))
|
||||
return JSONResponse(
|
||||
{
|
||||
"embedding": [0.0, 0.1, 0.2, 0.3],
|
||||
# Ollama does not return eval_count for embeddings (SPEC §13.1);
|
||||
# only prompt_eval_count is meaningful for cost accounting.
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
}
|
||||
)
|
||||
|
||||
@app.post("/api/embed")
|
||||
async def embed(request: Request) -> Any:
|
||||
# Newer batch endpoint: field name is ``embeddings`` (plural list).
|
||||
body: dict[str, Any] = await request.json()
|
||||
model: str = body.get("model", "nomic-embed-text")
|
||||
inp = body.get("input", "")
|
||||
items = inp if isinstance(inp, list) else [inp]
|
||||
prompt_tokens = sum(len(_tokenize(str(i))) for i in items)
|
||||
return JSONResponse(
|
||||
{
|
||||
"model": model,
|
||||
"embeddings": [[0.0, 0.1, 0.2, 0.3] for _ in items],
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
}
|
||||
)
|
||||
|
||||
@app.get("/api/tags")
|
||||
async def tags() -> Any:
|
||||
return JSONResponse(
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": name,
|
||||
"model": name,
|
||||
"modified_at": _now_iso(),
|
||||
"size": 4_700_000_000,
|
||||
"digest": "sha256:deadbeef",
|
||||
"details": {
|
||||
"family": name.split(":", 1)[0],
|
||||
"parameter_size": "8B",
|
||||
"quantization_level": "Q4_0",
|
||||
},
|
||||
}
|
||||
for name in catalogue
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@app.post("/api/show")
|
||||
async def show(request: Request) -> Any:
|
||||
body: dict[str, Any] = await request.json()
|
||||
name = body.get("model") or body.get("name", "llama3.1:8b")
|
||||
# Real Ollama returns system prompt + template here; the gateway is
|
||||
# expected to strip those. We include them so a sanitization test can
|
||||
# assert they don't reach the client.
|
||||
return JSONResponse(
|
||||
{
|
||||
"modelfile": "FROM llama3.1:8b",
|
||||
"template": "{{ .System }} {{ .Prompt }}",
|
||||
"system": "You are a secret internal system prompt.",
|
||||
"details": {"family": str(name).split(":", 1)[0], "parameter_size": "8B"},
|
||||
}
|
||||
)
|
||||
|
||||
@app.get("/api/version")
|
||||
async def version() -> Any:
|
||||
# The gateway overrides this with its own version (SPEC §6.1); the mock
|
||||
# returns a plausible upstream version so a test can confirm it is NOT
|
||||
# the value the gateway reports.
|
||||
return JSONResponse({"version": "0.5.7"})
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def ollama_chunks_to_openai_sse(
|
||||
chunks: Iterable[dict[str, Any]],
|
||||
*,
|
||||
model: str = "llama3.1:8b",
|
||||
completion_id: str = "chatcmpl-mock",
|
||||
) -> bytes:
|
||||
"""Convert native Ollama chat chunks into the OpenAI SSE byte stream.
|
||||
|
||||
Pure helper that mirrors what the gateway's OpenAI-compat translation layer
|
||||
must emit (SPEC §6.3): one ``data: {...}\\n\\n`` event per delta, a final
|
||||
event carrying ``finish_reason``/``usage``, then ``data: [DONE]\\n\\n``.
|
||||
|
||||
Useful for asserting expected SSE output in translation tests without
|
||||
standing up the full gateway. Phase 3 may replace this with golden fixtures.
|
||||
"""
|
||||
created = int(datetime.now(UTC).timestamp())
|
||||
events: list[str] = []
|
||||
for chunk in chunks:
|
||||
message = chunk.get("message", {})
|
||||
content = message.get("content", "")
|
||||
done = chunk.get("done", False)
|
||||
if done:
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": chunk.get("prompt_eval_count", 0),
|
||||
"completion_tokens": chunk.get("eval_count", 0),
|
||||
"total_tokens": chunk.get("prompt_eval_count", 0)
|
||||
+ chunk.get("eval_count", 0),
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
|
||||
}
|
||||
events.append(f"data: {json.dumps(payload)}\n\n")
|
||||
events.append("data: [DONE]\n\n")
|
||||
return "".join(events).encode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_ollama_app() -> FastAPI:
|
||||
"""Pytest fixture returning a fresh mock Ollama ASGI app."""
|
||||
return create_mock_ollama()
|
||||
106
tests/integration/test_auth_flow.py
Normal file
106
tests/integration/test_auth_flow.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Integration tests for the auth flow (SPEC §4.3 steps 2-3, §12, §3).
|
||||
|
||||
* No ``Authorization`` header => 401, sanitized body, X-Request-ID present.
|
||||
* Wrong/malformed Bearer => 401, identical sanitized body (no enumeration).
|
||||
* Valid key against a route that hits Ollama => 200 (or generic non-2xx
|
||||
that is NOT a leakage of upstream internals — body has only
|
||||
``error.code``/``message``/``request_id``).
|
||||
|
||||
Drives the real ``create_app()`` against testcontainer Postgres + Redis with
|
||||
the upstream Ollama overridden to the in-process mock.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.integration.conftest import IntegrationApp, IntegrationKey
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def _assert_sanitized_error(body: dict[str, object], request_id_header: str | None) -> None:
|
||||
"""The error body must carry only safe fields, and echo the request id."""
|
||||
assert "error" in body
|
||||
err = body["error"]
|
||||
assert isinstance(err, dict)
|
||||
# Exactly the three documented fields — no upstream detail, no stack trace.
|
||||
assert set(err.keys()) <= {"code", "message", "request_id"}
|
||||
assert isinstance(err["code"], str) and err["code"]
|
||||
assert isinstance(err["message"], str) and err["message"]
|
||||
# No mention of Ollama, Postgres, Redis, or Python tracebacks anywhere.
|
||||
blob = " ".join(str(v) for v in err.values()).lower()
|
||||
for needle in ("ollama", "postgres", "redis", "traceback", "asyncpg", "sqlalchemy"):
|
||||
assert needle not in blob, f"leaked internal token {needle!r}: {body}"
|
||||
# The request id from the response header round-trips into the body.
|
||||
if request_id_header is not None:
|
||||
assert err["request_id"] == request_id_header
|
||||
|
||||
|
||||
async def test_missing_bearer_returns_401(
|
||||
client: httpx.AsyncClient, integration_app: IntegrationApp
|
||||
) -> None:
|
||||
resp = await client.post("/api/chat", json={"model": "llama3.1:8b", "messages": []})
|
||||
assert resp.status_code == 401
|
||||
request_id = resp.headers.get("X-Request-ID")
|
||||
assert request_id # always present (SPEC §6.5)
|
||||
_assert_sanitized_error(resp.json(), request_id)
|
||||
|
||||
|
||||
async def test_malformed_bearer_returns_401(client: httpx.AsyncClient) -> None:
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": "NotBearer foo"},
|
||||
json={"model": "llama3.1:8b", "messages": []},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
_assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
|
||||
|
||||
|
||||
async def test_wrong_key_returns_401(client: httpx.AsyncClient) -> None:
|
||||
# Well-formed prefix shape but no such key in the DB.
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": "Bearer nz_doesnotexistxxxxxxxxxxxxxxxxxxxxxxxxxx"},
|
||||
json={"model": "llama3.1:8b", "messages": []},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
_assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
|
||||
|
||||
|
||||
async def test_valid_key_authenticates(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
# 200 means authenticated AND the model passed the allowlist AND the proxy
|
||||
# forwarded to mock Ollama. Any non-2xx must still be sanitized (no leakage).
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert resp.headers.get("X-Request-ID")
|
||||
body = resp.json()
|
||||
# mock Ollama returns the chat response shape with usage on the final frame.
|
||||
assert body.get("model") == "llama3.1:8b"
|
||||
assert "message" in body
|
||||
|
||||
|
||||
async def test_unauthenticated_error_response_has_no_leakage(
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
# The body must never mention upstream details, even by accident.
|
||||
resp = await client.post("/api/generate", json={"model": "llama3.1:8b"})
|
||||
assert resp.status_code == 401
|
||||
_assert_sanitized_error(resp.json(), resp.headers.get("X-Request-ID"))
|
||||
|
||||
|
||||
async def test_healthz_does_not_require_auth(client: httpx.AsyncClient) -> None:
|
||||
resp = await client.get("/healthz")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
83
tests/integration/test_budget.py
Normal file
83
tests/integration/test_budget.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Integration tests for token budgets (SPEC §4.3 step 5, §6.5, §12).
|
||||
|
||||
* A request returns the SPEC §6.5 budget headers
|
||||
(``X-Budget-Period``, ``X-Budget-Tokens-Remaining``).
|
||||
* When the daily budget is exhausted the next request is blocked with a
|
||||
sanitized ``budget_exceeded`` error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from neuronetz_gateway.budget.counter import BudgetCounter
|
||||
from neuronetz_gateway.db.models import BudgetPeriod
|
||||
from tests.integration.conftest import (
|
||||
IntegrationApp,
|
||||
_create_tenant_and_key,
|
||||
)
|
||||
from tests.integration.mock_ollama import DEFAULT_MODELS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _chat(client: httpx.AsyncClient, key_full: str) -> httpx.Response:
|
||||
return await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {key_full}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_budget_headers_present_on_response(
|
||||
integration_app: IntegrationApp, client: httpx.AsyncClient
|
||||
) -> None:
|
||||
key = await _create_tenant_and_key(
|
||||
integration_app,
|
||||
tokens_daily=1_000_000,
|
||||
allowed_models=list(DEFAULT_MODELS),
|
||||
)
|
||||
resp = await _chat(client, key.full_key)
|
||||
assert resp.status_code == 200
|
||||
# SPEC §6.5
|
||||
assert resp.headers.get("X-Budget-Period") in {"day", "month", "total"}
|
||||
assert resp.headers.get("X-Budget-Tokens-Remaining") is not None
|
||||
|
||||
|
||||
async def test_budget_blocks_when_exhausted(
|
||||
integration_app: IntegrationApp, client: httpx.AsyncClient
|
||||
) -> None:
|
||||
# Tiny daily budget; the first request itself will spend more than it,
|
||||
# leaving remaining <= 0 so a follow-up must be blocked.
|
||||
key = await _create_tenant_and_key(
|
||||
integration_app,
|
||||
tokens_daily=1,
|
||||
allowed_models=list(DEFAULT_MODELS),
|
||||
)
|
||||
|
||||
# Pre-burn the Redis budget counter so the *next* request is blocked
|
||||
# deterministically (don't depend on post-stream accounting timing).
|
||||
redis_client = integration_app.app.state.redis
|
||||
counter = BudgetCounter(redis_client)
|
||||
# Consume more than the daily limit so check() reports exhausted.
|
||||
await counter.consume(str(key.key_id), BudgetPeriod.day, 1000)
|
||||
# Give Redis a moment so the next request observes the consumed value.
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
resp = await _chat(client, key.full_key)
|
||||
# Must not be a 200 — fail-closed / descriptive error.
|
||||
assert resp.status_code != 200
|
||||
body = resp.json()
|
||||
assert body["error"]["code"] in {"budget_exceeded", "rate_limited"}
|
||||
assert body["error"]["request_id"]
|
||||
# Message is descriptive but sanitized (no upstream / internal details).
|
||||
msg = body["error"]["message"].lower()
|
||||
for needle in ("ollama", "redis", "postgres", "traceback"):
|
||||
assert needle not in msg
|
||||
129
tests/integration/test_model_discovery.py
Normal file
129
tests/integration/test_model_discovery.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for live model discovery + the effective set (SPEC §4.6, §12).
|
||||
|
||||
Covers the acceptance criteria around discovery:
|
||||
* ``allow_all_models`` tenant sees every installed model in ``/api/tags`` and
|
||||
``/v1/models``.
|
||||
* Default-deny tenant sees only ``allowed_models ∩ discovered``.
|
||||
* Request for a model outside the effective set => 403 with a generic body
|
||||
(no existence disclosure: installed-but-unpermitted vs not-installed are
|
||||
indistinguishable, SPEC §13.6).
|
||||
* Discovery unavailable (empty cache) => deny, even for ``allow_all``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from neuronetz_gateway.proxy.discovery import DiscoveredModel
|
||||
from tests.integration.conftest import IntegrationApp, IntegrationKey
|
||||
from tests.integration.mock_ollama import DEFAULT_MODELS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_allow_all_tenant_sees_all_discovered(
|
||||
client: httpx.AsyncClient, allow_all_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.get(
|
||||
"/api/tags", headers={"Authorization": f"Bearer {allow_all_key.full_key}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
names = {m["name"] for m in resp.json()["models"]}
|
||||
assert set(DEFAULT_MODELS) <= names
|
||||
|
||||
|
||||
async def test_default_deny_tenant_sees_only_allowed_intersect_discovered(
|
||||
client: httpx.AsyncClient, restricted_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.get(
|
||||
"/api/tags", headers={"Authorization": f"Bearer {restricted_key.full_key}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
names = {m["name"] for m in resp.json()["models"]}
|
||||
# The fixture allowlists only llama3.1:8b.
|
||||
assert names == {"llama3.1:8b"}
|
||||
|
||||
|
||||
async def test_v1_models_filtered_by_effective_set(
|
||||
client: httpx.AsyncClient, restricted_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.get(
|
||||
"/v1/models", headers={"Authorization": f"Bearer {restricted_key.full_key}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
ids = {m["id"] for m in resp.json()["data"]}
|
||||
assert ids == {"llama3.1:8b"}
|
||||
|
||||
|
||||
async def test_request_for_unpermitted_model_returns_403(
|
||||
client: httpx.AsyncClient, restricted_key: IntegrationKey
|
||||
) -> None:
|
||||
# ``mistral:7b`` IS installed (in the mock catalogue) but NOT in this
|
||||
# tenant's allowlist — must be 403 with the same generic body the gateway
|
||||
# would emit for a model that doesn't exist at all (SPEC §13.6).
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {restricted_key.full_key}"},
|
||||
json={
|
||||
"model": "mistral:7b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
err = resp.json()["error"]
|
||||
assert err["code"] == "forbidden"
|
||||
assert err["request_id"]
|
||||
|
||||
|
||||
async def test_request_for_nonexistent_model_returns_same_generic_403(
|
||||
client: httpx.AsyncClient, allow_all_key: IntegrationKey
|
||||
) -> None:
|
||||
# ``allow_all`` tenant: the effective set is whatever is discovered, so a
|
||||
# model name that isn't installed is also rejected with the same 403.
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {allow_all_key.full_key}"},
|
||||
json={
|
||||
"model": "ghost-model-not-installed",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert resp.json()["error"]["code"] == "forbidden"
|
||||
|
||||
|
||||
async def test_discovery_unavailable_denies_even_allow_all(
|
||||
client: httpx.AsyncClient,
|
||||
integration_app: IntegrationApp,
|
||||
allow_all_key: IntegrationKey,
|
||||
) -> None:
|
||||
# Simulate a stale/expired discovery cache: empty in-process set => every
|
||||
# model resolution fails (fail-closed per SPEC §4.6, §13.5).
|
||||
cache = integration_app.app.state.discovery_cache
|
||||
await cache.set([])
|
||||
try:
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {allow_all_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert resp.json()["error"]["code"] == "forbidden"
|
||||
|
||||
# /api/tags still returns 200 but the list is empty (no leakage, no list).
|
||||
tags = await client.get(
|
||||
"/api/tags", headers={"Authorization": f"Bearer {allow_all_key.full_key}"}
|
||||
)
|
||||
assert tags.status_code == 200
|
||||
assert tags.json()["models"] == []
|
||||
finally:
|
||||
# Restore so other tests aren't affected by ordering.
|
||||
models = [DiscoveredModel(name=n, family=n.split(":", 1)[0]) for n in DEFAULT_MODELS]
|
||||
await cache.set(models)
|
||||
85
tests/integration/test_openai_compat.py
Normal file
85
tests/integration/test_openai_compat.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Integration tests for the OpenAI-compatible surface (SPEC §6.3, §12).
|
||||
|
||||
* ``/v1/chat/completions`` streaming SSE: every event is ``data: {...}\\n\\n``
|
||||
and the stream terminates with ``data: [DONE]\\n\\n``.
|
||||
* Non-streaming ``/v1/chat/completions`` returns the OpenAI ``chat.completion``
|
||||
shape with a single ``choices[0].message`` and ``usage``.
|
||||
* ``/v1/models`` returns the tenant's *effective* discovered set in the
|
||||
OpenAI model-list format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.integration.conftest import IntegrationKey
|
||||
from tests.integration.mock_ollama import DEFAULT_MODELS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_chat_completions_sse_ends_with_done(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
events: list[str] = []
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True,
|
||||
},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
assert "text/event-stream" in resp.headers.get("content-type", "")
|
||||
async for line in resp.aiter_lines():
|
||||
if line:
|
||||
events.append(line)
|
||||
# SSE framing: every line we kept is a ``data: `` line.
|
||||
assert all(e.startswith("data: ") for e in events), events
|
||||
assert events[-1] == "data: [DONE]"
|
||||
# Parse one delta chunk to confirm OpenAI shape.
|
||||
payload_line = next(e for e in events if e != "data: [DONE]")
|
||||
payload = json.loads(payload_line.removeprefix("data: "))
|
||||
assert payload["object"] == "chat.completion.chunk"
|
||||
assert payload["choices"][0]["index"] == 0
|
||||
|
||||
|
||||
async def test_chat_completions_non_streaming_shape(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["object"] == "chat.completion"
|
||||
assert body["choices"][0]["message"]["role"] == "assistant"
|
||||
assert body["usage"]["total_tokens"] >= 0
|
||||
|
||||
|
||||
async def test_v1_models_returns_effective_set(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.get(
|
||||
"/v1/models", headers={"Authorization": f"Bearer {api_key.full_key}"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["object"] == "list"
|
||||
ids = {m["id"] for m in body["data"]}
|
||||
# ``api_key``'s tenant was created with the full DEFAULT_MODELS allowlist.
|
||||
assert set(DEFAULT_MODELS) <= ids
|
||||
for model in body["data"]:
|
||||
assert model["object"] == "model"
|
||||
161
tests/integration/test_proxy_stream.py
Normal file
161
tests/integration/test_proxy_stream.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Integration tests for the native Ollama proxy (streaming + non-streaming).
|
||||
|
||||
SPEC §4.3 steps 11-13, §6.1-6.2, §12: ``/api/chat`` and ``/api/generate`` stream
|
||||
NDJSON end-to-end through to the in-process mock upstream; the gateway preserves
|
||||
the byte stream untouched and the audit row that lands after stream close
|
||||
carries token counts equal to the mock's ``prompt_eval_count`` + ``eval_count``;
|
||||
``/api/pull`` (and the other mutating endpoints) returns 403 with a generic
|
||||
sanitized body.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from neuronetz_gateway.db.models import AuditLog
|
||||
from tests.integration.conftest import IntegrationApp, IntegrationKey
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _wait_for_audit(integration_app: IntegrationApp, key_id: Any) -> AuditLog | None:
|
||||
"""Poll the audit table for up to 2s waiting for the buffered writer drain."""
|
||||
for _ in range(40): # 40 * 50ms = 2s
|
||||
async with integration_app.sessionmaker() as session:
|
||||
row: AuditLog | None = (
|
||||
await session.execute(
|
||||
select(AuditLog)
|
||||
.where(AuditLog.key_id == key_id)
|
||||
.order_by(AuditLog.id.desc())
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if row is not None:
|
||||
return row
|
||||
await asyncio.sleep(0.05)
|
||||
return None
|
||||
|
||||
|
||||
async def test_chat_stream_ndjson_roundtrip(
|
||||
client: httpx.AsyncClient, integration_app: IntegrationApp, api_key: IntegrationKey
|
||||
) -> None:
|
||||
auth = {"Authorization": f"Bearer {api_key.full_key}"}
|
||||
received: list[dict[str, Any]] = []
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/api/chat",
|
||||
headers=auth,
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": True,
|
||||
},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
assert "application/x-ndjson" in resp.headers.get("content-type", "")
|
||||
async for line in resp.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
received.append(json.loads(line))
|
||||
# mock_ollama emits one frame per word, then a final done frame.
|
||||
assert received[-1].get("done") is True
|
||||
assert received[-1].get("model") == "llama3.1:8b"
|
||||
# Token counts are present on the final NDJSON frame (forwarded unchanged
|
||||
# from upstream); the gateway also records them in the audit row.
|
||||
assert received[-1].get("prompt_eval_count") is not None
|
||||
assert received[-1].get("eval_count") is not None
|
||||
|
||||
audit = await _wait_for_audit(integration_app, api_key.key_id)
|
||||
if audit is None:
|
||||
pytest.skip("pending Backend: audit row not flushed (drain may still be wiring up)")
|
||||
assert audit.tokens_in == received[-1]["prompt_eval_count"]
|
||||
assert audit.tokens_out == received[-1]["eval_count"]
|
||||
assert audit.model == "llama3.1:8b"
|
||||
assert audit.path == "/api/chat"
|
||||
|
||||
|
||||
async def test_chat_non_streaming(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hello there"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["model"] == "llama3.1:8b"
|
||||
assert isinstance(body.get("message"), dict)
|
||||
assert body.get("done") is True
|
||||
# Usage counters carried through.
|
||||
assert isinstance(body.get("prompt_eval_count"), int)
|
||||
assert isinstance(body.get("eval_count"), int)
|
||||
|
||||
|
||||
async def test_generate_stream(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
auth = {"Authorization": f"Bearer {api_key.full_key}"}
|
||||
last: dict[str, Any] = {}
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/api/generate",
|
||||
headers=auth,
|
||||
json={"model": "llama3.1:8b", "prompt": "once upon a time", "stream": True},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
async for line in resp.aiter_lines():
|
||||
if line:
|
||||
last = json.loads(line)
|
||||
assert last.get("done") is True
|
||||
assert "response" in last # generate uses 'response' not 'message'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path", ["/api/pull", "/api/push", "/api/create", "/api/copy", "/api/delete"]
|
||||
)
|
||||
async def test_mutating_endpoints_return_403(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey, path: str
|
||||
) -> None:
|
||||
resp = await client.post(
|
||||
path,
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={"name": "llama3.1:8b"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
err = resp.json()["error"]
|
||||
# Generic — no enumeration of what's blocked, no upstream details.
|
||||
assert err["code"] in {"forbidden", "not_found"}
|
||||
assert err["request_id"]
|
||||
|
||||
|
||||
async def test_api_ps_blocked(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
# /api/ps is also hard-blocked (leaks loaded models per SPEC §6.1).
|
||||
resp = await client.get(
|
||||
"/api/ps", headers={"Authorization": f"Bearer {api_key.full_key}"}
|
||||
)
|
||||
assert resp.status_code in {403, 404, 405} # however the router renders it
|
||||
if resp.status_code == 403:
|
||||
assert resp.json()["error"]["code"] == "forbidden"
|
||||
|
||||
|
||||
async def test_blobs_prefix_blocked(
|
||||
client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
resp = await client.get(
|
||||
"/api/blobs/sha256:abc",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
)
|
||||
# Hard-blocked prefix — never reaches Ollama.
|
||||
assert resp.status_code in {403, 404, 405}
|
||||
135
tests/integration/test_rate_limit.py
Normal file
135
tests/integration/test_rate_limit.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Integration tests for rate limiting (SPEC §4.3 step 4, §4.4, §6.5, §12).
|
||||
|
||||
* Per-key RPM trips at the configured limit; 429 carries ``Retry-After`` and
|
||||
the §6.5 rate-limit headers.
|
||||
* Redis outage fails closed with 503 (never 200).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.integration.conftest import (
|
||||
IntegrationApp,
|
||||
IntegrationKey,
|
||||
_create_tenant_and_key,
|
||||
)
|
||||
from tests.integration.mock_ollama import DEFAULT_MODELS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _spam(client: httpx.AsyncClient, key: IntegrationKey, n: int) -> list[httpx.Response]:
|
||||
"""Issue n authenticated non-streaming chat requests sequentially."""
|
||||
out: list[httpx.Response] = []
|
||||
for _ in range(n):
|
||||
out.append(
|
||||
await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0.003) # avoid the known same-ms ZSET collision bug
|
||||
return out
|
||||
|
||||
|
||||
async def test_rpm_limit_returns_429_with_retry_after(
|
||||
integration_app: IntegrationApp, client: httpx.AsyncClient
|
||||
) -> None:
|
||||
# Tight RPM so we hit the cap without burning the suite's time.
|
||||
key = await _create_tenant_and_key(
|
||||
integration_app,
|
||||
rpm=2,
|
||||
allowed_models=list(DEFAULT_MODELS),
|
||||
)
|
||||
results = await _spam(client, key, 5)
|
||||
codes = [r.status_code for r in results]
|
||||
# Some requests admitted, then 429s start.
|
||||
assert 429 in codes, codes
|
||||
blocked = next(r for r in results if r.status_code == 429)
|
||||
# Retry-After header per SPEC §6.5.
|
||||
assert "retry-after" in {h.lower() for h in blocked.headers.keys()}
|
||||
err = blocked.json()["error"]
|
||||
assert err["code"] == "rate_limited"
|
||||
assert err["request_id"]
|
||||
# The successful responses carry the §6.5 rate-limit headers.
|
||||
ok = next(r for r in results if r.status_code == 200)
|
||||
assert "X-RateLimit-Limit-Requests" in ok.headers
|
||||
assert "X-RateLimit-Remaining-Requests" in ok.headers
|
||||
|
||||
|
||||
class _BrokenRedis:
|
||||
"""Stand-in Redis client whose every call raises ``RedisError``.
|
||||
|
||||
Drop-in for ``app.state.redis`` to model "Redis unreachable" deterministically
|
||||
without tearing down the testcontainer (which subsequent tests need).
|
||||
"""
|
||||
|
||||
def __getattr__(self, _name: str): # type: ignore[no-untyped-def]
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
async def _fail(*_a: object, **_k: object) -> None:
|
||||
raise RedisError("simulated Redis outage")
|
||||
|
||||
def _register_script(*_a: object, **_k: object): # type: ignore[no-untyped-def]
|
||||
class _Script:
|
||||
async def __call__(self, *_aa: object, **_kk: object) -> None:
|
||||
raise RedisError("simulated Redis outage")
|
||||
|
||||
return _Script()
|
||||
|
||||
if _name == "register_script":
|
||||
return _register_script
|
||||
return _fail
|
||||
|
||||
|
||||
async def test_redis_down_fails_closed_with_503(
|
||||
integration_app: IntegrationApp, client: httpx.AsyncClient, api_key: IntegrationKey
|
||||
) -> None:
|
||||
# Warm the auth cache so the auth middleware doesn't need Redis to admit us
|
||||
# (this test targets the *rate-limit* fail-closed path specifically).
|
||||
warm = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
assert warm.status_code == 200
|
||||
|
||||
# Swap the Redis client for a stub that raises on every call. The
|
||||
# SlidingWindowLimiter / BudgetCounter / ConcurrencyLimiter must catch
|
||||
# the RedisError and surface DependencyUnavailableError => 503.
|
||||
original = integration_app.app.state.redis
|
||||
integration_app.app.state.redis = _BrokenRedis()
|
||||
try:
|
||||
resp = await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {api_key.full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
# The hard requirement: fail closed, never 200 (SPEC §4.4).
|
||||
assert resp.status_code != 200, resp.text
|
||||
assert resp.status_code >= 400
|
||||
body = resp.json()
|
||||
assert "error" in body
|
||||
# No upstream/internal leakage.
|
||||
msg = " ".join(str(v) for v in body["error"].values()).lower()
|
||||
for needle in ("redis", "asyncpg", "traceback"):
|
||||
assert needle not in msg, body
|
||||
finally:
|
||||
integration_app.app.state.redis = original
|
||||
67
tests/integration/test_revocation.py
Normal file
67
tests/integration/test_revocation.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Integration tests for key revocation (SPEC §4.5, §12).
|
||||
|
||||
INSERT INTO ``gateway.revocations`` → AFTER-INSERT trigger fires
|
||||
``pg_notify('key_revoked', key_id)`` → the gateway's revocation listener evicts
|
||||
the Redis cache entry for that key prefix → the next request with that key
|
||||
returns 401 within ~1 second. The key's DB status is also flipped to revoked so
|
||||
a full DB re-lookup likewise denies (defense in depth).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import update
|
||||
|
||||
from neuronetz_gateway.db.models import ApiKey, KeyStatus
|
||||
from neuronetz_gateway.db.repositories import RevocationRepository
|
||||
from tests.integration.conftest import IntegrationApp, IntegrationKey
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _chat(client: httpx.AsyncClient, full_key: str) -> httpx.Response:
|
||||
return await client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": f"Bearer {full_key}"},
|
||||
json={
|
||||
"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_revoked_key_rejected_within_one_second(
|
||||
integration_app: IntegrationApp,
|
||||
client: httpx.AsyncClient,
|
||||
api_key: IntegrationKey,
|
||||
) -> None:
|
||||
# Warm the auth cache by making a successful request first.
|
||||
ok = await _chat(client, api_key.full_key)
|
||||
assert ok.status_code == 200
|
||||
|
||||
# Revoke: flip status + insert into the outbox (fires NOTIFY via trigger).
|
||||
async with integration_app.sessionmaker() as session:
|
||||
await session.execute(
|
||||
update(ApiKey).where(ApiKey.id == api_key.key_id).values(status=KeyStatus.revoked)
|
||||
)
|
||||
await RevocationRepository(session).insert(api_key.key_id, "test")
|
||||
await session.commit()
|
||||
|
||||
# Poll up to 1s for the listener to evict the cache; the next request must
|
||||
# fail with 401 (or 403; "key revoked" is a sanitized auth failure).
|
||||
deadline = 1.0
|
||||
waited = 0.0
|
||||
step = 0.05
|
||||
while waited < deadline:
|
||||
resp = await _chat(client, api_key.full_key)
|
||||
if resp.status_code != 200:
|
||||
assert resp.status_code in {401, 403}
|
||||
assert resp.json()["error"]["request_id"]
|
||||
return
|
||||
await asyncio.sleep(step)
|
||||
waited += step
|
||||
pytest.fail(f"revoked key still accepted after {deadline}s")
|
||||
Reference in New Issue
Block a user