"""Reference GAP-TV solver for the CASSI forward operator.

Implements the algorithm described in `../solution.md` §1–§2. The public
`solve` function is what the PWM evaluation harness calls on each benchmark
scene. This is the genesis baseline for `PWM-L3-cassi-kaist-10scenes`.
"""

from __future__ import annotations

import numpy as np
from scipy.ndimage import shift as nd_shift


def _shifted_masks(mask: np.ndarray, dispersion: np.ndarray, B: int) -> np.ndarray:
    """Return the (H, W, B) float32 stack of per-band coded apertures.

    Band ``b`` uses ``mask`` shifted by ``int(dispersion[b])`` pixels along
    axis 1, with zeros shifted in. Building the stack once lets iterative
    solvers reuse it instead of re-shifting the mask on every operator call.
    """
    mask = np.asarray(mask, dtype=np.float32)
    if B == 0:
        return np.zeros(mask.shape + (0,), dtype=np.float32)
    return np.stack(
        [
            nd_shift(mask, (0, int(dispersion[b])), order=0, mode="constant", cval=0.0)
            for b in range(B)
        ],
        axis=-1,
    )


def _apply_forward(x: np.ndarray, masks: np.ndarray) -> np.ndarray:
    """Phi x with precomputed masks: ``sum_b masks[..., b] * x[..., b]`` (float32)."""
    return np.sum(masks * x, axis=-1).astype(np.float32, copy=False)


def _apply_adjoint(y: np.ndarray, masks: np.ndarray) -> np.ndarray:
    """Phi^T y with precomputed masks: band b is ``masks[..., b] * y`` (float32)."""
    return (masks * y[..., np.newaxis]).astype(np.float32, copy=False)


def forward_cassi(x: np.ndarray, mask: np.ndarray, dispersion: np.ndarray) -> np.ndarray:
    """CASSI forward operator: 3-D cube -> 2-D coded snapshot.

    For each band b, multiply by the mask shifted by delta_b along the
    dispersion axis (axis=1), then sum across bands.

    Parameters
    ----------
    x : (H, W, B) cube.
    mask : (H, W) aperture.
    dispersion : per-band integer pixel shift (at least B entries).

    Returns
    -------
    (H, W) float32 snapshot.
    """
    _, _, B = x.shape
    return _apply_forward(x, _shifted_masks(mask, dispersion, B))


def adjoint_cassi(y: np.ndarray, mask: np.ndarray, dispersion: np.ndarray, B: int) -> np.ndarray:
    """Adjoint (back-projection) of the CASSI forward operator.

    Band b of the result is ``y`` multiplied by the mask shifted by
    ``dispersion[b]``; returns an (H, W, B) float32 cube.
    """
    return _apply_adjoint(y, _shifted_masks(mask, dispersion, B))


def _grad2d(u: np.ndarray) -> np.ndarray:
    """Forward-difference gradient, shape (2, H, W); zero on the last row/column."""
    g = np.zeros((2,) + u.shape, dtype=u.dtype)
    g[0, :-1, :] = u[1:, :] - u[:-1, :]
    g[1, :, :-1] = u[:, 1:] - u[:, :-1]
    return g


def _div2d(p: np.ndarray) -> np.ndarray:
    """Discrete divergence of a (2, H, W) field; exactly the negative adjoint of `_grad2d`."""
    d = np.zeros(p.shape[1:], dtype=p.dtype)
    d[:-1, :] += p[0, :-1, :]
    d[1:, :] -= p[0, :-1, :]
    d[:, :-1] += p[1, :, :-1]
    d[:, 1:] -= p[1, :, :-1]
    return d


def tv_denoise_2d(z: np.ndarray, lam: float, iters: int = 20) -> np.ndarray:
    """Chambolle (2004) 2-D total-variation denoising on a single band.

    Approximately solves ``min_u 0.5 * ||u - z||^2 + lam * TV(u)`` (isotropic
    TV) with Chambolle's dual fixed-point iteration. ``lam`` is the TV weight:
    larger values give smoother output; ``lam <= 0`` returns a copy of ``z``.

    Parameters
    ----------
    z : (H, W) image to denoise.
    lam : TV weight (regularisation strength).
    iters : number of dual iterations.

    Returns
    -------
    (H, W) denoised image (float32 for float32 input, float64 for float64).
    """
    z = np.asarray(z)
    z = z.astype(np.result_type(z.dtype, np.float32), copy=False)
    if lam <= 0:
        return z.copy()
    # Chambolle proves convergence for tau <= 1/8; 1/4 is the customary
    # practical choice for 2-D images.
    tau = 0.25
    p = np.zeros((2,) + z.shape, dtype=z.dtype)
    for _ in range(iters):
        g = _grad2d(_div2d(p) - z / lam)
        norm = np.sqrt(np.sum(g * g, axis=0))
        # The normalisation keeps |p| <= 1 per pixel, so the final output
        # never differs from z by more than 4 * lam.
        p = (p + tau * g) / (1.0 + tau * norm)
    return z - lam * _div2d(p)


