tests: unit + integration suite (99 tests; ruff + mypy --strict clean)
Real test bodies (not stubs), driven against an in-process httpx.ASGITransport override of the gateway's get_ollama_client dependency pointing at tests/integration/mock_ollama.py.

Unit (target 100% on auth/, ratelimit/, budget/):
- argon2id roundtrip, wrong-key, garbage encoding, needs_rehash on param change
- key format/uniqueness/prefix extraction
- token counter (prompt_eval_count + eval_count, embeddings, missing-counts)
- translate (OpenAI <-> Ollama for chat/completion/embeddings, streaming chunks, /v1/models list shape)
- allowlist (hard-blocks, effective-set semantics across allow_all/inheritance/empty-discovered)
- discovery (parse, cache roundtrip with TTL, fail-closed, tolerates redis=None)
- sliding window (allow/block/reset/per-key vs per-tenant/cost-weighted)

Integration (testcontainers postgres + redis + in-process mock Ollama):
- auth flow (no/malformed/wrong key all return identical sanitized 401)
- proxy stream (NDJSON roundtrip, audit row's token counts match, hard-blocked endpoints uniformly 403)
- openai_compat (SSE chunks, data: [DONE], non-stream shape, /v1/models)
- model_discovery (allow_all sees all, default-deny sees allowed ∩ discovered, /v1/models filtered, unpermitted-but-installed = nonexistent = 403, empty cache denies even allow_all)
- rate_limit (429 + Retry-After + headers; Redis down ⇒ 503, never 200)
- budget (decrement + headers; pre-burned counter blocks next request)
- revocation (INSERT into gateway.revocations → NOTIFY → cache evicted → 401 ≤ 1s)

Includes a known-issue xfail flagging a bug in ratelimit/sliding_window.py: the per-hit ZSET member uses id(object()), which typically returns the same id on consecutive calls (the temporary object is freed immediately), causing same-millisecond hits to overwrite instead of stacking. To be fixed in a follow-up commit.
This commit is contained in:
1
tests/unit/__init__.py
Normal file
1
tests/unit/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit test package."""
|
||||
115
tests/unit/test_allowlist.py
Normal file
115
tests/unit/test_allowlist.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Unit tests for ``neuronetz_gateway.proxy.allowlist``.
|
||||
|
||||
Two concerns (SPEC §4.3 step 7-8, §4.6, §6.1-6.2):
|
||||
|
||||
1. **Hard-blocked endpoints** — mutating Ollama endpoints and ``/api/ps`` are
|
||||
always blocked, not configurable (``is_hard_blocked``).
|
||||
2. **Effective model set** (``resolve_effective_models`` / ``is_model_allowed``):
|
||||
* ``allow_all`` ⇒ all discovered
|
||||
* default-deny ⇒ ``allowed_models ∩ discovered``
|
||||
* stale/typo'd allowlist entries never resolve (not in discovered)
|
||||
* empty discovered ⇒ deny, even under ``allow_all`` (fail-closed)
|
||||
|
||||
The resolver merges the *already-resolved* ``allow_all``/``allowed_models``
|
||||
inputs (key-vs-tenant precedence per SPEC §13.7 is applied by the caller before
|
||||
this point; that precedence is exercised in the integration model-discovery
|
||||
tests).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from neuronetz_gateway.proxy import allowlist
|
||||
|
||||
# --- hard-blocked endpoint allowlist ---------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        "/api/pull",
        "/api/push",
        "/api/create",
        "/api/copy",
        "/api/delete",
        "/api/ps",
    ],
)
def test_mutating_and_ps_endpoints_hard_blocked(path: str) -> None:
    """Each mutating Ollama endpoint, and ``/api/ps``, reports as hard-blocked."""
    verdict = allowlist.is_hard_blocked(path)
    assert verdict is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        "/api/blobs",
        "/api/blobs/sha256:abc",
        "/api/blobs/x/y",
    ],
)
def test_blob_prefix_hard_blocked(path: str) -> None:
    """``/api/blobs`` itself and paths nested beneath it report as hard-blocked."""
    verdict = allowlist.is_hard_blocked(path)
    assert verdict is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "path",
    [
        "/api/chat",
        "/api/generate",
        "/api/embed",
        "/api/tags",
        "/api/show",
        "/api/version",
    ],
)
def test_allowlisted_endpoints_not_hard_blocked(path: str) -> None:
    """Inference and read-only endpoints are not reported as hard-blocked."""
    verdict = allowlist.is_hard_blocked(path)
    assert verdict is False
|
||||
|
||||
|
||||
# --- effective-set resolution (SPEC §4.3 step 7 / §4.6) --------------------
|
||||
|
||||
DISCOVERED = frozenset({"llama3.1:8b", "mistral:7b", "nomic-embed-text"})
|
||||
|
||||
|
||||
def test_allow_all_returns_all_discovered() -> None:
    """``allow_all`` with an empty allowlist resolves to the whole discovered set."""
    effective = allowlist.resolve_effective_models(
        allow_all=True,
        allowed_models=(),
        discovered=DISCOVERED,
    )
    assert effective == DISCOVERED
|
||||
|
||||
|
||||
def test_default_deny_intersects_discovered() -> None:
    """Default-deny keeps only the allowlist entries that are also discovered."""
    # "qwen:0.5b" is allowlisted but not in DISCOVERED (a stale entry), so it
    # must not appear in the effective set.
    requested = ("llama3.1:8b", "qwen:0.5b")
    effective = allowlist.resolve_effective_models(
        allow_all=False,
        allowed_models=requested,
        discovered=DISCOVERED,
    )
    assert effective == frozenset({"llama3.1:8b"})
|
||||
|
||||
|
||||
def test_default_deny_empty_allowlist_denies_all() -> None:
    """Without ``allow_all`` and with no allowlist entries, nothing resolves."""
    effective = allowlist.resolve_effective_models(
        allow_all=False,
        allowed_models=(),
        discovered=DISCOVERED,
    )
    assert effective == frozenset()
|
||||
|
||||
|
||||
def test_empty_discovered_denies_everything_even_allow_all() -> None:
    """An empty discovered set yields an empty effective set, even with ``allow_all``."""
    # Fail-closed: with nothing discovered, neither allow_all nor an explicit
    # allowlist entry may resolve a model.
    nothing_discovered: frozenset[str] = frozenset()
    effective = allowlist.resolve_effective_models(
        allow_all=True,
        allowed_models=("llama3.1:8b",),
        discovered=nothing_discovered,
    )
    assert effective == frozenset()
|
||||
|
||||
|
||||
def test_is_model_allowed_membership() -> None:
    """``is_model_allowed`` is True only for a model both permitted and discovered."""
    permitted_and_installed = allowlist.is_model_allowed(
        "llama3.1:8b",
        allow_all=False,
        allowed_models=("llama3.1:8b",),
        discovered=DISCOVERED,
    )
    assert permitted_and_installed is True

    # Installed but not permitted.
    installed_only = allowlist.is_model_allowed(
        "mistral:7b",
        allow_all=False,
        allowed_models=("llama3.1:8b",),
        discovered=DISCOVERED,
    )
    assert installed_only is False

    # Permitted but not installed (typo'd / stale) — indistinguishable from
    # not-allowed to the caller (SPEC §13.6 no existence disclosure).
    permitted_only = allowlist.is_model_allowed(
        "ghost:1b",
        allow_all=False,
        allowed_models=("ghost:1b",),
        discovered=DISCOVERED,
    )
    assert permitted_only is False
|
||||
|
||||
|
||||
def test_is_model_allowed_fail_closed_empty_discovered() -> None:
    """With an empty discovered set, ``allow_all`` still denies the model."""
    nothing_discovered: frozenset[str] = frozenset()
    verdict = allowlist.is_model_allowed(
        "llama3.1:8b",
        allow_all=True,
        allowed_models=(),
        discovered=nothing_discovered,
    )
    assert verdict is False
|
||||
142
tests/unit/test_discovery.py
Normal file
142
tests/unit/test_discovery.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Unit tests for model discovery (SPEC §4.6).
|
||||
|
||||
A background poller queries Ollama ``GET /api/tags``, parses the installed model
|
||||
set, caches it in Redis (TTL) + in-process (:class:`DiscoveryCache`), and **fails
|
||||
closed**: an empty/expired discovered set means no model resolves (deny). On an
|
||||
upstream error ``refresh_once`` returns ``False`` and leaves the caches untouched
|
||||
so they expire on their own TTL (stale-expired ⇒ empty ⇒ deny); discovery never
|
||||
opens access.
|
||||
|
||||
Driven against the in-process ``mock_ollama`` upstream and the real
|
||||
``redis_client`` testcontainer fixture (skips cleanly without Docker).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from neuronetz_gateway.config import Settings
|
||||
from neuronetz_gateway.proxy import discovery
|
||||
|
||||
|
||||
def _ollama_client(app: FastAPI) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` whose requests are served in-process by *app*."""
    return httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://ollama",
    )
