diff --git a/src/psyc/cli.py b/src/psyc/cli.py
index 12ecdfe..5a0615c 100644
--- a/src/psyc/cli.py
+++ b/src/psyc/cli.py
@@ -8,7 +8,7 @@ import typer
 import uvicorn
 
 from psyc import db, log
-from psyc.lines import classify, courier, route, scout, seal
+from psyc.lines import classify, courier, route, scout, seal, train
 from psyc.lines import map as map_line
 from psyc.models import Outcome
 from psyc.result import Err, Ok
@@ -204,6 +204,48 @@ def mock_cert_serve(host: str = "127.0.0.1", port: int = 8770) -> None:
     uvicorn.run("psyc.mock_cert:app", host=host, port=port)
 
 
+@app.command("train-build")
+def train_build(
+    task: str = typer.Option(..., "--task", "-t", help=f"one of: {', '.join(train.TASKS)}"),
+    limit: int = typer.Option(10_000, help="max cases to process"),
+) -> None:
+    """Build one versioned training dataset and print its write/skip report."""
+    if task not in train.TASKS:
+        typer.echo(f"unknown task: {task}; choices: {', '.join(train.TASKS)}", err=True)
+        raise typer.Exit(1)
+    cases = db.list_cases(limit=limit)
+    report = train.build(task, cases)
+    typer.echo(f"task: {report.task}")
+    typer.echo(f"path: {report.path}")
+    typer.echo(f" written: {report.written}")
+    typer.echo(f" skipped (empty): {report.skipped_empty}")
+    typer.echo(f" skipped (tlp red): {report.skipped_tlp_red}")
+    typer.echo(f" skipped (source): {report.skipped_restricted_source}")
+    typer.echo(f" skipped (quality): {report.skipped_quality}")
+
+
+@app.command("train-build-all")
+def train_build_all(limit: int = typer.Option(10_000, help="max cases per task")) -> None:
+    """Build a dataset for every task in train.TASKS from one shared case list."""
+    # Materialize once: the same cases are replayed for each task, and a
+    # one-shot iterator would be exhausted after the first build.
+    cases = list(db.list_cases(limit=limit))
+    for task in train.TASKS:
+        report = train.build(task, cases)
+        typer.echo(f" {task}: wrote {report.written} → {report.path.name}")
+
+
+@app.command("train-list-datasets")
+def train_list_datasets() -> None:
+    """Print one line per dataset file, or a placeholder when none exist."""
+    rows = train.list_datasets()
+    if not rows:
+        typer.echo("(no datasets)")
+        return
+    for r in rows:
+        typer.echo(f"{r['name']:<40} {r['examples']:>6} examples {int(r['size_bytes']):>8} B {r['modified']}")
+
+
 @app.command("demo")
 def demo() -> None:
     db.init_db()
