inference: openai-compatible mode + bearer auth (for api.neuronetz.ai etc.)
This commit is contained in:
151
tests/test_inference.py
Normal file
151
tests/test_inference.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for the inference client — both psyc-native and openai-compatible modes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from psyc.cockpit import inference
|
||||
from psyc.models import Case, Classification, Confidence, Severity, TLP, Observables, Evidence, Victim
|
||||
|
||||
|
||||
def _reload_with_env(monkeypatch, **env: str) -> Any:
|
||||
for k, v in env.items():
|
||||
monkeypatch.setenv(k, v)
|
||||
return importlib.reload(inference)
|
||||
|
||||
|
||||
def _case() -> Case:
|
||||
return Case(
|
||||
case_id="C-T-1",
|
||||
summary="test",
|
||||
source_type="test",
|
||||
source_ref="",
|
||||
observed_at="2026-01-01T00:00:00+00:00",
|
||||
ingested_at="2026-01-01T00:00:00+00:00",
|
||||
classification=Classification(tlp=TLP.GREEN, severity=Severity.HIGH),
|
||||
confidence=Confidence(level="medium", source_reliability="B", information_credibility="2"),
|
||||
observables=Observables(),
|
||||
evidence=Evidence(),
|
||||
source_metadata={},
|
||||
victim=Victim(),
|
||||
)
|
||||
|
||||
|
||||
def test_no_auth_header_when_token_unset(monkeypatch):
|
||||
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="")
|
||||
assert mod._auth_headers() == {}
|
||||
|
||||
|
||||
def test_bearer_header_when_token_set(monkeypatch):
|
||||
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="abc123")
|
||||
assert mod._auth_headers() == {"Authorization": "Bearer abc123"}
|
||||
|
||||
|
||||
def test_psyc_mode_server_adapter(monkeypatch):
|
||||
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x")
|
||||
seen: Dict[str, Any] = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
seen["url"] = str(request.url)
|
||||
seen["method"] = request.method
|
||||
return httpx.Response(200, json={"adapter": "/data/adapters/psyc-v5/final"})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client
|
||||
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||
|
||||
assert mod.server_adapter() == "/data/adapters/psyc-v5/final"
|
||||
assert seen["url"].endswith("/healthz")
|
||||
|
||||
|
||||
def test_openai_mode_server_adapter(monkeypatch):
|
||||
mod = _reload_with_env(
|
||||
monkeypatch,
|
||||
PSYC_INFERENCE_MODE="openai",
|
||||
PSYC_INFERENCE_URL="https://api.example",
|
||||
PSYC_INFERENCE_TOKEN="t0k",
|
||||
PSYC_INFERENCE_MODEL="psyc-v5",
|
||||
)
|
||||
seen: Dict[str, Any] = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
seen["url"] = str(request.url)
|
||||
seen["auth"] = request.headers.get("authorization")
|
||||
return httpx.Response(200, json={"data": [{"id": "llama3"}, {"id": "mistral"}]})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client
|
||||
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||
|
||||
assert mod.server_adapter() == "llama3"
|
||||
assert seen["url"].endswith("/v1/models")
|
||||
assert seen["auth"] == "Bearer t0k"
|
||||
|
||||
|
||||
def test_openai_mode_severity_request_shape(monkeypatch):
|
||||
mod = _reload_with_env(
|
||||
monkeypatch,
|
||||
PSYC_INFERENCE_MODE="openai",
|
||||
PSYC_INFERENCE_URL="https://api.example",
|
||||
PSYC_INFERENCE_TOKEN="t0k",
|
||||
PSYC_INFERENCE_MODEL="psyc-v5",
|
||||
)
|
||||
sent: Dict[str, Any] = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
sent["url"] = str(request.url)
|
||||
sent["auth"] = request.headers.get("authorization")
|
||||
sent["body"] = json.loads(request.content.decode())
|
||||
return httpx.Response(200, json={"choices": [{"message": {"content": "HIGH"}}]})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client
|
||||
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||
|
||||
result = mod.model_severity(_case())
|
||||
assert result == "high"
|
||||
assert sent["url"].endswith("/v1/chat/completions")
|
||||
assert sent["auth"] == "Bearer t0k"
|
||||
assert sent["body"]["model"] == "psyc-v5"
|
||||
assert sent["body"]["messages"][0]["role"] == "system"
|
||||
assert sent["body"]["messages"][1]["role"] == "user"
|
||||
assert sent["body"]["max_tokens"] == 16
|
||||
|
||||
|
||||
def test_psyc_mode_severity_unchanged(monkeypatch):
|
||||
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x", PSYC_INFERENCE_TOKEN="")
|
||||
sent: Dict[str, Any] = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
sent["url"] = str(request.url)
|
||||
sent["auth"] = request.headers.get("authorization")
|
||||
sent["body"] = json.loads(request.content.decode())
|
||||
return httpx.Response(200, json={"output": "MEDIUM"})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client
|
||||
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||
|
||||
assert mod.model_severity(_case()) == "medium"
|
||||
assert sent["url"].endswith("/infer")
|
||||
assert sent["auth"] is None
|
||||
assert "instruction" in sent["body"]
|
||||
assert "max_new_tokens" in sent["body"]
|
||||
|
||||
|
||||
def test_server_adapter_returns_none_on_http_error(monkeypatch):
|
||||
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="openai", PSYC_INFERENCE_URL="https://api.example")
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(401, json={"error": "unauthorized"})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client
|
||||
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||
|
||||
assert mod.server_adapter() is None
|
||||
Reference in New Issue
Block a user