tests: unit + integration suite (99 tests; ruff + mypy --strict clean)
Real test bodies (not stubs), driven against an in-process httpx.ASGITransport override of the gateway's get_ollama_client dependency pointing at tests/integration/mock_ollama.py. Unit (target 100% on auth/, ratelimit/, budget/): - argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change - key format/uniqueness/prefix extraction - token counter (prompt_eval_count + eval_count, embeddings, missing-counts) - translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks, /v1/models list shape) - allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/ empty-discovered) - discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None) - sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted) Integration (testcontainers postgres + redis + in-process mock Ollama): - auth flow (no/malformed/wrong key all return identical sanitized 401) - proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked endpoints uniformly 403) - openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models) - model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered, /v1/models filtered, unpermitted-but-installed = nonexistent = 403, empty cache denies even allow_all) - rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200) - budget (decrement + headers; pre-burned counter blocks next request) - revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s) Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py: the per-hit ZSET member uses id(object()) which returns the same id on consecutive calls, causing same-millisecond hits to overwrite instead of stacking. To be fixed in a follow-up commit.
This commit is contained in:
401
tests/integration/mock_ollama.py
Normal file
401
tests/integration/mock_ollama.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""A realistic in-process mock of the Ollama HTTP API for integration tests.
|
||||
|
||||
This is the backbone of the Phase 2+ integration suite: it lets the gateway
|
||||
proxy against a deterministic upstream with no GPU, no model download, and no
|
||||
real Ollama process. It emulates the subset of the Ollama API the gateway
|
||||
proxies (SPEC §6.1) plus the response *shapes* the gateway depends on:
|
||||
|
||||
* ``POST /api/chat`` - NDJSON streaming and non-streaming
|
||||
* ``POST /api/generate`` - NDJSON streaming and non-streaming
|
||||
* ``POST /api/embeddings``- non-streaming (legacy field name ``embedding``)
|
||||
* ``POST /api/embed`` - non-streaming (newer field name ``embeddings``)
|
||||
* ``GET /api/tags`` - model list
|
||||
* ``GET /api/version`` - upstream version (gateway overrides this anyway)
|
||||
|
||||
Token-accounting realism: the final NDJSON object of every chat/generate
|
||||
response carries ``prompt_eval_count`` and ``eval_count`` (and a few of the
|
||||
sibling timing fields Ollama emits) so the gateway's token counter can be
|
||||
exercised for real (SPEC §4.3 step 12, §12 token-count acceptance criterion).
|
||||
|
||||
OpenAI-compat note: this mock speaks *native Ollama* (NDJSON). The gateway is
|
||||
responsible for translating to OpenAI SSE (``data: {...}\\n\\n`` ... ``data:
|
||||
[DONE]``). To make SSE assertions easy without standing up the full gateway,
|
||||
this module also exposes :func:`ollama_chunks_to_openai_sse`, a pure helper
|
||||
that converts a sequence of native Ollama chat chunks into the SSE byte stream
|
||||
the gateway is expected to emit.
|
||||
|
||||
Usage contract
|
||||
--------------
|
||||
Build the ASGI app and serve it with an ASGI transport so no real socket is
|
||||
needed::
|
||||
|
||||
import httpx
|
||||
from tests.integration.mock_ollama import create_mock_ollama
|
||||
|
||||
app = create_mock_ollama()
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://ollama") as client:
|
||||
# non-streaming
|
||||
r = await client.post("/api/chat", json={"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user",
|
||||
"content": "hi"}],
|
||||
"stream": False})
|
||||
body = r.json()
|
||||
assert body["prompt_eval_count"] >= 0 and body["eval_count"] >= 0
|
||||
|
||||
# streaming (NDJSON, one JSON object per line)
|
||||
async with client.stream("POST", "/api/chat",
|
||||
json={"model": "llama3.1:8b",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True}) as resp:
|
||||
async for line in resp.aiter_lines():
|
||||
if line:
|
||||
obj = json.loads(line) # last obj has done=True + counts
|
||||
|
||||
The reply text of a chat/generate response is deterministic: it is
``Echo: <prompt>`` (for chat, the prompt is the last user message), or
``Hello from mock Ollama.`` when the prompt is empty, which keeps token counts
stable and assertions simple. Override by passing ``reply_text`` in the request
body (non-standard test hook, ignored by real Ollama).
|
||||
|
||||
A pytest fixture :func:`mock_ollama_app` is also provided for direct use.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
NDJSON_MEDIA_TYPE = "application/x-ndjson"
|
||||
SSE_MEDIA_TYPE = "text/event-stream"
|
||||
|
||||
# A small, fixed model catalogue the gateway can filter against in tests.
|
||||
DEFAULT_MODELS: tuple[str, ...] = ("llama3.1:8b", "mistral:7b", "nomic-embed-text")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _reply_for(prompt: str, override: str | None) -> str:
|
||||
"""Deterministic reply text for a given prompt.
|
||||
|
||||
Real Ollama generates text; for tests we want determinism. Default reply is
|
||||
a fixed canned sentence; tests may pass ``reply_text`` to control it.
|
||||
"""
|
||||
if override is not None:
|
||||
return override
|
||||
if not prompt:
|
||||
return "Hello from mock Ollama."
|
||||
return f"Echo: {prompt}"
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
    """Split ``text`` on runs of whitespace.

    Not a real model tokenizer: it only has to be deterministic so chunking
    and token counts stay stable across test runs.
    """
    tokens = text.split()
    return tokens
