stage-3b: Trainline — JSONL dataset pipeline for QLoRA training

ExampleBuilder emits Alpaca-style training rows for four defensive tasks
(ioc_extraction, severity_classification, routing_decision, tlp_assignment).
QualityGate enforces the dossier's training-data policy: drops TLP:RED,
restricted-source, empty, oversize, and credential-leak examples. DatasetWriter
versions outputs as data/datasets/<task>-v<n>.jsonl. CLI: train-build,
train-build-all, train-list-datasets.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
m17hr1l
2026-05-14 14:15:58 +02:00
parent da4792c179
commit b8ea4ead02
2 changed files with 285 additions and 1 deletions

View File

@@ -8,7 +8,7 @@ import typer
import uvicorn import uvicorn
from psyc import db, log from psyc import db, log
from psyc.lines import classify, courier, route, scout, seal from psyc.lines import classify, courier, route, scout, seal, train
from psyc.lines import map as map_line from psyc.lines import map as map_line
from psyc.models import Outcome from psyc.models import Outcome
from psyc.result import Err, Ok from psyc.result import Err, Ok
@@ -204,6 +204,43 @@ def mock_cert_serve(host: str = "127.0.0.1", port: int = 8770) -> None:
uvicorn.run("psyc.mock_cert:app", host=host, port=port) uvicorn.run("psyc.mock_cert:app", host=host, port=port)
@app.command("train-build")
def train_build(
task: str = typer.Option(..., "--task", "-t", help=f"one of: {', '.join(train.TASKS)}"),
limit: int = typer.Option(10_000, help="max cases to process"),
) -> None:
if task not in train.TASKS:
typer.echo(f"unknown task: {task}; choices: {', '.join(train.TASKS)}", err=True)
raise typer.Exit(1)
cases = db.list_cases(limit=limit)
report = train.build(task, cases)
typer.echo(f"task: {report.task}")
typer.echo(f"path: {report.path}")
typer.echo(f" written: {report.written}")
typer.echo(f" skipped (empty): {report.skipped_empty}")
typer.echo(f" skipped (tlp red): {report.skipped_tlp_red}")
typer.echo(f" skipped (source): {report.skipped_restricted_source}")
typer.echo(f" skipped (quality): {report.skipped_quality}")
@app.command("train-build-all")
def train_build_all(limit: int = typer.Option(10_000, help="max cases per task")) -> None:
cases = db.list_cases(limit=limit)
for task in train.TASKS:
report = train.build(task, cases)
typer.echo(f" {task}: wrote {report.written}{report.path.name}")
@app.command("train-list-datasets")
def train_list_datasets() -> None:
rows = train.list_datasets()
if not rows:
typer.echo("(no datasets)")
return
for r in rows:
typer.echo(f"{r['name']:<40} {r['examples']:>6} examples {int(r['size_bytes']):>8} B {r['modified']}")
@app.command("demo") @app.command("demo")
def demo() -> None: def demo() -> None:
db.init_db() db.init_db()

247
src/psyc/lines/train.py Normal file
View File

