"""Integration tests for the native Ollama proxy (streaming + non-streaming).

SPEC §4.3 steps 11-13, §6.1-6.2, §12: ``/api/chat`` and ``/api/generate``
stream NDJSON end-to-end between the client and the in-process mock upstream;
the gateway preserves the byte stream untouched, and the audit row that lands
after stream close carries token counts equal to the mock's
``prompt_eval_count`` + ``eval_count``; ``/api/pull`` (and the other mutating
endpoints) returns 403 with a generic sanitized body.
"""

from __future__ import annotations

import asyncio
import json
from typing import Any

import httpx
import pytest
from sqlalchemy import select

from neuronetz_gateway.db.models import AuditLog
from tests.integration.conftest import IntegrationApp, IntegrationKey

pytestmark = pytest.mark.asyncio


async def _wait_for_audit(
    integration_app: IntegrationApp,
    key_id: Any,
    *,
    attempts: int = 40,
    interval: float = 0.05,
) -> AuditLog | None:
    """Poll the audit table until a row for *key_id* appears.

    The audit writer is buffered, so the row can land some time after the
    response has been fully consumed. With the defaults this waits up to
    ``attempts * interval`` = 2s.

    Args:
        integration_app: Fixture exposing ``sessionmaker`` for DB access.
        key_id: Key whose audit row to look for.
        attempts: Maximum number of polls (keyword-only).
        interval: Seconds to sleep between polls (keyword-only).

    Returns:
        The most recent ``AuditLog`` row (highest ``id``) for *key_id*, or
        ``None`` if none appeared within the polling window.
    """
    # LIMIT 1 matters: without it, more than one row for the same key makes
    # scalar_one_or_none() raise MultipleResultsFound instead of returning
    # the newest row that the DESC ordering is selecting for.
    stmt = (
        select(AuditLog)
        .where(AuditLog.key_id == key_id)
        .order_by(AuditLog.id.desc())
        .limit(1)
    )
    for _ in range(attempts):
        async with integration_app.sessionmaker() as session:
            row: AuditLog | None = (await session.execute(stmt)).scalar_one_or_none()
        if row is not None:
            return row
        # Sleep outside the session so no connection is held while waiting.
        await asyncio.sleep(interval)
    return None


async def test_chat_stream_ndjson_roundtrip(
    client: httpx.AsyncClient, integration_app: IntegrationApp, api_key: IntegrationKey
) -> None:
    """Stream ``/api/chat`` as NDJSON and reconcile token counts with the audit row."""
    auth = {"Authorization": f"Bearer {api_key.full_key}"}
    received: list[dict[str, Any]] = []
    async with client.stream(
        "POST",
        "/api/chat",
        headers=auth,
        json={
            "model": "llama3.1:8b",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        },
    ) as resp:
        assert resp.status_code == 200
        assert "application/x-ndjson" in resp.headers.get("content-type", "")
        async for line in resp.aiter_lines():
            if not line:
                continue
            received.append(json.loads(line))

    # mock_ollama emits one frame per word, then a final done frame.
    # Guard first so an empty stream fails with a clear message rather than
    # an IndexError on received[-1].
    assert received, "stream produced no NDJSON frames"
    final = received[-1]
    assert final.get("done") is True
    assert final.get("model") == "llama3.1:8b"
    # Token counts are present on the final NDJSON frame (forwarded unchanged
    # from upstream); the gateway also records them in the audit row.
    assert final.get("prompt_eval_count") is not None
    assert final.get("eval_count") is not None

    audit = await _wait_for_audit(integration_app, api_key.key_id)
    if audit is None:
        pytest.skip("pending Backend: audit row not flushed (drain may still be wiring up)")
    assert audit.tokens_in == final["prompt_eval_count"]
    assert audit.tokens_out == final["eval_count"]
    assert audit.model == "llama3.1:8b"
    assert audit.path == "/api/chat"


async def test_chat_non_streaming(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """A ``stream: false`` chat returns a single JSON body with usage counters."""
    resp = await client.post(
        "/api/chat",
        headers={"Authorization": f"Bearer {api_key.full_key}"},
        json={
            "model": "llama3.1:8b",
            "messages": [{"role": "user", "content": "hello there"}],
            "stream": False,
        },
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["model"] == "llama3.1:8b"
    assert isinstance(body.get("message"), dict)
    assert body.get("done") is True
    # Usage counters carried through.
    assert isinstance(body.get("prompt_eval_count"), int)
    assert isinstance(body.get("eval_count"), int)


async def test_generate_stream(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Stream ``/api/generate``; the last NDJSON frame must be the done frame."""
    auth = {"Authorization": f"Bearer {api_key.full_key}"}
    last: dict[str, Any] = {}
    async with client.stream(
        "POST",
        "/api/generate",
        headers=auth,
        json={"model": "llama3.1:8b", "prompt": "once upon a time", "stream": True},
    ) as resp:
        assert resp.status_code == 200
        async for line in resp.aiter_lines():
            if line:
                last = json.loads(line)
    assert last.get("done") is True
    assert "response" in last  # generate uses 'response' not 'message'


@pytest.mark.parametrize(
    "path", ["/api/pull", "/api/push", "/api/create", "/api/copy", "/api/delete"]
)
async def test_mutating_endpoints_return_403(
    client: httpx.AsyncClient, api_key: IntegrationKey, path: str
) -> None:
    """Mutating Ollama endpoints are refused with a generic 403 error body."""
    resp = await client.post(
        path,
        headers={"Authorization": f"Bearer {api_key.full_key}"},
        json={"name": "llama3.1:8b"},
    )
    assert resp.status_code == 403
    err = resp.json()["error"]
    # Generic — no enumeration of what's blocked, no upstream details.
    assert err["code"] in {"forbidden", "not_found"}
    assert err["request_id"]


async def test_api_ps_blocked(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """``/api/ps`` is hard-blocked; a 403 must carry the ``forbidden`` code."""
    # /api/ps is also hard-blocked (leaks loaded models per SPEC §6.1).
    resp = await client.get(
        "/api/ps", headers={"Authorization": f"Bearer {api_key.full_key}"}
    )
    assert resp.status_code in {403, 404, 405}  # however the router renders it
    if resp.status_code == 403:
        assert resp.json()["error"]["code"] == "forbidden"


async def test_blobs_prefix_blocked(
    client: httpx.AsyncClient, api_key: IntegrationKey
) -> None:
    """Anything under the ``/api/blobs/`` prefix is refused by the gateway."""
    resp = await client.get(
        "/api/blobs/sha256:abc",
        headers={"Authorization": f"Bearer {api_key.full_key}"},
    )
    # Hard-blocked prefix — never reaches Ollama.
    assert resp.status_code in {403, 404, 405}