|
||||
|
||||
|
||||
def _final_metrics(prompt_tokens: int, completion_tokens: int) -> dict[str, Any]:
    """Build the usage/timing fields attached to a terminal response object.

    ``prompt_eval_count`` and ``eval_count`` carry the supplied token counts
    (the fields the gateway's token counter reads). The duration fields are
    fixed placeholder values included for realism, so tests can assert they
    are *not* leaked to clients if the gateway is supposed to strip them.
    """
    metrics: dict[str, Any] = {
        "total_duration": 1_234_567_890,
        "load_duration": 12_345_678,
    }
    # Inserted one by one so the key order matches the emitted JSON layout.
    metrics["prompt_eval_count"] = prompt_tokens
    metrics["prompt_eval_duration"] = 23_456_789
    metrics["eval_count"] = completion_tokens
    metrics["eval_duration"] = 34_567_890
    return metrics
|
||||
|
||||
|
||||
def _chat_chunk(
    model: str,
    *,
    content: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """Build one ``/api/chat`` response object (one NDJSON line when streamed).

    Every object carries the model name, a creation timestamp, an assistant
    message holding ``content``, and the ``done`` flag. When ``done`` is true
    the object additionally gets ``done_reason="stop"`` and the usage/timing
    metrics built from ``prompt_tokens`` / ``completion_tokens``; the token
    arguments are ignored otherwise.
    """
    chunk: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "message": {"role": "assistant", "content": content},
        "done": done,
    }
    if not done:
        return chunk
    # Terminal object: append the stop reason, then the usage/timing fields.
    return {
        **chunk,
        "done_reason": "stop",
        **_final_metrics(prompt_tokens, completion_tokens),
    }
