stage-3b: Trainline — JSONL dataset pipeline for QLoRA training
ExampleBuilder emits Alpaca-style training rows for four defensive tasks (ioc_extraction, severity_classification, routing_decision, tlp_assignment). QualityGate enforces the dossier's training-data policy: drops TLP:RED, restricted-source, empty, oversize, and credential-leak examples. DatasetWriter versions outputs as data/datasets/<task>-v<n>.jsonl. CLI: train-build, train-build-all, train-list-datasets. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,7 @@ import typer
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from psyc import db, log
|
from psyc import db, log
|
||||||
from psyc.lines import classify, courier, route, scout, seal
|
from psyc.lines import classify, courier, route, scout, seal, train
|
||||||
from psyc.lines import map as map_line
|
from psyc.lines import map as map_line
|
||||||
from psyc.models import Outcome
|
from psyc.models import Outcome
|
||||||
from psyc.result import Err, Ok
|
from psyc.result import Err, Ok
|
||||||
@@ -204,6 +204,43 @@ def mock_cert_serve(host: str = "127.0.0.1", port: int = 8770) -> None:
|
|||||||
uvicorn.run("psyc.mock_cert:app", host=host, port=port)
|
uvicorn.run("psyc.mock_cert:app", host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("train-build")
|
||||||
|
def train_build(
|
||||||
|
task: str = typer.Option(..., "--task", "-t", help=f"one of: {', '.join(train.TASKS)}"),
|
||||||
|
limit: int = typer.Option(10_000, help="max cases to process"),
|
||||||
|
) -> None:
|
||||||
|
if task not in train.TASKS:
|
||||||
|
typer.echo(f"unknown task: {task}; choices: {', '.join(train.TASKS)}", err=True)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
cases = db.list_cases(limit=limit)
|
||||||
|
report = train.build(task, cases)
|
||||||
|
typer.echo(f"task: {report.task}")
|
||||||
|
typer.echo(f"path: {report.path}")
|
||||||
|
typer.echo(f" written: {report.written}")
|
||||||
|
typer.echo(f" skipped (empty): {report.skipped_empty}")
|
||||||
|
typer.echo(f" skipped (tlp red): {report.skipped_tlp_red}")
|
||||||
|
typer.echo(f" skipped (source): {report.skipped_restricted_source}")
|
||||||
|
typer.echo(f" skipped (quality): {report.skipped_quality}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("train-build-all")
|
||||||
|
def train_build_all(limit: int = typer.Option(10_000, help="max cases per task")) -> None:
|
||||||
|
cases = db.list_cases(limit=limit)
|
||||||
|
for task in train.TASKS:
|
||||||
|
report = train.build(task, cases)
|
||||||
|
typer.echo(f" {task}: wrote {report.written} → {report.path.name}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("train-list-datasets")
|
||||||
|
def train_list_datasets() -> None:
|
||||||
|
rows = train.list_datasets()
|
||||||
|
if not rows:
|
||||||
|
typer.echo("(no datasets)")
|
||||||
|
return
|
||||||
|
for r in rows:
|
||||||
|
typer.echo(f"{r['name']:<40} {r['examples']:>6} examples {int(r['size_bytes']):>8} B {r['modified']}")
|
||||||
|
|
||||||
|
|
||||||
@app.command("demo")
|
@app.command("demo")
|
||||||
def demo() -> None:
|
def demo() -> None:
|
||||||
db.init_db()
|
db.init_db()
|
||||||
|
|||||||
247
src/psyc/lines/train.py
Normal file
247
src/psyc/lines/train.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""Trainline — ExampleBuilder + QualityGate + DatasetWriter.
|
||||||
|
|
||||||
|
Reads reviewed cases from the DB, emits Alpaca-style JSONL training examples
|
||||||
|
for the defensive task set defined in docs/dossier.md §intelminer:
|
||||||
|
- ioc_extraction
|
||||||
|
- severity_classification
|
||||||
|
- routing_decision
|
||||||
|
- tlp_assignment
|
||||||
|
|
||||||
|
Outputs are versioned JSONL files under data/datasets/<task>-v<n>.jsonl.
|
||||||
|
QualityGate enforces the dossier's training-data policy: never TLP:RED, never
|
||||||
|
restricted source types, never empty input/output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from psyc import DATA_DIR, log
|
||||||
|
from psyc.lines import classify as classify_line
|
||||||
|
from psyc.lines import route as route_line
|
||||||
|
from psyc.models import Case, TLP
|
||||||
|
|
||||||
|
|
||||||
|
_log = log.get(__name__)
|
||||||
|
|
||||||
|
DATASETS_DIR = DATA_DIR / "datasets"
|
||||||
|
|
||||||
|
TASKS = ("ioc_extraction", "severity_classification", "routing_decision", "tlp_assignment")
|
||||||
|
|
||||||
|
RESTRICTED_SOURCE_TYPES = {"criminal_forum", "stolen_data_paste", "unauthorized_dump"}
|
||||||
|
|
||||||
|
|
||||||
|
class Example(BaseModel):
|
||||||
|
instruction: str
|
||||||
|
input: str
|
||||||
|
output: str
|
||||||
|
task: str
|
||||||
|
case_id: str
|
||||||
|
tlp: str
|
||||||
|
meta: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetReport(BaseModel):
|
||||||
|
task: str
|
||||||
|
path: Path
|
||||||
|
written: int
|
||||||
|
skipped_quality: int
|
||||||
|
skipped_tlp_red: int
|
||||||
|
skipped_restricted_source: int
|
||||||
|
skipped_empty: int
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- ExampleBuilder per task ---------------------------------------
|
||||||
|
|
||||||
|
def _ex_ioc_extraction(case: Case) -> Optional[Example]:
|
||||||
|
obs = case.observables
|
||||||
|
if not (obs.urls or obs.domains or obs.ips or obs.hashes):
|
||||||
|
return None
|
||||||
|
summary_or_threat = case.summary or case.source_metadata.get("threat", "")
|
||||||
|
input_text = f"Advisory: {summary_or_threat}\nSource: {case.source_ref or case.source_type}"
|
||||||
|
output_obj = {
|
||||||
|
"urls": obs.urls,
|
||||||
|
"domains": obs.domains,
|
||||||
|
"ips": obs.ips,
|
||||||
|
"hashes": obs.hashes,
|
||||||
|
"cves": obs.cves,
|
||||||
|
}
|
||||||
|
return Example(
|
||||||
|
instruction="Extract all indicators of compromise (URLs, domains, IPs, hashes, CVEs) from the advisory. Return JSON with keys: urls, domains, ips, hashes, cves.",
|
||||||
|
input=input_text,
|
||||||
|
output=json.dumps(output_obj, ensure_ascii=False),
|
||||||
|
task="ioc_extraction",
|
||||||
|
case_id=case.case_id,
|
||||||
|
tlp=case.classification.tlp.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ex_severity_classification(case: Case) -> Optional[Example]:
|
||||||
|
if case.classification.severity is None:
|
||||||
|
return None
|
||||||
|
input_obj = {
|
||||||
|
"summary": case.summary,
|
||||||
|
"source_type": case.source_type,
|
||||||
|
"incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
|
||||||
|
"url_status": case.source_metadata.get("url_status", ""),
|
||||||
|
"victim_country": case.victim.country,
|
||||||
|
"critical_infrastructure": case.victim.critical_infrastructure,
|
||||||
|
"url_count": len(case.observables.urls),
|
||||||
|
"ip_count": len(case.observables.ips),
|
||||||
|
}
|
||||||
|
return Example(
|
||||||
|
instruction="Classify the defensive severity of this case as one of: low | medium | high | critical. Return only the label.",
|
||||||
|
input=json.dumps(input_obj, ensure_ascii=False),
|
||||||
|
output=case.classification.severity.value,
|
||||||
|
task="severity_classification",
|
||||||
|
case_id=case.case_id,
|
||||||
|
tlp=case.classification.tlp.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ex_routing_decision(case: Case) -> Optional[Example]:
|
||||||
|
if case.classification.incident_type is None:
|
||||||
|
return None
|
||||||
|
routes, blocked = route_line.plan(case)
|
||||||
|
input_obj = {
|
||||||
|
"incident_type": case.classification.incident_type.value,
|
||||||
|
"tlp": case.classification.tlp.value,
|
||||||
|
"severity": case.classification.severity.value if case.classification.severity else None,
|
||||||
|
"victim_country": case.victim.country,
|
||||||
|
"critical_infrastructure": case.victim.critical_infrastructure,
|
||||||
|
"url_count": len(case.observables.urls),
|
||||||
|
}
|
||||||
|
output_obj = {
|
||||||
|
"allowed": [r.destination_name for r in routes],
|
||||||
|
"blocked": [{"destination": b.destination_name, "reason": b.reason} for b in blocked],
|
||||||
|
}
|
||||||
|
return Example(
|
||||||
|
instruction="Given the case, decide which destinations may receive a submission and which must be blocked. Return JSON with keys 'allowed' (destination names) and 'blocked' (objects with destination + reason).",
|
||||||
|
input=json.dumps(input_obj, ensure_ascii=False),
|
||||||
|
output=json.dumps(output_obj, ensure_ascii=False),
|
||||||
|
task="routing_decision",
|
||||||
|
case_id=case.case_id,
|
||||||
|
tlp=case.classification.tlp.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ex_tlp_assignment(case: Case) -> Optional[Example]:
|
||||||
|
input_obj = {
|
||||||
|
"source_type": case.source_type,
|
||||||
|
"incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
|
||||||
|
"victim_named": bool(case.victim.name),
|
||||||
|
"critical_infrastructure": case.victim.critical_infrastructure,
|
||||||
|
"summary": case.summary,
|
||||||
|
}
|
||||||
|
return Example(
|
||||||
|
instruction="Assign the Traffic Light Protocol (TLP) level for this case: RED | AMBER | GREEN | CLEAR. Return only the label.",
|
||||||
|
input=json.dumps(input_obj, ensure_ascii=False),
|
||||||
|
output=case.classification.tlp.value,
|
||||||
|
task="tlp_assignment",
|
||||||
|
case_id=case.case_id,
|
||||||
|
tlp=case.classification.tlp.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_BUILDERS: Dict[str, Callable[[Case], Optional[Example]]] = {
|
||||||
|
"ioc_extraction": _ex_ioc_extraction,
|
||||||
|
"severity_classification": _ex_severity_classification,
|
||||||
|
"routing_decision": _ex_routing_decision,
|
||||||
|
"tlp_assignment": _ex_tlp_assignment,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- QualityGate ---------------------------------------------------
|
||||||
|
|
||||||
|
class _Reject(BaseModel):
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
|
||||||
|
def quality_gate(example: Example, case: Case) -> Optional[str]:
|
||||||
|
if case.classification.tlp == TLP.RED:
|
||||||
|
return "tlp_red"
|
||||||
|
if case.source_type in RESTRICTED_SOURCE_TYPES:
|
||||||
|
return "restricted_source"
|
||||||
|
if not example.input.strip() or not example.output.strip():
|
||||||
|
return "empty"
|
||||||
|
if len(example.input) > 8000 or len(example.output) > 4000:
|
||||||
|
return "too_long"
|
||||||
|
if re.search(r"\b(password|api[_-]?key|secret|bearer)\b", example.input, re.IGNORECASE):
|
||||||
|
return "credential_leak_in_input"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- DatasetWriter -------------------------------------------------
|
||||||
|
|
||||||
|
def _next_version(task: str) -> int:
|
||||||
|
DATASETS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
existing = list(DATASETS_DIR.glob(f"{task}-v*.jsonl"))
|
||||||
|
used = []
|
||||||
|
for p in existing:
|
||||||
|
m = re.search(rf"{re.escape(task)}-v(\d+)\.jsonl$", p.name)
|
||||||
|
if m:
|
||||||
|
used.append(int(m.group(1)))
|
||||||
|
return (max(used) + 1) if used else 1
|
||||||
|
|
||||||
|
|
||||||
|
def build(task: str, cases: Iterable[Case]) -> DatasetReport:
|
||||||
|
if task not in _BUILDERS:
|
||||||
|
raise ValueError(f"unknown task: {task}; choices: {sorted(_BUILDERS)}")
|
||||||
|
builder = _BUILDERS[task]
|
||||||
|
version = _next_version(task)
|
||||||
|
path = DATASETS_DIR / f"{task}-v{version}.jsonl"
|
||||||
|
written = 0
|
||||||
|
skipped_quality = 0
|
||||||
|
skipped_tlp_red = 0
|
||||||
|
skipped_restricted = 0
|
||||||
|
skipped_empty = 0
|
||||||
|
with path.open("w", encoding="utf-8") as fh:
|
||||||
|
for case in cases:
|
||||||
|
example = builder(case)
|
||||||
|
if example is None:
|
||||||
|
skipped_empty += 1
|
||||||
|
continue
|
||||||
|
reject = quality_gate(example, case)
|
||||||
|
if reject == "tlp_red":
|
||||||
|
skipped_tlp_red += 1
|
||||||
|
continue
|
||||||
|
if reject == "restricted_source":
|
||||||
|
skipped_restricted += 1
|
||||||
|
continue
|
||||||
|
if reject is not None:
|
||||||
|
skipped_quality += 1
|
||||||
|
continue
|
||||||
|
fh.write(example.model_dump_json() + "\n")
|
||||||
|
written += 1
|
||||||
|
_log.info("train.dataset.built", task=task, version=version, written=written, path=str(path))
|
||||||
|
return DatasetReport(
|
||||||
|
task=task,
|
||||||
|
path=path,
|
||||||
|
written=written,
|
||||||
|
skipped_quality=skipped_quality,
|
||||||
|
skipped_tlp_red=skipped_tlp_red,
|
||||||
|
skipped_restricted_source=skipped_restricted,
|
||||||
|
skipped_empty=skipped_empty,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_datasets() -> List[Dict[str, str]]:
|
||||||
|
if not DATASETS_DIR.exists():
|
||||||
|
return []
|
||||||
|
out: List[Dict[str, str]] = []
|
||||||
|
for p in sorted(DATASETS_DIR.glob("*.jsonl")):
|
||||||
|
line_count = sum(1 for _ in p.open("r", encoding="utf-8"))
|
||||||
|
out.append({
|
||||||
|
"name": p.name,
|
||||||
|
"path": str(p),
|
||||||
|
"examples": str(line_count),
|
||||||
|
"size_bytes": str(p.stat().st_size),
|
||||||
|
"modified": datetime.fromtimestamp(p.stat().st_mtime, tz=timezone.utc).isoformat(),
|
||||||
|
})
|
||||||
|
return out
|
||||||
Reference in New Issue
Block a user