#!/usr/bin/env python3
"""Vloex audit-chain verifier — runs offline, stdlib only.

Usage: python vloex-verify.py <bundle-dir>

The bundle directory must contain:
  events.jsonl   — one envelope per line:
                   {event_id, day, key_id, signature, payload_hash,
                    leaf_hash, anchor_id, status}
  proofs.jsonl   — {event_id, day, anchor_id, leaf_index,
                    sibling_hashes (hex), status}
  anchors.json   — list of {id, day, sequence, root, prev_root,
                            leaf_count, s3_uri, created_at}

Each verified event carries an ``anchor_id`` pointing at the specific
Merkle anchor that committed to its leaf. The verifier looks up that
anchor's root and walks the supplied sibling path. Anchors are
identified by id (UUID), not by day, because a single day can have
multiple sequences when straggler events arrive after the initial
seal.

Each leaf is re-derived from the envelope's (payload_hash, key_id,
signature.value) — the same triple the backend used at ingest. Auditors
do NOT need the underlying event text (encrypted at rest, redaction-
sensitive). Every mismatch is printed and the script exits 1.
pending_anchor events (verified rows the worker has not folded into a
tree yet — e.g. exported before today's seal at 00:05 UTC) are counted
separately and do NOT cause failure; neither do verified events whose
anchor is missing from this bundle (they are counted as pending too).
"""

import hashlib
import json
import os
import sys

# Domain-separation prefixes, the same 0x00-leaf / 0x01-interior-node
# convention as RFC 6962. A leaf digest is sha256(LEAF_PREFIX + body) and an
# interior node is sha256(NODE_PREFIX + left + right), so a leaf preimage can
# never be confused with an interior-node preimage.
LEAF_PREFIX = bytes([0x00])
NODE_PREFIX = bytes([0x01])


def compute_leaf_hash(payload_hash_hex, key_id, sig_value_b64):
    """Return the SHA-256 leaf digest (raw bytes) for one envelope.

    The preimage is ``LEAF_PREFIX`` followed by the three fields joined
    with ``|``: ``payload_hash | key_id | sig_value``.

    Args:
        payload_hash_hex: the envelope's payload hash; must be pure ASCII
            (a non-ASCII value raises ``UnicodeEncodeError``).
        key_id: signing key id; ``None``/empty becomes ``""``.
        sig_value_b64: signature value; ``None``/empty becomes ``""``.

    Non-ASCII characters in ``key_id`` / ``sig_value_b64`` are encoded as
    ``?`` (``errors="replace"``) — kept as-is so the digest stays identical
    to the one the backend computed at ingest.
    """
    parts = (
        payload_hash_hex.encode("ascii"),
        (key_id or "").encode("ascii", errors="replace"),
        (sig_value_b64 or "").encode("ascii", errors="replace"),
    )
    preimage = LEAF_PREFIX + b"|".join(parts)
    return hashlib.sha256(preimage).digest()


def verify_inclusion(leaf, idx, sibling_hashes_hex, anchored_root_hex):
    """Return True iff *leaf* at position *idx* folds up to the anchored root.

    Args:
        leaf: leaf digest bytes (e.g. from ``compute_leaf_hash``).
        idx: zero-based position of the leaf in the anchor's tree. At each
            level the low bit of the running index says whether the current
            node is a right child (1: sibling goes on the left) or a left
            child (0: sibling goes on the right).
        sibling_hashes_hex: sibling digests as hex strings, consumed in order
            from the leaf level upward.
        anchored_root_hex: the anchor's root as hex (either letter case).

    Returns:
        False for malformed hex (including non-string values), a sibling
        that is not a SHA-256-sized digest, an index that does not fit the
        path length, or a root mismatch; True otherwise.
    """
    digest_size = hashlib.sha256().digest_size
    try:
        # Compare raw bytes so upper/lower-case hex roots are equivalent.
        expected_root = bytes.fromhex(anchored_root_hex)
    except (TypeError, ValueError):
        return False
    cur = leaf
    i = idx
    for sib_hex in sibling_hashes_hex:
        try:
            sib = bytes.fromhex(sib_hex)
        except (TypeError, ValueError):
            return False
        # Every node in the tree is a SHA-256 output; anything else
        # (including an empty string) cannot be a genuine sibling.
        if len(sib) != digest_size:
            return False
        if i % 2 == 1:
            cur = hashlib.sha256(NODE_PREFIX + sib + cur).digest()
        else:
            cur = hashlib.sha256(NODE_PREFIX + cur + sib).digest()
        i //= 2
    # The root is the single node at index 0 of the top level. A leftover
    # index means leaf_index is larger than this path allows, i.e. the
    # claimed position is not the one the path actually proves.
    return i == 0 and cur == expected_root