|
||||
|
||||
|
||||
def _generate_chunk(
    model: str,
    *,
    response: str,
    done: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> dict[str, Any]:
    """Build one ``/api/generate`` response object (one NDJSON line when streamed).

    Every object carries the model name, a creation timestamp, the
    ``response`` text, and the ``done`` flag. When ``done`` is true the object
    additionally gets ``done_reason="stop"``, a fixed ``context`` list, and
    the usage/timing metrics built from ``prompt_tokens`` /
    ``completion_tokens``; the token arguments are ignored otherwise.
    """
    chunk: dict[str, Any] = {
        "model": model,
        "created_at": _now_iso(),
        "response": response,
        "done": done,
    }
    if not done:
        return chunk
    # Terminal object: stop reason, placeholder context, then usage/timing.
    chunk |= {"done_reason": "stop", "context": [1, 2, 3]}
    chunk |= _final_metrics(prompt_tokens, completion_tokens)
    return chunk
|
||||
|
||||
|
||||
async def _ndjson_stream(objects: Iterable[dict[str, Any]]) -> AsyncIterator[bytes]:
    """Yield each object as one UTF-8 encoded JSON line (NDJSON).

    Objects are serialized lazily, in iteration order, each terminated by a
    single ``\\n``.
    """
    for record in objects:
        line = f"{json.dumps(record)}\n"
        yield line.encode("utf-8")
|
||||
|
||||
|
||||
def _extract_last_user_message(messages: list[dict[str, Any]]) -> str:
    """Return the text of the most recent ``role == "user"`` message.

    Returns ``""`` when there is no user message, when that message has no
    ``content`` key, or when its content is not a plain string. Only the
    latest user message is inspected; earlier ones are never used as a
    fallback.
    """
    latest = next(
        (entry for entry in reversed(messages) if entry.get("role") == "user"),
        None,
    )
    if latest is None:
        return ""
    text = latest.get("content", "")
    if isinstance(text, str):
        return text
    return ""
|
||||
|
||||
|
||||
def create_mock_ollama(models: Iterable[str] = DEFAULT_MODELS) -> FastAPI:
    """Assemble a FastAPI application that imitates the Ollama HTTP API.

    Routes registered: ``POST /api/chat``, ``POST /api/generate``,
    ``POST /api/embeddings``, ``POST /api/embed``, ``GET /api/tags``,
    ``POST /api/show`` and ``GET /api/version``.

    Parameters
    ----------
    models:
        Catalogue returned by ``/api/tags``. Defaults to :data:`DEFAULT_MODELS`.
        It is snapshotted into a tuple once, when the app is built.

    Returns
    -------
    FastAPI
        A fresh app instance with the interactive docs routes disabled.
    """
    installed = tuple(models)
    app = FastAPI(title="mock-ollama", docs_url=None, redoc_url=None)

    def _fragments(reply: str) -> list[str]:
        """Split ``reply`` into streamed pieces; all but the first get a leading space."""
        tokens = _tokenize(reply) or [""]  # an empty reply still streams one empty piece
        return [tokens[0], *(f" {token}" for token in tokens[1:])]

    def _as_ndjson(frames: list[dict[str, Any]]) -> StreamingResponse:
        """Wrap already-built response objects in an NDJSON streaming response."""
        return StreamingResponse(_ndjson_stream(frames), media_type=NDJSON_MEDIA_TYPE)

    @app.post("/api/chat")
    async def chat(request: Request) -> Any:
        # The prompt is the last user message; ``reply_text`` is the
        # non-standard test hook that overrides the generated reply.
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "llama3.1:8b")
        wants_stream: bool = payload.get("stream", True)
        forced_reply: str | None = payload.get("reply_text")
        user_text = _extract_last_user_message(payload.get("messages", []))
        reply = _reply_for(user_text, forced_reply)

        n_prompt = len(_tokenize(user_text))
        n_completion = len(_tokenize(reply))

        if wants_stream:
            # One delta per word, then an empty terminal object with the counts.
            frames = [
                _chat_chunk(model, content=piece, done=False)
                for piece in _fragments(reply)
            ]
            frames.append(
                _chat_chunk(
                    model,
                    content="",
                    done=True,
                    prompt_tokens=n_prompt,
                    completion_tokens=n_completion,
                )
            )
            return _as_ndjson(frames)

        # Non-streaming: a single terminal object holding the whole reply.
        return JSONResponse(
            _chat_chunk(
                model,
                content=reply,
                done=True,
                prompt_tokens=n_prompt,
                completion_tokens=n_completion,
            )
        )

    @app.post("/api/generate")
    async def generate(request: Request) -> Any:
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "llama3.1:8b")
        wants_stream: bool = payload.get("stream", True)
        prompt = payload.get("prompt", "")
        reply = _reply_for(prompt, payload.get("reply_text"))

        n_prompt = len(_tokenize(prompt))
        n_completion = len(_tokenize(reply))

        if wants_stream:
            # One delta per word, then an empty terminal object with the counts.
            frames = [
                _generate_chunk(model, response=piece, done=False)
                for piece in _fragments(reply)
            ]
            frames.append(
                _generate_chunk(
                    model,
                    response="",
                    done=True,
                    prompt_tokens=n_prompt,
                    completion_tokens=n_completion,
                )
            )
            return _as_ndjson(frames)

        # Non-streaming: a single terminal object holding the whole reply.
        return JSONResponse(
            _generate_chunk(
                model,
                response=reply,
                done=True,
                prompt_tokens=n_prompt,
                completion_tokens=n_completion,
            )
        )

    @app.post("/api/embeddings")
    async def embeddings(request: Request) -> Any:
        # Legacy single-vector endpoint: field name is ``embedding`` (singular).
        payload: dict[str, Any] = await request.json()
        n_prompt = len(_tokenize(payload.get("prompt", "")))
        # Ollama does not return eval_count for embeddings (SPEC §13.1);
        # only prompt_eval_count is meaningful for cost accounting.
        result: dict[str, Any] = {
            "embedding": [0.0, 0.1, 0.2, 0.3],
            "prompt_eval_count": n_prompt,
        }
        return JSONResponse(result)

    @app.post("/api/embed")
    async def embed(request: Request) -> Any:
        # Newer batch endpoint: field name is ``embeddings`` (plural list).
        payload: dict[str, Any] = await request.json()
        model: str = payload.get("model", "nomic-embed-text")
        raw_input = payload.get("input", "")
        # A bare (non-list) input is treated as a batch of one.
        batch = raw_input if isinstance(raw_input, list) else [raw_input]
        n_prompt = sum(len(_tokenize(str(item))) for item in batch)
        result: dict[str, Any] = {
            "model": model,
            "embeddings": [[0.0, 0.1, 0.2, 0.3] for _ in batch],
            "prompt_eval_count": n_prompt,
        }
        return JSONResponse(result)

    def _tag_entry(name: str) -> dict[str, Any]:
        """Describe one catalogue model the way ``/api/tags`` lists it."""
        return {
            "name": name,
            "model": name,
            "modified_at": _now_iso(),
            "size": 4_700_000_000,
            "digest": "sha256:deadbeef",
            "details": {
                # Family is the part of the name before the first ":".
                "family": name.partition(":")[0],
                "parameter_size": "8B",
                "quantization_level": "Q4_0",
            },
        }

    @app.get("/api/tags")
    async def tags() -> Any:
        return JSONResponse({"models": [_tag_entry(name) for name in installed]})

    @app.post("/api/show")
    async def show(request: Request) -> Any:
        payload: dict[str, Any] = await request.json()
        # ``model`` takes precedence; ``name`` is the fallback key.
        name = payload.get("model") or payload.get("name", "llama3.1:8b")
        family = str(name).partition(":")[0]
        # Real Ollama returns system prompt + template here; the gateway is
        # expected to strip those. We include them so a sanitization test can
        # assert they don't reach the client.
        result: dict[str, Any] = {
            "modelfile": "FROM llama3.1:8b",
            "template": "{{ .System }} {{ .Prompt }}",
            "system": "You are a secret internal system prompt.",
            "details": {"family": family, "parameter_size": "8B"},
        }
        return JSONResponse(result)

    @app.get("/api/version")
    async def version() -> Any:
        # The gateway overrides this with its own version (SPEC §6.1); the mock
        # returns a plausible upstream version so a test can confirm it is NOT
        # the value the gateway reports.
        return JSONResponse({"version": "0.5.7"})

    return app