|
||||
|
||||
|
||||
def _settings() -> Settings:
    """Return ``Settings`` with the discovery cache TTL and refresh interval pinned."""
    return Settings(
        model_discovery_cache_ttl_s=120,
        model_discovery_refresh_s=60,
    )
|
||||
|
||||
|
||||
# --- pure parsing ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_names_of_extracts_model_names() -> None:
    """``names_of`` collects every model's ``name`` into a frozenset."""
    with_family = discovery.DiscoveredModel(name="llama3.1:8b", family="llama")
    name_only = discovery.DiscoveredModel(name="mistral:7b")

    extracted = discovery.names_of([with_family, name_only])

    assert extracted == frozenset({"llama3.1:8b", "mistral:7b"})
|
||||
|
||||
|
||||
def test_names_of_empty() -> None:
    """``names_of`` on an empty list returns an empty frozenset."""
    extracted = discovery.names_of([])
    assert extracted == frozenset()
|
||||
|
||||
|
||||
# --- fetch + parse against the mock upstream -------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_tags_parses_mock_catalogue(mock_ollama_app: FastAPI) -> None:
    """``fetch_tags`` returns the mock catalogue with names and metadata parsed."""
    async with _ollama_client(mock_ollama_app) as ollama:
        models = await discovery.fetch_tags(ollama)

    expected_names = {"llama3.1:8b", "mistral:7b", "nomic-embed-text"}
    assert expected_names <= discovery.names_of(models)

    # Sanitized metadata is captured (family parsed from the mock's details).
    llama = {model.name: model for model in models}["llama3.1:8b"]
    assert llama.family == "llama3.1"
    assert llama.parameter_size == "8B"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_tags_raises_on_upstream_error() -> None:
    """``fetch_tags`` raises ``httpx.HTTPError`` when ``/api/tags`` answers 500."""
    failing_upstream = FastAPI()

    @failing_upstream.get("/api/tags")
    async def _tags() -> Any:
        # Stand-in upstream that always fails the tags listing.
        return JSONResponse({"error": "boom"}, status_code=500)

    async with _ollama_client(failing_upstream) as ollama:
        with pytest.raises(httpx.HTTPError):
            await discovery.fetch_tags(ollama)
