inference: openai-compatible mode + bearer auth (for api.neuronetz.ai etc.)

This commit is contained in:
m17hr1l
2026-06-07 01:09:19 +02:00
parent 2c7f71eff8
commit 351e16c3ce
2 changed files with 210 additions and 11 deletions

View File

@@ -3,13 +3,19 @@
The cockpit venv has no torch; the fine-tuned model only runs inside the CUDA
container behind serve_model.py. This client reaches it over HTTP and degrades
gracefully — if the server is down, callers get None and fall back to rules.
Two backends are supported via PSYC_INFERENCE_MODE:
- "psyc" (default) — native serve_model.py, POST /infer
- "openai" — OpenAI-compatible / Ollama, POST /v1/chat/completions
A bearer token can be set via PSYC_INFERENCE_TOKEN; it is sent on every request
when present (psyc-native ignores it; api.neuronetz.ai requires it).
"""
from __future__ import annotations
import json
import os
from typing import Optional
from typing import Dict, Optional
import httpx
@@ -18,15 +24,31 @@ from psyc.lines.train import SEVERITY_INSTRUCTION, severity_features
from psyc.models import Case
INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771")
INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771")
INFERENCE_TOKEN = os.environ.get("PSYC_INFERENCE_TOKEN", "")
INFERENCE_MODE = os.environ.get("PSYC_INFERENCE_MODE", "psyc").lower()
INFERENCE_MODEL = os.environ.get("PSYC_INFERENCE_MODEL", "psyc-v5")
_log = log.get(__name__)
def _auth_headers() -> Dict[str, str]:
"""Bearer header when a token is set, empty dict otherwise."""
return {"Authorization": f"Bearer {INFERENCE_TOKEN}"} if INFERENCE_TOKEN else {}
def server_adapter(timeout: float = 2.0) -> Optional[str]:
"""Return the adapter the server is running, or None if it is unreachable."""
try:
with httpx.Client(timeout=timeout) as client:
if INFERENCE_MODE == "openai":
# OpenAI/Ollama exposes GET /v1/models — first available id wins.
resp = client.get(f"{INFERENCE_URL}/v1/models", headers=_auth_headers())
resp.raise_for_status()
data = resp.json().get("data") or []
if data:
return str(data[0].get("id") or INFERENCE_MODEL)
return INFERENCE_MODEL
resp = client.get(f"{INFERENCE_URL}/healthz")
resp.raise_for_status()
return resp.json().get("adapter")
@@ -45,18 +67,44 @@ def adapter_name(timeout: float = 2.0) -> Optional[str]:
def model_severity(case: Case, timeout: float = 15.0) -> Optional[str]:
"""Ask the live model to classify case severity. None if the server is down."""
payload = {
"instruction": SEVERITY_INSTRUCTION,
"input": json.dumps(severity_features(case), ensure_ascii=False),
"max_new_tokens": 16,
}
features_json = json.dumps(severity_features(case), ensure_ascii=False)
try:
with httpx.Client(timeout=timeout) as client:
resp = client.post(f"{INFERENCE_URL}/infer", json=payload)
resp.raise_for_status()
output = str(resp.json().get("output", "")).strip().lower()
if INFERENCE_MODE == "openai":
payload = {
"model": INFERENCE_MODEL,
"messages": [
{"role": "system", "content": SEVERITY_INSTRUCTION},
{"role": "user", "content": features_json},
],
"max_tokens": 16,
"temperature": 0.0,
}
resp = client.post(
f"{INFERENCE_URL}/v1/chat/completions",
json=payload,
headers=_auth_headers(),
)
resp.raise_for_status()
choices = resp.json().get("choices") or []
if not choices:
return None
output = str(choices[0].get("message", {}).get("content", "")).strip().lower()
else:
payload = {
"instruction": SEVERITY_INSTRUCTION,
"input": features_json,
"max_new_tokens": 16,
}
resp = client.post(
f"{INFERENCE_URL}/infer",
json=payload,
headers=_auth_headers(),
)
resp.raise_for_status()
output = str(resp.json().get("output", "")).strip().lower()
except httpx.HTTPError as exc:
_log.info("inference.unavailable", error=str(exc))
return None
_log.info("inference.severity", case_id=case.case_id, model_answer=output)
_log.info("inference.severity", case_id=case.case_id, model_answer=output, mode=INFERENCE_MODE)
return output

151
tests/test_inference.py Normal file
View File

@@ -0,0 +1,151 @@
"""Tests for the inference client — both psyc-native and openai-compatible modes."""
from __future__ import annotations
import importlib
import json
from typing import Any, Dict
import httpx
import pytest
from psyc.cockpit import inference
from psyc.models import Case, Classification, Confidence, Severity, TLP, Observables, Evidence, Victim
def _reload_with_env(monkeypatch, **env: str) -> Any:
for k, v in env.items():
monkeypatch.setenv(k, v)
return importlib.reload(inference)
def _case() -> Case:
return Case(
case_id="C-T-1",
summary="test",
source_type="test",
source_ref="",
observed_at="2026-01-01T00:00:00+00:00",
ingested_at="2026-01-01T00:00:00+00:00",
classification=Classification(tlp=TLP.GREEN, severity=Severity.HIGH),
confidence=Confidence(level="medium", source_reliability="B", information_credibility="2"),
observables=Observables(),
evidence=Evidence(),
source_metadata={},
victim=Victim(),
)
def test_no_auth_header_when_token_unset(monkeypatch):
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="")
assert mod._auth_headers() == {}
def test_bearer_header_when_token_set(monkeypatch):
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="abc123")
assert mod._auth_headers() == {"Authorization": "Bearer abc123"}
def test_psyc_mode_server_adapter(monkeypatch):
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x")
seen: Dict[str, Any] = {}
def handler(request: httpx.Request) -> httpx.Response:
seen["url"] = str(request.url)
seen["method"] = request.method
return httpx.Response(200, json={"adapter": "/data/adapters/psyc-v5/final"})
transport = httpx.MockTransport(handler)
real_client = httpx.Client
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
assert mod.server_adapter() == "/data/adapters/psyc-v5/final"
assert seen["url"].endswith("/healthz")
def test_openai_mode_server_adapter(monkeypatch):
mod = _reload_with_env(
monkeypatch,
PSYC_INFERENCE_MODE="openai",
PSYC_INFERENCE_URL="https://api.example",
PSYC_INFERENCE_TOKEN="t0k",
PSYC_INFERENCE_MODEL="psyc-v5",
)
seen: Dict[str, Any] = {}
def handler(request: httpx.Request) -> httpx.Response:
seen["url"] = str(request.url)
seen["auth"] = request.headers.get("authorization")
return httpx.Response(200, json={"data": [{"id": "llama3"}, {"id": "mistral"}]})
transport = httpx.MockTransport(handler)
real_client = httpx.Client
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
assert mod.server_adapter() == "llama3"
assert seen["url"].endswith("/v1/models")
assert seen["auth"] == "Bearer t0k"
def test_openai_mode_severity_request_shape(monkeypatch):
mod = _reload_with_env(
monkeypatch,
PSYC_INFERENCE_MODE="openai",
PSYC_INFERENCE_URL="https://api.example",
PSYC_INFERENCE_TOKEN="t0k",
PSYC_INFERENCE_MODEL="psyc-v5",
)
sent: Dict[str, Any] = {}
def handler(request: httpx.Request) -> httpx.Response:
sent["url"] = str(request.url)
sent["auth"] = request.headers.get("authorization")
sent["body"] = json.loads(request.content.decode())
return httpx.Response(200, json={"choices": [{"message": {"content": "HIGH"}}]})
transport = httpx.MockTransport(handler)
real_client = httpx.Client
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
result = mod.model_severity(_case())
assert result == "high"
assert sent["url"].endswith("/v1/chat/completions")
assert sent["auth"] == "Bearer t0k"
assert sent["body"]["model"] == "psyc-v5"
assert sent["body"]["messages"][0]["role"] == "system"
assert sent["body"]["messages"][1]["role"] == "user"
assert sent["body"]["max_tokens"] == 16
def test_psyc_mode_severity_unchanged(monkeypatch):
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x", PSYC_INFERENCE_TOKEN="")
sent: Dict[str, Any] = {}
def handler(request: httpx.Request) -> httpx.Response:
sent["url"] = str(request.url)
sent["auth"] = request.headers.get("authorization")
sent["body"] = json.loads(request.content.decode())
return httpx.Response(200, json={"output": "MEDIUM"})
transport = httpx.MockTransport(handler)
real_client = httpx.Client
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
assert mod.model_severity(_case()) == "medium"
assert sent["url"].endswith("/infer")
assert sent["auth"] is None
assert "instruction" in sent["body"]
assert "max_new_tokens" in sent["body"]
def test_server_adapter_returns_none_on_http_error(monkeypatch):
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="openai", PSYC_INFERENCE_URL="https://api.example")
def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(401, json={"error": "unauthorized"})
transport = httpx.MockTransport(handler)
real_client = httpx.Client
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
assert mod.server_adapter() is None