def solve(
    y: np.ndarray,
    mask: np.ndarray,
    dispersion: np.ndarray,
    sigma: float,
    *,
    B: int = 28,
    iters: int = 200,
    lam_init: float = 1.0,
    lam_final: float = 0.05,
    tol: float = 1e-4,
) -> np.ndarray:
    """Recover a hyperspectral cube from a CASSI snapshot via GAP-TV.

    Each iteration projects the estimate onto ``{x : Phi x = y}`` (the GAP
    step) and then TV-denoises every band independently.

    Parameters
    ----------
    y : (H, W) coded snapshot.
    mask : (H, W) binary aperture (uint8 or float).
    dispersion : (B,) per-band integer pixel shift.
    sigma : noise standard deviation (used to set the denoising scale;
        values below 1e-3 are clamped to 1e-3).
    B : number of spectral bands (default 28 for KAIST).
    iters : maximum GAP iterations.
    lam_init : starting TV weight (annealed geometrically to lam_final).
    lam_final : ending TV weight.
    tol : early-stop on relative change.

    Returns
    -------
    x_hat : (H, W, B) reconstructed cube, clipped to [0, 1].
    """
    masks = _shifted_masks(mask, dispersion, B)
    # For CASSI, Phi Phi^T is diagonal with entries sum_b masks[i, j, b]**2.
    phi_sum = np.sum(masks ** 2, axis=-1)
    # Pixels seen by no band back-project to zero anyway; avoid 0 / 0 there.
    phi_sum[phi_sum == 0] = 1.0

    x = _apply_adjoint(y, masks)
    lams = np.geomspace(lam_init, lam_final, iters).astype(np.float32)
    noise_scale = max(sigma, 1e-3)
    for k in range(iters):
        # GAP projection: x + Phi^T (Phi Phi^T)^{-1} (y - Phi x).
        residual = y - _apply_forward(x, masks)
        x_data = x + _apply_adjoint(residual / phi_sum, masks)

        lam_k = float(lams[k]) * noise_scale
        x_next = np.empty_like(x_data)
        for b in range(B):
            x_next[:, :, b] = tv_denoise_2d(x_data[:, :, b], lam=lam_k, iters=5)

        rel = float(np.linalg.norm(x_next - x) / max(np.linalg.norm(x), 1e-8))
        x = x_next
        if rel < tol:
            break
    return np.clip(x, 0.0, 1.0)


def main() -> None:
    """Evaluate `solve` on the ten benchmark scenes and optionally write a report.

    Loads ``mask.npy``, ``dispersion.npy`` and ``scene_01.npy`` ..
    ``scene_10.npy`` from ``--data-dir``, simulates a noisy snapshot per scene,
    reconstructs it and prints band-averaged PSNR / SSIM.
    """
    import argparse
    import pathlib

    ap = argparse.ArgumentParser(description="GAP-TV reference solver")
    # Accepted for harness compatibility; not otherwise used by this script.
    ap.add_argument("--benchmark", default="PWM-L3-cassi-kaist-10scenes")
    ap.add_argument("--data-dir", default="./cassi-kaist-10s")
    ap.add_argument("--report", action="store_true", help="emit reports/results.md")
    ap.add_argument(
        "--sigma",
        type=float,
        default=0.01,
        help="std of the simulated Gaussian measurement noise",
    )
    args = ap.parse_args()

    data = pathlib.Path(args.data_dir)
    mask = np.load(data / "mask.npy")
    dispersion = np.load(data / "dispersion.npy")

    from skimage.metrics import peak_signal_noise_ratio as psnr_fn
    from skimage.metrics import structural_similarity as ssim_fn

    rng = np.random.default_rng(42)
    results = []
    for i in range(1, 11):
        x_gt = np.load(data / f"scene_{i:02d}.npy")
        B = x_gt.shape[-1]
        y_clean = forward_cassi(x_gt, mask, dispersion)
        y = y_clean + rng.normal(0, args.sigma, size=y_clean.shape).astype(np.float32)
        x_hat = solve(y, mask, dispersion, sigma=args.sigma, B=B)
        ps = float(np.mean([psnr_fn(x_gt[..., b], x_hat[..., b], data_range=1.0) for b in range(B)]))
        ss = float(np.mean([ssim_fn(x_gt[..., b], x_hat[..., b], data_range=1.0) for b in range(B)]))
        results.append({"scene": i, "psnr_db": ps, "ssim": ss})
        print(f"scene_{i:02d}: psnr={ps:.2f} dB, ssim={ss:.3f}")

    if args.report:
        out = pathlib.Path("reports/results.md")
        out.parent.mkdir(parents=True, exist_ok=True)
        lines = ["# Reconstruction results", "", "| Scene | PSNR (dB) | SSIM |", "|---|---|---|"]
        lines += [f"| {r['scene']:02d} | {r['psnr_db']:.2f} | {r['ssim']:.3f} |" for r in results]
        mean_psnr = sum(r["psnr_db"] for r in results) / len(results)
        mean_ssim = sum(r["ssim"] for r in results) / len(results)
        lines += ["", f"**Mean PSNR:** {mean_psnr:.2f} dB", f"**Mean SSIM:** {mean_ssim:.3f}"]
        out.write_text("\n".join(lines), encoding="utf-8")
        print(f"wrote {out}")


if __name__ == "__main__":
    main()