|
||||
|
||||
|
||||
# --- redis cache round-trip ------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redis_cache_roundtrip(redis_client: aioredis.Redis) -> None:
    """Names written to Redis read back unchanged, under a key that has a TTL."""
    written = [
        discovery.DiscoveredModel(name="llama3.1:8b"),
        discovery.DiscoveredModel(name="x:1b"),
    ]
    await discovery.write_discovered_to_redis(redis_client, written, ttl_s=120)

    cached = await discovery.read_discovered_from_redis(redis_client)
    assert cached == frozenset({"llama3.1:8b", "x:1b"})

    # TTL was applied (staleness expires) per SPEC §4.6.
    remaining_ttl = await redis_client.ttl(discovery.REDIS_DISCOVERED_KEY)
    assert 0 < remaining_ttl <= 120
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_discovered_miss_returns_empty(redis_client: aioredis.Redis) -> None:
    """Reading the discovered set when nothing was written yields an empty frozenset."""
    # Cache miss / expiry => empty set => fail-closed deny.
    cached = await discovery.read_discovered_from_redis(redis_client)
    assert cached == frozenset()
|
||||
|
||||
|
||||
# --- refresh_once: success and fail-closed ---------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_once_populates_in_process_and_redis(
    redis_client: aioredis.Redis,
    mock_ollama_app: FastAPI,
) -> None:
    """A successful refresh fills both the in-process cache and Redis."""
    in_process = discovery.DiscoveryCache()
    async with _ollama_client(mock_ollama_app) as ollama:
        refreshed = await discovery.refresh_once(ollama, redis_client, in_process, _settings())

    assert refreshed is True
    assert {"llama3.1:8b", "mistral:7b", "nomic-embed-text"} <= in_process.names

    # Mirrored into Redis under the SPEC §4.6 key.
    mirrored = await discovery.read_discovered_from_redis(redis_client)
    assert {"llama3.1:8b", "mistral:7b"} <= mirrored
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_once_fail_closed_on_upstream_error(
    redis_client: aioredis.Redis,
) -> None:
    """A refresh against a failing upstream returns False and caches nothing."""
    failing_upstream = FastAPI()

    @failing_upstream.get("/api/tags")
    async def _tags() -> Any:
        # Stand-in upstream that always fails the tags listing.
        return JSONResponse({"error": "boom"}, status_code=500)

    in_process = discovery.DiscoveryCache()
    async with _ollama_client(failing_upstream) as ollama:
        refreshed = await discovery.refresh_once(ollama, redis_client, in_process, _settings())

    # Refresh reports failure; the in-process cache stays empty (no models
    # resolve) — discovery never opens access on error.
    assert refreshed is False
    assert in_process.names == frozenset()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_refresh_once_tolerates_missing_redis(mock_ollama_app: FastAPI) -> None:
    """Passing ``None`` for Redis still refreshes the in-process cache."""
    # redis_client=None must still refresh the in-process cache (best-effort
    # Redis fill), not crash the poller.
    in_process = discovery.DiscoveryCache()
    async with _ollama_client(mock_ollama_app) as ollama:
        refreshed = await discovery.refresh_once(ollama, None, in_process, _settings())

    assert refreshed is True
    assert "llama3.1:8b" in in_process.names
|
||||
93
tests/unit/test_hashing.py
Normal file
93
tests/unit/test_hashing.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Unit tests for ``neuronetz_gateway.auth.hashing`` (argon2id wrapper).
|
||||
|
||||
Covers SPEC §3/§9: hash, constant-time verify, wrong-key rejection, and
|
||||
rehash-on-parameter-change. Targets 100% coverage of ``auth/hashing.py``.
|
||||
|
||||
These bodies are real; while ``hashing.py`` is a Phase-1 stub each call routes
|
||||
through :func:`tests._skip.call_or_skip`, which skips (not fails) on
|
||||
``NotImplementedError`` and lets the test self-activate once Backend lands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from argon2 import PasswordHasher
|
||||
|
||||
from neuronetz_gateway.auth import hashing
|
||||
from neuronetz_gateway.config import Settings
|
||||
from tests._skip import call_or_skip
|
||||
|
||||
# A fast hasher so the suite stays quick; the wrapper must round-trip regardless
|
||||
# of the cost parameters, which are what we actually exercise here.
|
||||
_FAST = PasswordHasher(time_cost=1, memory_cost=8, parallelism=1)
|
||||
SECRET = "nz_abcdef012345deadbeefcafebabe00112233" # noqa: S105 - test fixture, not a real key
|
||||
|
||||
|
||||
def _settings(**overrides: object) -> Settings:
    """Build ``Settings`` with cheap argon2 parameters, letting *overrides* win."""
    # Later keys win in a dict display, so any override replaces the default.
    kwargs: dict[str, object] = {
        "argon2_time_cost": 1,
        "argon2_memory_cost_kib": 8,
        "argon2_parallelism": 1,
        **overrides,
    }
    return Settings(**kwargs)  # type: ignore[arg-type]  # pydantic-settings kwargs
