#!/usr/bin/env python3
"""Build the asset bundle embedded in combine_simulator.html.

Reads:
  - scripts/match_nights/2026-04-29.json  (canonical event log)
  - scripts/match_nights/2026-04-30.json  (canonical event log)
  - scripts/match_nights/historical/eligible_players.json

Writes:
  - scripts/match_nights/combine_simulator_data.json

The bundle is small (~200 KB) and self-contained: it carries both nights'
JOIN/USER_LEAVE event streams (so the JS replay can re-derive pops, with the
recorded POP timestamps kept alongside for comparison), per-night player
metadata, the active eligible pool, the tier bounds and tier order, and a
handful of empirical priors used by the synthetic-night generator.
"""

import bisect
import datetime as dt
import json
import math
import statistics
from collections import Counter
from pathlib import Path

# All inputs and the output live next to this script.
ROOT = Path(__file__).resolve().parent
NIGHTS = [ROOT / f"{day}.json" for day in ("2026-04-29", "2026-04-30")]
ELIGIBLE = ROOT.joinpath("historical", "eligible_players.json")
OUT = ROOT.joinpath("combine_simulator_data.json")

# Player types that actually queue for combines. Filters out Spec/Expired.
ACTIVE_TYPES = set("DE FA PermFA UGM UAGM Signed".split())


def parse_iso(s: str) -> dt.datetime:
    """Parse an ISO-8601 timestamp into a timezone-aware datetime.

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts ``Z`` natively from Python 3.11.
    A string with no UTC offset is taken to be UTC. Without that, the naive
    datetime's ``.timestamp()`` would be read in the host's local timezone,
    and epoch-ms values would change from machine to machine.
    """
    if s.endswith("Z"):
        s = s[:-1] + "+00:00"
    parsed = dt.datetime.fromisoformat(s)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    return parsed


def load_night(path: Path) -> dict:
    """Read and return one night's JSON export from ``path``.

    The file is decoded as UTF-8 explicitly. The platform default encoding
    (e.g. cp1252 on Windows) would mangle or reject non-ASCII content such as
    player names.
    """
    with path.open(encoding="utf-8") as f:
        return json.load(f)


def night_label(path: Path) -> str:
    """Turn a night file's stem (e.g. ``2026-04-29``) into ``2026_04_29``."""
    pieces = path.stem.split("-")
    return "_".join(pieces)


def fit_arrival_cdf(nights: list[dict], buckets: int = 100) -> list[dict]:
    """Empirical CDF of each player's first JOIN time, pooled across nights.

    For every night, each player's earliest JOIN is expressed as a fraction of
    that night's ``metadata.window``: 0 is the window start and 1 is the
    window end, clamped to [0, 1]. The pooled fractions are sampled at
    ``buckets + 1`` evenly spaced points.

    Args:
        nights: Parsed night exports, each with ``metadata.window.start`` /
            ``metadata.window.end`` ISO strings and an ``events`` list.
        buckets: Number of CDF steps; the result has ``buckets + 1`` points.

    Returns:
        ``[{"t_frac": x, "cdf": F(x)}, ...]``, where ``F(x)`` is the fraction
        of pooled first-JOIN times that are ``<= x``. Every ``cdf`` is 0 when
        no JOIN with a player id was found.
    """
    pooled: list[float] = []
    for night in nights:
        window = night["metadata"]["window"]
        start_ms = parse_iso(window["start"]).timestamp() * 1000
        end_ms = parse_iso(window["end"]).timestamp() * 1000
        duration = max(end_ms - start_ms, 1)
        # Take the earliest JOIN per player by time, not by list position.
        # Events are not assumed to be time-ordered; other consumers in this
        # script sort them explicitly.
        first_join: dict[int, float] = {}
        for ev in night["events"]:
            if ev["type"] != "JOIN":
                continue
            pid = ev.get("player_id")
            if pid is None:
                continue
            t = ev["t_ms"]
            if pid not in first_join or t < first_join[pid]:
                first_join[pid] = t
        for t in first_join.values():
            frac = (t - start_ms) / duration
            pooled.append(max(0.0, min(1.0, frac)))
    pooled.sort()
    denom = max(len(pooled), 1)
    out = []
    for i in range(buckets + 1):
        x = i / buckets
        # bisect_right on the sorted list == number of pooled values <= x.
        n_le = bisect.bisect_right(pooled, x)
        out.append({"t_frac": round(x, 4), "cdf": round(n_le / denom, 5)})
    return out


