Skip to content

nrcatalogtools.comparisons

End-to-end NR vs NRSur7dq4 comparison pipeline.

Conceptual background: see Package Internals § 12 for a description of the pipeline steps and output format.


Functions

comparisons

NR vs NRSur7dq4 comparison utilities.

compare_sim_vs_surrogate runs the full per-mode match pipeline for a single simulation: loads the catalog waveform, generates surrogate modes, computes noise-weighted matches and phase-drift metrics, writes a CSV, and saves a figure.

compare_sim_vs_surrogate

compare_sim_vs_surrogate(catalog_name: str, sim_name: str, total_mass: float = 40.0, psd_name: str = 'aLIGOZeroDetHighPower', outdir: str | None = None, figsdir: str | None = None, delta_t: float = DELTA_T, rotate: bool = False) -> dict

Run the full NR vs NRSur7dq4 comparison for one simulation.

Parameters:

Name Type Description Default
catalog_name str

One of 'SXS', 'RIT', 'MAYA'.

required
sim_name str

Simulation identifier (e.g. 'SXS:BBH:0001').

required
total_mass float

Total mass in solar masses (default 40).

40.0
psd_name str

PyCBC analytic PSD name (default 'aLIGOZeroDetHighPower').

'aLIGOZeroDetHighPower'
outdir str

Directory for the output CSV (default 'results' under cwd).

None
figsdir str

Directory for the output figure (default 'figs' under cwd).

None
delta_t float

Sample spacing in physical seconds (default 1/4096).

DELTA_T
rotate bool

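Also compute the SO(3)-rotation-optimized match (slow). Enabled automatically for precessing systems (in-plane spin magnitude above 1e-4 on either black hole).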

False

Returns:

Type Description
dict

{(ell, em): {'match', 'f_lower_mode', 'phase_diff_per_cycle', 'n_cycles', 'match_rotated', 'R_alpha', 'R_beta', 'R_gamma'}}

Source code in nrcatalogtools/comparisons.py
def compare_sim_vs_surrogate(
    catalog_name: str,
    sim_name: str,
    total_mass: float = 40.0,
    psd_name: str = "aLIGOZeroDetHighPower",
    outdir: str | None = None,
    figsdir: str | None = None,
    delta_t: float = DELTA_T,
    rotate: bool = False,
) -> dict:
    """Run the full NR vs NRSur7dq4 comparison for one simulation.

    Loads the NR waveform from the catalog, generates the corresponding
    NRSur7dq4 modes, computes a noise-weighted match and a phase-drift
    metric for every mode in ``NR_MODES``, optionally computes an
    SO(3)-rotation-optimized match, and finally writes a CSV, a figure and
    a printed summary table.

    Parameters
    ----------
    catalog_name : str
        One of ``'SXS'``, ``'RIT'``, ``'MAYA'``.
    sim_name : str
        Simulation identifier (e.g. ``'SXS:BBH:0001'``).
    total_mass : float, optional
        Total mass in solar masses (default 40).
    psd_name : str, optional
        PyCBC analytic PSD name (default ``'aLIGOZeroDetHighPower'``).
    outdir : str, optional
        Directory for the output CSV (default ``'results'`` under cwd).
    figsdir : str, optional
        Directory for the output figure (default ``'figs'`` under cwd).
    delta_t : float, optional
        Sample spacing in physical seconds (default 1/4096).
    rotate : bool, optional
        Also compute the SO(3)-rotation-optimized match (slow). Forced to
        ``True`` when either black hole has an in-plane spin magnitude above
        1e-4 (precessing system).

    Returns
    -------
    dict
        ``{(ell, em): {'match', 'f_lower_mode', 'phase_diff_per_cycle',
        'n_cycles', 'match_rotated', 'R_alpha', 'R_beta', 'R_gamma'}}``.
        Modes missing from either the NR or the surrogate waveform get NaN
        metrics. ``match_rotated`` and the ``R_*`` Euler angles stay ``None``
        unless the rotation-optimized match succeeds, in which case the same
        values are stored on every mode entry.

    Raises
    ------
    Exception
        Any error raised while generating the surrogate modes is reported
        and re-raised unchanged.
    """

    def _nan_result() -> dict:
        """Return a fresh placeholder entry for a mode that could not be matched."""
        # A new dict per call: the rotation step below mutates entries in place.
        return {
            "match": float("nan"),
            "f_lower_mode": float("nan"),
            "phase_diff_per_cycle": float("nan"),
            "n_cycles": float("nan"),
            "match_rotated": None,
            "R_alpha": None,
            "R_beta": None,
            "R_gamma": None,
        }

    if outdir is None:
        outdir = os.path.join(os.getcwd(), "results")
    if figsdir is None:
        figsdir = os.path.join(os.getcwd(), "figs")

    os.makedirs(outdir, exist_ok=True)
    os.makedirs(figsdir, exist_ok=True)

    # 1. Load catalog and waveform
    from . import load_catalog

    print(f"\n[1/6] Loading {catalog_name} catalog...")
    cat = load_catalog(catalog_name)
    print(f"      Fetching {sim_name}...")
    wfm = cat.get(sim_name)

    # 2. Extract parameters
    print(f"[2/6] Extracting source parameters (M={total_mass} M☉)...")
    params = cat.get_parameters(sim_name, total_mass=total_mass)
    q = params["mass1"] / params["mass2"]
    chi1 = np.sqrt(
        params["spin1x"] ** 2 + params["spin1y"] ** 2 + params["spin1z"] ** 2
    )
    chi2 = np.sqrt(
        params["spin2x"] ** 2 + params["spin2y"] ** 2 + params["spin2z"] ** 2
    )
    f_lower = params["f_lower"]
    print(
        f"      q={q:.4f}  |χ₁|={chi1:.4f}  |χ₂|={chi2:.4f}  f_lower={f_lower:.2f} Hz"
    )

    if not check_surrogate_prior(params):
        print(
            "WARNING: parameters lie outside NRSur7dq4 prior (q > 4 or |χ| > 0.8). "
            "Proceeding, but surrogate extrapolation may be unreliable."
        )

    # Any non-negligible in-plane spin means the system precesses; a plain
    # per-mode match is then frame-dependent, so force the SO(3) optimisation.
    chi1_perp = np.sqrt(params["spin1x"] ** 2 + params["spin1y"] ** 2)
    chi2_perp = np.sqrt(params["spin2x"] ** 2 + params["spin2y"] ** 2)
    is_precessing = (chi1_perp > 1e-4) or (chi2_perp > 1e-4)
    if is_precessing and not rotate:
        print("      Precessing system: enabling SO(3)-optimized match automatically.")
        rotate = True

    # 3. Generate surrogate modes.
    # For precessing SXS binaries the surrogate spins are epoch-aligned to the
    # NR dynamics at the surrogate's training-window start (Phase 2), regardless
    # of the rotate flag.  The rotate flag only controls whether an additional
    # SO(3) frame optimisation is performed after the per-mode matches.
    print(
        f"[3/6] Generating NRSur7dq4 modes (M={total_mass:.1f} M☉, D={DISTANCE:.1f} Mpc)..."
    )
    try:
        h_sur, f_lower_sur = generate_surrogate_modes(
            params,
            total_mass,
            DISTANCE,
            delta_t_seconds=delta_t,
            sim_name=sim_name,
            catalog=cat,
            nr_wfm=wfm,
        )
    except Exception as exc:
        print(f"      Failed to generate surrogate: {exc}")
        raise
    print(f"      Generated {len(h_sur)} surrogate modes.")
    if f_lower_sur > f_lower * 1.05:
        print(
            f"      NOTE: surrogate starts at f_GW={f_lower_sur:.1f} Hz "
            f"(NRSur7dq4 minimum for these params), above NR f_lower={f_lower:.1f} Hz. "
            f"Match f_lower raised to {f_lower_sur:.1f} Hz."
        )
    # The match can only start where both waveforms have content.
    f_lower_match = max(f_lower, f_lower_sur)

    print(f"[4/6] PSD: {psd_name} (built per-mode at matched frequency resolution)")

    # 5. Per-mode match
    print("[5/6] Computing per-mode matches...")
    results = {}

    for (ell, em) in NR_MODES:
        try:
            h_nr_complex = wfm.get_mode(
                ell,
                em,
                total_mass=total_mass,
                distance=DISTANCE,
                delta_t_seconds=delta_t,
            )
        except Exception as exc:
            print(f"      ({ell},{em:+d}): NR mode unavailable — {exc}")
            results[(ell, em)] = _nan_result()
            continue

        h_nr = h_nr_complex.real()

        if (ell, em) not in h_sur:
            print(f"      ({ell},{em:+d}): surrogate mode unavailable (ell > 4?)")
            results[(ell, em)] = _nan_result()
            continue

        h_sur_mode = h_sur[(ell, em)].real()
        f_low_mode = mode_f_lower(f_lower_match, em)
        mm = compute_mode_match(h_nr, h_sur_mode, f_low_mode, psd_name=psd_name)
        dphase, n_cycles = compute_phase_diff_per_cycle(h_nr_complex, h_sur[(ell, em)])

        entry = _nan_result()
        entry.update(
            {
                "match": mm,
                "f_lower_mode": f_low_mode,
                "phase_diff_per_cycle": dphase,
                "n_cycles": n_cycles,
            }
        )
        results[(ell, em)] = entry

        # Report NaN metrics explicitly rather than printing an empty field.
        flag = "N/A" if np.isnan(mm) else f"{mm:.6f}"
        dp_str = "N/A" if np.isnan(dphase) else f"{dphase:.4f} rad"
        print(
            f"      ({ell},{em:+d}): match = {flag}  phase_diff/cycle = {dp_str}"
            f"  [f_lower_mode={f_low_mode:.1f} Hz]"
        )

    if rotate:
        print(
            "[5b] Computing SO(3)-rotation-optimized match (differential evolution)..."
        )
        # Best-effort: a failure here is reported but does not abort the run;
        # the per-mode results above are still written out.
        try:
            from .waveform.matching import load_psd

            # Size the PSD for the longer of the two waveforms (+10% headroom).
            dur_nr = _waveform_duration(wfm, total_mass)
            dur_sur = len(next(iter(h_sur.values()))) * delta_t
            psd_rot = load_psd(
                f_lower_match, delta_t, max(dur_nr, dur_sur) * 1.1, psd_name=psd_name
            )

            mm_rot, R_opt = wfm.match_sphere_averaged(
                h_sur,
                psd=psd_rot,
                f_lower=f_lower_match,
                delta_t=delta_t,
                return_rotation=True,
                total_mass=total_mass,
                distance=DISTANCE,
            )
            print(f"      SO(3)-optimized match = {mm_rot:.6f}")
            alpha, beta, gamma = R_opt.to_euler_angles
            # The rotated match is a single whole-waveform value; store it on
            # every mode entry so the CSV rows are self-contained.
            for key in results:
                results[key]["match_rotated"] = mm_rot
                results[key]["R_alpha"] = alpha
                results[key]["R_beta"] = beta
                results[key]["R_gamma"] = gamma
        except Exception as exc:
            print(f"      SO(3)-optimized match failed: {exc}")
            import traceback

            traceback.print_exc()

    # 6. Output
    print(f"[6/6] Writing outputs (results to {outdir}/, figures to {figsdir}/)...")
    _write_csv(results, sim_name, catalog_name, total_mass, params, outdir)
    _plot(
        results,
        wfm,
        h_sur,
        sim_name,
        catalog_name,
        total_mass,
        params,
        delta_t,
        psd_name,
        figsdir,
    )
    _print_table(results, sim_name)

    return results

Constants

Name Value Description
DELTA_T 1/4096 s Default waveform sample interval
DISTANCE 1.0 Mpc Reference distance used internally (amplitude-irrelevant for match)