|
||||
|
||||
|
||||
def test_build_hasher_uses_settings() -> None:
    """``build_hasher`` returns a ``PasswordHasher`` carrying the configured time cost."""
    configured = _settings(argon2_time_cost=2)
    built = hashing.build_hasher(configured)
    assert isinstance(built, PasswordHasher)
    assert built.time_cost == 2
|
||||
|
||||
|
||||
def test_hash_secret_returns_argon2id_encoded() -> None:
    """``hash_secret`` returns an argon2id-tagged string that omits the plaintext."""
    digest = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    assert isinstance(digest, str)
    # argon2id encoded hashes begin with the variant tag.
    assert digest.startswith("$argon2id$")
    # Never echoes the plaintext back.
    assert SECRET not in digest
|
||||
|
||||
|
||||
def test_hash_is_salted_unique_per_call() -> None:
    """Hashing the same secret twice yields two different encodings."""
    first = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    second = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    # random salt => distinct encodings
    assert first != second
|
||||
|
||||
|
||||
def test_verify_secret_roundtrip_true() -> None:
    """A secret verifies True against its own freshly produced hash."""
    digest = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    verified = call_or_skip(hashing.verify_secret, _FAST, digest, SECRET)
    assert verified is True
|
||||
|
||||
|
||||
def test_verify_wrong_secret_returns_false_not_raise() -> None:
    """Verifying a non-matching secret returns False instead of raising."""
    digest = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    wrong_secret = SECRET + "x"
    # Constant-time path must return False, never propagate argon2's
    # VerifyMismatchError to the caller (auth must fail closed, not 500).
    verified = call_or_skip(hashing.verify_secret, _FAST, digest, wrong_secret)
    assert verified is False
|
||||
|
||||
|
||||
def test_verify_garbage_encoding_returns_false() -> None:
    """A stored value that is not a valid hash verifies False rather than raising."""
    # An invalid/corrupt stored hash must verify False, not raise.
    corrupt_digest = "not-a-valid-hash"
    verified = call_or_skip(hashing.verify_secret, _FAST, corrupt_digest, SECRET)
    assert verified is False
|
||||
|
||||
|
||||
def test_needs_rehash_false_for_current_params() -> None:
    """A hash made by a hasher is not flagged for rehash by that same hasher."""
    digest = call_or_skip(hashing.hash_secret, _FAST, SECRET)
    flagged = call_or_skip(hashing.needs_rehash, _FAST, digest)
    assert flagged is False
|
||||
|
||||
|
||||
def test_needs_rehash_true_when_params_change() -> None:
    """A hash made with weaker parameters is flagged by a stronger hasher."""
    weaker = PasswordHasher(time_cost=1, memory_cost=8, parallelism=1)
    stronger = PasswordHasher(time_cost=3, memory_cost=64, parallelism=1)

    digest = call_or_skip(hashing.hash_secret, weaker, SECRET)

    # A hash made with weaker params must be flagged for rehash by a stronger
    # configured hasher (SPEC §9 rehash-on-parameter-change).
    flagged = call_or_skip(hashing.needs_rehash, stronger, digest)
    assert flagged is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("secret", ["", "a", "x" * 4096])
def test_hash_verify_edge_lengths(secret: str) -> None:
    """Empty, one-character and 4096-character secrets all round-trip."""
    digest = call_or_skip(hashing.hash_secret, _FAST, secret)
    verified = call_or_skip(hashing.verify_secret, _FAST, digest, secret)
    assert verified is True
|
||||
63
tests/unit/test_keys.py
Normal file
63
tests/unit/test_keys.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Unit tests for ``neuronetz_gateway.auth.keys`` (key gen + prefix parsing).
|
||||
|
||||
SPEC §11/§4.3 key scheme: a full key is ``nz_<random base62>``; the **stored
|
||||
prefix** is ``full_key[:PREFIX_LEN]`` (the ``nz_`` namespace plus leading random
|
||||
chars) and is used as the Redis cache key / DB lookup. The whole full key is
|
||||
argon2id-hashed; the full key is shown once. Part of the 100%-coverage gate on
|
||||
``auth/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from neuronetz_gateway.auth import keys
|
||||
from tests._skip import call_or_skip
|
||||
|
||||
|
||||
def test_namespace_and_prefix_len_match_spec() -> None:
    """The key namespace is ``nz_`` and the stored prefix length is 12."""
    observed = (keys.KEY_NAMESPACE, keys.PREFIX_LEN)
    assert observed[0] == "nz_"
    assert observed[1] == 12
