# Viewer metadata: 152 lines, 5.6 KiB, Python
"""Tests for the inference client — both psyc-native and openai-compatible modes."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import json
|
|
from typing import Any, Dict
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from psyc.cockpit import inference
|
|
from psyc.models import Case, Classification, Confidence, Severity, TLP, Observables, Evidence, Victim
|
|
|
|
|
|
def _reload_with_env(monkeypatch, **env: str) -> Any:
    """Apply environment overrides, then re-import the ``inference`` module.

    Each keyword pair in ``env`` is set with ``monkeypatch.setenv`` before
    ``importlib.reload(inference)`` runs. The reloaded module object is
    returned so the caller can exercise it directly.

    NOTE(review): presumably ``inference`` reads its ``PSYC_INFERENCE_*``
    settings at import time, which is why a reload is needed -- confirm
    against the module.
    """
    for name in env:
        monkeypatch.setenv(name, env[name])
    reloaded = importlib.reload(inference)
    return reloaded
|
|
|
|
|
|
def _case() -> Case:
    """Build the fixed ``Case`` instance passed to ``model_severity`` in these tests."""
    # The same timestamp is used for both the observation and ingestion times.
    stamp = "2026-01-01T00:00:00+00:00"
    classification = Classification(tlp=TLP.GREEN, severity=Severity.HIGH)
    confidence = Confidence(
        level="medium",
        source_reliability="B",
        information_credibility="2",
    )
    return Case(
        case_id="C-T-1",
        summary="test",
        source_type="test",
        source_ref="",
        observed_at=stamp,
        ingested_at=stamp,
        classification=classification,
        confidence=confidence,
        observables=Observables(),
        evidence=Evidence(),
        source_metadata={},
        victim=Victim(),
    )
|
|
|
|
|
|
def test_no_auth_header_when_token_unset(monkeypatch):
    """With ``PSYC_INFERENCE_TOKEN`` set to an empty string, ``_auth_headers()`` is ``{}``."""
    reloaded = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="")
    headers = reloaded._auth_headers()
    assert headers == {}
|
|
|
|
|
|
def test_bearer_header_when_token_set(monkeypatch):
    """With a token configured, ``_auth_headers()`` carries it as a Bearer credential."""
    reloaded = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="abc123")
    headers = reloaded._auth_headers()
    assert headers == {"Authorization": "Bearer abc123"}
|
|
|
|
|
|
def test_psyc_mode_server_adapter(monkeypatch):
    """In ``psyc`` mode, ``server_adapter()`` returns the ``adapter`` value from the reply.

    The mocked endpoint answers with an ``adapter`` field; the test asserts
    that value is returned and that the request URL ends in ``/healthz``.
    """
    mod = _reload_with_env(
        monkeypatch,
        PSYC_INFERENCE_MODE="psyc",
        PSYC_INFERENCE_URL="http://x",
    )
    captured: Dict[str, Any] = {}

    def _respond(request: httpx.Request) -> httpx.Response:
        # Record what the client sent so it can be inspected after the call.
        captured["url"] = str(request.url)
        captured["method"] = request.method
        payload = {"adapter": "/data/adapters/psyc-v5/final"}
        return httpx.Response(200, json=payload)

    mock_transport = httpx.MockTransport(_respond)
    original_client = httpx.Client

    def _client_factory(**kwargs):
        # Route every client through the mock transport, overriding any
        # transport the caller supplied.
        kwargs["transport"] = mock_transport
        return original_client(**kwargs)

    monkeypatch.setattr(httpx, "Client", _client_factory)

    adapter = mod.server_adapter()
    assert adapter == "/data/adapters/psyc-v5/final"
    assert captured["url"].endswith("/healthz")
|
|
|
|
|
|
def test_openai_mode_server_adapter(monkeypatch):
    """In ``openai`` mode, ``server_adapter()`` reports a model id from the models listing.

    The mocked endpoint lists ``llama3`` then ``mistral``; the test expects
    ``"llama3"`` back, a request URL ending in ``/v1/models``, and the
    configured token sent as a Bearer ``Authorization`` header.
    """
    settings = {
        "PSYC_INFERENCE_MODE": "openai",
        "PSYC_INFERENCE_URL": "https://api.example",
        "PSYC_INFERENCE_TOKEN": "t0k",
        "PSYC_INFERENCE_MODEL": "psyc-v5",
    }
    mod = _reload_with_env(monkeypatch, **settings)
    captured: Dict[str, Any] = {}

    def _respond(request: httpx.Request) -> httpx.Response:
        # Record what the client sent so it can be inspected after the call.
        captured["url"] = str(request.url)
        captured["auth"] = request.headers.get("authorization")
        listing = {"data": [{"id": "llama3"}, {"id": "mistral"}]}
        return httpx.Response(200, json=listing)

    mock_transport = httpx.MockTransport(_respond)
    original_client = httpx.Client

    def _client_factory(**kwargs):
        # Route every client through the mock transport, overriding any
        # transport the caller supplied.
        kwargs["transport"] = mock_transport
        return original_client(**kwargs)

    monkeypatch.setattr(httpx, "Client", _client_factory)

    adapter = mod.server_adapter()
    assert adapter == "llama3"
    assert captured["url"].endswith("/v1/models")
    assert captured["auth"] == "Bearer t0k"
|
|
|
|
|
|
def test_openai_mode_severity_request_shape(monkeypatch):
    """In ``openai`` mode, ``model_severity()`` sends a chat-completions style request.

    The mocked endpoint replies with message content ``"HIGH"``; the test
    expects ``"high"`` back and checks the URL suffix, the Bearer header,
    the ``model`` field, the roles of the first two messages, and
    ``max_tokens``.
    """
    settings = {
        "PSYC_INFERENCE_MODE": "openai",
        "PSYC_INFERENCE_URL": "https://api.example",
        "PSYC_INFERENCE_TOKEN": "t0k",
        "PSYC_INFERENCE_MODEL": "psyc-v5",
    }
    mod = _reload_with_env(monkeypatch, **settings)
    recorded: Dict[str, Any] = {}

    def _respond(request: httpx.Request) -> httpx.Response:
        # Record what the client sent so it can be inspected after the call.
        recorded["url"] = str(request.url)
        recorded["auth"] = request.headers.get("authorization")
        recorded["body"] = json.loads(request.content.decode())
        completion = {"choices": [{"message": {"content": "HIGH"}}]}
        return httpx.Response(200, json=completion)

    mock_transport = httpx.MockTransport(_respond)
    original_client = httpx.Client

    def _client_factory(**kwargs):
        # Route every client through the mock transport, overriding any
        # transport the caller supplied.
        kwargs["transport"] = mock_transport
        return original_client(**kwargs)

    monkeypatch.setattr(httpx, "Client", _client_factory)

    severity = mod.model_severity(_case())
    assert severity == "high"
    assert recorded["url"].endswith("/v1/chat/completions")
    assert recorded["auth"] == "Bearer t0k"

    body = recorded["body"]
    assert body["model"] == "psyc-v5"
    messages = body["messages"]
    assert (messages[0]["role"], messages[1]["role"]) == ("system", "user")
    assert body["max_tokens"] == 16
|
|
|
|
|
|
def test_psyc_mode_severity_unchanged(monkeypatch):
    """In ``psyc`` mode with an empty token, ``model_severity()`` uses the ``/infer`` request shape.

    The mocked endpoint replies with ``{"output": "MEDIUM"}``; the test
    expects ``"medium"`` back, no ``Authorization`` header, and a body
    containing both ``instruction`` and ``max_new_tokens``.
    """
    mod = _reload_with_env(
        monkeypatch,
        PSYC_INFERENCE_MODE="psyc",
        PSYC_INFERENCE_URL="http://x",
        PSYC_INFERENCE_TOKEN="",
    )
    recorded: Dict[str, Any] = {}

    def _respond(request: httpx.Request) -> httpx.Response:
        # Record what the client sent so it can be inspected after the call.
        recorded["url"] = str(request.url)
        recorded["auth"] = request.headers.get("authorization")
        recorded["body"] = json.loads(request.content.decode())
        return httpx.Response(200, json={"output": "MEDIUM"})

    mock_transport = httpx.MockTransport(_respond)
    original_client = httpx.Client

    def _client_factory(**kwargs):
        # Route every client through the mock transport, overriding any
        # transport the caller supplied.
        kwargs["transport"] = mock_transport
        return original_client(**kwargs)

    monkeypatch.setattr(httpx, "Client", _client_factory)

    severity = mod.model_severity(_case())
    assert severity == "medium"
    assert recorded["url"].endswith("/infer")
    assert recorded["auth"] is None
    for expected_key in ("instruction", "max_new_tokens"):
        assert expected_key in recorded["body"]
|
|
|
|
|
|
def test_server_adapter_returns_none_on_http_error(monkeypatch):
    """When the endpoint answers 401, ``server_adapter()`` is expected to return ``None``."""
    mod = _reload_with_env(
        monkeypatch,
        PSYC_INFERENCE_MODE="openai",
        PSYC_INFERENCE_URL="https://api.example",
    )

    def _reject(request: httpx.Request) -> httpx.Response:
        # Every request is refused with an "unauthorized" error payload.
        return httpx.Response(401, json={"error": "unauthorized"})

    mock_transport = httpx.MockTransport(_reject)
    original_client = httpx.Client

    def _client_factory(**kwargs):
        # Route every client through the mock transport, overriding any
        # transport the caller supplied.
        kwargs["transport"] = mock_transport
        return original_client(**kwargs)

    monkeypatch.setattr(httpx, "Client", _client_factory)

    adapter = mod.server_adapter()
    assert adapter is None
|