tests: unit + integration suite (99 tests; ruff + mypy --strict clean)

Real test bodies (not stubs), driven against an in-process httpx.ASGITransport
override of the gateway's get_ollama_client dependency pointing at
tests/integration/mock_ollama.py.

Unit (target 100% on auth/, ratelimit/, budget/):
- argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change
- key format/uniqueness/prefix extraction
- token counter (prompt_eval_count + eval_count, embeddings, missing-counts)
- translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks,
  /v1/models list shape)
- allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/
  empty-discovered)
- discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None)
- sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted)

Integration (testcontainers postgres + redis + in-process mock Ollama):
- auth flow (no/malformed/wrong key all return identical sanitized 401)
- proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked
  endpoints uniformly 403)
- openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models)
- model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered,
  /v1/models filtered, unpermitted-but-installed = nonexistent = 403,
  empty cache denies even allow_all)
- rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200)
- budget (decrement + headers; pre-burned counter blocks next request)
- revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s)

Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py:
the per-hit ZSET member uses id(object()) which returns the same id on
consecutive calls, causing same-millisecond hits to overwrite instead of
stacking. To be fixed in a follow-up commit.
This commit is contained in:
Stephan Berbig
2026-05-26 20:52:33 +02:00
parent 6a92bc8ce9
commit 844b02aade
23 changed files with 2567 additions and 0 deletions