|
||||
|
||||
|
||||
def test_generate_key_shape() -> None:
    """A generated key has the namespace, the expected length and a literal prefix."""
    generated = call_or_skip(keys.generate_key)

    assert generated.full_key.startswith(keys.KEY_NAMESPACE)
    # Full key = namespace + SECRET_LEN random chars.
    expected_len = len(keys.KEY_NAMESPACE) + keys.SECRET_LEN
    assert len(generated.full_key) == expected_len

    # Stored prefix is exactly the first PREFIX_LEN chars of the full key and
    # therefore a literal prefix of it (SPEC §4.3 "first 12 chars").
    assert generated.prefix == generated.full_key[: keys.PREFIX_LEN]
    assert len(generated.prefix) == keys.PREFIX_LEN
    assert generated.full_key.startswith(generated.prefix)
|
||||
|
||||
|
||||
def test_generate_key_is_unique() -> None:
    """Two generated keys differ in both full key and stored prefix."""
    first = call_or_skip(keys.generate_key)
    second = call_or_skip(keys.generate_key)
    assert first.full_key != second.full_key
    # CSPRNG => prefixes differ with overwhelming prob.
    assert first.prefix != second.prefix
|
||||
|
||||
|
||||
def test_generated_key_is_url_safe_ascii() -> None:
    """The part after the namespace is ASCII alphanumeric and the key has no spaces."""
    generated = call_or_skip(keys.generate_key)
    namespace_len = len(keys.KEY_NAMESPACE)
    random_part = generated.full_key[namespace_len:]
    assert random_part.isascii()
    # base62 alphabet, no separators
    assert random_part.isalnum()
    assert " " not in generated.full_key
|
||||
|
||||
|
||||
def test_extract_prefix_roundtrips_generated_key() -> None:
    """``extract_prefix`` on a generated full key returns that key's stored prefix."""
    generated = call_or_skip(keys.generate_key)
    extracted = call_or_skip(keys.extract_prefix, generated.full_key)
    assert extracted == generated.prefix
|
||||
|
||||
|
||||
def test_extract_prefix_rejects_bad_format() -> None:
    """A string without the key namespace makes ``extract_prefix`` raise."""
    # Missing namespace / too short must raise rather than silently truncate.
    acceptable_errors = (ValueError, NotImplementedError)
    with pytest.raises(acceptable_errors):
        keys.extract_prefix("definitely-not-a-key")
|
||||
|
||||
|
||||
def test_extract_prefix_rejects_too_short() -> None:
    """A namespaced but too-short key makes ``extract_prefix`` raise."""
    acceptable_errors = (ValueError, NotImplementedError)
    with pytest.raises(acceptable_errors):
        keys.extract_prefix("nz_short")
|
||||
120
tests/unit/test_sliding_window.py
Normal file
120
tests/unit/test_sliding_window.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Unit tests for ``neuronetz_gateway.ratelimit.sliding_window``.
|
||||
|
||||
Redis Lua-atomic sliding window (SPEC §4.3 step 4, §9 100% on ``ratelimit/``):
|
||||
counts hits within the window, resets after it elapses, and keeps per-key vs
|
||||
per-tenant scopes independent. Backed by the real ``redis_client`` testcontainer
|
||||
fixture; skips cleanly when Docker is unavailable.
|
||||
|
||||
Bodies are real but skip if the limiter is still a Phase-1 stub
|
||||
(``NotImplementedError``) so the suite stays green until Backend lands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from neuronetz_gateway.ratelimit.sliding_window import RateLimitResult, SlidingWindowLimiter
|
||||
from tests._skip import call_or_skip
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
# Known Backend bug (sliding_window.py): the per-hit ZSET member is
# ``f"{now_ms}-{id(object())}"``. In CPython ``id(object())`` typically returns
# the SAME value on consecutive calls (the temporary object is freed immediately
# and its address reused), so two hits landing in the same millisecond produce
# an identical member; ZADD then *overwrites* instead of adding a second entry,
# undercounting the window and admitting requests that should be blocked. The
# test carrying this marker (test_same_millisecond_burst_undercounts) asserts
# correct counting and therefore fails until Backend gives each hit a unique
# member (e.g. a counter or ``secrets.token_hex``). ``strict=False`` so it flips
# to XPASS once fixed without breaking the suite. See QA report.
_MEMBER_COLLISION = pytest.mark.xfail(
    reason="sliding_window member id(object()) collides within a millisecond; "
    "undercounts the window (see QA report)",
    strict=False,
)
|
||||
|
||||
|
||||
async def _check(
    limiter: SlidingWindowLimiter,
    key: str,
    limit: int,
    window_s: int,
    cost: int = 1,
) -> RateLimitResult:
    """Await ``limiter.check`` via ``call_or_skip`` and return its result."""
    outcome: RateLimitResult = await call_or_skip(limiter.check, key, limit, window_s, cost)
    return outcome