def _read_jsonl(path):
    """Return the JSON objects stored one per non-blank line of *path*."""
    with open(path, "r", encoding="utf-8") as f:
        # json.loads tolerates the trailing newline / surrounding whitespace.
        return [json.loads(line) for line in f if line.strip()]


def _check_event(event_id, evt, proof, anchors):
    """Run every check for one envelope and classify the result.

    Args:
        event_id: id used in diagnostic output.
        evt: the envelope dict read from events.jsonl.
        proof: the matching proofs.jsonl record, or None if there is none.
        anchors: anchors.json records keyed by anchor id.

    Returns:
        ``(outcome, alg)`` where ``outcome`` is ``"ok"``, ``"bad"``,
        ``"pending"`` or ``"skipped"``; ``alg`` is the signature algorithm
        label for ``"ok"`` and None otherwise. Every ``"bad"`` outcome
        prints one diagnostic line.
    """
    status = evt.get("status")
    # Only ``verified`` rows participate in the Merkle tree.
    #   - unsigned / invalid / duplicate: no proof, never had one
    #   - pending_anchor: verified row not yet folded into any
    #     anchor (will be picked up on the next worker pass).
    if status == "pending_anchor":
        return "pending", None
    if status != "verified":
        return "skipped", None

    leaf_index = proof.get("leaf_index", -1) if proof is not None else -1
    # Reject null / non-integer / boolean indices instead of letting the
    # ``< 0`` comparison raise TypeError mid-run.
    if (
        isinstance(leaf_index, bool)
        or not isinstance(leaf_index, int)
        or leaf_index < 0
    ):
        print("MISSING PROOF for event", event_id)
        return "bad", None
    siblings = proof.get("sibling_hashes")
    if not isinstance(siblings, list):
        print("MALFORMED PROOF for event", event_id)
        return "bad", None

    evt_anchor_id = evt.get("anchor_id")
    proof_anchor_id = proof.get("anchor_id")
    if evt_anchor_id and proof_anchor_id and evt_anchor_id != proof_anchor_id:
        # The proof's sibling path belongs to the proof's anchor. Preferring
        # the envelope's id would check it against the wrong root, or — if
        # that anchor is absent — silently count the event as pending.
        print(
            "ANCHOR_ID MISMATCH for event",
            event_id,
            "envelope",
            evt_anchor_id,
            "proof",
            proof_anchor_id,
        )
        return "bad", None
    anchor_id = evt_anchor_id or proof_anchor_id
    if not anchor_id:
        # Verified status but no anchor_id — internal inconsistency.
        print("MISSING ANCHOR_ID for event", event_id)
        return "bad", None
    anchor = anchors.get(anchor_id)
    if anchor is None:
        # The anchor isn't in this export (e.g. exported a tighter
        # range than the events span). Treat as pending — the
        # auditor can re-export with a wider window if they need
        # the proof.
        return "pending", None

    leaf_count = anchor.get("leaf_count")
    if (
        isinstance(leaf_count, int)
        and not isinstance(leaf_count, bool)
        and leaf_index >= leaf_count
    ):
        # A position past the end of the tree cannot be a real leaf.
        print(
            "LEAF INDEX OUT OF RANGE for event",
            event_id,
            "anchor",
            anchor_id,
            "leaf_index",
            leaf_index,
            "leaf_count",
            leaf_count,
        )
        return "bad", None

    payload_hash = evt.get("payload_hash")
    if not isinstance(payload_hash, str):
        print("MISSING PAYLOAD_HASH for event", event_id)
        return "bad", None
    sig = evt.get("signature")
    if not isinstance(sig, dict):
        # Absent or non-object signature: derive with an empty value (as the
        # original ``or {}`` did for None) rather than crash on ``.get``.
        sig = {}
    leaf = compute_leaf_hash(
        payload_hash,
        evt.get("key_id") or "",
        sig.get("value") or "",
    )
    # Sanity check: the envelope's leaf_hash must match what we
    # re-derived. If it doesn't, the export is internally
    # inconsistent — flag it before the proof step.
    expected_leaf = evt.get("leaf_hash")
    if not isinstance(expected_leaf, str) or expected_leaf.lower() != leaf.hex():
        print("LEAF DERIVATION MISMATCH for event", event_id)
        return "bad", None

    if not verify_inclusion(leaf, leaf_index, siblings, anchor.get("root")):
        print(
            "MISMATCH for event",
            event_id,
            "anchor",
            anchor_id,
            "day",
            anchor.get("day"),
            "sequence",
            anchor.get("sequence"),
        )
        return "bad", None
    # str() keeps the summary's sorted() safe if alg is not a string.
    return "ok", str(sig.get("alg") or "unknown")


