neuronetz-gateway/tests/integration/mock_ollama.py
Stephan Berbig 844b02aade tests: unit + integration suite (99 tests; ruff + mypy --strict clean)
Real test bodies (not stubs), driven against an in-process httpx.ASGITransport
override of the gateway's get_ollama_client dependency pointing at
tests/integration/mock_ollama.py.

Unit (target 100% on auth/, ratelimit/, budget/):
- argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change
- key format/uniqueness/prefix extraction
- token counter (prompt_eval_count + eval_count, embeddings, missing-counts)
- translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks,
  /v1/models list shape)
- allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/
  empty-discovered)
- discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None)
- sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted)
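
The token-counter cases listed above (both counts present, embeddings, missing counts) can be sketched roughly as follows. This is only an illustration: `count_tokens` is a hypothetical name, not the gateway's actual function, and treating absent counts as zero is an assumption.

```python
import json
from typing import Any


def count_tokens(body: bytes) -> tuple[int, int]:
    """Return (prompt_tokens, completion_tokens) from an Ollama response body.

    Hypothetical helper. Works for NDJSON streams (the terminal object is the
    last line) and for single-object bodies. Absent counts are assumed to be 0.
    """
    final: dict[str, Any] = {}
    for line in body.splitlines():
        if line.strip():
            final = json.loads(line)  # the last non-empty line wins
    return int(final.get("prompt_eval_count", 0)), int(final.get("eval_count", 0))


stream = b"\n".join(
    json.dumps(obj).encode("utf-8")
    for obj in (
        {"message": {"content": "Echo:"}, "done": False},
        {"message": {"content": " hi"}, "done": False},
        {"message": {"content": ""}, "done": True, "prompt_eval_count": 1, "eval_count": 2},
    )
)
print(count_tokens(stream))  # (1, 2)
print(count_tokens(b'{"embedding": [0.0], "prompt_eval_count": 3}'))  # (3, 0)
print(count_tokens(b""))  # (0, 0)
```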

Integration (testcontainers postgres + redis + in-process mock Ollama):
- auth flow (no/malformed/wrong key all return identical sanitized 401)
- proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked
  endpoints uniformly 403)
- openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models)
- model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered,
  /v1/models filtered, unpermitted-but-installed = nonexistent = 403,
  empty cache denies even allow_all)
- rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200)
- budget (decrement + headers; pre-burned counter blocks next request)
- revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s)
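
The model_discovery rules above (allow_all sees everything discovered, default-deny sees allowed ∩ discovered, an empty cache denies everyone) amount to a small set computation. A minimal sketch, assuming the allowlist and the discovery cache are plain collections of model names; `effective_models` is a hypothetical name, not the gateway's API.

```python
from collections.abc import Iterable


def effective_models(
    allowed: Iterable[str],
    discovered: Iterable[str],
    *,
    allow_all: bool = False,
) -> frozenset[str]:
    """Models a tenant may use, under the fail-closed rules described above."""
    found = frozenset(discovered)
    if not found:
        # Empty discovery cache: deny everything, even for allow_all tenants.
        return frozenset()
    if allow_all:
        return found
    # Default-deny: only models that are both permitted and actually installed.
    return frozenset(allowed) & found


discovered = {"llama3.1:8b", "mistral:7b"}
print(sorted(effective_models({"mistral:7b", "phi3"}, discovered)))  # ['mistral:7b']
print(sorted(effective_models((), discovered, allow_all=True)))  # ['llama3.1:8b', 'mistral:7b']
print(sorted(effective_models({"mistral:7b"}, (), allow_all=True)))  # []
```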

Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py:
the per-hit ZSET member uses id(object()), and because the temporary object is
freed immediately, CPython typically returns the same id on consecutive calls.
Same-millisecond hits therefore overwrite each other instead of stacking. To
be fixed in a follow-up commit.
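
The effect of the ZSET member collision can be reproduced without Redis. The sketch below emulates ZADD with a dict (re-adding an existing member only updates its score). The member format and the helper names are assumptions made for illustration, and a random UUID is just one possible fix, not necessarily the one the follow-up commit will use.

```python
import uuid


def zadd(zset: dict[str, float], member: str, score: float) -> None:
    """Emulate Redis ZADD: an existing member is overwritten, not duplicated."""
    zset[member] = score


def buggy_member(now_ms: int) -> str:
    # The temporary object is freed as soon as id() returns, so CPython
    # commonly reuses the same address (and thus the same id) next time.
    return f"{now_ms}:{id(object())}"


def unique_member(now_ms: int) -> str:
    # Assumed fix: a random suffix makes every hit a distinct member.
    return f"{now_ms}:{uuid.uuid4().hex}"


now_ms = 1_700_000_000_000  # three hits within the same millisecond
buggy: dict[str, float] = {}
fixed: dict[str, float] = {}
for _ in range(3):
    zadd(buggy, buggy_member(now_ms), float(now_ms))
    zadd(fixed, unique_member(now_ms), float(now_ms))

print(len(buggy))  # often 1 on CPython; the exact value is not guaranteed
print(len(fixed))  # 3
```

With the colliding member, a limiter that counts ZSET cardinality undercounts bursts, so it lets through more requests than the configured limit.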
2026-05-26 20:52:33 +02:00

"""A realistic in-process mock of the Ollama HTTP API for integration tests.

This is the backbone of the Phase 2+ integration suite: it lets the gateway
proxy against a deterministic upstream with no GPU, no model download, and no
real Ollama process. It emulates the subset of the Ollama API the gateway
proxies (SPEC §6.1) plus the response *shapes* the gateway depends on:

* ``POST /api/chat``       - NDJSON streaming and non-streaming
* ``POST /api/generate``   - NDJSON streaming and non-streaming
* ``POST /api/embeddings`` - non-streaming (legacy field name ``embedding``)
* ``POST /api/embed``      - non-streaming (newer field name ``embeddings``)
* ``POST /api/show``       - model details, including ``template`` and ``system``
* ``GET /api/tags``        - model list
* ``GET /api/version``     - upstream version (gateway overrides this anyway)

Token-accounting realism: the final NDJSON object of every chat/generate
response carries ``prompt_eval_count`` and ``eval_count`` (and a few of the
sibling timing fields Ollama emits) so the gateway's token counter can be
exercised for real (SPEC §4.3 step 12, §12 token-count acceptance criterion).

OpenAI-compat note: this mock speaks *native Ollama* (NDJSON). The gateway is
responsible for translating to OpenAI SSE (``data: {...}\\n\\n`` ...
``data: [DONE]``). To make SSE assertions easy without standing up the full
gateway, this module also exposes :func:`ollama_chunks_to_openai_sse`, a pure
helper that converts a sequence of native Ollama chat chunks into the SSE byte
stream the gateway is expected to emit.

Usage contract
--------------
Build the ASGI app and serve it with an ASGI transport so no real socket is
needed::

    import json

    import httpx
    from tests.integration.mock_ollama import create_mock_ollama

    app = create_mock_ollama()
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://ollama") as client:
        # non-streaming
        r = await client.post(
            "/api/chat",
            json={
                "model": "llama3.1:8b",
                "messages": [{"role": "user", "content": "hi"}],
                "stream": False,
            },
        )
        body = r.json()
        assert body["prompt_eval_count"] >= 0 and body["eval_count"] >= 0

        # streaming (NDJSON, one JSON object per line)
        async with client.stream(
            "POST",
            "/api/chat",
            json={
                "model": "llama3.1:8b",
                "messages": [{"role": "user", "content": "hi"}],
                "stream": True,
            },
        ) as resp:
            async for line in resp.aiter_lines():
                if line:
                    obj = json.loads(line)  # last obj has done=True + counts

The reply text is deterministic: by default it is ``Echo: <prompt>``, where
the prompt is the last user message for ``/api/chat`` or the ``prompt`` field
for ``/api/generate``. An empty prompt yields a fixed greeting. This keeps
token counts stable and assertions simple. Override the reply by passing
``reply_text`` in the request body (non-standard test hook, ignored by real
Ollama).

A pytest fixture :func:`mock_ollama_app` is also provided for direct use.
"""

from __future__ import annotations

import json
from collections.abc import AsyncIterator, Iterable
from datetime import UTC, datetime
from typing import Any

import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

NDJSON_MEDIA_TYPE = "application/x-ndjson"
SSE_MEDIA_TYPE = "text/event-stream"

# A small, fixed model catalogue the gateway can filter against in tests.
DEFAULT_MODELS: tuple[str, ...] = ("llama3.1:8b", "mistral:7b", "nomic-embed-text")


def _now_iso() -> str:
    return datetime.now(UTC).isoformat().replace("+00:00", "Z")


def _reply_for(prompt: str, override: str | None) -> str:
    """Deterministic reply text for a given prompt.

    Real Ollama generates text; for tests we want determinism. The default
    reply echoes the prompt as ``Echo: <prompt>``, or is a fixed canned
    sentence when the prompt is empty; tests may pass ``reply_text`` to
    control it.
    """
    if override is not None:
        return override
    if not prompt:
        return "Hello from mock Ollama."
    return f"Echo: {prompt}"


def _tokenize(text: str) -> list[str]:
    """Whitespace tokenizer good enough for deterministic chunking/counts."""
    return text.split()


def _final_metrics(prompt_tokens: int, completion_tokens: int) -> dict[str, Any]:
    """The timing/usage fields Ollama attaches to the terminal stream object.

    The gateway's token counter reads ``prompt_eval_count`` and ``eval_count``;
    the rest are included for realism so tests can assert they are *not* leaked
    to clients if the gateway is supposed to strip them.
    """
    return {
        "total_duration": 1_234_567_890,
        "load_duration": 12_345_678,
        "prompt_eval_count": prompt_tokens,
        "prompt_eval_duration": 23_456_789,
        "eval_count": completion_tokens,
        "eval_duration": 34_567_890,
    }


def _chat_chunk(
    model: str,
    *,
    content: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """One ``/api/chat`` NDJSON object."""
    obj: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "message": {"role": "assistant", "content": content},
        "done": done,
    }
    if done:
        obj["done_reason"] = "stop"
        obj.update(_final_metrics(prompt_tokens, completion_tokens))
    return obj


def _generate_chunk(
    model: str,
    *,
    response: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """One ``/api/generate`` NDJSON object."""
    obj: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "response": response,
        "done": done,
    }
    if done:
        obj["done_reason"] = "stop"
        obj["context"] = [1, 2, 3]
        obj.update(_final_metrics(prompt_tokens, completion_tokens))
    return obj


async def _ndjson_stream(objects: Iterable[dict[str, Any]]) -> AsyncIterator[bytes]:
    """Serialize objects as newline-delimited JSON bytes (Ollama stream format)."""
    for obj in objects:
        yield (json.dumps(obj) + "\n").encode("utf-8")
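

def _extract_last_user_message(messages: list[dict[str, Any]]) -> str:
    """Return the content of the most recent ``user`` message, or ``""``.

    Non-string content (e.g. a list of parts) is treated as empty.
    """
    for msg in reversed(messages):
        if msg.get("role") == "user":
            content = msg.get("content", "")
            return content if isinstance(content, str) else ""
    return ""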


def create_mock_ollama(models: Iterable[str] = DEFAULT_MODELS) -> FastAPI:
    """Build and return a FastAPI app emulating the Ollama API.

    Parameters
    ----------
    models:
        Catalogue returned by ``/api/tags``. Defaults to :data:`DEFAULT_MODELS`.
    """
    app = FastAPI(title="mock-ollama", docs_url=None, redoc_url=None)
    catalogue = tuple(models)

    @app.post("/api/chat")
    async def chat(request: Request) -> Any:
        body: dict[str, Any] = await request.json()
        model: str = body.get("model", "llama3.1:8b")
        stream: bool = body.get("stream", True)
        reply_override: str | None = body.get("reply_text")
        prompt = _extract_last_user_message(body.get("messages", []))
        reply = _reply_for(prompt, reply_override)
        prompt_tokens = len(_tokenize(prompt))
        completion_tokens = len(_tokenize(reply))
        if not stream:
            obj = _chat_chunk(
                model,
                content=reply,
                done=True,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            return JSONResponse(obj)

        words = _tokenize(reply) or [""]

        def chunks() -> list[dict[str, Any]]:
            out: list[dict[str, Any]] = []
            for i, word in enumerate(words):
                piece = word if i == 0 else f" {word}"
                out.append(_chat_chunk(model, content=piece, done=False))
            out.append(
                _chat_chunk(
                    model,
                    content="",
                    done=True,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
            )
            return out

        return StreamingResponse(_ndjson_stream(chunks()), media_type=NDJSON_MEDIA_TYPE)

    @app.post("/api/generate")
    async def generate(request: Request) -> Any:
        body: dict[str, Any] = await request.json()
        model: str = body.get("model", "llama3.1:8b")
        stream: bool = body.get("stream", True)
        prompt = body.get("prompt", "")
        reply = _reply_for(prompt, body.get("reply_text"))
        prompt_tokens = len(_tokenize(prompt))
        completion_tokens = len(_tokenize(reply))
        if not stream:
            obj = _generate_chunk(
                model,
                response=reply,
                done=True,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            return JSONResponse(obj)

        words = _tokenize(reply) or [""]

        def chunks() -> list[dict[str, Any]]:
            out: list[dict[str, Any]] = []
            for i, word in enumerate(words):
                piece = word if i == 0 else f" {word}"
                out.append(_generate_chunk(model, response=piece, done=False))
            out.append(
                _generate_chunk(
                    model,
                    response="",
                    done=True,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
            )
            return out

        return StreamingResponse(_ndjson_stream(chunks()), media_type=NDJSON_MEDIA_TYPE)

    @app.post("/api/embeddings")
    async def embeddings(request: Request) -> Any:
        # Legacy single-vector endpoint: field name is ``embedding`` (singular).
        body: dict[str, Any] = await request.json()
        prompt = body.get("prompt", "")
        prompt_tokens = len(_tokenize(prompt))
        return JSONResponse(
            {
                "embedding": [0.0, 0.1, 0.2, 0.3],
                # Ollama does not return eval_count for embeddings (SPEC §13.1);
                # only prompt_eval_count is meaningful for cost accounting.
                "prompt_eval_count": prompt_tokens,
            }
        )

    @app.post("/api/embed")
    async def embed(request: Request) -> Any:
        # Newer batch endpoint: field name is ``embeddings`` (plural list).
        body: dict[str, Any] = await request.json()
        model: str = body.get("model", "nomic-embed-text")
        inp = body.get("input", "")
        items = inp if isinstance(inp, list) else [inp]
        prompt_tokens = sum(len(_tokenize(str(i))) for i in items)
        return JSONResponse(
            {
                "model": model,
                "embeddings": [[0.0, 0.1, 0.2, 0.3] for _ in items],
                "prompt_eval_count": prompt_tokens,
            }
        )

    @app.get("/api/tags")
    async def tags() -> Any:
        return JSONResponse(
            {
                "models": [
                    {
                        "name": name,
                        "model": name,
                        "modified_at": _now_iso(),
                        "size": 4_700_000_000,
                        "digest": "sha256:deadbeef",
                        "details": {
                            "family": name.split(":", 1)[0],
                            "parameter_size": "8B",
                            "quantization_level": "Q4_0",
                        },
                    }
                    for name in catalogue
                ]
            }
        )

    @app.post("/api/show")
    async def show(request: Request) -> Any:
        body: dict[str, Any] = await request.json()
        name = body.get("model") or body.get("name", "llama3.1:8b")
        # Real Ollama returns system prompt + template here; the gateway is
        # expected to strip those. We include them so a sanitization test can
        # assert they don't reach the client.
        return JSONResponse(
            {
                "modelfile": "FROM llama3.1:8b",
                "template": "{{ .System }} {{ .Prompt }}",
                "system": "You are a secret internal system prompt.",
                "details": {"family": str(name).split(":", 1)[0], "parameter_size": "8B"},
            }
        )

    @app.get("/api/version")
    async def version() -> Any:
        # The gateway overrides this with its own version (SPEC §6.1); the mock
        # returns a plausible upstream version so a test can confirm it is NOT
        # the value the gateway reports.
        return JSONResponse({"version": "0.5.7"})

    return app


def ollama_chunks_to_openai_sse(
    chunks: Iterable[dict[str, Any]],
    *,
    model: str = "llama3.1:8b",
    completion_id: str = "chatcmpl-mock",
) -> bytes:
    """Convert native Ollama chat chunks into the OpenAI SSE byte stream.

    Pure helper that mirrors what the gateway's OpenAI-compat translation layer
    must emit (SPEC §6.3): one ``data: {...}\\n\\n`` event per delta, a final
    event carrying ``finish_reason``/``usage``, then ``data: [DONE]\\n\\n``.
    Useful for asserting expected SSE output in translation tests without
    standing up the full gateway. Phase 3 may replace this with golden fixtures.
    """
    created = int(datetime.now(UTC).timestamp())
    events: list[str] = []
    for chunk in chunks:
        message = chunk.get("message", {})
        content = message.get("content", "")
        done = chunk.get("done", False)
        if done:
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                "usage": {
                    "prompt_tokens": chunk.get("prompt_eval_count", 0),
                    "completion_tokens": chunk.get("eval_count", 0),
                    "total_tokens": chunk.get("prompt_eval_count", 0)
                    + chunk.get("eval_count", 0),
                },
            }
        else:
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
            }
        events.append(f"data: {json.dumps(payload)}\n\n")
    events.append("data: [DONE]\n\n")
    return "".join(events).encode("utf-8")


@pytest.fixture()
def mock_ollama_app() -> FastAPI:
    """Pytest fixture returning a fresh mock Ollama ASGI app."""
    return create_mock_ollama()