|
||||
|
||||
|
||||
# A spacing larger than 1ms so consecutive hits land on distinct ZSET members,
|
||||
# isolating the *windowing* logic from the separate member-collision bug
|
||||
# (asserted directly in test_same_millisecond_burst_undercounts below).
|
||||
_SPACING_S = 0.003
|
||||
|
||||
|
||||
async def test_allows_up_to_limit_then_blocks(redis_client: aioredis.Redis) -> None:
    """Admit exactly ``limit`` spaced hits, then block the next one."""
    limiter = SlidingWindowLimiter(redis_client)
    key = "rl:key:abc"
    limit, window = 3, 60

    admitted: list[RateLimitResult] = []
    for _ in range(limit):
        admitted.append(await _check(limiter, key, limit, window))
        # Space the hits out so each one lands on a distinct ZSET member.
        await asyncio.sleep(_SPACING_S)

    assert all(result.allowed for result in admitted)
    assert admitted[0].limit == limit
    # Remaining decrements monotonically toward zero: it never rises from one
    # hit to the next and is exhausted by the final admitted hit.
    remaining = [result.remaining for result in admitted]
    assert remaining == sorted(remaining, reverse=True)
    assert remaining[-1] == 0

    blocked = await _check(limiter, key, limit, window)
    assert blocked.allowed is False
    assert blocked.remaining == 0
    # A blocked result advertises when to retry (used for Retry-After).
    assert blocked.retry_after_s is not None
    assert blocked.retry_after_s >= 0
|
||||
|
||||
|
||||
async def test_window_resets_after_elapse(redis_client: aioredis.Redis) -> None:
    """Capacity comes back once the hits that filled the window have aged out."""
    limiter = SlidingWindowLimiter(redis_client)
    key = "rl:key:resets"
    limit, window = 2, 1  # 1-second window for a fast test

    first = await _check(limiter, key, limit, window)
    assert first.allowed
    await asyncio.sleep(_SPACING_S)

    second = await _check(limiter, key, limit, window)
    assert second.allowed
    await asyncio.sleep(_SPACING_S)

    over_limit = await _check(limiter, key, limit, window)
    assert over_limit.allowed is False

    # After the window passes, the oldest hits age out and capacity returns.
    await asyncio.sleep(1.2)
    after_window = await _check(limiter, key, limit, window)
    assert after_window.allowed is True
|
||||
|
||||
|
||||
@_MEMBER_COLLISION
async def test_same_millisecond_burst_undercounts(redis_client: aioredis.Redis) -> None:
    """Four unspaced hits against a limit of two: two admitted, two blocked."""
    # A burst of hits within one millisecond must still each count. With the
    # current member scheme they collide and only one is recorded, so the
    # limiter wrongly keeps admitting. xfail until Backend makes members unique.
    limiter = SlidingWindowLimiter(redis_client)
    key = "rl:key:burst"
    limit, window = 2, 60

    verdicts: list[bool] = []
    for _ in range(4):
        outcome = await _check(limiter, key, limit, window)
        verdicts.append(outcome.allowed)

    # Correct behaviour: first two admitted, rest blocked.
    assert verdicts == [True, True, False, False]
|
||||
|
||||
|
||||
async def test_per_key_and_per_tenant_scopes_independent(
    redis_client: aioredis.Redis,
) -> None:
    """Exhausting one limiter key leaves a different key's window untouched."""
    limiter = SlidingWindowLimiter(redis_client)
    per_key, per_tenant = "rl:key:k1", "rl:tenant:t1"

    # Distinct keys => distinct windows; exhausting one must not affect another.
    await _check(limiter, per_key, 1, 60)
    exhausted = await _check(limiter, per_key, 1, 60)
    assert exhausted.allowed is False

    untouched = await _check(limiter, per_tenant, 1, 60)
    assert untouched.allowed is True
|
||||
|
||||
|
||||
async def test_cost_consumes_multiple_slots(redis_client: aioredis.Redis) -> None:
    """A check with ``cost=8`` against a limit of 10 leaves 2, so a second is blocked."""
    # TPM-style accounting: a single check may cost >1 (per-key TPM, SPEC §4.3).
    limiter = SlidingWindowLimiter(redis_client)
    key = "rl:tpm:k"

    fits = await _check(limiter, key, limit=10, window_s=60, cost=8)
    assert fits.allowed is True
    assert fits.remaining == 2

    overflows = await _check(limiter, key, limit=10, window_s=60, cost=8)
    assert overflows.allowed is False
|
||||
21
tests/unit/test_smoke.py
Normal file
21
tests/unit/test_smoke.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Smoke test: the package imports and exposes a version string.
|
||||
|
||||
This is the one always-passing test that guarantees ``pytest`` exits 0 in
|
||||
Phase 1 (rather than exit code 5, "no tests ran"). It also serves as the
|
||||
canary that the src-layout package is importable on the configured path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def test_package_imports() -> None:
    """Import ``neuronetz_gateway`` and check that the module object is bound."""
    import neuronetz_gateway as package

    assert package is not None
|
||||
|
||||
|
||||
def test_version_is_str() -> None:
    """``neuronetz_gateway.__version__`` is a non-empty string."""
    import neuronetz_gateway as package

    version = package.__version__
    assert isinstance(version, str)
    assert version
|
||||
49
tests/unit/test_token_counter.py
Normal file
49
tests/unit/test_token_counter.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Unit tests for ``neuronetz_gateway.proxy.token_counter``.
|
||||
|
||||
Tokens are read precisely from Ollama's final frame: ``prompt_eval_count``
|
||||
(input) and ``eval_count`` (output) — never estimated (SPEC §2, §4.3 step 12,
|
||||
§13.1). Embeddings carry only ``prompt_eval_count`` (SPEC §13.1).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from neuronetz_gateway.proxy.token_counter import TokenUsage, extract_usage
|
||||
from tests._skip import call_or_skip
|
||||
|
||||
|
||||
def test_extract_from_final_chat_frame() -> None:
    """A final chat frame yields a ``TokenUsage`` with both counters read."""
    # Mirrors the terminal NDJSON object emitted by mock_ollama (_final_metrics).
    terminal_frame = {
        "model": "llama3.1:8b",
        "done": True,
        "done_reason": "stop",
        "total_duration": 1_234_567_890,
        "prompt_eval_count": 11,
        "eval_count": 7,
    }

    usage = call_or_skip(extract_usage, terminal_frame)

    assert isinstance(usage, TokenUsage)
    assert usage.tokens_in == 11
    assert usage.tokens_out == 7
