Files
psyc/src/psyc/lines/train.py
m17hr1l f6fa52839f stage-20: defanging pipeline for IOC-extraction augmentation
Real CTI prose defangs IOCs (1[.]2[.]3[.]4, hxxp://, evil[dot]com) so they
don't auto-link in email/chat. A model trained only on canonical inputs
will fail to extract them.

New lines/defang.py: defang_ip, defang_domain, defang_url, defang_text —
four dot-styles ([.], (.), [dot], {.}) plus protocol defanging
(http→hxxp, https→hxxps). Each occurrence picks its style independently
since real advisories don't keep one style across paragraphs.

train.BuildOptions adds defang_frac (default 0.0) and seed; build()
threads options + a seeded Random through the example builders so
the augmentation is reproducible. Only _ex_ioc_extraction reads it
today — output stays canonical so the model learns messy→canonical.

CLI: train-build and train-build-all gain --defang-frac and --seed.
8 new tests including a frac=1.0 / output-canonical integration check.
The pipeline runs but is dormant at defang_frac=0.0 — psyc-v5 dataset
build will set 0.5 once OTX cases land.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 22:33:52 +02:00

343 lines
12 KiB
Python

"""Trainline — ExampleBuilder + QualityGate + DatasetWriter.
Reads reviewed cases from the DB, emits Alpaca-style JSONL training examples
for the defensive task set defined in docs/dossier.md §intelminer:
- ioc_extraction
- severity_classification
- routing_decision
- tlp_assignment
Outputs are versioned JSONL files under data/datasets/<task>-v<n>.jsonl.
QualityGate enforces the dossier's training-data policy: never TLP:RED, never
restricted source types, never empty input/output.
"""
from __future__ import annotations
import json
import random
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field
from psyc import DATA_DIR, log
from psyc.lines import classify as classify_line
from psyc.lines import defang as defang_line
from psyc.lines import route as route_line
from psyc.models import Case, TLP
class BuildOptions(BaseModel):
    """Per-build configuration. Currently only ioc_extraction reads any field."""

    # Fraction of ioc_extraction inputs to defang. Validated to [0.0, 1.0]:
    # the builder compares it against rng.random(), so out-of-range values
    # would otherwise silently behave as 1.0 (>= 1) or 0.0 (<= 0).
    defang_frac: float = Field(default=0.0, ge=0.0, le=1.0)
    # Seed for the per-build random.Random; None gives a non-reproducible stream.
    seed: Optional[int] = None
# Module logger; call sites pass structured key=value fields.
_log = log.get(__name__)
# Versioned JSONL datasets are written here by build() and listed by list_datasets().
DATASETS_DIR = DATA_DIR / "datasets"
# Trained adapter directories inspected by list_adapters().
ADAPTERS_DIR = DATA_DIR / "adapters"
# The defensive task set; keep in sync with the keys of _BUILDERS.
TASKS = ("ioc_extraction", "severity_classification", "routing_decision", "tlp_assignment")
# Source types whose cases must never become training data (rejected by quality_gate).
RESTRICTED_SOURCE_TYPES = {"criminal_forum", "stolen_data_paste", "unauthorized_dump"}
class Example(BaseModel):
    """One Alpaca-style training record.

    build() writes each instance as a single JSONL line via model_dump_json(),
    so the field declaration order below is the key order in dataset files.
    """

    instruction: str  # task prompt given to the model
    input: str  # per-case context: synthesized advisory text or JSON-encoded features
    output: str  # target completion: a label or a JSON string
    task: str  # task name, one of TASKS
    case_id: str  # identifier of the source Case, for traceability
    tlp: str  # TLP value of the source case
    meta: Dict[str, str] = Field(default_factory=dict)  # free-form string annotations; the builders in this module leave it empty
class DatasetReport(BaseModel):
    """Summary of one build() run: where the dataset went and why cases were dropped."""

    task: str  # task the dataset was built for
    path: Path  # JSONL file that was written
    written: int  # number of examples written to path
    skipped_quality: int  # examples rejected by quality_gate for other quality reasons (e.g. too_long)
    skipped_tlp_red: int  # cases dropped because they are TLP:RED
    skipped_restricted_source: int  # cases whose source_type is in RESTRICTED_SOURCE_TYPES
    skipped_empty: int  # cases that yielded no usable example (e.g. the builder returned None)
# ---------- ExampleBuilder per task ---------------------------------------
def _ex_ioc_extraction(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build an IOC-extraction example from a case's observables.

    The advisory text is synthesized from the observables themselves, so every
    IOC in the target JSON also appears in the input; otherwise the extraction
    task would be ill-posed. When both ``options`` and ``rng`` are given and
    ``options.defang_frac`` is positive, the input is defanged with that
    probability. The target JSON always stays canonical.

    Returns ``None`` when the case has no URLs, domains, IPs, hashes or CVEs.
    """
    obs = case.observables
    sections = (
        ("Malicious URLs", obs.urls),
        ("Domains", obs.domains),
        ("Hosting IPs", obs.ips),
        ("Sample hashes", obs.hashes),
        ("Related CVEs", obs.cves),
    )
    if not any(values for _, values in sections):
        return None

    metadata = case.source_metadata
    threat = metadata.get("threat", "malware")
    tags = metadata.get("tags", "")

    sentences = [f"Threat advisory — {threat}."]
    sentences.extend(
        f"{label}: {', '.join(values)}." for label, values in sections if values
    )
    if tags:
        sentences.append(f"Tags: {tags}.")
    advisory = " ".join(sentences)

    # Short-circuit order matters: rng.random() is drawn only when defanging
    # is enabled, so defang_frac == 0.0 never advances the RNG stream.
    augment = (
        options is not None
        and rng is not None
        and options.defang_frac > 0.0
        and rng.random() < options.defang_frac
    )
    if augment:
        advisory = defang_line.defang_text(advisory, obs.ips, obs.domains, obs.urls, rng)

    target = {
        "urls": obs.urls,
        "domains": obs.domains,
        "ips": obs.ips,
        "hashes": obs.hashes,
        "cves": obs.cves,
    }
    return Example(
        instruction="Extract all indicators of compromise from the advisory and return JSON with keys: urls, domains, ips, hashes, cves.",
        input=advisory,
        output=json.dumps(target, ensure_ascii=False),
        task="ioc_extraction",
        case_id=case.case_id,
        tlp=case.classification.tlp.value,
    )
# Exported so the cockpit's live inference sends exactly the prompt the model
# was trained on. If this wording changes, update _ex_severity_classification's
# consumers and the inference client together.
SEVERITY_INSTRUCTION = " ".join((
    "Classify the defensive severity of this case as one of:",
    "low | medium | high | critical.",
    "Return only the label.",
))
def severity_features(case: Case) -> Dict[str, object]:
    """Return the feature dict used as the severity-classification input.

    Key order is significant: callers JSON-encode this dict as model input.
    ``status`` prefers the ``url_status`` metadata entry and falls back to
    ``status`` (or ``""``); ``incident_type`` is ``None`` when unset.
    """
    incident = case.classification.incident_type
    metadata = case.source_metadata
    victim = case.victim
    observables = case.observables

    features: Dict[str, object] = {}
    features["summary"] = case.summary
    features["source_type"] = case.source_type
    features["incident_type"] = incident.value if incident else None
    features["status"] = metadata.get("url_status") or metadata.get("status", "")
    features["victim_country"] = victim.country
    features["critical_infrastructure"] = victim.critical_infrastructure
    features["url_count"] = len(observables.urls)
    features["ip_count"] = len(observables.ips)
    return features
def _ex_severity_classification(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a severity-label example, or ``None`` when the case has no severity.

    The input is the JSON-encoded severity_features() dict and the prompt is
    SEVERITY_INSTRUCTION, shared with live inference. ``options`` and ``rng``
    exist only for builder-signature uniformity and are not read.
    """
    severity = case.classification.severity
    if severity is None:
        return None
    encoded_features = json.dumps(severity_features(case), ensure_ascii=False)
    return Example(
        task="severity_classification",
        instruction=SEVERITY_INSTRUCTION,
        input=encoded_features,
        output=severity.value,
        case_id=case.case_id,
        tlp=case.classification.tlp.value,
    )
def _ex_routing_decision(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a routing example, or ``None`` when the case has no incident type.

    The target is whatever route_line.plan(case) decides: the names of the
    allowed destinations plus the blocked ones with their reasons.
    ``options`` and ``rng`` are accepted for signature uniformity only.
    """
    classification = case.classification
    incident = classification.incident_type
    if incident is None:
        return None

    routes, blocked = route_line.plan(case)

    severity = classification.severity
    features = {
        "incident_type": incident.value,
        "tlp": classification.tlp.value,
        "severity": severity.value if severity else None,
        "victim_country": case.victim.country,
        "critical_infrastructure": case.victim.critical_infrastructure,
        "url_count": len(case.observables.urls),
    }
    allowed_names = [route.destination_name for route in routes]
    blocked_entries = [
        {"destination": item.destination_name, "reason": item.reason}
        for item in blocked
    ]
    decision = {"allowed": allowed_names, "blocked": blocked_entries}

    return Example(
        instruction="Given the case, decide which destinations may receive a submission and which must be blocked. Return JSON with keys 'allowed' (destination names) and 'blocked' (objects with destination + reason).",
        input=json.dumps(features, ensure_ascii=False),
        output=json.dumps(decision, ensure_ascii=False),
        task="routing_decision",
        case_id=case.case_id,
        tlp=classification.tlp.value,
    )
def _ex_tlp_assignment(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a TLP-assignment example; every case yields one.

    The target label is the case's own TLP value. ``options`` and ``rng`` are
    accepted for signature uniformity only.

    NOTE(review): quality_gate rejects TLP:RED cases, so datasets built via
    build() never contain the RED label even though the instruction lists it.
    """
    classification = case.classification
    incident = classification.incident_type
    label = classification.tlp.value
    features = {
        "source_type": case.source_type,
        "incident_type": incident.value if incident else None,
        "victim_named": bool(case.victim.name),
        "critical_infrastructure": case.victim.critical_infrastructure,
        "summary": case.summary,
    }
    return Example(
        instruction="Assign the Traffic Light Protocol (TLP) level for this case: RED | AMBER | GREEN | CLEAR. Return only the label.",
        input=json.dumps(features, ensure_ascii=False),
        output=label,
        task="tlp_assignment",
        case_id=case.case_id,
        tlp=label,
    )
# Task name -> example builder. build() calls every builder as
# builder(case, options, rng), so the value type lists all three parameters
# (the previous Callable[[Case], ...] annotation under-described them).
_BUILDERS: Dict[
    str,
    Callable[[Case, Optional[BuildOptions], Optional[random.Random]], Optional[Example]],
] = {
    "ioc_extraction": _ex_ioc_extraction,
    "severity_classification": _ex_severity_classification,
    "routing_decision": _ex_routing_decision,
    "tlp_assignment": _ex_tlp_assignment,
}
# ---------- QualityGate ---------------------------------------------------
class _Reject(BaseModel):
    """Rejection-reason wrapper.

    NOTE(review): nothing in this module constructs or returns _Reject;
    quality_gate returns a plain reason string instead. This is a candidate
    for removal once it is confirmed that no external code uses it.
    """

    reason: str  # rejection reason code
# Heuristic screen for credential-looking words in model inputs; compiled once
# at import rather than looked up on every quality_gate call.
_CREDENTIAL_RE = re.compile(r"\b(password|api[_-]?key|secret|bearer)\b", re.IGNORECASE)


def quality_gate(
    example: Example,
    case: Case,
    *,
    max_input_chars: int = 8000,
    max_output_chars: int = 4000,
) -> Optional[str]:
    """Return a rejection reason for *example*, or ``None`` if it may be written.

    Checks run in this order and the first failure wins:

    - ``"tlp_red"``: the source case is TLP:RED.
    - ``"restricted_source"``: the case's source_type is in RESTRICTED_SOURCE_TYPES.
    - ``"empty"``: input or output is empty or whitespace-only.
    - ``"too_long"``: input exceeds *max_input_chars* or output exceeds
      *max_output_chars* (defaults 8000 / 4000 characters).
    - ``"credential_leak_in_input"``: the input contains a credential-like word.

    Args:
        example: Candidate training example.
        case: The case the example was built from.
        max_input_chars: Maximum allowed length of ``example.input``.
        max_output_chars: Maximum allowed length of ``example.output``.
    """
    if case.classification.tlp == TLP.RED:
        return "tlp_red"
    if case.source_type in RESTRICTED_SOURCE_TYPES:
        return "restricted_source"
    if not example.input.strip() or not example.output.strip():
        return "empty"
    if len(example.input) > max_input_chars or len(example.output) > max_output_chars:
        return "too_long"
    if _CREDENTIAL_RE.search(example.input):
        return "credential_leak_in_input"
    return None
# ---------- DatasetWriter -------------------------------------------------
def _next_version(task: str) -> int:
DATASETS_DIR.mkdir(parents=True, exist_ok=True)
existing = list(DATASETS_DIR.glob(f"{task}-v*.jsonl"))
used = []
for p in existing:
m = re.search(rf"{re.escape(task)}-v(\d+)\.jsonl$", p.name)
if m:
used.append(int(m.group(1)))
return (max(used) + 1) if used else 1
def build(task: str, cases: Iterable[Case], options: Optional[BuildOptions] = None) -> DatasetReport:
    """Build the next versioned JSONL dataset for *task* from *cases*.

    Each case is passed through the task's builder and then quality_gate.
    Surviving examples are written one JSON object per line to
    ``DATASETS_DIR/<task>-v<N>.jsonl``, where N comes from _next_version().

    Args:
        task: One of the _BUILDERS keys.
        cases: Cases to convert; iterated once.
        options: Build configuration; ``None`` means ``BuildOptions()``.

    Returns:
        A DatasetReport. Gate rejections are counted by reason: ``tlp_red``
        and ``restricted_source`` each have their own field. ``empty`` is
        counted with builder-returned ``None`` under ``skipped_empty``.
        Any other reason goes to ``skipped_quality``.

    Raises:
        ValueError: *task* is not a known task.
        Any exception raised by a builder or by iterating *cases* propagates,
        and the partially written dataset file is removed first.
    """
    if task not in _BUILDERS:
        raise ValueError(f"unknown task: {task}; choices: {sorted(_BUILDERS)}")
    builder = _BUILDERS[task]
    if options is None:
        options = BuildOptions()
    # One RNG per build; with options.seed set, augmentation is reproducible.
    rng = random.Random(options.seed)
    version = _next_version(task)
    path = DATASETS_DIR / f"{task}-v{version}.jsonl"
    written = 0
    skipped_quality = 0
    skipped_tlp_red = 0
    skipped_restricted = 0
    skipped_empty = 0
    try:
        with path.open("w", encoding="utf-8") as fh:
            for case in cases:
                example = builder(case, options, rng)
                if example is None:
                    skipped_empty += 1
                    continue
                reject = quality_gate(example, case)
                if reject is None:
                    fh.write(example.model_dump_json() + "\n")
                    written += 1
                elif reject == "tlp_red":
                    skipped_tlp_red += 1
                elif reject == "restricted_source":
                    skipped_restricted += 1
                elif reject == "empty":
                    # Previously miscounted as a quality skip although the
                    # report has a dedicated skipped_empty field.
                    skipped_empty += 1
                else:
                    skipped_quality += 1
    except BaseException:
        # A truncated file would otherwise consume this version number and
        # show up in list_datasets() as if it were a complete dataset.
        path.unlink(missing_ok=True)
        raise
    _log.info("train.dataset.built", task=task, version=version, written=written, path=str(path))
    return DatasetReport(
        task=task,
        path=path,
        written=written,
        skipped_quality=skipped_quality,
        skipped_tlp_red=skipped_tlp_red,
        skipped_restricted_source=skipped_restricted,
        skipped_empty=skipped_empty,
    )
def list_datasets() -> List[Dict[str, str]]:
    """Describe every ``*.jsonl`` file in DATASETS_DIR.

    Returns one dict per file, sorted by path, with string values:
    ``name``, ``path``, ``examples`` (line count), ``size_bytes`` and
    ``modified`` (UTC ISO-8601 mtime). Returns ``[]`` when the directory
    does not exist.
    """
    if not DATASETS_DIR.exists():
        return []
    out: List[Dict[str, str]] = []
    for p in sorted(DATASETS_DIR.glob("*.jsonl")):
        # Close the handle deterministically instead of leaving it to GC.
        with p.open("r", encoding="utf-8") as fh:
            line_count = sum(1 for _ in fh)
        # A single stat() so size and mtime describe the same file state.
        st = p.stat()
        out.append({
            "name": p.name,
            "path": str(p),
            "examples": str(line_count),
            "size_bytes": str(st.st_size),
            "modified": datetime.fromtimestamp(st.st_mtime, tz=timezone.utc).isoformat(),
        })
    return out
def _adapter_status(d: Path) -> str:
if (d / "final" / "adapter_model.safetensors").exists():
return "trained"
if (d / "checkpoints").exists():
return "in_progress"
return "not_started"
def _read_training_meta(meta_path: Path) -> Dict[str, object]:
    """Load an adapter's training_meta.json; ``{}`` if missing or unusable."""
    if not meta_path.exists():
        return {}
    try:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        # One truncated or corrupt meta file must not break the whole listing.
        _log.warning("train.adapter.meta_unreadable", path=str(meta_path), error=str(exc))
        return {}
    if not isinstance(meta, dict):
        _log.warning("train.adapter.meta_unreadable", path=str(meta_path), error="not a JSON object")
        return {}
    return meta


def list_adapters() -> List[Dict[str, object]]:
    """Describe every adapter directory under ADAPTERS_DIR.

    Returns one dict per subdirectory, sorted by path. Each dict holds the
    directory name, its _adapter_status(), and fields read from
    ``training_meta.json``. Missing fields fall back to defaults: ``""`` or
    ``0`` for scalars, ``None`` for train_loss, and ``[]`` for lists. An
    unreadable or malformed meta file is logged and treated as missing.
    Returns ``[]`` when ADAPTERS_DIR does not exist.
    """
    if not ADAPTERS_DIR.exists():
        return []
    out: List[Dict[str, object]] = []
    for d in sorted(ADAPTERS_DIR.iterdir()):
        if not d.is_dir():
            continue
        meta = _read_training_meta(d / "training_meta.json")
        # `or []` also covers an explicit JSON null, which is not iterable.
        datasets = meta.get("datasets") or []
        out.append({
            "name": d.name,
            "status": _adapter_status(d),
            "base_model": meta.get("base_model", ""),
            "examples": meta.get("examples", 0),
            "epochs": meta.get("epochs", 0),
            "lora_r": meta.get("lora_r", 0),
            "lr": meta.get("lr", 0),
            "datasets": [Path(str(p)).name for p in datasets],
            "train_loss": meta.get("train_loss"),
            "loss_history": meta.get("loss_history", []),
        })
    return out