stage-fed-d federation: signed feed export + verified import
This commit is contained in:
@@ -12,15 +12,19 @@ vouching/quorum, transparency log and auto-pull live in later stages.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Tuple
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from psyc import DATA_DIR, log
|
from psyc import DATA_DIR, db, log
|
||||||
|
from psyc.result import Err, Ok, Result
|
||||||
|
|
||||||
|
|
||||||
_log = log.get(__name__)
|
_log = log.get(__name__)
|
||||||
@@ -151,3 +155,251 @@ def dns_record(domain: str, port: int = 443) -> DNSRecord:
|
|||||||
txt_value=txt_value,
|
txt_value=txt_value,
|
||||||
human_instructions=instructions,
|
human_instructions=instructions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- signing -----------------------------------------------------
|
||||||
|
|
||||||
|
def canonical_json(obj: Dict[str, Any]) -> bytes:
|
||||||
|
"""Deterministic JSON serialization — what we sign + hash over."""
|
||||||
|
return json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def sign_payload(payload: bytes) -> bytes:
|
||||||
|
"""Ed25519 signature over `payload`. Raw 64-byte sig."""
|
||||||
|
priv, _ = node_keypair()
|
||||||
|
return priv.sign(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_payload(payload: bytes, signature: bytes, pubkey_pem: str) -> bool:
|
||||||
|
"""True iff `signature` verifies under `pubkey_pem`. Never raises."""
|
||||||
|
try:
|
||||||
|
pub = serialization.load_pem_public_key(pubkey_pem.encode("ascii"))
|
||||||
|
if not isinstance(pub, ed25519.Ed25519PublicKey):
|
||||||
|
return False
|
||||||
|
pub.verify(signature, payload)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- feed export -------------------------------------------------
|
||||||
|
|
||||||
|
def _case_digest(case_record: Dict[str, Any]) -> str:
|
||||||
|
return hashlib.sha256(canonical_json(case_record)).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_case_records(window_hours: int) -> List[Dict[str, Any]]:
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(hours=window_hours)
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
for case in db.list_cases(limit=10_000):
|
||||||
|
if case.ingested_at < cutoff:
|
||||||
|
continue
|
||||||
|
record: Dict[str, Any] = {
|
||||||
|
"case_id": case.case_id,
|
||||||
|
"summary": case.summary,
|
||||||
|
"severity": case.classification.severity.value if case.classification.severity else None,
|
||||||
|
"incident_type": case.classification.incident_type.value if case.classification.incident_type else None,
|
||||||
|
"observed_at": case.observed_at.isoformat(),
|
||||||
|
"feed_source": case.source_metadata.get("feed", ""),
|
||||||
|
"iocs": (
|
||||||
|
[{"value": v, "type": "url"} for v in case.observables.urls]
|
||||||
|
+ [{"value": v, "type": "domain"} for v in case.observables.domains]
|
||||||
|
+ [{"value": v, "type": "ip"} for v in case.observables.ips]
|
||||||
|
+ [{"value": v, "type": "hash"} for v in case.observables.hashes]
|
||||||
|
+ [{"value": v, "type": "cve"} for v in case.observables.cves]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
record["digest_sha256"] = _case_digest(
|
||||||
|
{k: v for k, v in record.items() if k != "digest_sha256"}
|
||||||
|
)
|
||||||
|
out.append(record)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _build_ioc_records(window_hours: int) -> List[Dict[str, Any]]:
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(hours=window_hours)
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
seen: set = set()
|
||||||
|
for ioc_type in ("url", "domain", "ip", "hash", "cve"):
|
||||||
|
for row in db.iocs_by_type(ioc_type):
|
||||||
|
first_seen = row.get("first_seen")
|
||||||
|
if first_seen:
|
||||||
|
try:
|
||||||
|
if datetime.fromisoformat(first_seen) < cutoff:
|
||||||
|
continue
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
key = (row["value"], row["ioc_type"])
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
record = {
|
||||||
|
"value": row["value"],
|
||||||
|
"type": row["ioc_type"],
|
||||||
|
"severity": row.get("severity"),
|
||||||
|
"first_seen": first_seen,
|
||||||
|
}
|
||||||
|
record["digest_sha256"] = hashlib.sha256(canonical_json(record)).hexdigest()
|
||||||
|
out.append(record)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def build_signed_feed(window_hours: int = 24) -> Dict[str, Any]:
|
||||||
|
"""Build the JSON feed peers will pull from /federation/feed.
|
||||||
|
|
||||||
|
Pulls cases ingested in the last `window_hours` plus the corresponding
|
||||||
|
IOC slice, attaches per-record `digest_sha256` (so peers can later
|
||||||
|
quorum-match across nodes), and signs the canonical JSON of the whole
|
||||||
|
payload-minus-signature with our Ed25519 key.
|
||||||
|
"""
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"version": FEED_VERSION,
|
||||||
|
"fingerprint": node_fingerprint(),
|
||||||
|
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"window_hours": window_hours,
|
||||||
|
"cases": _build_case_records(window_hours),
|
||||||
|
"iocs": _build_ioc_records(window_hours),
|
||||||
|
}
|
||||||
|
sig = sign_payload(canonical_json(payload))
|
||||||
|
payload["signature"] = base64.b64encode(sig).decode("ascii")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- import + quorum-signal buffer -------------------------------
|
||||||
|
|
||||||
|
class ImportSummary(BaseModel):
|
||||||
|
peer_fingerprint: str
|
||||||
|
cases_seen: int
|
||||||
|
iocs_seen: int
|
||||||
|
signal_ids: List[Tuple[str, str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def import_signed_feed(feed: Dict[str, Any], expected_pubkey_pem: str) -> Result[ImportSummary, str]:
|
||||||
|
"""Verify + record a peer's feed into the federation_signals buffer.
|
||||||
|
|
||||||
|
Does NOT merge into the local case store — that's the quorum stage's
|
||||||
|
job. The buffer is the per-hash signal log that quorum logic later
|
||||||
|
aggregates ("3 trusted peers reported this same IOC → promote").
|
||||||
|
"""
|
||||||
|
sig_b64 = feed.get("signature")
|
||||||
|
if not sig_b64:
|
||||||
|
return Err("missing signature")
|
||||||
|
try:
|
||||||
|
signature = base64.b64decode(sig_b64)
|
||||||
|
except Exception:
|
||||||
|
return Err("malformed signature (not base64)")
|
||||||
|
|
||||||
|
unsigned = {k: v for k, v in feed.items() if k != "signature"}
|
||||||
|
if not verify_payload(canonical_json(unsigned), signature, expected_pubkey_pem):
|
||||||
|
return Err("signature verification failed")
|
||||||
|
|
||||||
|
peer_fp = feed.get("fingerprint", "")
|
||||||
|
if not peer_fp:
|
||||||
|
return Err("missing fingerprint")
|
||||||
|
if peer_fp == node_fingerprint():
|
||||||
|
return Err("loop: own feed")
|
||||||
|
|
||||||
|
# Cross-check the declared fingerprint matches the pubkey we verified with.
|
||||||
|
try:
|
||||||
|
if _fingerprint_for_pubkey_pem(expected_pubkey_pem) != peer_fp:
|
||||||
|
return Err("fingerprint does not match provided pubkey")
|
||||||
|
except Exception as exc:
|
||||||
|
return Err(f"bad pubkey: {exc}")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
signal_ids: List[Tuple[str, str]] = []
|
||||||
|
cases = feed.get("cases") or []
|
||||||
|
iocs = feed.get("iocs") or []
|
||||||
|
|
||||||
|
for c in cases:
|
||||||
|
case_id = c.get("case_id") or ""
|
||||||
|
digest = c.get("digest_sha256") or hashlib.sha256(canonical_json(c)).hexdigest()
|
||||||
|
db.record_signal(dict(
|
||||||
|
peer_fingerprint=peer_fp,
|
||||||
|
signal_type="case",
|
||||||
|
signal_id=case_id,
|
||||||
|
signal_hash=digest,
|
||||||
|
received_at=now,
|
||||||
|
raw_json=json.dumps(c, sort_keys=True),
|
||||||
|
))
|
||||||
|
signal_ids.append(("case", digest))
|
||||||
|
|
||||||
|
for i in iocs:
|
||||||
|
value = i.get("value") or ""
|
||||||
|
digest = i.get("digest_sha256") or hashlib.sha256(canonical_json(i)).hexdigest()
|
||||||
|
db.record_signal(dict(
|
||||||
|
peer_fingerprint=peer_fp,
|
||||||
|
signal_type="ioc",
|
||||||
|
signal_id=value,
|
||||||
|
signal_hash=digest,
|
||||||
|
received_at=now,
|
||||||
|
raw_json=json.dumps(i, sort_keys=True),
|
||||||
|
))
|
||||||
|
signal_ids.append(("ioc", digest))
|
||||||
|
|
||||||
|
_log.info("federation.import.ok", peer=peer_fp, cases=len(cases), iocs=len(iocs))
|
||||||
|
return Ok(ImportSummary(
|
||||||
|
peer_fingerprint=peer_fp,
|
||||||
|
cases_seen=len(cases),
|
||||||
|
iocs_seen=len(iocs),
|
||||||
|
signal_ids=signal_ids,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- peer registry ------------------------------------------------
|
||||||
|
|
||||||
|
class Peer(BaseModel):
|
||||||
|
domain: str
|
||||||
|
fingerprint: str
|
||||||
|
pubkey_pem: str
|
||||||
|
status: str = "unknown" # unknown | trusted | blocked
|
||||||
|
discovered_at: str
|
||||||
|
last_seen: Optional[str] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_peer(row: Dict[str, Any]) -> Peer:
|
||||||
|
return Peer(
|
||||||
|
domain=row["domain"],
|
||||||
|
fingerprint=row["fingerprint"],
|
||||||
|
pubkey_pem=row["pubkey_pem"],
|
||||||
|
status=row.get("status") or "unknown",
|
||||||
|
discovered_at=row.get("discovered_at") or "",
|
||||||
|
last_seen=row.get("last_seen"),
|
||||||
|
notes=row.get("notes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_peer(domain: str, fingerprint: str, pubkey_pem: str, status: str = "unknown") -> None:
|
||||||
|
"""Insert or update a peer in the registry. Idempotent on `domain`."""
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
existing = db.get_peer(domain)
|
||||||
|
discovered_at = existing["discovered_at"] if existing else now
|
||||||
|
db.upsert_peer(dict(
|
||||||
|
domain=domain,
|
||||||
|
fingerprint=fingerprint,
|
||||||
|
pubkey_pem=pubkey_pem,
|
||||||
|
status=status,
|
||||||
|
discovered_at=discovered_at,
|
||||||
|
last_seen=now,
|
||||||
|
notes=existing.get("notes") if existing else None,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def list_peers() -> List[Peer]:
|
||||||
|
return [_row_to_peer(r) for r in db.list_peers()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_peer(domain: str) -> Optional[Peer]:
|
||||||
|
row = db.get_peer(domain)
|
||||||
|
return _row_to_peer(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def set_peer_status(domain: str, status: str) -> None:
|
||||||
|
if status not in ("unknown", "trusted", "blocked"):
|
||||||
|
raise ValueError(f"unknown peer status: {status}")
|
||||||
|
db.set_peer_status(domain, status)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_peer(domain: str) -> None:
|
||||||
|
db.remove_peer(domain)
|
||||||
|
|||||||
Reference in New Issue
Block a user