def fit_arrivals_per_player(nights: list[dict]) -> list[float]:
    """Return the PMF of JOINs-per-player, pooled over (player, night) pairs.

    Entry ``k`` is the share of players who JOINed exactly ``k`` times in a
    night, rounded to 6 decimals. JOIN events without a ``player_id`` are
    ignored. Entry 0 is always 0.0, because only players with at least one
    JOIN are counted. Returns ``[0.0]`` when no JOIN was seen at all.
    """
    histogram: Counter[int] = Counter()
    for night in nights:
        joins_by_player = Counter(
            ev.get("player_id")
            for ev in night["events"]
            if ev["type"] == "JOIN" and ev.get("player_id") is not None
        )
        # Each player's JOIN count becomes one observation in the histogram.
        histogram.update(joins_by_player.values())
    if not histogram:
        return [0.0]
    n_players = sum(histogram.values())
    return [round(histogram[k] / n_players, 6) for k in range(max(histogram) + 1)]


def fit_requeue_gap(nights: list[dict]) -> dict:
    """Fit a lognormal to the gaps (seconds) between a player's consecutive JOINs.

    Gaps are computed per player and per night, after sorting that player's
    JOIN timestamps; zero-length gaps are discarded. The fit is the lognormal
    MLE: ``mu`` is the mean and ``sigma`` the population stdev of ``log(gap)``.
    ``sigma`` falls back to 1.0 when it comes out as 0. Without any gaps a
    default centred on 60 s is returned.
    """
    intervals: list[float] = []
    for night in nights:
        joins: dict[int, list[int]] = {}
        for ev in night["events"]:
            if ev["type"] != "JOIN":
                continue
            pid = ev.get("player_id")
            if pid is not None:
                joins.setdefault(pid, []).append(ev["t_ms"])
        for stamps in joins.values():
            ordered = sorted(stamps)
            deltas = ((later - earlier) / 1000 for earlier, later in zip(ordered, ordered[1:]))
            intervals.extend(d for d in deltas if d > 0)

    if not intervals:
        return {"mu": math.log(60), "sigma": 1.0, "median_seconds": 60, "n": 0}

    logs = [math.log(d) for d in intervals]
    mu = statistics.fmean(logs)
    sigma = statistics.pstdev(logs) or 1.0
    return {
        "mu": round(mu, 4),
        "sigma": round(sigma, 4),
        "median_seconds": round(statistics.median(intervals), 1),
        "n": len(intervals),
    }


def user_leave_rate(nights: list[dict]) -> float:
    """Return USER_LEAVE events per JOIN event across all nights, rounded to 5 places.

    Every JOIN is counted, with or without a player id. The denominator is
    clamped to 1, so nights with no JOINs yield 0.0 rather than dividing by zero.
    """
    kinds = Counter(ev["type"] for night in nights for ev in night["events"])
    return round(kinds["USER_LEAVE"] / max(kinds["JOIN"], 1), 5)


def night_duration_seconds(nights: list[dict]) -> int:
    """Median ``metadata.window`` length over ``nights``, truncated to whole seconds.

    Raises ``statistics.StatisticsError`` when ``nights`` is empty.
    """
    def window_length(night: dict) -> float:
        # Parse start before end so a malformed start is reported first.
        window = night["metadata"]["window"]
        opened = parse_iso(window["start"])
        closed = parse_iso(window["end"])
        return (closed - opened).total_seconds()

    lengths = [window_length(n) for n in nights]
    return int(statistics.median(lengths))


def slim_events(night: dict) -> dict:
    """Reduce a night to what the JS replay needs: JOIN/USER_LEAVE events plus POP times.

    Returns ``{"events": [...], "pops": [...]}``, both sorted by ``t_ms``
    (stable, so ties keep their input order):

    - ``events`` entries are ``{"t_ms", "pid", "type"}`` for JOIN and
      USER_LEAVE events. Events with a missing or null ``player_id`` are
      skipped. The fit_* priors in this script already ignore such JOINs, and
      a replay entry with no player cannot be attributed to anyone.
    - ``pops`` entries are ``{"t_ms"}`` for each recorded POP, kept so the
      actual pops can be shown next to the replayed ones.
    """
    join_leave = []
    pops = []
    for ev in night["events"]:
        kind = ev["type"]
        if kind in ("JOIN", "USER_LEAVE"):
            pid = ev.get("player_id")
            if pid is None:
                continue
            join_leave.append({"t_ms": ev["t_ms"], "pid": pid, "type": kind})
        elif kind == "POP":
            pops.append({"t_ms": ev["t_ms"]})
    join_leave.sort(key=lambda e: e["t_ms"])
    pops.sort(key=lambda e: e["t_ms"])
    return {"events": join_leave, "pops": pops}


def player_metadata(night: dict) -> dict:
    """Return ``{int player_id: {"mmr", "type", "tier"}}`` for one night's roster.

    Only ``mmr``, ``type`` and ``tier`` are copied from each entry of
    ``night["players"]``; all other fields, ``name`` and ``discord_id``
    included, are dropped. Missing or empty values default to ``0.0``, ``""``
    and ``"Unrated"`` respectively. A night without ``players`` yields ``{}``.
    """
    return {
        int(pid_str): {
            "mmr": float(info.get("mmr") or 0),
            "type": info.get("type") or "",
            "tier": info.get("tier") or "Unrated",
        }
        for pid_str, info in night.get("players", {}).items()
    }