diff --git a/src/psyc/lines/train.py b/src/psyc/lines/train.py
new file mode 100644
index 0000000..14ecd55
--- /dev/null
+++ b/src/psyc/lines/train.py
@@ -0,0 +1,293 @@
+"""Trainline — ExampleBuilder + QualityGate + DatasetWriter.
+
+Reads reviewed cases from the DB, emits Alpaca-style JSONL training examples
+for the defensive task set defined in docs/dossier.md §intelminer:
+  - ioc_extraction
+  - severity_classification
+  - routing_decision
+  - tlp_assignment
+
+Outputs are versioned JSONL files under data/datasets/{task}-v{version}.jsonl.
+QualityGate enforces the dossier's training-data policy: never TLP:RED, never
+restricted source types, never empty input/output.
+"""
+
+from __future__ import annotations
+
+import json
+import re
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional
+
+from pydantic import BaseModel, Field
+
+from psyc import DATA_DIR, log
+from psyc.lines import classify as classify_line
+from psyc.lines import route as route_line
+from psyc.models import Case, TLP
+
+
+_log = log.get(__name__)
+
+DATASETS_DIR = DATA_DIR / "datasets"
+
+TASKS = ("ioc_extraction", "severity_classification", "routing_decision", "tlp_assignment")
+
+RESTRICTED_SOURCE_TYPES = {"criminal_forum", "stolen_data_paste", "unauthorized_dump"}
+
+
+class Example(BaseModel):
+    """One training record: instruction/input/output plus task, case_id, tlp, meta."""
+
+    instruction: str
+    input: str
+    output: str
+    task: str
+    case_id: str
+    tlp: str
+    meta: Dict[str, str] = Field(default_factory=dict)
+
+
+class DatasetReport(BaseModel):
+    """Outcome of one build(): output path, written count and skip counters."""
+
+    task: str
+    path: Path
+    written: int
+    skipped_quality: int
+    skipped_tlp_red: int
+    skipped_restricted_source: int
+    skipped_empty: int
+
+
+# ---------- ExampleBuilder per task ---------------------------------------
+
+def _ex_ioc_extraction(case: Case) -> Optional[Example]:
+    """Build an IOC-extraction example, or None when the case has no IOCs.
+
+    Only URLs, domains, IPs and hashes count towards "has IOCs"; CVEs are
+    emitted in the output but a CVE-only case is still skipped.
+    """
+    obs = case.observables
+    if not (obs.urls or obs.domains or obs.ips or obs.hashes):
+        return None
+    summary_or_threat = case.summary or case.source_metadata.get("threat", "")
+    input_text = f"Advisory: {summary_or_threat}\nSource: {case.source_ref or case.source_type}"
+    output_obj = {
+        "urls": obs.urls,
+        "domains": obs.domains,
+        "ips": obs.ips,
+        "hashes": obs.hashes,
+        "cves": obs.cves,
+    }
+    return Example(
+        instruction="Extract all indicators of compromise (URLs, domains, IPs, hashes, CVEs) from the advisory. Return JSON with keys: urls, domains, ips, hashes, cves.",
+        input=input_text,
+        output=json.dumps(output_obj, ensure_ascii=False),
+        task="ioc_extraction",
+        case_id=case.case_id,
+        tlp=case.classification.tlp.value,
+    )
+
+
+def _ex_severity_classification(case: Case) -> Optional[Example]:
+    """Build a severity-labelling example, or None when no severity is set."""
+    if case.classification.severity is None:
+        return None
+    input_obj = {
+        "summary": case.summary,
+        "source_type": case.source_type,
+        "incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
+        "url_status": case.source_metadata.get("url_status", ""),
+        "victim_country": case.victim.country,
+        "critical_infrastructure": case.victim.critical_infrastructure,
+        "url_count": len(case.observables.urls),
+        "ip_count": len(case.observables.ips),
+    }
+    return Example(
+        instruction="Classify the defensive severity of this case as one of: low | medium | high | critical. Return only the label.",
+        input=json.dumps(input_obj, ensure_ascii=False),
+        output=case.classification.severity.value,
+        task="severity_classification",
+        case_id=case.case_id,
+        tlp=case.classification.tlp.value,
+    )
+
+
+def _ex_routing_decision(case: Case) -> Optional[Example]:
+    """Build a routing example, or None when the case has no incident type.
+
+    The expected output is whatever route_line.plan() returns for the case.
+    """
+    if case.classification.incident_type is None:
+        return None
+    routes, blocked = route_line.plan(case)
+    input_obj = {
+        "incident_type": case.classification.incident_type.value,
+        "tlp": case.classification.tlp.value,
+        "severity": case.classification.severity.value if case.classification.severity else None,
+        "victim_country": case.victim.country,
+        "critical_infrastructure": case.victim.critical_infrastructure,
+        "url_count": len(case.observables.urls),
+    }
+    output_obj = {
+        "allowed": [r.destination_name for r in routes],
+        "blocked": [{"destination": b.destination_name, "reason": b.reason} for b in blocked],
+    }
+    return Example(
+        instruction="Given the case, decide which destinations may receive a submission and which must be blocked. Return JSON with keys 'allowed' (destination names) and 'blocked' (objects with destination + reason).",
+        input=json.dumps(input_obj, ensure_ascii=False),
+        output=json.dumps(output_obj, ensure_ascii=False),
+        task="routing_decision",
+        case_id=case.case_id,
+        tlp=case.classification.tlp.value,
+    )
+
+
+def _ex_tlp_assignment(case: Case) -> Optional[Example]:
+    """Build a TLP-labelling example; every case yields one (never None)."""
+    input_obj = {
+        "source_type": case.source_type,
+        "incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
+        "victim_named": bool(case.victim.name),
+        "critical_infrastructure": case.victim.critical_infrastructure,
+        "summary": case.summary,
+    }
+    return Example(
+        instruction="Assign the Traffic Light Protocol (TLP) level for this case: RED | AMBER | GREEN | CLEAR. Return only the label.",
+        input=json.dumps(input_obj, ensure_ascii=False),
+        output=case.classification.tlp.value,
+        task="tlp_assignment",
+        case_id=case.case_id,
+        tlp=case.classification.tlp.value,
+    )
+
+
+# Task name -> builder; build() dispatches through this table.
+_BUILDERS: Dict[str, Callable[[Case], Optional[Example]]] = {
+    "ioc_extraction": _ex_ioc_extraction,
+    "severity_classification": _ex_severity_classification,
+    "routing_decision": _ex_routing_decision,
+    "tlp_assignment": _ex_tlp_assignment,
+}
+
+
+# ---------- QualityGate ---------------------------------------------------
+
+class _Reject(BaseModel):
+    """Rejection reason wrapper; not referenced by the code visible in this module."""
+
+    reason: str
+
+
+def quality_gate(example: Example, case: Case) -> Optional[str]:
+    """Return a rejection reason for *example*, or None when it may be written.
+
+    Checks run in order and the first hit wins: "tlp_red", "restricted_source",
+    "empty", "too_long", "credential_leak_in_input".
+    """
+    if case.classification.tlp == TLP.RED:
+        return "tlp_red"
+    if case.source_type in RESTRICTED_SOURCE_TYPES:
+        return "restricted_source"
+    if not example.input.strip() or not example.output.strip():
+        return "empty"
+    if len(example.input) > 8000 or len(example.output) > 4000:
+        return "too_long"
+    if re.search(r"\b(password|api[_-]?key|secret|bearer)\b", example.input, re.IGNORECASE):
+        return "credential_leak_in_input"
+    return None
+
+
+# ---------- DatasetWriter -------------------------------------------------
+
+def _next_version(task: str) -> int:
+    """Return 1 + the highest existing version for *task* (1 if none exist).
+
+    Creates DATASETS_DIR as a side effect so the caller can write immediately.
+    """
+    DATASETS_DIR.mkdir(parents=True, exist_ok=True)
+    pattern = re.compile(rf"{re.escape(task)}-v(\d+)\.jsonl$")
+    matches = (pattern.search(p.name) for p in DATASETS_DIR.glob(f"{task}-v*.jsonl"))
+    used = [int(m.group(1)) for m in matches if m]
+    return max(used, default=0) + 1
+
+
+def build(task: str, cases: Iterable[Case]) -> DatasetReport:
+    """Write the next versioned JSONL dataset for *task* and report the counts.
+
+    Each case goes through the task's builder and then quality_gate(); only
+    examples that pass both are written, one JSON object per line.
+
+    Raises:
+        ValueError: if *task* has no registered builder.
+    """
+    if task not in _BUILDERS:
+        raise ValueError(f"unknown task: {task}; choices: {sorted(_BUILDERS)}")
+    builder = _BUILDERS[task]
+    version = _next_version(task)
+    path = DATASETS_DIR / f"{task}-v{version}.jsonl"
+    written = 0
+    skipped_quality = 0
+    skipped_tlp_red = 0
+    skipped_restricted = 0
+    skipped_empty = 0
+    with path.open("w", encoding="utf-8") as fh:
+        for case in cases:
+            example = builder(case)
+            if example is None:
+                # The builder found nothing to train on for this case.
+                skipped_empty += 1
+                continue
+            reject = quality_gate(example, case)
+            if reject == "tlp_red":
+                skipped_tlp_red += 1
+                continue
+            if reject == "restricted_source":
+                skipped_restricted += 1
+                continue
+            if reject == "empty":
+                # Blank input/output belongs with the "empty" counter, not the
+                # generic quality bucket.
+                skipped_empty += 1
+                continue
+            if reject is not None:
+                skipped_quality += 1
+                continue
+            fh.write(example.model_dump_json() + "\n")
+            written += 1
+    _log.info("train.dataset.built", task=task, version=version, written=written, path=str(path))
+    return DatasetReport(
+        task=task,
+        path=path,
+        written=written,
+        skipped_quality=skipped_quality,
+        skipped_tlp_red=skipped_tlp_red,
+        skipped_restricted_source=skipped_restricted,
+        skipped_empty=skipped_empty,
+    )
+
+
+def list_datasets() -> List[Dict[str, str]]:
+    """Describe every *.jsonl file in DATASETS_DIR, sorted by path.
+
+    Every value is a string (including counts and sizes); callers convert as
+    needed. Returns an empty list when the directory does not exist.
+    """
+    if not DATASETS_DIR.exists():
+        return []
+    out: List[Dict[str, str]] = []
+    for p in sorted(DATASETS_DIR.glob("*.jsonl")):
+        # Close the handle deterministically instead of leaking it to the GC.
+        with p.open("r", encoding="utf-8") as fh:
+            line_count = sum(1 for _ in fh)
+        stat = p.stat()
+        out.append({
+            "name": p.name,
+            "path": str(p),
+            "examples": str(line_count),
+            "size_bytes": str(stat.st_size),
+            "modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
+        })
+    return out