|
||||
|
||||
|
||||
def test_extract_from_generate_frame() -> None:
    """A final generate frame yields the frame's prompt and eval counts."""
    terminal_frame = {
        "done": True,
        "context": [1, 2, 3],
        "prompt_eval_count": 5,
        "eval_count": 42,
    }
    usage = call_or_skip(extract_usage, terminal_frame)
    assert (usage.tokens_in, usage.tokens_out) == (5, 42)
|
||||
|
||||
|
||||
def test_embeddings_frame_only_prompt_eval_count() -> None:
    """A frame without ``eval_count`` reports zero output tokens."""
    # Embeddings: Ollama returns no eval_count (SPEC §13.1) => tokens_out == 0.
    embedding_frame = {"embedding": [0.0, 0.1], "prompt_eval_count": 9}
    usage = call_or_skip(extract_usage, embedding_frame)
    assert usage.tokens_in == 9
    assert usage.tokens_out == 0
|
||||
|
||||
|
||||
def test_missing_counts_default_to_zero() -> None:
    """A frame carrying neither counter yields zero for both token counts."""
    # A frame lacking the counter fields must not raise; charge nothing rather
    # than crash the audit/budget path.
    bare_frame = {"done": True}
    usage = call_or_skip(extract_usage, bare_frame)
    assert usage.tokens_in == 0
    assert usage.tokens_out == 0
|
||||
179
tests/unit/test_translate.py
Normal file
179
tests/unit/test_translate.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Unit tests for ``neuronetz_gateway.proxy.translate`` (OpenAI <-> Ollama).
|
||||
|
||||
Golden-fixture tests for the OpenAI-compat layer (SPEC §6.3):
|
||||
* OpenAI chat/completion/embeddings request -> Ollama request body
|
||||
* Ollama stream frame -> OpenAI ``chat.completion.chunk`` (delta + final usage)
|
||||
* Ollama non-stream response -> OpenAI ``chat.completion`` / ``text_completion``
|
||||
* model name list -> OpenAI ``/v1/models`` list shape
|
||||
|
||||
The streaming chunk shape is anchored to ``mock_ollama``'s reference helper
|
||||
``ollama_chunks_to_openai_sse``. SSE *framing* (``data: {...}\\n\\n`` +
|
||||
``data: [DONE]``) is asserted in the integration test_openai_compat.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from neuronetz_gateway.proxy import translate
|
||||
|
||||
|
||||
def _as_dict(value: object) -> dict[str, Any]:
|
||||
"""Narrow a translator-returned ``object`` to a typed dict for assertions."""
|
||||
assert isinstance(value, dict), value
|
||||
return value
|
||||
|
||||
|
||||
def _as_list(value: object) -> list[Any]:
|
||||
"""Narrow a translator-returned ``object`` to a typed list for assertions."""
|
||||
assert isinstance(value, list), value
|
||||
return value
|
||||
|
||||
# --- request translation: OpenAI -> Ollama ---------------------------------
|
||||
|
||||
|
||||
def test_openai_chat_request_to_ollama_preserves_messages_and_model() -> None:
    """Model, messages and the stream flag pass through to the Ollama body."""
    conversation = [
        {"role": "system", "content": "be terse"},
        {"role": "user", "content": "hi"},
    ]
    openai_req: dict[str, Any] = {
        "model": "llama3.1:8b",
        "messages": conversation,
        "stream": True,
    }

    translated = translate.openai_chat_to_ollama(openai_req)

    assert translated["model"] == "llama3.1:8b"
    assert translated["messages"] == openai_req["messages"]
    assert translated["stream"] is True
|
||||
|
||||
|
||||
def test_openai_chat_options_mapped() -> None:
    """Sampling parameters land under the Ollama ``options`` mapping."""
    openai_req: dict[str, Any] = {
        "model": "llama3.1:8b",
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.2,
        "max_tokens": 128,
        "stream": False,
    }

    translated = translate.openai_chat_to_ollama(openai_req)
    mapped_options = _as_dict(translated["options"])

    assert mapped_options["temperature"] == 0.2
    # OpenAI ``max_tokens`` => Ollama ``num_predict``.
    assert mapped_options["num_predict"] == 128
    assert translated["stream"] is False
|
||||
|
||||
|
||||
def test_openai_completion_to_ollama_generate() -> None:
    """Model, prompt and the stream flag pass through to the Ollama body."""
    openai_req: dict[str, Any] = {
        "model": "llama3.1:8b",
        "prompt": "once upon a time",
        "stream": True,
    }

    translated = translate.openai_completion_to_ollama(openai_req)

    assert translated["model"] == "llama3.1:8b"
    assert translated["prompt"] == "once upon a time"
    assert translated["stream"] is True
