Real CTI prose defangs IOCs (1[.]2[.]3[.]4, hxxp://, evil[dot]com) so they
don't auto-link in email/chat. A model trained only on canonical inputs
will fail to extract them.
New lines/defang.py: defang_ip, defang_domain, defang_url, defang_text —
four dot-styles ([.], (.), [dot], {.}) plus protocol defanging
(http→hxxp, https→hxxps). Each occurrence picks its style independently
since real advisories don't keep one style across paragraphs.
train.BuildOptions adds defang_frac (default 0.0) and seed; build()
threads the options and a random.Random seeded from options.seed through
the example builders, so the augmentation is reproducible whenever a seed
is given. Only _ex_ioc_extraction reads the options today; its output stays
canonical so the model learns the messy→canonical mapping.
CLI: train-build and train-build-all gain --defang-frac and --seed.
8 new tests including a frac=1.0 / output-canonical integration check.
The pipeline runs but is dormant at defang_frac=0.0 — psyc-v5 dataset
build will set 0.5 once OTX cases land.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
343 lines
12 KiB
Python
343 lines
12 KiB
Python
"""Trainline — ExampleBuilder + QualityGate + DatasetWriter.
|
|
|
|
Reads reviewed cases from the DB, emits Alpaca-style JSONL training examples
|
|
for the defensive task set defined in docs/dossier.md §intelminer:
|
|
- ioc_extraction
|
|
- severity_classification
|
|
- routing_decision
|
|
- tlp_assignment
|
|
|
|
Outputs are versioned JSONL files under data/datasets/<task>-v<n>.jsonl.
|
|
QualityGate enforces the dossier's training-data policy: never TLP:RED, never
|
|
restricted source types, never empty input/output.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import random
|
|
import re
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Callable, Dict, Iterable, List, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from psyc import DATA_DIR, log
|
|
from psyc.lines import classify as classify_line
|
|
from psyc.lines import defang as defang_line
|
|
from psyc.lines import route as route_line
|
|
from psyc.models import Case, TLP
|
|
|
|
|
|
class BuildOptions(BaseModel):
    """Per-build configuration handed to every example builder by build().

    Currently only the ioc_extraction builder reads any field.
    """

    # Fraction of ioc_extraction inputs rewritten in defanged form
    # (1[.]2[.]3[.]4, hxxp://, ...). Enforced to lie in [0.0, 1.0] so a typo
    # such as 5 or -0.5 fails loudly instead of silently meaning "always" or
    # "never". The default 0.0 disables the augmentation.
    defang_frac: float = Field(0.0, ge=0.0, le=1.0)
    # Seed for the build's random.Random; set it for a reproducible build.
    # None leaves random.Random to seed itself.
    seed: Optional[int] = None
|
|
|
|
|
|
_log = log.get(__name__)

# On-disk layout under the project data root.
DATASETS_DIR = DATA_DIR / "datasets"  # versioned <task>-v<n>.jsonl files
ADAPTERS_DIR = DATA_DIR / "adapters"  # one subdirectory per adapter (see list_adapters)

# The defensive task set; each name has a builder registered in _BUILDERS.
TASKS = (
    "ioc_extraction",
    "severity_classification",
    "routing_decision",
    "tlp_assignment",
)

# Source types whose cases quality_gate always rejects.
RESTRICTED_SOURCE_TYPES = {
    "criminal_forum",
    "stolen_data_paste",
    "unauthorized_dump",
}
|
|
|
|
|
|
class Example(BaseModel):
    """One Alpaca-style training record; build() writes each as one JSONL line.

    Field order is the key order of the serialized JSON (model_dump_json), so
    keep it stable.
    """

    # Alpaca triple: task prompt, model input, expected model output.
    instruction: str
    input: str
    output: str
    # Provenance: producing task and the source case's id.
    task: str
    case_id: str
    # The source case's TLP label (the builders store classification.tlp.value).
    tlp: str
    # Free-form string metadata; no builder in this module sets it.
    meta: Dict[str, str] = Field(default_factory=dict)
|
|
|
|
|
|
class DatasetReport(BaseModel):
    """Outcome of one build() run for a single task."""

    task: str
    # The <task>-v<n>.jsonl file the examples were written to.
    path: Path
    # Examples that passed quality_gate and were written.
    written: int
    # Built examples rejected by quality_gate for any reason other than
    # tlp_red / restricted_source (e.g. empty, too_long, credential_leak_in_input).
    skipped_quality: int
    # Rejected because the case is TLP:RED.
    skipped_tlp_red: int
    # Rejected because the case's source_type is in RESTRICTED_SOURCE_TYPES.
    skipped_restricted_source: int
    # Cases for which the task's builder returned None (nothing to build from).
    # quality_gate's own "empty" reason is counted in skipped_quality instead.
    skipped_empty: int
|
|
|
|
|
|
# ---------- ExampleBuilder per task ---------------------------------------
|
|
|
|
def _ex_ioc_extraction(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build an ioc_extraction example: advisory prose in, IOC JSON out.

    The advisory text is synthesized from the case's own observables, so every
    IOC listed in the expected output also appears in the input; that is what
    makes the extraction task well-posed. Returns None when the case has no
    URLs, domains, IPs, hashes or CVEs.

    When both *options* and *rng* are given and options.defang_frac > 0, the
    input is defanged (via defang_line.defang_text) with probability
    defang_frac. The expected output always lists the canonical IOCs, so the
    model learns the defanged -> canonical mapping.
    """
    obs = case.observables
    if not any((obs.urls, obs.domains, obs.ips, obs.hashes, obs.cves)):
        return None

    metadata = case.source_metadata
    threat = metadata.get("threat", "malware")
    tags = metadata.get("tags", "")

    # (label, values) in the order the sentences appear in the advisory.
    sections = (
        ("Malicious URLs", obs.urls),
        ("Domains", obs.domains),
        ("Hosting IPs", obs.ips),
        ("Sample hashes", obs.hashes),
        ("Related CVEs", obs.cves),
    )
    sentences = [f"Threat advisory — {threat}."]
    sentences += [f"{label}: {', '.join(values)}." for label, values in sections if values]
    if tags:
        sentences.append(f"Tags: {tags}.")
    advisory = " ".join(sentences)

    # The RNG is only consulted once every other precondition holds.
    should_defang = (
        options is not None
        and rng is not None
        and options.defang_frac > 0.0
        and rng.random() < options.defang_frac
    )
    if should_defang:
        advisory = defang_line.defang_text(advisory, obs.ips, obs.domains, obs.urls, rng)

    canonical = {key: getattr(obs, key) for key in ("urls", "domains", "ips", "hashes", "cves")}
    return Example(
        instruction="Extract all indicators of compromise from the advisory and return JSON with keys: urls, domains, ips, hashes, cves.",
        input=advisory,
        output=json.dumps(canonical, ensure_ascii=False),
        task="ioc_extraction",
        case_id=case.case_id,
        tlp=case.classification.tlp.value,
    )
|
|
|
|
|
|
# The cockpit's live inference must send exactly the prompt the model was
# trained on, so this string is shared rather than duplicated. Keep
# _ex_severity_classification and the inference client in sync with it.
SEVERITY_INSTRUCTION = " ".join((
    "Classify the defensive severity of this case as one of:",
    "low | medium | high | critical.",
    "Return only the label.",
))
|
|
|
|
|
|
def severity_features(case: Case) -> Dict[str, object]:
    """Return the feature dict a severity_classification input is built from.

    The dict is serialized with json.dumps as the model input, so its key order
    is part of the prompt format and must stay stable.
    """
    classification = case.classification
    metadata = case.source_metadata
    victim = case.victim
    observables = case.observables

    incident = classification.incident_type
    incident_label = incident.value if incident else None
    # Prefer a truthy URL-specific status; otherwise fall back to the generic one.
    status = metadata.get("url_status") or metadata.get("status", "")

    return dict(
        summary=case.summary,
        source_type=case.source_type,
        incident_type=incident_label,
        status=status,
        victim_country=victim.country,
        critical_infrastructure=victim.critical_infrastructure,
        url_count=len(observables.urls),
        ip_count=len(observables.ips),
    )
|
|
|
|
|
|
def _ex_severity_classification(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a severity_classification example: feature JSON in, severity label out.

    Uses the shared SEVERITY_INSTRUCTION / severity_features pair so training
    and live inference see the same prompt. Returns None for cases that have no
    severity label. *options* and *rng* exist only to match the shared builder
    signature.
    """
    label = case.classification.severity
    if label is not None:
        return Example(
            instruction=SEVERITY_INSTRUCTION,
            input=json.dumps(severity_features(case), ensure_ascii=False),
            output=label.value,
            task="severity_classification",
            case_id=case.case_id,
            tlp=case.classification.tlp.value,
        )
    return None
|
|
|
|
|
|
def _ex_routing_decision(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a routing_decision example: case summary JSON in, route plan JSON out.

    The target is whatever route_line.plan(case) decides, reshaped into allowed
    destination names plus blocked destinations with their reasons. Returns
    None for cases without an incident type. *options* and *rng* exist only to
    match the shared builder signature.
    """
    if case.classification.incident_type is None:
        return None

    allowed_routes, blocked_routes = route_line.plan(case)

    # Case attributes are read after planning, exactly as the prompt needs them.
    classification = case.classification
    severity = classification.severity
    features = {
        "incident_type": classification.incident_type.value,
        "tlp": classification.tlp.value,
        "severity": severity.value if severity else None,
        "victim_country": case.victim.country,
        "critical_infrastructure": case.victim.critical_infrastructure,
        "url_count": len(case.observables.urls),
    }
    plan = {
        "allowed": [route.destination_name for route in allowed_routes],
        "blocked": [
            {"destination": blocked.destination_name, "reason": blocked.reason}
            for blocked in blocked_routes
        ],
    }
    return Example(
        instruction="Given the case, decide which destinations may receive a submission and which must be blocked. Return JSON with keys 'allowed' (destination names) and 'blocked' (objects with destination + reason).",
        input=json.dumps(features, ensure_ascii=False),
        output=json.dumps(plan, ensure_ascii=False),
        task="routing_decision",
        case_id=case.case_id,
        tlp=classification.tlp.value,
    )
|
|
|
|
|
|
def _ex_tlp_assignment(
    case: Case,
    options: Optional["BuildOptions"] = None,
    rng: Optional[random.Random] = None,
) -> Optional[Example]:
    """Build a tlp_assignment example: case feature JSON in, TLP label out.

    Never returns None. *options* and *rng* exist only to match the shared
    builder signature.

    NOTE(review): quality_gate rejects every TLP:RED case, so written datasets
    contain no RED targets even though the instruction offers RED. Confirm
    this is intended.
    """
    classification = case.classification
    incident = classification.incident_type
    label = classification.tlp.value
    features = {
        "source_type": case.source_type,
        "incident_type": incident.value if incident else None,
        "victim_named": bool(case.victim.name),
        "critical_infrastructure": case.victim.critical_infrastructure,
        "summary": case.summary,
    }
    return Example(
        instruction="Assign the Traffic Light Protocol (TLP) level for this case: RED | AMBER | GREEN | CLEAR. Return only the label.",
        input=json.dumps(features, ensure_ascii=False),
        output=label,
        task="tlp_assignment",
        case_id=case.case_id,
        tlp=label,
    )
|
|
|
|
|
|
# Task name -> example builder. Every builder follows build()'s calling
# convention builder(case, options, rng) and returns an Example, or None when
# the case carries nothing for that task. (The previous annotation,
# Callable[[Case], ...], did not match that three-argument call.)
_BUILDERS: Dict[
    str,
    Callable[[Case, Optional[BuildOptions], Optional[random.Random]], Optional[Example]],
] = {
    "ioc_extraction": _ex_ioc_extraction,
    "severity_classification": _ex_severity_classification,
    "routing_decision": _ex_routing_decision,
    "tlp_assignment": _ex_tlp_assignment,
}
|
|
|
|
|
|
# ---------- QualityGate ---------------------------------------------------
|
|
|
|
class _Reject(BaseModel):
    """A quality-gate rejection reason wrapped in a model.

    NOTE(review): nothing visible in this module constructs or reads _Reject;
    quality_gate returns the reason as a plain str. Confirm no other module
    imports it before deleting.
    """

    reason: str  # rejection reason string (cf. quality_gate's return values)
|
|
|
|
|
|
def quality_gate(example: Example, case: Case) -> Optional[str]:
    """Decide whether *example*, built from *case*, may enter a training set.

    Rules are checked in order and the first violated one wins:

    - "tlp_red": the case is TLP:RED.
    - "restricted_source": the case's source_type is restricted.
    - "empty": input or output is blank after stripping whitespace.
    - "too_long": input over 8000 chars or output over 4000 chars.
    - "credential_leak_in_input": the input mentions password / api key /
      secret / bearer as a whole word (case-insensitive).

    Returns the violated rule's name, or None when the example passes.
    """
    # Predicates are lambdas so each one is evaluated only if all earlier ones passed.
    rules = (
        ("tlp_red", lambda: case.classification.tlp == TLP.RED),
        ("restricted_source", lambda: case.source_type in RESTRICTED_SOURCE_TYPES),
        ("empty", lambda: not (example.input.strip() and example.output.strip())),
        ("too_long", lambda: len(example.input) > 8000 or len(example.output) > 4000),
        (
            "credential_leak_in_input",
            lambda: re.search(
                r"\b(password|api[_-]?key|secret|bearer)\b", example.input, re.IGNORECASE
            ) is not None,
        ),
    )
    for reason, violated in rules:
        if violated():
            return reason
    return None
|
|
|
|
|
|
# ---------- DatasetWriter -------------------------------------------------
|
|
|
|
def _next_version(task: str) -> int:
    """Return the next dataset version number for *task*.

    Creates DATASETS_DIR if needed, scans it for <task>-v<n>.jsonl files and
    returns the highest n plus one, or 1 when there are none. Gaps left by
    deleted versions are not reused.
    """
    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
    version_re = re.compile(rf"{re.escape(task)}-v(\d+)\.jsonl$")
    highest = 0
    for candidate in DATASETS_DIR.glob(f"{task}-v*.jsonl"):
        match = version_re.search(candidate.name)
        if match is not None:
            highest = max(highest, int(match.group(1)))
    return highest + 1
|
|
|
|
|
|
def build(task: str, cases: Iterable[Case], options: Optional[BuildOptions] = None) -> DatasetReport:
    """Build the next versioned JSONL dataset for *task* from *cases*.

    Each case goes through the task's builder (with *options* and one RNG
    shared across the build), then through quality_gate. Surviving examples
    are written one JSON object per line to DATASETS_DIR/<task>-v<n>.jsonl,
    where n comes from _next_version.

    Args:
        task: One of the keys of _BUILDERS.
        cases: Cases to convert; iterated once, in order.
        options: Build configuration; defaults to BuildOptions().

    Returns:
        A DatasetReport with the output path and per-reason skip counts.

    Raises:
        ValueError: if *task* has no registered builder.
    """
    if task not in _BUILDERS:
        raise ValueError(f"unknown task: {task}; choices: {sorted(_BUILDERS)}")
    builder = _BUILDERS[task]
    if options is None:
        options = BuildOptions()
    # One RNG for the whole build, seeded from options.seed, so augmentation
    # draws are reproducible for the same seed and case order.
    rng = random.Random(options.seed)

    version = _next_version(task)
    path = DATASETS_DIR / f"{task}-v{version}.jsonl"
    # Stream into a sibling temp file and rename it into place only after the
    # loop completes. A builder exception (or an interrupt) then cannot leave a
    # truncated <task>-v<n>.jsonl that list_datasets() and training would treat
    # as finished. The ".tmp" suffix keeps it out of the "*.jsonl" globs.
    tmp_path = path.with_name(path.name + ".tmp")

    written = 0
    skipped_quality = 0
    skipped_tlp_red = 0
    skipped_restricted = 0
    skipped_empty = 0
    try:
        with tmp_path.open("w", encoding="utf-8") as fh:
            for case in cases:
                example = builder(case, options, rng)
                if example is None:
                    # The builder had nothing to build from for this task.
                    skipped_empty += 1
                    continue
                reject = quality_gate(example, case)
                if reject is None:
                    fh.write(example.model_dump_json() + "\n")
                    written += 1
                elif reject == "tlp_red":
                    skipped_tlp_red += 1
                elif reject == "restricted_source":
                    skipped_restricted += 1
                else:
                    skipped_quality += 1
        tmp_path.replace(path)
    finally:
        # After a successful rename the temp file is gone; on any failure,
        # discard the partial output.
        if tmp_path.exists():
            tmp_path.unlink()

    _log.info(
        "train.dataset.built",
        task=task,
        version=version,
        written=written,
        path=str(path),
        skipped_quality=skipped_quality,
        skipped_tlp_red=skipped_tlp_red,
        skipped_restricted_source=skipped_restricted,
        skipped_empty=skipped_empty,
    )
    return DatasetReport(
        task=task,
        path=path,
        written=written,
        skipped_quality=skipped_quality,
        skipped_tlp_red=skipped_tlp_red,
        skipped_restricted_source=skipped_restricted,
        skipped_empty=skipped_empty,
    )
|
|
|
|
|
|
def list_datasets() -> List[Dict[str, str]]:
    """Describe every *.jsonl file in DATASETS_DIR, sorted by path.

    Each entry maps name, path, examples (line count), size_bytes and modified
    (mtime as an ISO-8601 UTC timestamp) to string values. Returns [] when
    DATASETS_DIR does not exist yet.
    """
    if not DATASETS_DIR.exists():
        return []
    out: List[Dict[str, str]] = []
    for p in sorted(DATASETS_DIR.glob("*.jsonl")):
        # Close the handle deterministically instead of leaving it to the GC.
        with p.open("r", encoding="utf-8") as fh:
            line_count = sum(1 for _ in fh)
        # A single stat() so size and mtime describe the same file state.
        st = p.stat()
        out.append({
            "name": p.name,
            "path": str(p),
            "examples": str(line_count),
            "size_bytes": str(st.st_size),
            "modified": datetime.fromtimestamp(st.st_mtime, tz=timezone.utc).isoformat(),
        })
    return out
|
|
|
|
|
|
def _adapter_status(d: Path) -> str:
|
|
if (d / "final" / "adapter_model.safetensors").exists():
|
|
return "trained"
|
|
if (d / "checkpoints").exists():
|
|
return "in_progress"
|
|
return "not_started"
|
|
|
|
|
|
def list_adapters() -> List[Dict[str, object]]:
    """Describe each adapter subdirectory of ADAPTERS_DIR, sorted by path.

    status comes from _adapter_status(); the other fields come from the
    directory's training_meta.json, with defaults for a missing file or key.
    An unreadable or malformed metadata file (including one whose top level
    is not a JSON object) is logged and treated as empty, so one bad directory
    cannot break the whole listing. Returns [] when ADAPTERS_DIR does not exist.
    """
    if not ADAPTERS_DIR.exists():
        return []
    out: List[Dict[str, object]] = []
    for d in sorted(ADAPTERS_DIR.iterdir()):
        if not d.is_dir():
            continue
        meta: Dict[str, object] = {}
        meta_path = d / "training_meta.json"
        if meta_path.exists():
            try:
                loaded = json.loads(meta_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                _log.warning("train.adapter.meta_unreadable", path=str(meta_path), error=str(exc))
            else:
                if isinstance(loaded, dict):
                    meta = loaded
                else:
                    _log.warning("train.adapter.meta_not_object", path=str(meta_path))
        out.append({
            "name": d.name,
            "status": _adapter_status(d),
            "base_model": meta.get("base_model", "—"),
            "examples": meta.get("examples", 0),
            "epochs": meta.get("epochs", 0),
            "lora_r": meta.get("lora_r", 0),
            "lr": meta.get("lr", 0),
            # `or []` also tolerates an explicit null in the metadata.
            "datasets": [Path(str(p)).name for p in (meta.get("datasets") or [])],
            "train_loss": meta.get("train_loss"),
            "loss_history": meta.get("loss_history", []),
        })
    return out
|