|
||||
|
||||
|
||||
def ollama_chunks_to_openai_sse(
|
||||
chunks: Iterable[dict[str, Any]],
|
||||
*,
|
||||
model: str = "llama3.1:8b",
|
||||
completion_id: str = "chatcmpl-mock",
|
||||
) -> bytes:
|
||||
"""Convert native Ollama chat chunks into the OpenAI SSE byte stream.
|
||||
|
||||
Pure helper that mirrors what the gateway's OpenAI-compat translation layer
|
||||
must emit (SPEC §6.3): one ``data: {...}\\n\\n`` event per delta, a final
|
||||
event carrying ``finish_reason``/``usage``, then ``data: [DONE]\\n\\n``.
|
||||
|
||||
Useful for asserting expected SSE output in translation tests without
|
||||
standing up the full gateway. Phase 3 may replace this with golden fixtures.
|
||||
"""
|
||||
created = int(datetime.now(UTC).timestamp())
|
||||
events: list[str] = []
|
||||
for chunk in chunks:
|
||||
message = chunk.get("message", {})
|
||||
content = message.get("content", "")
|
||||
done = chunk.get("done", False)
|
||||
if done:
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": chunk.get("prompt_eval_count", 0),
|
||||
"completion_tokens": chunk.get("eval_count", 0),
|
||||
"total_tokens": chunk.get("prompt_eval_count", 0)
|
||||
+ chunk.get("eval_count", 0),
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
|
||||
}
|
||||
events.append(f"data: {json.dumps(payload)}\n\n")
|
||||
events.append("data: [DONE]\n\n")
|
||||
return "".join(events).encode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture()
def mock_ollama_app() -> FastAPI:
    """Provide a newly built mock Ollama ASGI app with the default catalogue."""
    app = create_mock_ollama()
    return app
|
||||
Reference in New Issue
Block a user