def _load_active_eligible(path: Path) -> list[dict]:
    """Load the eligible-player export; keep ACTIVE_TYPES with MMR > 0, sorted by id."""
    with path.open(encoding="utf-8") as f:
        eligible_raw = json.load(f)
    active = [
        {
            "id": p["id"],
            "mmr": float(p["mmr"]),
            "tier": p["tier"],
            "type": p["type"],
        }
        for p in eligible_raw
        if p.get("type") in ACTIVE_TYPES and float(p.get("mmr") or 0) > 0
    ]
    active.sort(key=lambda p: p["id"])
    return active


def _epoch_ms(iso: str) -> int:
    """Convert an ISO-8601 timestamp to integer milliseconds since the Unix epoch."""
    # round() rather than int(): timestamp() * 1000 can come out as e.g.
    # ...122.9999 for fractional seconds. Truncating that would shift every
    # rebased event time and duration_ms by 1 ms.
    return round(parse_iso(iso).timestamp() * 1000)


def _night_entry(path: Path, night: dict) -> dict:
    """Build one night's bundle entry with event/pop times rebased to the window start."""
    window = night["metadata"]["window"]
    start_ms = _epoch_ms(window["start"])
    end_ms = _epoch_ms(window["end"])
    slim = slim_events(night)
    # Re-base events to relative ms-since-window-start so JS doesn't sweat 13-digit ints.
    for ev in slim["events"] + slim["pops"]:
        ev["t_ms"] = ev["t_ms"] - start_ms
    return {
        "label": path.stem,
        "duration_ms": end_ms - start_ms,
        "config_snapshot": night["metadata"]["config_snapshot"],
        "stats": night["metadata"]["stats"],
        "events": slim["events"],
        "actual_pops": slim["pops"],
        "players": player_metadata(night),
    }


def _print_summary(bundle: dict) -> None:
    """Print the output path, file size and headline counts of the written bundle."""
    size_kb = OUT.stat().st_size / 1024
    print(f"wrote {OUT} ({size_kb:.1f} KB)")
    print(
        f"  eligible (active): {bundle['eligible_breakdown']['total']} "
        f"by_type={bundle['eligible_breakdown']['by_type']}"
    )
    for label, data in bundle["nights"].items():
        print(
            f"  {label}: {len(data['events'])} JOIN/LEAVE events, "
            f"{len(data['actual_pops'])} pops, {len(data['players'])} players"
        )
    print(f"  arrival_cdf: {len(bundle['priors']['arrival_cdf'])} buckets")
    print(f"  arrivals_per_player_pmf: max_count={len(bundle['priors']['arrivals_per_player_pmf'])-1}")
    print(f"  requeue_gap_lognormal: {bundle['priors']['requeue_gap_lognormal']}")
    print(f"  user_leave_rate: {bundle['priors']['user_leave_rate']}")


def main() -> None:
    """Build the simulator bundle from NIGHTS and ELIGIBLE, write it to OUT, print a summary."""
    nights = [load_night(p) for p in NIGHTS]
    eligible_active = _load_active_eligible(ELIGIBLE)

    # Tier bounds: prefer the 04-29 export (well-formed schema with all 7 tiers).
    tiers = nights[0]["metadata"]["tier_bounds"]
    tier_order = nights[0]["metadata"]["tier_order"]

    bundle = {
        "schema_version": 1,
        "generated_at": dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z"),
        "tiers": tiers,
        "tier_order": tier_order,
        "eligible_players": eligible_active,
        "eligible_breakdown": {
            "by_type": dict(Counter(p["type"] for p in eligible_active)),
            "by_tier": dict(Counter(p["tier"] for p in eligible_active)),
            "total": len(eligible_active),
        },
        "nights": {},
        "priors": {
            "arrival_cdf": fit_arrival_cdf(nights),
            "arrivals_per_player_pmf": fit_arrivals_per_player(nights),
            "requeue_gap_lognormal": fit_requeue_gap(nights),
            "user_leave_rate": user_leave_rate(nights),
            "night_duration_seconds": night_duration_seconds(nights),
        },
    }

    for path, night in zip(NIGHTS, nights):
        bundle["nights"][night_label(path)] = _night_entry(path, night)

    # json.dumps escapes non-ASCII by default, so the file is ASCII; the
    # explicit encoding just avoids depending on the locale.
    OUT.write_text(json.dumps(bundle, separators=(",", ":")), encoding="utf-8")
    _print_summary(bundle)


# Build the bundle only when run as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()