@@ -0,0 +1,247 @@
"""Trainline — ExampleBuilder + QualityGate + DatasetWriter.
Reads reviewed cases from the DB, emits Alpaca-style JSONL training examples
for the defensive task set defined in docs/dossier.md §intelminer:
- ioc_extraction
- severity_classification
- routing_decision
- tlp_assignment
Outputs are versioned JSONL files under data/datasets/<task>-v<n>.jsonl.
QualityGate enforces the dossier's training-data policy: never TLP:RED, never
restricted source types, never empty input/output.
"""
from __future__ import annotations
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional
from pydantic import BaseModel, Field
from psyc import DATA_DIR, log
from psyc.lines import classify as classify_line
from psyc.lines import route as route_line
from psyc.models import Case, TLP
_log = log.get(__name__)
DATASETS_DIR = DATA_DIR / "datasets"
TASKS = ("ioc_extraction", "severity_classification", "routing_decision", "tlp_assignment")
RESTRICTED_SOURCE_TYPES = {"criminal_forum", "stolen_data_paste", "unauthorized_dump"}
class Example(BaseModel):
instruction: str
input: str
output: str
task: str
case_id: str
tlp: str
meta: Dict[str, str] = Field(default_factory=dict)
class DatasetReport(BaseModel):
task: str
path: Path
written: int
skipped_quality: int
skipped_tlp_red: int
skipped_restricted_source: int
skipped_empty: int
# ---------- ExampleBuilder per task ---------------------------------------
def _ex_ioc_extraction(case: Case) -> Optional[Example]:
obs = case.observables
if not (obs.urls or obs.domains or obs.ips or obs.hashes):
return None
summary_or_threat = case.summary or case.source_metadata.get("threat", "")
input_text = f"Advisory: {summary_or_threat}\nSource: {case.source_ref or case.source_type}"
output_obj = {
"urls": obs.urls,
"domains": obs.domains,
"ips": obs.ips,
"hashes": obs.hashes,
"cves": obs.cves,
}
return Example(
instruction="Extract all indicators of compromise (URLs, domains, IPs, hashes, CVEs) from the advisory. Return JSON with keys: urls, domains, ips, hashes, cves.",
input=input_text,
output=json.dumps(output_obj, ensure_ascii=False),
task="ioc_extraction",
case_id=case.case_id,
tlp=case.classification.tlp.value,
)
def _ex_severity_classification(case: Case) -> Optional[Example]:
if case.classification.severity is None:
return None
input_obj = {
"summary": case.summary,
"source_type": case.source_type,
"incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
"url_status": case.source_metadata.get("url_status", ""),
"victim_country": case.victim.country,
"critical_infrastructure": case.victim.critical_infrastructure,
"url_count": len(case.observables.urls),
"ip_count": len(case.observables.ips),
}
return Example(
instruction="Classify the defensive severity of this case as one of: low | medium | high | critical. Return only the label.",
input=json.dumps(input_obj, ensure_ascii=False),
output=case.classification.severity.value,
task="severity_classification",
case_id=case.case_id,
tlp=case.classification.tlp.value,
)
def _ex_routing_decision(case: Case) -> Optional[Example]:
if case.classification.incident_type is None:
return None
routes, blocked = route_line.plan(case)
input_obj = {
"incident_type": case.classification.incident_type.value,
"tlp": case.classification.tlp.value,
"severity": case.classification.severity.value if case.classification.severity else None,
"victim_country": case.victim.country,
"critical_infrastructure": case.victim.critical_infrastructure,
"url_count": len(case.observables.urls),
}
output_obj = {
"allowed": [r.destination_name for r in routes],
"blocked": [{"destination": b.destination_name, "reason": b.reason} for b in blocked],
}
return Example(
instruction="Given the case, decide which destinations may receive a submission and which must be blocked. Return JSON with keys 'allowed' (destination names) and 'blocked' (objects with destination + reason).",
input=json.dumps(input_obj, ensure_ascii=False),
output=json.dumps(output_obj, ensure_ascii=False),
task="routing_decision",
case_id=case.case_id,
tlp=case.classification.tlp.value,
)
def _ex_tlp_assignment(case: Case) -> Optional[Example]:
input_obj = {
"source_type": case.source_type,
"incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
"victim_named": bool(case.victim.name),
"critical_infrastructure": case.victim.critical_infrastructure,
"summary": case.summary,
}
return Example(
instruction="Assign the Traffic Light Protocol (TLP) level for this case: RED | AMBER | GREEN | CLEAR. Return only the label.",
input=json.dumps(input_obj, ensure_ascii=False),
output=case.classification.tlp.value,
task="tlp_assignment",
case_id=case.case_id,
tlp=case.classification.tlp.value,
)
_BUILDERS: Dict[str, Callable[[Case], Optional[Example]]] = {
"ioc_extraction": _ex_ioc_extraction,
"severity_classification": _ex_severity_classification,
"routing_decision": _ex_routing_decision,
"tlp_assignment": _ex_tlp_assignment,
}
# ---------- QualityGate ---------------------------------------------------
class _Reject(BaseModel):
reason: str
def quality_gate(example: Example, case: Case) -> Optional[str]:
if case.classification.tlp == TLP.RED:
return "tlp_red"
if case.source_type in RESTRICTED_SOURCE_TYPES:
return "restricted_source"
if not example.input.strip() or not example.output.strip():
return "empty"
if len(example.input) > 8000 or len(example.output) > 4000:
return "too_long"
if re.search(r"\b(password|api[_-]?key|secret|bearer)\b", example.input, re.IGNORECASE):
return "credential_leak_in_input"
return None
# ---------- DatasetWriter -------------------------------------------------
def _next_version(task: str) -> int:
DATASETS_DIR.mkdir(parents=True, exist_ok=True)
existing = list(DATASETS_DIR.glob(f"{task}-v*.jsonl"))
used = []
for p in existing:
m = re.search(rf"{re.escape(task)}-v(\d+)\.jsonl$", p.name)
if m:
used.append(int(m.group(1)))
return (max(used) + 1) if used else 1
def build(task: str, cases: Iterable[Case]) -> DatasetReport:
if task not in _BUILDERS:
raise ValueError(f"unknown task: {task}; choices: {sorted(_BUILDERS)}")
builder = _BUILDERS[task]
version = _next_version(task)
path = DATASETS_DIR / f"{task}-v{version}.jsonl"
written = 0
skipped_quality = 0
skipped_tlp_red = 0
skipped_restricted = 0
skipped_empty = 0
with path.open("w", encoding="utf-8") as fh:
for case in cases:
example = builder(case)
if example is None:
skipped_empty += 1
continue
reject = quality_gate(example, case)
if reject == "tlp_red":
skipped_tlp_red += 1
continue
if reject == "restricted_source":
skipped_restricted += 1
continue
if reject is not None:
skipped_quality += 1
continue
fh.write(example.model_dump_json() + "\n")
written += 1
_log.info("train.dataset.built", task=task, version=version, written=written, path=str(path))
return DatasetReport(
task=task,
path=path,
written=written,
skipped_quality=skipped_quality,
skipped_tlp_red=skipped_tlp_red,
skipped_restricted_source=skipped_restricted,
skipped_empty=skipped_empty,
)
def list_datasets() -> List[Dict[str, str]]:
if not DATASETS_DIR.exists():
return []
out: List[Dict[str, str]] = []
for p in sorted(DATASETS_DIR.glob("*.jsonl")):
line_count = sum(1 for _ in p.open("r", encoding="utf-8"))
out.append({
"name": p.name,
"path": str(p),
"examples": str(line_count),
"size_bytes": str(p.stat().st_size),
"modified": datetime.fromtimestamp(p.stat().st_mtime, tz=timezone.utc).isoformat(),
})
return out