|
||||
|
||||
|
||||
def test_openai_embeddings_to_ollama_embed() -> None:
    """Model and input pass through to the Ollama embed body."""
    openai_req: dict[str, Any] = {
        "model": "nomic-embed-text",
        "input": "hello world",
    }

    translated = translate.openai_embeddings_to_ollama(openai_req)

    assert translated["model"] == "nomic-embed-text"
    assert translated["input"] == "hello world"
|
||||
|
||||
|
||||
# --- streaming response translation: Ollama frame -> OpenAI chunk ----------
|
||||
|
||||
|
||||
def test_chat_delta_chunk_to_openai() -> None:
    """A non-final frame becomes a chunk with delta content and no finish reason."""
    frame: dict[str, Any] = {
        "model": "llama3.1:8b",
        "message": {"role": "assistant", "content": "Echo:"},
        "done": False,
    }

    chunk = translate.ollama_chat_chunk_to_openai(
        frame,
        completion_id="chatcmpl-x",
        model="llama3.1:8b",
        created=1700,
    )

    assert chunk["object"] == "chat.completion.chunk"
    first_choice = _as_dict(_as_list(chunk["choices"])[0])
    delta = _as_dict(first_choice["delta"])
    assert delta["content"] == "Echo:"
    assert first_choice["finish_reason"] is None
|
||||
|
||||
|
||||
def test_chat_final_chunk_carries_usage_and_finish_reason() -> None:
    """A final frame becomes a chunk with a finish reason and summed usage."""
    frame: dict[str, Any] = {
        "model": "llama3.1:8b",
        "message": {"role": "assistant", "content": ""},
        "done": True,
        "done_reason": "stop",
        "prompt_eval_count": 4,
        "eval_count": 6,
    }

    chunk = translate.ollama_chat_chunk_to_openai(
        frame,
        completion_id="chatcmpl-x",
        model="llama3.1:8b",
        created=1700,
    )

    first_choice = _as_dict(_as_list(chunk["choices"])[0])
    assert first_choice["finish_reason"] == "stop"

    usage = _as_dict(chunk["usage"])
    assert usage["prompt_tokens"] == 4
    assert usage["completion_tokens"] == 6
    assert usage["total_tokens"] == 10
|
||||
|
||||
|
||||
# --- non-streaming response translation ------------------------------------
|
||||
|
||||
|
||||
def test_nonstream_chat_to_openai_completion() -> None:
    """A complete chat response becomes a ``chat.completion`` with message and usage."""
    ollama_resp: dict[str, Any] = {
        "model": "llama3.1:8b",
        "message": {"role": "assistant", "content": "Echo: hi"},
        "done": True,
        "prompt_eval_count": 2,
        "eval_count": 3,
    }

    completion = translate.ollama_chat_to_openai(ollama_resp)

    assert completion["object"] == "chat.completion"
    first_choice = _as_dict(_as_list(completion["choices"])[0])
    assert first_choice["message"] == {"role": "assistant", "content": "Echo: hi"}
    assert first_choice["finish_reason"] == "stop"
    usage = _as_dict(completion["usage"])
    assert usage["total_tokens"] == 5
|
||||
|
||||
|
||||
def test_nonstream_generate_to_openai() -> None:
    """A complete generate response becomes a ``text_completion`` with text and usage."""
    ollama_resp: dict[str, Any] = {
        "model": "llama3.1:8b",
        "response": "once upon a time",
        "done": True,
        "prompt_eval_count": 1,
        "eval_count": 4,
    }

    completion = translate.ollama_generate_to_openai(ollama_resp)

    assert completion["object"] == "text_completion"
    first_choice = _as_dict(_as_list(completion["choices"])[0])
    assert first_choice["text"] == "once upon a time"
    usage = _as_dict(completion["usage"])
    assert usage["total_tokens"] == 5
|
||||
|
||||
|
||||
def test_embed_to_openai_shape() -> None:
    """An embed response becomes an OpenAI list with indexed vectors and usage."""
    ollama_resp: dict[str, Any] = {
        "model": "nomic-embed-text",
        "embeddings": [[0.0, 0.1], [0.2, 0.3]],
        "prompt_eval_count": 7,
    }

    listing = translate.ollama_embed_to_openai(ollama_resp, model="nomic-embed-text")

    assert listing["object"] == "list"
    vectors = _as_list(listing["data"])
    assert len(vectors) == 2
    assert vectors[0] == {"object": "embedding", "index": 0, "embedding": [0.0, 0.1]}
    # Embeddings charge prompt tokens only (SPEC §13.1).
    assert listing["usage"] == {"prompt_tokens": 7, "total_tokens": 7}
|
||||
|
||||
|
||||
def test_models_to_openai_list_shape() -> None:
    """Model names become an OpenAI list whose entries are ``model`` objects."""
    model_names = ["llama3.1:8b", "mistral:7b"]

    listing = translate.models_to_openai_list(model_names)

    assert listing["object"] == "list"
    entries = _as_list(listing["data"])
    listed_ids = {_as_dict(entry)["id"] for entry in entries}
    assert listed_ids == {"llama3.1:8b", "mistral:7b"}
    assert all(_as_dict(entry)["object"] == "model" for entry in entries)
|
||||
Reference in New Issue
Block a user