diff --git a/src/psyc/cockpit/inference.py b/src/psyc/cockpit/inference.py index 2a47acf..06d6690 100644 --- a/src/psyc/cockpit/inference.py +++ b/src/psyc/cockpit/inference.py @@ -3,13 +3,19 @@ The cockpit venv has no torch; the fine-tuned model only runs inside the CUDA container behind serve_model.py. This client reaches it over HTTP and degrades gracefully — if the server is down, callers get None and fall back to rules. + +Two backends are supported via PSYC_INFERENCE_MODE: + - "psyc" (default) — native serve_model.py, POST /infer + - "openai" — OpenAI-compatible / Ollama, POST /v1/chat/completions +A bearer token can be set via PSYC_INFERENCE_TOKEN; when present it is sent on every +request except GET /healthz (psyc-native ignores it; api.neuronetz.ai requires it). """ from __future__ import annotations import json import os -from typing import Optional +from typing import Dict, Optional import httpx @@ -18,15 +24,31 @@ from psyc.lines.train import SEVERITY_INSTRUCTION, severity_features from psyc.models import Case -INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771") +INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771") +INFERENCE_TOKEN = os.environ.get("PSYC_INFERENCE_TOKEN", "") +INFERENCE_MODE = os.environ.get("PSYC_INFERENCE_MODE", "psyc").lower() +INFERENCE_MODEL = os.environ.get("PSYC_INFERENCE_MODEL", "psyc-v5") _log = log.get(__name__) +def _auth_headers() -> Dict[str, str]: + """Bearer header when a token is set, empty dict otherwise.""" + return {"Authorization": f"Bearer {INFERENCE_TOKEN}"} if INFERENCE_TOKEN else {} + + def server_adapter(timeout: float = 2.0) -> Optional[str]: """Return the adapter the server is running, or None if it is unreachable.""" try: with httpx.Client(timeout=timeout) as client: + if INFERENCE_MODE == "openai": + # OpenAI/Ollama exposes GET /v1/models — first available id wins. 
+ resp = client.get(f"{INFERENCE_URL}/v1/models", headers=_auth_headers()) + resp.raise_for_status() + data = resp.json().get("data") or [] + if data: + return str(data[0].get("id") or INFERENCE_MODEL) + return INFERENCE_MODEL resp = client.get(f"{INFERENCE_URL}/healthz") resp.raise_for_status() return resp.json().get("adapter") @@ -45,18 +67,44 @@ def adapter_name(timeout: float = 2.0) -> Optional[str]: def model_severity(case: Case, timeout: float = 15.0) -> Optional[str]: """Ask the live model to classify case severity. None if the server is down.""" - payload = { - "instruction": SEVERITY_INSTRUCTION, - "input": json.dumps(severity_features(case), ensure_ascii=False), - "max_new_tokens": 16, - } + features_json = json.dumps(severity_features(case), ensure_ascii=False) try: with httpx.Client(timeout=timeout) as client: - resp = client.post(f"{INFERENCE_URL}/infer", json=payload) - resp.raise_for_status() - output = str(resp.json().get("output", "")).strip().lower() + if INFERENCE_MODE == "openai": + payload = { + "model": INFERENCE_MODEL, + "messages": [ + {"role": "system", "content": SEVERITY_INSTRUCTION}, + {"role": "user", "content": features_json}, + ], + "max_tokens": 16, + "temperature": 0.0, + } + resp = client.post( + f"{INFERENCE_URL}/v1/chat/completions", + json=payload, + headers=_auth_headers(), + ) + resp.raise_for_status() + choices = resp.json().get("choices") or [] + if not choices: + return None + output = str(choices[0].get("message", {}).get("content", "")).strip().lower() + else: + payload = { + "instruction": SEVERITY_INSTRUCTION, + "input": features_json, + "max_new_tokens": 16, + } + resp = client.post( + f"{INFERENCE_URL}/infer", + json=payload, + headers=_auth_headers(), + ) + resp.raise_for_status() + output = str(resp.json().get("output", "")).strip().lower() except httpx.HTTPError as exc: _log.info("inference.unavailable", error=str(exc)) return None - _log.info("inference.severity", case_id=case.case_id, model_answer=output) + 
_log.info("inference.severity", case_id=case.case_id, model_answer=output, mode=INFERENCE_MODE) return output diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 0000000..a5dcc48 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,151 @@ +"""Tests for the inference client — both psyc-native and openai-compatible modes.""" + +from __future__ import annotations + +import importlib +import json +from typing import Any, Dict + +import httpx +import pytest + +from psyc.cockpit import inference +from psyc.models import Case, Classification, Confidence, Severity, TLP, Observables, Evidence, Victim + + +def _reload_with_env(monkeypatch, **env: str) -> Any: + for k, v in env.items(): + monkeypatch.setenv(k, v) + return importlib.reload(inference) + + +def _case() -> Case: + return Case( + case_id="C-T-1", + summary="test", + source_type="test", + source_ref="", + observed_at="2026-01-01T00:00:00+00:00", + ingested_at="2026-01-01T00:00:00+00:00", + classification=Classification(tlp=TLP.GREEN, severity=Severity.HIGH), + confidence=Confidence(level="medium", source_reliability="B", information_credibility="2"), + observables=Observables(), + evidence=Evidence(), + source_metadata={}, + victim=Victim(), + ) + + +def test_no_auth_header_when_token_unset(monkeypatch): + mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="") + assert mod._auth_headers() == {} + + +def test_bearer_header_when_token_set(monkeypatch): + mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="abc123") + assert mod._auth_headers() == {"Authorization": "Bearer abc123"} + + +def test_psyc_mode_server_adapter(monkeypatch): + mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x") + seen: Dict[str, Any] = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["url"] = str(request.url) + seen["method"] = request.method + return httpx.Response(200, json={"adapter": "/data/adapters/psyc-v5/final"}) + + 
transport = httpx.MockTransport(handler) + real_client = httpx.Client + monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"})) + + assert mod.server_adapter() == "/data/adapters/psyc-v5/final" + assert seen["url"].endswith("/healthz") + + +def test_openai_mode_server_adapter(monkeypatch): + mod = _reload_with_env( + monkeypatch, + PSYC_INFERENCE_MODE="openai", + PSYC_INFERENCE_URL="https://api.example", + PSYC_INFERENCE_TOKEN="t0k", + PSYC_INFERENCE_MODEL="psyc-v5", + ) + seen: Dict[str, Any] = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["url"] = str(request.url) + seen["auth"] = request.headers.get("authorization") + return httpx.Response(200, json={"data": [{"id": "llama3"}, {"id": "mistral"}]}) + + transport = httpx.MockTransport(handler) + real_client = httpx.Client + monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"})) + + assert mod.server_adapter() == "llama3" + assert seen["url"].endswith("/v1/models") + assert seen["auth"] == "Bearer t0k" + + +def test_openai_mode_severity_request_shape(monkeypatch): + mod = _reload_with_env( + monkeypatch, + PSYC_INFERENCE_MODE="openai", + PSYC_INFERENCE_URL="https://api.example", + PSYC_INFERENCE_TOKEN="t0k", + PSYC_INFERENCE_MODEL="psyc-v5", + ) + sent: Dict[str, Any] = {} + + def handler(request: httpx.Request) -> httpx.Response: + sent["url"] = str(request.url) + sent["auth"] = request.headers.get("authorization") + sent["body"] = json.loads(request.content.decode()) + return httpx.Response(200, json={"choices": [{"message": {"content": "HIGH"}}]}) + + transport = httpx.MockTransport(handler) + real_client = httpx.Client + monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"})) + + result = mod.model_severity(_case()) + assert result == "high" + 
assert sent["url"].endswith("/v1/chat/completions") + assert sent["auth"] == "Bearer t0k" + assert sent["body"]["model"] == "psyc-v5" + assert sent["body"]["messages"][0]["role"] == "system" + assert sent["body"]["messages"][1]["role"] == "user" + assert sent["body"]["max_tokens"] == 16 + + +def test_psyc_mode_severity_unchanged(monkeypatch): + mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x", PSYC_INFERENCE_TOKEN="") + sent: Dict[str, Any] = {} + + def handler(request: httpx.Request) -> httpx.Response: + sent["url"] = str(request.url) + sent["auth"] = request.headers.get("authorization") + sent["body"] = json.loads(request.content.decode()) + return httpx.Response(200, json={"output": "MEDIUM"}) + + transport = httpx.MockTransport(handler) + real_client = httpx.Client + monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"})) + + assert mod.model_severity(_case()) == "medium" + assert sent["url"].endswith("/infer") + assert sent["auth"] is None + assert "instruction" in sent["body"] + assert "max_new_tokens" in sent["body"] + + +def test_server_adapter_returns_none_on_http_error(monkeypatch): + mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="openai", PSYC_INFERENCE_URL="https://api.example") + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(401, json={"error": "unauthorized"}) + + transport = httpx.MockTransport(handler) + real_client = httpx.Client + monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"})) + + assert mod.server_adapter() is None