diff --git a/tests/test_pulse_respond.py b/tests/test_pulse_respond.py new file mode 100644 index 0000000..c940b3c --- /dev/null +++ b/tests/test_pulse_respond.py @@ -0,0 +1,314 @@ +"""Pulseline auto-response gating — severity threshold, quorum, local-only. + +The runner here is the live `_run_respond` from pulse.py. We point it at a +temp DB, monkeypatch federation.is_quorum_met to a controllable function, and +swap respond.execute_action for a counter so we don't reach the SOAR sink. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import List, Tuple + +import pytest +from sqlalchemy import create_engine + +from psyc import db +from psyc.lines import pulse, respond +from psyc.lines import federation +from psyc.models import ( + ActionStatus, + ActionType, + Case, + Classification, + Observables, + ResponseAction, + Severity, + TLP, +) +from psyc.result import Ok + +from conftest import make_case + + +# ----- fixtures -------------------------------------------------------------- + +@pytest.fixture +def fresh_db(tmp_path, monkeypatch): + """Temp SQLite + the real runner registry. Mode pinned to auto-execute.""" + test_db = tmp_path / "respond.db" + eng = create_engine(f"sqlite:///{test_db}", future=True) + db._metadata.create_all(eng, checkfirst=True) + monkeypatch.setattr(db, "_engine", eng) + monkeypatch.setattr(db, "DB_PATH", test_db) + yield test_db + + +@pytest.fixture +def fired(monkeypatch): + """Capture every execute_action(action_id, approver=...) — no SOAR sink call.""" + log: List[Tuple[int, str]] = [] + + def fake_execute(action_id: int, approver: str = "operator"): + log.append((action_id, approver)) + # Re-read the action so we can return a realistic Ok value + got = respond.get_action(action_id) + return got if isinstance(got, Ok) else got + + monkeypatch.setattr(respond, "execute_action", fake_execute) + return log + + +@pytest.fixture +def quorum_yes(monkeypatch): + monkeypatch.setattr(federation, "is_quorum_met", + lambda h, k=None: True, raising=False) + + +@pytest.fixture +def quorum_no(monkeypatch): + monkeypatch.setattr(federation, "is_quorum_met", + lambda h, k=None: False, raising=False) + + +def _set_respond_mode(mode: pulse.PulseMode) -> None: + pulse.set_mode("respond", mode) + + +def _propose_one(case: Case) -> int: + db.upsert_case(case) + ids = respond.propose_for_case(case) + assert ids, "test setup expected at least one action proposed" + return ids[0] + + +# ----- severity rank --------------------------------------------------------- + +def test_severity_rank_ordering(): + assert pulse._severity_rank(Severity.LOW) == 0 + assert pulse._severity_rank(Severity.MEDIUM) == 1 + assert pulse._severity_rank(Severity.HIGH) == 2 + assert pulse._severity_rank(Severity.CRITICAL) == 3 + assert pulse._severity_rank(None) == -1 + + +# ----- runner mode gating ---------------------------------------------------- + +def test_runner_no_auto_fire_when_mode_is_propose(fresh_db, fired, quorum_yes): + case = make_case(feed="feodo", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + # default seed mode for respond is auto-propose → no auto-fire even with PROPOSED actions + result = pulse._run_respond() + assert "no auto-fire" in result + assert fired == [] + + +def test_runner_no_auto_fire_when_manual(fresh_db, fired, quorum_yes): + case = make_case(feed="feodo", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + _set_respond_mode(pulse.PulseMode.MANUAL) + result = pulse._run_respond() + assert "no auto-fire" in result + assert fired == [] + + +# ----- severity threshold ---------------------------------------------------- + +def test_below_threshold_is_skipped(fresh_db, fired, quorum_yes): + # Propose an action carrying severity=MEDIUM by hand — propose_for_case + # only generates HIGH/CRITICAL actions, but the gate must still work for + # any below-threshold severity we drop in. + case = make_case(feed="feodo", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + respond.propose_for_case(case) + # Demote every action's severity to MEDIUM so all should be skipped under HIGH threshold. + from sqlalchemy import update as sa_update + with db.engine().begin() as conn: + conn.execute(sa_update(db.response_actions).values(severity=Severity.MEDIUM.value)) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert fired == [], "below-threshold action must not fire" + audit = db.pulse_audit_recent("respond", limit=5) + assert any(r["action"] == "skipped" and "below threshold" in (r["detail"] or "") for r in audit) + + +# ----- quorum gate ----------------------------------------------------------- + +def test_federation_case_no_quorum_skipped(fresh_db, fired, quorum_no): + case = make_case(feed="urlhaus", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + # Mark this case as federation-sourced by inserting a signal row for it. + db.record_signal(dict( + peer_fingerprint="peer-a", + signal_type="case", + signal_id=case.case_id, + signal_hash="dummyhash", + received_at=datetime.now(timezone.utc).isoformat(), + raw_json="{}", + )) + respond.propose_for_case(case) + + pulse.set_respond_require_quorum(True) + pulse.set_respond_local_only(False) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert fired == [] + audit = db.pulse_audit_recent("respond", limit=5) + assert any(r["action"] == "skipped" and "no quorum" in (r["detail"] or "") for r in audit) + + +def test_local_case_fires_when_quorum_required(fresh_db, fired, quorum_no): + """Locally-generated cases bypass quorum — they're our own work.""" + case = make_case(feed="urlhaus", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + # No federation_signals row → locally-generated + respond.propose_for_case(case) + + pulse.set_respond_require_quorum(True) + pulse.set_respond_local_only(True) # both armed; local cases still fire + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert len(fired) >= 1 + audit = db.pulse_audit_recent("respond", limit=10) + assert any(r["action"] == "auto-fire" for r in audit) + + +def test_local_case_fires_local_only_off(fresh_db, fired, quorum_no): + """Even with local_only OFF, a locally-generated case still fires (no quorum needed).""" + case = make_case(feed="urlhaus", ips=["1.1.1.1"], severity=Severity.CRITICAL) + db.upsert_case(case) + respond.propose_for_case(case) + + pulse.set_respond_require_quorum(True) + pulse.set_respond_local_only(False) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert len(fired) >= 1 + + +def test_federation_case_with_quorum_fires(fresh_db, fired, quorum_yes): + case = make_case(feed="urlhaus", ips=["2.2.2.2"], severity=Severity.HIGH) + db.upsert_case(case) + db.record_signal(dict( + peer_fingerprint="peer-b", + signal_type="case", + signal_id=case.case_id, + signal_hash="dummyhash2", + received_at=datetime.now(timezone.utc).isoformat(), + raw_json="{}", + )) + respond.propose_for_case(case) + + pulse.set_respond_require_quorum(True) + pulse.set_respond_local_only(False) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert len(fired) >= 1 + + +def test_quorum_off_fires_federation_case(fresh_db, fired, quorum_no): + """With quorum gating disabled entirely, federation cases fire too.""" + case = make_case(feed="urlhaus", ips=["3.3.3.3"], severity=Severity.HIGH) + db.upsert_case(case) + db.record_signal(dict( + peer_fingerprint="peer-c", + signal_type="case", + signal_id=case.case_id, + signal_hash="dummyhash3", + received_at=datetime.now(timezone.utc).isoformat(), + raw_json="{}", + )) + respond.propose_for_case(case) + + pulse.set_respond_require_quorum(False) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + + pulse._run_respond() + assert len(fired) >= 1 + + +# ----- kill switch ----------------------------------------------------------- + +def test_kill_switch_blocks_tick(fresh_db, fired, quorum_yes): + """The parent tick() skips everything when kill switch is armed.""" + case = make_case(feed="feodo", ips=["9.9.9.9"], severity=Severity.HIGH) + db.upsert_case(case) + respond.propose_for_case(case) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + pulse.set_kill_switch(True) + results = pulse.tick() + assert all(o == "skipped" for _, o, _ in results) + assert fired == [] + + +# ----- audit ----------------------------------------------------------------- + +def test_pulse_audit_records_fire_and_skip(fresh_db, fired, quorum_no): + # Local case → should fire and audit auto-fire + local = make_case(feed="urlhaus", ips=["10.0.0.1"], severity=Severity.HIGH, age_days=1) + db.upsert_case(local) + respond.propose_for_case(local) + + # Federation-sourced case w/o quorum → should skip and audit skip + fedcase = make_case(feed="urlhaus", ips=["10.0.0.2"], severity=Severity.HIGH, age_days=2) + db.upsert_case(fedcase) + db.record_signal(dict( + peer_fingerprint="peer-x", + signal_type="case", + signal_id=fedcase.case_id, + signal_hash="xhash", + received_at=datetime.now(timezone.utc).isoformat(), + raw_json="{}", + )) + respond.propose_for_case(fedcase) + + pulse.set_respond_require_quorum(True) + pulse.set_respond_local_only(False) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + pulse._run_respond() + + audit = db.pulse_audit_recent("respond", limit=20) + actions = {r["action"] for r in audit} + assert "auto-fire" in actions + assert "skipped" in actions + + +def test_audit_count_since(fresh_db, fired, quorum_no): + case = make_case(feed="urlhaus", ips=["8.8.8.8"], severity=Severity.HIGH) + db.upsert_case(case) + respond.propose_for_case(case) + pulse.set_respond_require_quorum(True) + pulse.set_respond_auto_threshold(Severity.HIGH) + _set_respond_mode(pulse.PulseMode.AUTO_EXECUTE) + pulse._run_respond() + from datetime import timedelta + since = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + assert db.pulse_audit_count_since("respond", "auto-fire", since) >= 1 + + +# ----- config round-trip ----------------------------------------------------- + +def test_config_round_trips(fresh_db): + assert pulse.respond_auto_threshold() == Severity.HIGH + assert pulse.respond_require_quorum() is True + assert pulse.respond_local_only() is False + + pulse.set_respond_auto_threshold(Severity.CRITICAL) + pulse.set_respond_require_quorum(False) + pulse.set_respond_local_only(True) + + assert pulse.respond_auto_threshold() == Severity.CRITICAL + assert pulse.respond_require_quorum() is False + assert pulse.respond_local_only() is True