inference: openai-compatible mode + bearer auth (for api.neuronetz.ai etc.)
This commit is contained in:
@@ -3,13 +3,19 @@
|
|||||||
The cockpit venv has no torch; the fine-tuned model only runs inside the CUDA
|
The cockpit venv has no torch; the fine-tuned model only runs inside the CUDA
|
||||||
container behind serve_model.py. This client reaches it over HTTP and degrades
|
container behind serve_model.py. This client reaches it over HTTP and degrades
|
||||||
gracefully — if the server is down, callers get None and fall back to rules.
|
gracefully — if the server is down, callers get None and fall back to rules.
|
||||||
|
|
||||||
|
Two backends are supported via PSYC_INFERENCE_MODE:
|
||||||
|
- "psyc" (default) — native serve_model.py, POST /infer
|
||||||
|
- "openai" — OpenAI-compatible / Ollama, POST /v1/chat/completions
|
||||||
|
A bearer token can be set via PSYC_INFERENCE_TOKEN; it is sent on every request
|
||||||
|
when present (psyc-native ignores it; api.neuronetz.ai requires it).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -18,15 +24,31 @@ from psyc.lines.train import SEVERITY_INSTRUCTION, severity_features
|
|||||||
from psyc.models import Case
|
from psyc.models import Case
|
||||||
|
|
||||||
|
|
||||||
INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771")
|
INFERENCE_URL = os.environ.get("PSYC_INFERENCE_URL", "http://127.0.0.1:8771")
|
||||||
|
INFERENCE_TOKEN = os.environ.get("PSYC_INFERENCE_TOKEN", "")
|
||||||
|
INFERENCE_MODE = os.environ.get("PSYC_INFERENCE_MODE", "psyc").lower()
|
||||||
|
INFERENCE_MODEL = os.environ.get("PSYC_INFERENCE_MODEL", "psyc-v5")
|
||||||
|
|
||||||
_log = log.get(__name__)
|
_log = log.get(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_headers() -> Dict[str, str]:
|
||||||
|
"""Bearer header when a token is set, empty dict otherwise."""
|
||||||
|
return {"Authorization": f"Bearer {INFERENCE_TOKEN}"} if INFERENCE_TOKEN else {}
|
||||||
|
|
||||||
|
|
||||||
def server_adapter(timeout: float = 2.0) -> Optional[str]:
|
def server_adapter(timeout: float = 2.0) -> Optional[str]:
|
||||||
"""Return the adapter the server is running, or None if it is unreachable."""
|
"""Return the adapter the server is running, or None if it is unreachable."""
|
||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=timeout) as client:
|
with httpx.Client(timeout=timeout) as client:
|
||||||
|
if INFERENCE_MODE == "openai":
|
||||||
|
# OpenAI/Ollama exposes GET /v1/models — first available id wins.
|
||||||
|
resp = client.get(f"{INFERENCE_URL}/v1/models", headers=_auth_headers())
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json().get("data") or []
|
||||||
|
if data:
|
||||||
|
return str(data[0].get("id") or INFERENCE_MODEL)
|
||||||
|
return INFERENCE_MODEL
|
||||||
resp = client.get(f"{INFERENCE_URL}/healthz")
|
resp = client.get(f"{INFERENCE_URL}/healthz")
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json().get("adapter")
|
return resp.json().get("adapter")
|
||||||
@@ -45,18 +67,44 @@ def adapter_name(timeout: float = 2.0) -> Optional[str]:
|
|||||||
|
|
||||||
def model_severity(case: Case, timeout: float = 15.0) -> Optional[str]:
|
def model_severity(case: Case, timeout: float = 15.0) -> Optional[str]:
|
||||||
"""Ask the live model to classify case severity. None if the server is down."""
|
"""Ask the live model to classify case severity. None if the server is down."""
|
||||||
payload = {
|
features_json = json.dumps(severity_features(case), ensure_ascii=False)
|
||||||
"instruction": SEVERITY_INSTRUCTION,
|
|
||||||
"input": json.dumps(severity_features(case), ensure_ascii=False),
|
|
||||||
"max_new_tokens": 16,
|
|
||||||
}
|
|
||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=timeout) as client:
|
with httpx.Client(timeout=timeout) as client:
|
||||||
resp = client.post(f"{INFERENCE_URL}/infer", json=payload)
|
if INFERENCE_MODE == "openai":
|
||||||
resp.raise_for_status()
|
payload = {
|
||||||
output = str(resp.json().get("output", "")).strip().lower()
|
"model": INFERENCE_MODEL,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": SEVERITY_INSTRUCTION},
|
||||||
|
{"role": "user", "content": features_json},
|
||||||
|
],
|
||||||
|
"max_tokens": 16,
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
f"{INFERENCE_URL}/v1/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
headers=_auth_headers(),
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
choices = resp.json().get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
return None
|
||||||
|
output = str(choices[0].get("message", {}).get("content", "")).strip().lower()
|
||||||
|
else:
|
||||||
|
payload = {
|
||||||
|
"instruction": SEVERITY_INSTRUCTION,
|
||||||
|
"input": features_json,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
f"{INFERENCE_URL}/infer",
|
||||||
|
json=payload,
|
||||||
|
headers=_auth_headers(),
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
output = str(resp.json().get("output", "")).strip().lower()
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
_log.info("inference.unavailable", error=str(exc))
|
_log.info("inference.unavailable", error=str(exc))
|
||||||
return None
|
return None
|
||||||
_log.info("inference.severity", case_id=case.case_id, model_answer=output)
|
_log.info("inference.severity", case_id=case.case_id, model_answer=output, mode=INFERENCE_MODE)
|
||||||
return output
|
return output
|
||||||
|
|||||||
151
tests/test_inference.py
Normal file
151
tests/test_inference.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Tests for the inference client — both psyc-native and openai-compatible modes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from psyc.cockpit import inference
|
||||||
|
from psyc.models import Case, Classification, Confidence, Severity, TLP, Observables, Evidence, Victim
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_with_env(monkeypatch, **env: str) -> Any:
|
||||||
|
for k, v in env.items():
|
||||||
|
monkeypatch.setenv(k, v)
|
||||||
|
return importlib.reload(inference)
|
||||||
|
|
||||||
|
|
||||||
|
def _case() -> Case:
|
||||||
|
return Case(
|
||||||
|
case_id="C-T-1",
|
||||||
|
summary="test",
|
||||||
|
source_type="test",
|
||||||
|
source_ref="",
|
||||||
|
observed_at="2026-01-01T00:00:00+00:00",
|
||||||
|
ingested_at="2026-01-01T00:00:00+00:00",
|
||||||
|
classification=Classification(tlp=TLP.GREEN, severity=Severity.HIGH),
|
||||||
|
confidence=Confidence(level="medium", source_reliability="B", information_credibility="2"),
|
||||||
|
observables=Observables(),
|
||||||
|
evidence=Evidence(),
|
||||||
|
source_metadata={},
|
||||||
|
victim=Victim(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_auth_header_when_token_unset(monkeypatch):
|
||||||
|
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="")
|
||||||
|
assert mod._auth_headers() == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearer_header_when_token_set(monkeypatch):
|
||||||
|
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_TOKEN="abc123")
|
||||||
|
assert mod._auth_headers() == {"Authorization": "Bearer abc123"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_psyc_mode_server_adapter(monkeypatch):
|
||||||
|
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x")
|
||||||
|
seen: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
seen["url"] = str(request.url)
|
||||||
|
seen["method"] = request.method
|
||||||
|
return httpx.Response(200, json={"adapter": "/data/adapters/psyc-v5/final"})
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client
|
||||||
|
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||||
|
|
||||||
|
assert mod.server_adapter() == "/data/adapters/psyc-v5/final"
|
||||||
|
assert seen["url"].endswith("/healthz")
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_mode_server_adapter(monkeypatch):
|
||||||
|
mod = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
PSYC_INFERENCE_MODE="openai",
|
||||||
|
PSYC_INFERENCE_URL="https://api.example",
|
||||||
|
PSYC_INFERENCE_TOKEN="t0k",
|
||||||
|
PSYC_INFERENCE_MODEL="psyc-v5",
|
||||||
|
)
|
||||||
|
seen: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
seen["url"] = str(request.url)
|
||||||
|
seen["auth"] = request.headers.get("authorization")
|
||||||
|
return httpx.Response(200, json={"data": [{"id": "llama3"}, {"id": "mistral"}]})
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client
|
||||||
|
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||||
|
|
||||||
|
assert mod.server_adapter() == "llama3"
|
||||||
|
assert seen["url"].endswith("/v1/models")
|
||||||
|
assert seen["auth"] == "Bearer t0k"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_mode_severity_request_shape(monkeypatch):
|
||||||
|
mod = _reload_with_env(
|
||||||
|
monkeypatch,
|
||||||
|
PSYC_INFERENCE_MODE="openai",
|
||||||
|
PSYC_INFERENCE_URL="https://api.example",
|
||||||
|
PSYC_INFERENCE_TOKEN="t0k",
|
||||||
|
PSYC_INFERENCE_MODEL="psyc-v5",
|
||||||
|
)
|
||||||
|
sent: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
sent["url"] = str(request.url)
|
||||||
|
sent["auth"] = request.headers.get("authorization")
|
||||||
|
sent["body"] = json.loads(request.content.decode())
|
||||||
|
return httpx.Response(200, json={"choices": [{"message": {"content": "HIGH"}}]})
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client
|
||||||
|
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||||
|
|
||||||
|
result = mod.model_severity(_case())
|
||||||
|
assert result == "high"
|
||||||
|
assert sent["url"].endswith("/v1/chat/completions")
|
||||||
|
assert sent["auth"] == "Bearer t0k"
|
||||||
|
assert sent["body"]["model"] == "psyc-v5"
|
||||||
|
assert sent["body"]["messages"][0]["role"] == "system"
|
||||||
|
assert sent["body"]["messages"][1]["role"] == "user"
|
||||||
|
assert sent["body"]["max_tokens"] == 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_psyc_mode_severity_unchanged(monkeypatch):
|
||||||
|
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="psyc", PSYC_INFERENCE_URL="http://x", PSYC_INFERENCE_TOKEN="")
|
||||||
|
sent: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
sent["url"] = str(request.url)
|
||||||
|
sent["auth"] = request.headers.get("authorization")
|
||||||
|
sent["body"] = json.loads(request.content.decode())
|
||||||
|
return httpx.Response(200, json={"output": "MEDIUM"})
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client
|
||||||
|
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||||
|
|
||||||
|
assert mod.model_severity(_case()) == "medium"
|
||||||
|
assert sent["url"].endswith("/infer")
|
||||||
|
assert sent["auth"] is None
|
||||||
|
assert "instruction" in sent["body"]
|
||||||
|
assert "max_new_tokens" in sent["body"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_adapter_returns_none_on_http_error(monkeypatch):
|
||||||
|
mod = _reload_with_env(monkeypatch, PSYC_INFERENCE_MODE="openai", PSYC_INFERENCE_URL="https://api.example")
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(401, json={"error": "unauthorized"})
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client
|
||||||
|
monkeypatch.setattr(httpx, "Client", lambda **kw: real_client(transport=transport, **{k: v for k, v in kw.items() if k != "transport"}))
|
||||||
|
|
||||||
|
assert mod.server_adapter() is None
|
||||||
Reference in New Issue
Block a user