Source code for stereocomplex.metrics.rayfield_metrics

from __future__ import annotations

from dataclasses import dataclass
from collections.abc import Callable

import numpy as np


@dataclass(frozen=True)
class RayfieldComparisonReport:
    """Immutable bundle of statistics describing how two ray fields differ.

    The positional error figures are statistics of a per-ray distance
    between the points where the two models' rays cross reference
    z-planes; the angular figure compares the ray directions themselves.

    Attributes
    ----------
    plane_intersection_rms : float
        Root-mean-square of the per-ray plane intersection distance, in mm.
    plane_intersection_median : float
        Median of the per-ray plane intersection distance, in mm.
    plane_intersection_p95 : float
        95th percentile of the per-ray plane intersection distance, in mm.
    direction_angle_rms_deg : float
        Root-mean-square angle between corresponding ray directions,
        in degrees.
    n_samples : int
        How many rays went into the statistics.
    """

    # Positional error statistics (plane intersections).
    plane_intersection_rms: float
    plane_intersection_median: float
    plane_intersection_p95: float

    # Angular error statistic (ray directions).
    direction_angle_rms_deg: float

    # Sample count behind all of the figures above.
    n_samples: int
def intersect_rays_with_z_plane(
    origin: np.ndarray, direction: np.ndarray, z_plane: float
) -> np.ndarray:
    """Find where each ray crosses the plane ``z = z_plane``.

    Each ray is treated as the line ``origin + t * direction``; ``t`` is
    solved from the z-coordinate alone and may be negative (no check is made
    that the crossing lies in front of the origin).

    Parameters
    ----------
    origin : np.ndarray
        Ray origins. Converted to float64 and reshaped to (N, 3).
    direction : np.ndarray
        Ray directions. Converted to float64 and reshaped to (N, 3).
    z_plane : float
        Z-coordinate of the target plane.

    Returns
    -------
    np.ndarray
        Intersection points, shape (N, 3).

    Raises
    ------
    ValueError
        If the absolute z-component of any direction is below 1e-12, i.e.
        the ray is (numerically) parallel to the plane.
    """
    starts = np.asarray(origin, dtype=np.float64).reshape(-1, 3)
    headings = np.asarray(direction, dtype=np.float64).reshape(-1, 3)

    # The z-component of the direction is the divisor when solving for t,
    # so reject rays for which it is effectively zero.
    dz = headings[:, 2]
    if (np.abs(dz) < 1e-12).any():
        raise ValueError("ray is parallel to z plane")

    target_z = float(z_plane)
    t = (target_z - starts[:, 2]) / dz
    return starts + t[:, None] * headings
def compare_rayfields_on_planes(
    fitted_field,
    oracle_ray_function: Callable[[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]],
    image_size: tuple[int, int],
    z_planes: tuple[float, float],
    grid_shape: tuple[int, int] = (25, 19),
) -> RayfieldComparisonReport:
    """Compare two ray fields by intersecting each ray with reference z planes.

    Rays are sampled on a regular pixel grid, and each model's rays are
    intersected with every plane in ``z_planes``. Comparing intersection
    points rather than raw origins avoids penalising origins that are merely
    expressed in different gauges (a ray's origin may slide along the ray).

    Parameters
    ----------
    fitted_field
        Object exposing ``ray(u, v)`` that returns ``(origins, directions)``
        for flat arrays of pixel coordinates.
    oracle_ray_function : Callable
        Reference model; called as ``oracle_ray_function(u, v)`` and
        returning ``(origins, directions)``.
    image_size : tuple[int, int]
        ``(width, height)``; the grid spans ``0 .. width - 1`` horizontally
        and ``0 .. height - 1`` vertically.
    z_planes : tuple[float, float]
        Z-coordinates of the reference planes. Two planes are the intended
        use; any non-empty sequence is accepted and all planes contribute.
    grid_shape : tuple[int, int], optional
        ``(nx, ny)`` number of grid samples along u and v.

    Returns
    -------
    RayfieldComparisonReport
        RMS / median / 95th percentile of the per-ray plane error (the
        per-plane intersection distances combined in quadrature), the RMS
        direction angle in degrees, and the number of sampled rays.

    Raises
    ------
    ValueError
        If the grid is empty, if ``z_planes`` is empty, or (from the plane
        intersection) if a ray is parallel to a z plane.
    """
    width, height = image_size
    nx, ny = grid_shape
    u = np.linspace(0.0, float(width - 1), int(nx))
    v = np.linspace(0.0, float(height - 1), int(ny))
    uu, vv = np.meshgrid(u, v)
    uf = uu.reshape(-1)
    vf = vv.reshape(-1)
    if uf.size == 0:
        # Statistics over zero rays would be NaN (or fail inside NumPy).
        raise ValueError("grid_shape must describe at least one sample")

    planes = tuple(z_planes)
    if not planes:
        raise ValueError("z_planes must contain at least one plane")

    O_fit, d_fit = fitted_field.ray(uf, vf)
    O_true, d_true = oracle_ray_function(uf, vf)
    # Bring BOTH models to float64 (N, 3) so that the element-wise products
    # below work regardless of whether a model returns lists or gridded arrays.
    O_fit = np.asarray(O_fit, dtype=np.float64).reshape(-1, 3)
    d_fit = np.asarray(d_fit, dtype=np.float64).reshape(-1, 3)
    O_true = np.asarray(O_true, dtype=np.float64).reshape(-1, 3)
    d_true = np.asarray(d_true, dtype=np.float64).reshape(-1, 3)

    errors: list[np.ndarray] = []
    for z in planes:
        A_fit = intersect_rays_with_z_plane(O_fit, d_fit, z)
        A_true = intersect_rays_with_z_plane(O_true, d_true, z)
        errors.append(np.linalg.norm(A_fit - A_true, axis=1))
    # Quadrature sum over every plane (identical to the two-plane formula
    # sqrt(e0**2 + e1**2) when exactly two planes are given).
    plane_err = np.sqrt(np.sum(np.square(errors), axis=0))

    # atan2(|a x b|, a . b) stays accurate for very small angles, where
    # arccos of a dot product close to 1 loses precision, and it does not
    # depend on the directions being exactly unit length.
    cross_norm = np.linalg.norm(np.cross(d_fit, d_true), axis=1)
    dots = np.sum(d_fit * d_true, axis=1)
    angles_deg = np.rad2deg(np.arctan2(cross_norm, dots))

    return RayfieldComparisonReport(
        plane_intersection_rms=float(np.sqrt(np.mean(plane_err**2))),
        plane_intersection_median=float(np.median(plane_err)),
        plane_intersection_p95=float(np.percentile(plane_err, 95)),
        direction_angle_rms_deg=float(np.sqrt(np.mean(angles_deg**2))),
        n_samples=int(uf.size),
    )