1
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Unit test package."""

View File

@@ -0,0 +1,115 @@
"""Unit tests for ``neuronetz_gateway.proxy.allowlist``.
Two concerns (SPEC §4.3 step 7-8, §4.6, §6.1-6.2):
1. **Hard-blocked endpoints** — mutating Ollama endpoints and ``/api/ps`` are
always blocked, not configurable (``is_hard_blocked``).
2. **Effective model set** (``resolve_effective_models`` / ``is_model_allowed``):
* ``allow_all`` ⇒ all discovered
* default-deny ⇒ ``allowed_models ∩ discovered``
* stale/typo'd allowlist entries never resolve (not in discovered)
* empty discovered ⇒ deny, even under ``allow_all`` (fail-closed)
The resolver merges the *already-resolved* ``allow_all``/``allowed_models``
inputs (key-vs-tenant precedence per SPEC §13.7 is applied by the caller before
this point; that precedence is exercised in the integration model-discovery
tests).
"""
from __future__ import annotations
import pytest
from neuronetz_gateway.proxy import allowlist
# --- hard-blocked endpoint allowlist ---------------------------------------
@pytest.mark.parametrize(
"path",
["/api/pull", "/api/push", "/api/create", "/api/copy", "/api/delete", "/api/ps"],
)
def test_mutating_and_ps_endpoints_hard_blocked(path: str) -> None:
assert allowlist.is_hard_blocked(path) is True
@pytest.mark.parametrize("path", ["/api/blobs", "/api/blobs/sha256:abc", "/api/blobs/x/y"])
def test_blob_prefix_hard_blocked(path: str) -> None:
assert allowlist.is_hard_blocked(path) is True
@pytest.mark.parametrize(
"path", ["/api/chat", "/api/generate", "/api/embed", "/api/tags", "/api/show", "/api/version"]
)
def test_allowlisted_endpoints_not_hard_blocked(path: str) -> None:
assert allowlist.is_hard_blocked(path) is False
# --- effective-set resolution (SPEC §4.3 step 7 / §4.6) --------------------
DISCOVERED = frozenset({"llama3.1:8b", "mistral:7b", "nomic-embed-text"})
def test_allow_all_returns_all_discovered() -> None:
eff = allowlist.resolve_effective_models(
allow_all=True, allowed_models=(), discovered=DISCOVERED
)
assert eff == DISCOVERED
def test_default_deny_intersects_discovered() -> None:
# "qwen" is allowed but not installed => must not resolve (stale entry).
eff = allowlist.resolve_effective_models(
allow_all=False,
allowed_models=("llama3.1:8b", "qwen:0.5b"),
discovered=DISCOVERED,
)
assert eff == frozenset({"llama3.1:8b"})
def test_default_deny_empty_allowlist_denies_all() -> None:
eff = allowlist.resolve_effective_models(
allow_all=False, allowed_models=(), discovered=DISCOVERED
)
assert eff == frozenset()
def test_empty_discovered_denies_everything_even_allow_all() -> None:
# Fail-closed: empty discovered set => no model resolves, even allow_all.
eff = allowlist.resolve_effective_models(
allow_all=True, allowed_models=("llama3.1:8b",), discovered=frozenset()
)
assert eff == frozenset()
def test_is_model_allowed_membership() -> None:
assert (
allowlist.is_model_allowed(
"llama3.1:8b", allow_all=False, allowed_models=("llama3.1:8b",), discovered=DISCOVERED
)
is True
)
# Installed but not permitted.
assert (
allowlist.is_model_allowed(
"mistral:7b", allow_all=False, allowed_models=("llama3.1:8b",), discovered=DISCOVERED
)
is False
)
# Permitted but not installed (typo'd / stale) — indistinguishable from
# not-allowed to the caller (SPEC §13.6 no existence disclosure).
assert (
allowlist.is_model_allowed(
"ghost:1b", allow_all=False, allowed_models=("ghost:1b",), discovered=DISCOVERED
)
is False
)
def test_is_model_allowed_fail_closed_empty_discovered() -> None:
assert (
allowlist.is_model_allowed(
"llama3.1:8b", allow_all=True, allowed_models=(), discovered=frozenset()
)
is False
)

View File

@@ -0,0 +1,142 @@
"""Unit tests for model discovery (SPEC §4.6).
A background poller queries Ollama ``GET /api/tags``, parses the installed model
set, caches it in Redis (TTL) + in-process (:class:`DiscoveryCache`), and **fails
closed**: an empty/expired discovered set means no model resolves (deny). On an
upstream error ``refresh_once`` returns ``False`` and leaves the caches untouched
so they expire on their own TTL (stale-expired ⇒ empty ⇒ deny); discovery never
opens access.
Driven against the in-process ``mock_ollama`` upstream and the real
``redis_client`` testcontainer fixture (skips cleanly without Docker).
"""
from __future__ import annotations
from typing import Any
import httpx
import pytest
import redis.asyncio as aioredis
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from neuronetz_gateway.config import Settings
from neuronetz_gateway.proxy import discovery
def _ollama_client(app: FastAPI) -> httpx.AsyncClient:
transport = httpx.ASGITransport(app=app)
return httpx.AsyncClient(transport=transport, base_url="http://ollama")
def _settings() -> Settings:
return Settings(model_discovery_cache_ttl_s=120, model_discovery_refresh_s=60)
# --- pure parsing ----------------------------------------------------------
def test_names_of_extracts_model_names() -> None:
models = [
discovery.DiscoveredModel(name="llama3.1:8b", family="llama"),
discovery.DiscoveredModel(name="mistral:7b"),
]
assert discovery.names_of(models) == frozenset({"llama3.1:8b", "mistral:7b"})
def test_names_of_empty() -> None:
assert discovery.names_of([]) == frozenset()
# --- fetch + parse against the mock upstream -------------------------------
@pytest.mark.asyncio
async def test_fetch_tags_parses_mock_catalogue(mock_ollama_app: FastAPI) -> None:
async with _ollama_client(mock_ollama_app) as ollama:
models = await discovery.fetch_tags(ollama)
names = discovery.names_of(models)
assert {"llama3.1:8b", "mistral:7b", "nomic-embed-text"} <= names
# Sanitized metadata is captured (family parsed from the mock's details).
by_name = {m.name: m for m in models}
assert by_name["llama3.1:8b"].family == "llama3.1"
assert by_name["llama3.1:8b"].parameter_size == "8B"
@pytest.mark.asyncio
async def test_fetch_tags_raises_on_upstream_error() -> None:
broken = FastAPI()
@broken.get("/api/tags")
async def _tags() -> Any:
return JSONResponse({"error": "boom"}, status_code=500)
async with _ollama_client(broken) as ollama:
with pytest.raises(httpx.HTTPError):
await discovery.fetch_tags(ollama)
# --- redis cache round-trip ------------------------------------------------
@pytest.mark.asyncio
async def test_redis_cache_roundtrip(redis_client: aioredis.Redis) -> None:
models = [discovery.DiscoveredModel(name="llama3.1:8b"), discovery.DiscoveredModel(name="x:1b")]
await discovery.write_discovered_to_redis(redis_client, models, ttl_s=120)
names = await discovery.read_discovered_from_redis(redis_client)
assert names == frozenset({"llama3.1:8b", "x:1b"})
# TTL was applied (staleness expires) per SPEC §4.6.
assert 0 < await redis_client.ttl(discovery.REDIS_DISCOVERED_KEY) <= 120
@pytest.mark.asyncio
async def test_read_discovered_miss_returns_empty(redis_client: aioredis.Redis) -> None:
# Cache miss / expiry => empty set => fail-closed deny.
assert await discovery.read_discovered_from_redis(redis_client) == frozenset()
# --- refresh_once: success and fail-closed ---------------------------------
@pytest.mark.asyncio
async def test_refresh_once_populates_in_process_and_redis(
redis_client: aioredis.Redis, mock_ollama_app: FastAPI
) -> None:
cache = discovery.DiscoveryCache()
async with _ollama_client(mock_ollama_app) as ollama:
ok = await discovery.refresh_once(ollama, redis_client, cache, _settings())
assert ok is True
assert {"llama3.1:8b", "mistral:7b", "nomic-embed-text"} <= cache.names
# Mirrored into Redis under the SPEC §4.6 key.
assert {"llama3.1:8b", "mistral:7b"} <= await discovery.read_discovered_from_redis(redis_client)
@pytest.mark.asyncio
async def test_refresh_once_fail_closed_on_upstream_error(
redis_client: aioredis.Redis,
) -> None:
broken = FastAPI()
@broken.get("/api/tags")
async def _tags() -> Any:
return JSONResponse({"error": "boom"}, status_code=500)
cache = discovery.DiscoveryCache()
async with _ollama_client(broken) as ollama:
ok = await discovery.refresh_once(ollama, redis_client, cache, _settings())
# Refresh reports failure; the in-process cache stays empty (no models
# resolve) — discovery never opens access on error.
assert ok is False
assert cache.names == frozenset()
@pytest.mark.asyncio
async def test_refresh_once_tolerates_missing_redis(mock_ollama_app: FastAPI) -> None:
# redis_client=None must still refresh the in-process cache (best-effort
# Redis fill), not crash the poller.
cache = discovery.DiscoveryCache()
async with _ollama_client(mock_ollama_app) as ollama:
ok = await discovery.refresh_once(ollama, None, cache, _settings())
assert ok is True
assert "llama3.1:8b" in cache.names

View File

@@ -0,0 +1,93 @@
"""Unit tests for ``neuronetz_gateway.auth.hashing`` (argon2id wrapper).
Covers SPEC §3/§9: hash, constant-time verify, wrong-key rejection, and
rehash-on-parameter-change. Targets 100% coverage of ``auth/hashing.py``.
These bodies are real; while ``hashing.py`` is a Phase-1 stub each call routes
through :func:`tests._skip.call_or_skip`, which skips (not fails) on
``NotImplementedError`` and lets the test self-activate once Backend lands.
"""
from __future__ import annotations
import pytest
from argon2 import PasswordHasher
from neuronetz_gateway.auth import hashing
from neuronetz_gateway.config import Settings
from tests._skip import call_or_skip
# A fast hasher so the suite stays quick; the wrapper must round-trip regardless
# of the cost parameters, which are what we actually exercise here.
_FAST = PasswordHasher(time_cost=1, memory_cost=8, parallelism=1)
SECRET = "nz_abcdef012345deadbeefcafebabe00112233" # noqa: S105 - test fixture, not a real key
def _settings(**overrides: object) -> Settings:
base: dict[str, object] = {
"argon2_time_cost": 1,
"argon2_memory_cost_kib": 8,
"argon2_parallelism": 1,
}
base.update(overrides)
return Settings(**base) # type: ignore[arg-type] # pydantic-settings kwargs
def test_build_hasher_uses_settings() -> None:
hasher = hashing.build_hasher(_settings(argon2_time_cost=2))
assert isinstance(hasher, PasswordHasher)
assert hasher.time_cost == 2
def test_hash_secret_returns_argon2id_encoded() -> None:
encoded = call_or_skip(hashing.hash_secret, _FAST, SECRET)
assert isinstance(encoded, str)
# argon2id encoded hashes begin with the variant tag.
assert encoded.startswith("$argon2id$")
# Never echoes the plaintext back.
assert SECRET not in encoded
def test_hash_is_salted_unique_per_call() -> None:
a = call_or_skip(hashing.hash_secret, _FAST, SECRET)
b = call_or_skip(hashing.hash_secret, _FAST, SECRET)
assert a != b # random salt => distinct encodings
def test_verify_secret_roundtrip_true() -> None:
encoded = call_or_skip(hashing.hash_secret, _FAST, SECRET)
assert call_or_skip(hashing.verify_secret, _FAST, encoded, SECRET) is True
def test_verify_wrong_secret_returns_false_not_raise() -> None:
encoded = call_or_skip(hashing.hash_secret, _FAST, SECRET)
# Constant-time path must return False, never propagate argon2's
# VerifyMismatchError to the caller (auth must fail closed, not 500).
result = call_or_skip(hashing.verify_secret, _FAST, encoded, SECRET + "x")
assert result is False
def test_verify_garbage_encoding_returns_false() -> None:
# An invalid/corrupt stored hash must verify False, not raise.
result = call_or_skip(hashing.verify_secret, _FAST, "not-a-valid-hash", SECRET)
assert result is False
def test_needs_rehash_false_for_current_params() -> None:
encoded = call_or_skip(hashing.hash_secret, _FAST, SECRET)
assert call_or_skip(hashing.needs_rehash, _FAST, encoded) is False
def test_needs_rehash_true_when_params_change() -> None:
weak = PasswordHasher(time_cost=1, memory_cost=8, parallelism=1)
strong = PasswordHasher(time_cost=3, memory_cost=64, parallelism=1)
encoded = call_or_skip(hashing.hash_secret, weak, SECRET)
# A hash made with weaker params must be flagged for rehash by a stronger
# configured hasher (SPEC §9 rehash-on-parameter-change).
assert call_or_skip(hashing.needs_rehash, strong, encoded) is True
@pytest.mark.parametrize("secret", ["", "a", "x" * 4096])
def test_hash_verify_edge_lengths(secret: str) -> None:
encoded = call_or_skip(hashing.hash_secret, _FAST, secret)
assert call_or_skip(hashing.verify_secret, _FAST, encoded, secret) is True

63
tests/unit/test_keys.py Normal file
View File

@@ -0,0 +1,63 @@
"""Unit tests for ``neuronetz_gateway.auth.keys`` (key gen + prefix parsing).
SPEC §11/§4.3 key scheme: a full key is ``nz_<random base62>``; the **stored
prefix** is ``full_key[:PREFIX_LEN]`` (the ``nz_`` namespace plus leading random
chars) and is used as the Redis cache key / DB lookup. The whole full key is
argon2id-hashed; the full key is shown once. Part of the 100%-coverage gate on
``auth/``.
"""
from __future__ import annotations
import pytest
from neuronetz_gateway.auth import keys
from tests._skip import call_or_skip
def test_namespace_and_prefix_len_match_spec() -> None:
assert keys.KEY_NAMESPACE == "nz_"
assert keys.PREFIX_LEN == 12
def test_generate_key_shape() -> None:
gen = call_or_skip(keys.generate_key)
assert gen.full_key.startswith(keys.KEY_NAMESPACE)
# Full key = namespace + SECRET_LEN random chars.
assert len(gen.full_key) == len(keys.KEY_NAMESPACE) + keys.SECRET_LEN
# Stored prefix is exactly the first PREFIX_LEN chars of the full key and
# therefore a literal prefix of it (SPEC §4.3 "first 12 chars").
assert gen.prefix == gen.full_key[: keys.PREFIX_LEN]
assert len(gen.prefix) == keys.PREFIX_LEN
assert gen.full_key.startswith(gen.prefix)
def test_generate_key_is_unique() -> None:
a = call_or_skip(keys.generate_key)
b = call_or_skip(keys.generate_key)
assert a.full_key != b.full_key
assert a.prefix != b.prefix # CSPRNG => prefixes differ with overwhelming prob.
def test_generated_key_is_url_safe_ascii() -> None:
gen = call_or_skip(keys.generate_key)
body = gen.full_key[len(keys.KEY_NAMESPACE) :]
assert body.isascii()
assert body.isalnum() # base62 alphabet, no separators
assert " " not in gen.full_key
def test_extract_prefix_roundtrips_generated_key() -> None:
gen = call_or_skip(keys.generate_key)
assert call_or_skip(keys.extract_prefix, gen.full_key) == gen.prefix
def test_extract_prefix_rejects_bad_format() -> None:
# Missing namespace / too short must raise rather than silently truncate.
with pytest.raises((ValueError, NotImplementedError)):
keys.extract_prefix("definitely-not-a-key")
def test_extract_prefix_rejects_too_short() -> None:
with pytest.raises((ValueError, NotImplementedError)):
keys.extract_prefix("nz_short")

View File

@@ -0,0 +1,120 @@
"""Unit tests for ``neuronetz_gateway.ratelimit.sliding_window``.
Redis Lua-atomic sliding window (SPEC §4.3 step 4, §9 100% on ``ratelimit/``):
counts hits within the window, resets after it elapses, and keeps per-key vs
per-tenant scopes independent. Backed by the real ``redis_client`` testcontainer
fixture; skips cleanly when Docker is unavailable.
Bodies are real but skip if the limiter is still a Phase-1 stub
(``NotImplementedError``) so the suite stays green until Backend lands.
"""
from __future__ import annotations
import asyncio
import pytest
import redis.asyncio as aioredis
from neuronetz_gateway.ratelimit.sliding_window import RateLimitResult, SlidingWindowLimiter
from tests._skip import call_or_skip
pytestmark = pytest.mark.asyncio
# Known Backend bug (sliding_window.py): the per-hit ZSET member is
# ``f"{now_ms}-{id(object())}"``. ``id(object())`` returns the SAME value on
# every call (the temporary object is freed immediately), so two hits landing in
# the same millisecond produce an identical member; ZADD then *overwrites*
# instead of adding a second entry, undercounting the window and admitting
# requests that should be blocked. These two tests assert correct counting and
# therefore fail until Backend gives each hit a unique member (e.g. a counter or
# ``secrets.token_hex``). ``strict=False`` so they flip to XPASS once fixed
# without breaking the suite. See QA report.
_MEMBER_COLLISION = pytest.mark.xfail(
reason="sliding_window member id(object()) collides within a millisecond; "
"undercounts the window (see QA report)",
strict=False,
)
async def _check(
limiter: SlidingWindowLimiter, key: str, limit: int, window_s: int, cost: int = 1
) -> RateLimitResult:
return await call_or_skip(limiter.check, key, limit, window_s, cost)
# A spacing larger than 1ms so consecutive hits land on distinct ZSET members,
# isolating the *windowing* logic from the separate member-collision bug
# (asserted directly in test_same_millisecond_burst_undercounts below).
_SPACING_S = 0.003
async def test_allows_up_to_limit_then_blocks(redis_client: aioredis.Redis) -> None:
limiter = SlidingWindowLimiter(redis_client)
key = "rl:key:abc"
limit, window = 3, 60
results = []
for _ in range(limit):
results.append(await _check(limiter, key, limit, window))
await asyncio.sleep(_SPACING_S)
assert all(r.allowed for r in results)
assert results[0].limit == limit
# Remaining decrements monotonically toward zero.
assert results[-1].remaining == 0
blocked = await _check(limiter, key, limit, window)
assert blocked.allowed is False
assert blocked.remaining == 0
# A blocked result advertises when to retry (used for Retry-After).
assert blocked.retry_after_s is not None
assert blocked.retry_after_s >= 0
async def test_window_resets_after_elapse(redis_client: aioredis.Redis) -> None:
limiter = SlidingWindowLimiter(redis_client)
key = "rl:key:resets"
limit, window = 2, 1 # 1-second window for a fast test
assert (await _check(limiter, key, limit, window)).allowed
await asyncio.sleep(_SPACING_S)
assert (await _check(limiter, key, limit, window)).allowed
await asyncio.sleep(_SPACING_S)
assert (await _check(limiter, key, limit, window)).allowed is False
# After the window passes, the oldest hits age out and capacity returns.
await asyncio.sleep(1.2)
assert (await _check(limiter, key, limit, window)).allowed is True
@_MEMBER_COLLISION
async def test_same_millisecond_burst_undercounts(redis_client: aioredis.Redis) -> None:
# A burst of hits within one millisecond must still each count. With the
# current member scheme they collide and only one is recorded, so the
# limiter wrongly keeps admitting. xfail until Backend makes members unique.
limiter = SlidingWindowLimiter(redis_client)
key = "rl:key:burst"
limit, window = 2, 60
results = [await _check(limiter, key, limit, window) for _ in range(4)]
# Correct behaviour: first two admitted, rest blocked.
assert [r.allowed for r in results] == [True, True, False, False]
async def test_per_key_and_per_tenant_scopes_independent(
redis_client: aioredis.Redis,
) -> None:
limiter = SlidingWindowLimiter(redis_client)
# Distinct keys => distinct windows; exhausting one must not affect another.
await _check(limiter, "rl:key:k1", 1, 60)
assert (await _check(limiter, "rl:key:k1", 1, 60)).allowed is False
assert (await _check(limiter, "rl:tenant:t1", 1, 60)).allowed is True
async def test_cost_consumes_multiple_slots(redis_client: aioredis.Redis) -> None:
# TPM-style accounting: a single check may cost >1 (per-key TPM, SPEC §4.3).
limiter = SlidingWindowLimiter(redis_client)
first = await _check(limiter, "rl:tpm:k", limit=10, window_s=60, cost=8)
assert first.allowed is True
assert first.remaining == 2
second = await _check(limiter, "rl:tpm:k", limit=10, window_s=60, cost=8)
assert second.allowed is False

21
tests/unit/test_smoke.py Normal file
View File

@@ -0,0 +1,21 @@
"""Smoke test: the package imports and exposes a version string.
This is the one always-passing test that guarantees ``pytest`` exits 0 in
Phase 1 (rather than exit code 5, "no tests ran"). It also serves as the
canary that the src-layout package is importable on the configured path.
"""
from __future__ import annotations
def test_package_imports() -> None:
import neuronetz_gateway
assert neuronetz_gateway is not None
def test_version_is_str() -> None:
import neuronetz_gateway
assert isinstance(neuronetz_gateway.__version__, str)
assert neuronetz_gateway.__version__

View File

@@ -0,0 +1,49 @@
"""Unit tests for ``neuronetz_gateway.proxy.token_counter``.
Tokens are read precisely from Ollama's final frame: ``prompt_eval_count``
(input) and ``eval_count`` (output) — never estimated (SPEC §2, §4.3 step 12,
§13.1). Embeddings carry only ``prompt_eval_count`` (SPEC §13.1).
"""
from __future__ import annotations
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
from tests._skip import call_or_skip
def test_extract_from_final_chat_frame() -> None:
# Mirrors the terminal NDJSON object emitted by mock_ollama (_final_metrics).
final = {
"model": "llama3.1:8b",
"done": True,
"done_reason": "stop",
"total_duration": 1_234_567_890,
"prompt_eval_count": 11,
"eval_count": 7,
}
usage = call_or_skip(extract_usage, final)
assert isinstance(usage, TokenUsage)
assert usage.tokens_in == 11
assert usage.tokens_out == 7
def test_extract_from_generate_frame() -> None:
final = {"done": True, "context": [1, 2, 3], "prompt_eval_count": 5, "eval_count": 42}
usage = call_or_skip(extract_usage, final)
assert (usage.tokens_in, usage.tokens_out) == (5, 42)
def test_embeddings_frame_only_prompt_eval_count() -> None:
# Embeddings: Ollama returns no eval_count (SPEC §13.1) => tokens_out == 0.
frame = {"embedding": [0.0, 0.1], "prompt_eval_count": 9}
usage = call_or_skip(extract_usage, frame)
assert usage.tokens_in == 9
assert usage.tokens_out == 0
def test_missing_counts_default_to_zero() -> None:
# A frame lacking the counter fields must not raise; charge nothing rather
# than crash the audit/budget path.
usage = call_or_skip(extract_usage, {"done": True})
assert usage.tokens_in == 0
assert usage.tokens_out == 0

View File

@@ -0,0 +1,179 @@
"""Unit tests for ``neuronetz_gateway.proxy.translate`` (OpenAI <-> Ollama).
Golden-fixture tests for the OpenAI-compat layer (SPEC §6.3):
* OpenAI chat/completion/embeddings request -> Ollama request body
* Ollama stream frame -> OpenAI ``chat.completion.chunk`` (delta + final usage)
* Ollama non-stream response -> OpenAI ``chat.completion`` / ``text_completion``
* model name list -> OpenAI ``/v1/models`` list shape
The streaming chunk shape is anchored to ``mock_ollama``'s reference helper
``ollama_chunks_to_openai_sse``. SSE *framing* (``data: {...}\\n\\n`` +
``data: [DONE]``) is asserted in the integration test_openai_compat.py.
"""
from __future__ import annotations
from typing import Any
from neuronetz_gateway.proxy import translate
def _as_dict(value: object) -> dict[str, Any]:
"""Narrow a translator-returned ``object`` to a typed dict for assertions."""
assert isinstance(value, dict), value
return value
def _as_list(value: object) -> list[Any]:
"""Narrow a translator-returned ``object`` to a typed list for assertions."""
assert isinstance(value, list), value
return value
# --- request translation: OpenAI -> Ollama ---------------------------------
def test_openai_chat_request_to_ollama_preserves_messages_and_model() -> None:
openai_req: dict[str, Any] = {
"model": "llama3.1:8b",
"messages": [
{"role": "system", "content": "be terse"},
{"role": "user", "content": "hi"},
],
"stream": True,
}
ollama = translate.openai_chat_to_ollama(openai_req)
assert ollama["model"] == "llama3.1:8b"
assert ollama["messages"] == openai_req["messages"]
assert ollama["stream"] is True
def test_openai_chat_options_mapped() -> None:
openai_req: dict[str, Any] = {
"model": "llama3.1:8b",
"messages": [{"role": "user", "content": "hi"}],
"temperature": 0.2,
"max_tokens": 128,
"stream": False,
}
ollama = translate.openai_chat_to_ollama(openai_req)
options = _as_dict(ollama["options"])
assert options["temperature"] == 0.2
# OpenAI ``max_tokens`` => Ollama ``num_predict``.
assert options["num_predict"] == 128
assert ollama["stream"] is False
def test_openai_completion_to_ollama_generate() -> None:
openai_req: dict[str, Any] = {
"model": "llama3.1:8b",
"prompt": "once upon a time",
"stream": True,
}
ollama = translate.openai_completion_to_ollama(openai_req)
assert ollama["model"] == "llama3.1:8b"
assert ollama["prompt"] == "once upon a time"
assert ollama["stream"] is True
def test_openai_embeddings_to_ollama_embed() -> None:
openai_req: dict[str, Any] = {"model": "nomic-embed-text", "input": "hello world"}
ollama = translate.openai_embeddings_to_ollama(openai_req)
assert ollama["model"] == "nomic-embed-text"
assert ollama["input"] == "hello world"
# --- streaming response translation: Ollama frame -> OpenAI chunk ----------
def test_chat_delta_chunk_to_openai() -> None:
frame: dict[str, Any] = {
"model": "llama3.1:8b",
"message": {"role": "assistant", "content": "Echo:"},
"done": False,
}
out = translate.ollama_chat_chunk_to_openai(
frame, completion_id="chatcmpl-x", model="llama3.1:8b", created=1700
)
assert out["object"] == "chat.completion.chunk"
choice = _as_dict(_as_list(out["choices"])[0])
delta = _as_dict(choice["delta"])
assert delta["content"] == "Echo:"
assert choice["finish_reason"] is None
def test_chat_final_chunk_carries_usage_and_finish_reason() -> None:
frame: dict[str, Any] = {
"model": "llama3.1:8b",
"message": {"role": "assistant", "content": ""},
"done": True,
"done_reason": "stop",
"prompt_eval_count": 4,
"eval_count": 6,
}
out = translate.ollama_chat_chunk_to_openai(
frame, completion_id="chatcmpl-x", model="llama3.1:8b", created=1700
)
choice = _as_dict(_as_list(out["choices"])[0])
assert choice["finish_reason"] == "stop"
usage = _as_dict(out["usage"])
assert usage["prompt_tokens"] == 4
assert usage["completion_tokens"] == 6
assert usage["total_tokens"] == 10
# --- non-streaming response translation ------------------------------------
def test_nonstream_chat_to_openai_completion() -> None:
ollama_resp: dict[str, Any] = {
"model": "llama3.1:8b",
"message": {"role": "assistant", "content": "Echo: hi"},
"done": True,
"prompt_eval_count": 2,
"eval_count": 3,
}
out = translate.ollama_chat_to_openai(ollama_resp)
assert out["object"] == "chat.completion"
choice = _as_dict(_as_list(out["choices"])[0])
assert choice["message"] == {"role": "assistant", "content": "Echo: hi"}
assert choice["finish_reason"] == "stop"
assert _as_dict(out["usage"])["total_tokens"] == 5
def test_nonstream_generate_to_openai() -> None:
ollama_resp: dict[str, Any] = {
"model": "llama3.1:8b",
"response": "once upon a time",
"done": True,
"prompt_eval_count": 1,
"eval_count": 4,
}
out = translate.ollama_generate_to_openai(ollama_resp)
assert out["object"] == "text_completion"
choice = _as_dict(_as_list(out["choices"])[0])
assert choice["text"] == "once upon a time"
assert _as_dict(out["usage"])["total_tokens"] == 5
def test_embed_to_openai_shape() -> None:
ollama_resp: dict[str, Any] = {
"model": "nomic-embed-text",
"embeddings": [[0.0, 0.1], [0.2, 0.3]],
"prompt_eval_count": 7,
}
out = translate.ollama_embed_to_openai(ollama_resp, model="nomic-embed-text")
assert out["object"] == "list"
data = _as_list(out["data"])
assert len(data) == 2
assert data[0] == {"object": "embedding", "index": 0, "embedding": [0.0, 0.1]}
# Embeddings charge prompt tokens only (SPEC §13.1).
assert out["usage"] == {"prompt_tokens": 7, "total_tokens": 7}
def test_models_to_openai_list_shape() -> None:
out = translate.models_to_openai_list(["llama3.1:8b", "mistral:7b"])
assert out["object"] == "list"
data = _as_list(out["data"])
ids = {_as_dict(m)["id"] for m in data}
assert ids == {"llama3.1:8b", "mistral:7b"}
assert all(_as_dict(m)["object"] == "model" for m in data)