def main(bundle_dir):
    """Verify every envelope in *bundle_dir* and print a one-line summary.

    Reads events.jsonl, proofs.jsonl and anchors.json from *bundle_dir*.
    Prints one line per failing event; if any failed, prints a FAILED line
    and calls ``sys.exit(1)``. Otherwise prints an ``OK:`` summary with the
    pending / skipped counts and a per-signature-algorithm breakdown.
    """
    # NOTE(review): a repeated event_id keeps only the last line for that
    # id, so earlier copies are never checked — confirm ids are unique.
    events = {
        evt["event_id"]: evt
        for evt in _read_jsonl(os.path.join(bundle_dir, "events.jsonl"))
    }
    proofs = {
        proof["event_id"]: proof
        for proof in _read_jsonl(os.path.join(bundle_dir, "proofs.jsonl"))
    }
    # Index anchors by their id (UUID) — events carry anchor_id, not day,
    # so this is the lookup auditors actually need.
    with open(os.path.join(bundle_dir, "anchors.json"), "r", encoding="utf-8") as f:
        anchors = {a["id"]: a for a in json.load(f)}

    counts = {"ok": 0, "bad": 0, "pending": 0, "skipped": 0}
    alg_counts: dict[str, int] = {}
    for event_id, evt in events.items():
        outcome, alg = _check_event(event_id, evt, proofs.get(event_id), anchors)
        counts[outcome] += 1
        if outcome == "ok":
            alg_counts[alg] = alg_counts.get(alg, 0) + 1

    if counts["bad"]:
        print("FAILED:", counts["bad"], "events did not verify")
        sys.exit(1)
    alg_breakdown = (
        " [" + ", ".join(f"{a}: {n}" for a, n in sorted(alg_counts.items())) + "]"
        if alg_counts
        else ""
    )
    print(
        "OK:",
        counts["ok"],
        "events verified against",
        len(anchors),
        "anchors (",
        counts["pending"],
        "pending_anchor,",
        counts["skipped"],
        "non-verified skipped)" + alg_breakdown,
    )


if __name__ == "__main__":
    # Exactly one positional argument: the bundle directory. Anything else
    # prints the module docstring as usage help and exits with status 2.
    if len(sys.argv) == 2:
        main(sys.argv[1])
    else:
        print(__doc__)
        sys.exit(2)
