"""Source code for stereocomplex.calibration.fit_zernike_origin_field."""

from __future__ import annotations

from dataclasses import dataclass
from collections.abc import Sequence
from typing import Literal

import numpy as np

from stereocomplex.rayfields.zernike_origin_field import (
    ZernikeOriginField,
    ZernikeOriginFieldCoefficients,
    ZernikeOriginFieldConfig,
    ZernikeRayField,
    ZernikeRayFieldCoefficients,
    _project_transverse,
)
from stereocomplex.synthetic.parallel_plate import SyntheticStereoDataset, transform_points


@dataclass(frozen=True)
class StereoZernikeOriginFieldFitResult:
    """Fitted stereo Zernike origin fields and residual diagnostics.

    Immutable record returned by ``fit_stereo_zernike_origin_field``.
    """

    # Fitted left-channel field built from the optimised coefficients.
    # NOTE(review): the fit builds a ZernikeRayField here when directions are
    # optimised — confirm that type is compatible with this annotation.
    left_field: ZernikeOriginField
    # Fitted right-channel field (same remark as for ``left_field``).
    right_field: ZernikeOriginField
    # 4x4 transform from left-camera to right-camera coordinates after the fit.
    stereo_transform: np.ndarray
    # One 4x4 board pose per frame, expressed back in the world frame.
    board_poses: list[np.ndarray]
    # Root-mean-square of the per-observation point-to-ray residual norms.
    residual_rms: float
    # Median of the per-observation residual norms.
    residual_median: float
    # 95th percentile of the per-observation residual norms.
    residual_p95: float
    # Number of left-camera observations summed over all frames.
    n_observations: int
    # ``success`` flag reported by the least-squares solver.
    success: bool
    # Termination message reported by the least-squares solver.
    message: str
def _make_transform(R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Assemble a 4x4 homogeneous transform from a rotation and a translation.

    ``R`` is reshaped to (3, 3) and ``t`` to (3,); both are cast to float64.
    """
    T = np.eye(4, dtype=np.float64)
    T[:3, :3] = np.asarray(R, dtype=np.float64).reshape(3, 3)
    T[:3, 3] = np.asarray(t, dtype=np.float64).reshape(3)
    return T


def _mode_weights(field: ZernikeOriginField) -> np.ndarray:
    """Return one regularisation weight ``1 + n**2`` per mode of ``field``.

    ``n`` is read from each entry of ``field.modes`` (presumably the radial
    order of the Zernike mode — confirm against the mode definition).
    """
    return np.array([1.0 + float(mode.n) ** 2 for mode in field.modes], dtype=np.float64)


def _point_line_residual(P: np.ndarray, origin: np.ndarray, direction: np.ndarray) -> np.ndarray:
    """Return the cross product ``(P - origin) x direction``.

    When ``direction`` has unit norm, the norm of the result is the
    perpendicular distance from ``P`` to the line through ``origin``.
    """
    return np.cross(P - origin, direction)


def _left_board_poses(
    board_poses: Sequence[np.ndarray],
    T_left_world: np.ndarray,
) -> list[np.ndarray]:
    """Left-multiply every board pose by ``T_left_world``.

    Each input is reshaped to (4, 4); the result list has one 4x4 matrix per
    input pose.
    """
    T_left_world_arr = np.asarray(T_left_world, dtype=np.float64).reshape(4, 4)
    return [
        T_left_world_arr @ np.asarray(T_board_world, dtype=np.float64).reshape(4, 4)
        for T_board_world in board_poses
    ]


def _frame_points_from_left_board_poses(
    object_points: np.ndarray,
    left_board_poses: Sequence[np.ndarray],
) -> list[np.ndarray]:
    """Apply ``transform_points`` to ``object_points`` once per board pose."""
    return [transform_points(T_left_board, object_points) for T_left_board in left_board_poses]


def _pose_params_from_transforms(left_board_poses: Sequence[np.ndarray]) -> np.ndarray:
    """Flatten 4x4 poses into a vector of ``[rotvec(3), translation(3)]`` blocks.

    Returns an empty float64 vector for an empty pose list, so the result
    round-trips through ``_pose_params_to_transforms(params, 0)``.
    """
    blocks = [_se3_params_from_transform(T) for T in left_board_poses]
    if not blocks:
        # np.concatenate rejects an empty list; zero poses means zero parameters.
        return np.zeros((0,), dtype=np.float64)
    return np.concatenate(blocks, axis=0)


def _se3_params_from_transform(T: np.ndarray) -> np.ndarray:
    """Return the 6-vector ``[rotvec(3), translation(3)]`` of one 4x4 transform."""
    from scipy.spatial.transform import Rotation  # type: ignore

    T_arr = np.asarray(T, dtype=np.float64).reshape(4, 4)
    return np.concatenate([Rotation.from_matrix(T_arr[:3, :3]).as_rotvec(), T_arr[:3, 3]], axis=0)


def _se3_params_to_transform(params: np.ndarray) -> np.ndarray:
    """Rebuild a 4x4 transform from a ``[rotvec(3), translation(3)]`` 6-vector."""
    from scipy.spatial.transform import Rotation  # type: ignore

    p = np.asarray(params, dtype=np.float64).reshape(6)
    return _make_transform(Rotation.from_rotvec(p[:3]).as_matrix(), p[3:6])


def _pose_params_to_transforms(params: np.ndarray, n_frames: int) -> list[np.ndarray]:
    """Split a flat pose vector into ``n_frames`` 4x4 transforms.

    Raises
    ------
    ValueError
        If ``params`` does not hold exactly ``6 * n_frames`` values.
    """
    from scipy.spatial.transform import Rotation  # type: ignore

    p = np.asarray(params, dtype=np.float64).reshape(-1)
    if p.size != 6 * int(n_frames):
        raise ValueError(f"expected {6 * int(n_frames)} pose parameters, got {p.size}")
    poses = []
    for i in range(int(n_frames)):
        rv = p[6 * i : 6 * i + 3]
        t = p[6 * i + 3 : 6 * i + 6]
        poses.append(_make_transform(Rotation.from_rotvec(rv).as_matrix(), t))
    return poses


def _world_board_poses(
    left_board_poses: Sequence[np.ndarray],
    T_left_world: np.ndarray,
) -> list[np.ndarray]:
    """Left-multiply every pose by the inverse of ``T_left_world``.

    This undoes the composition performed by ``_left_board_poses``.
    """
    T_world_left = np.linalg.inv(np.asarray(T_left_world, dtype=np.float64).reshape(4, 4))
    return [T_world_left @ np.asarray(T, dtype=np.float64).reshape(4, 4) for T in left_board_poses]


def _residual_stats(values: np.ndarray) -> tuple[float, float, float]:
    """Return ``(rms, median, 95th percentile)`` of the flattened ``values``.

    All three entries are NaN when ``values`` is empty.
    """
    vals = np.asarray(values, dtype=np.float64).reshape(-1)
    if vals.size == 0:
        return float("nan"), float("nan"), float("nan")
    return (
        float(np.sqrt(np.mean(vals**2))),
        float(np.median(vals)),
        float(np.percentile(vals, 95)),
    )
def fit_stereo_zernike_origin_field(
    observations: SyntheticStereoDataset,
    K_left: np.ndarray,
    K_right: np.ndarray,
    T_right_left_initial: np.ndarray,
    board_poses_initial: Sequence[np.ndarray],
    config_left: ZernikeOriginFieldConfig,
    config_right: ZernikeOriginFieldConfig,
    optimize_board_poses: bool = False,
    optimize_stereo_extrinsics: bool = False,
    optimize_directions: bool = False,
    robust_loss: Literal["linear", "huber", "soft_l1", "cauchy", "arctan"] = "huber",
    regularization: float = 1e-6,
    direction_regularization: float | None = None,
    pose_regularization: float = 0.0,
    rig_regularization: float = 0.0,
    max_nfev: int = 200,
) -> StereoZernikeOriginFieldFitResult:
    """Fit stereo non-central Zernike origin fields from point-to-ray residuals.

    The default mode identifies per-pixel ray origins ``O_left(u,v)`` and
    ``O_right(u,v)`` on top of fixed pinhole directions, fixed board poses and
    a fixed stereo rig. Optional flags can extend the optimisation to direction
    perturbations, per-frame board poses and the stereo extrinsics.

    The fit is purely geometric: it uses only observed pixels, camera
    intrinsics, initial poses and the board object points, never the physical
    oracle parameters used to generate synthetic data.

    Parameters
    ----------
    observations : SyntheticStereoDataset
        Stereo calibration observations. ``left_pixels`` and ``right_pixels``
        are per-frame pixel arrays; ``object_points`` or
        ``per_frame_object_points`` provide matching board coordinates in mm.
        A frame whose right pixel array is empty is treated as left-only.
    K_left, K_right : ndarray, shape (3, 3)
        Left and right pinhole intrinsics used for the fixed base directions.
    T_right_left_initial : ndarray, shape (4, 4)
        Initial transform from left-camera coordinates to right-camera
        coordinates.
    board_poses_initial : sequence of ndarray, each shape (4, 4)
        Initial board-to-world transforms, one per observed frame.
    config_left, config_right : ZernikeOriginFieldConfig
        Zernike basis, image size and gauge settings for each channel.
    optimize_board_poses : bool
        If True, include one SE(3) board pose block per frame in the least-
        squares vector.
    optimize_stereo_extrinsics : bool
        If True, refine the stereo rig transform around
        ``T_right_left_initial``.
    optimize_directions : bool
        If True, fit Zernike direction perturbation coefficients as well as
        origin coefficients.
    robust_loss : {"linear", "huber", "soft_l1", "cauchy", "arctan"}
        Robust loss passed to ``scipy.optimize.least_squares``.
    regularization : float
        L2 weight for origin-field coefficients.
    direction_regularization : float or None
        L2 weight for direction coefficients. ``None`` reuses
        ``regularization``.
    pose_regularization : float
        L2 weight anchoring optimised board poses to their initial values.
    rig_regularization : float
        L2 weight anchoring optimised stereo extrinsics to the initial rig.
    max_nfev : int
        Maximum number of least-squares function evaluations.

    Returns
    -------
    StereoZernikeOriginFieldFitResult
        Dataclass result object containing the fitted left/right rayfields,
        optional direction coefficients, refined poses/rig and residual
        diagnostics.

    Raises
    ------
    ValueError
        If the number of initial board poses differs from the number of
        frames, if the two fields have a different number of Zernike terms, or
        if the per-frame pixel/object-point arrays have inconsistent shapes.
    """
    if len(board_poses_initial) != len(observations.left_pixels):
        raise ValueError("board_poses_initial must match the number of observed frames")

    left0 = ZernikeOriginField(K_left, config_left)
    right0 = ZernikeOriginField(K_right, config_right)
    if len(left0.modes) != len(right0.modes):
        raise ValueError("left and right fields must use the same number of Zernike terms")
    n_terms = len(left0.modes)

    T_RL_initial = np.asarray(T_right_left_initial, dtype=np.float64).reshape(4, 4)
    left_board_initial = _left_board_poses(board_poses_initial, observations.T_left_world)
    all_obj_per_frame = (
        observations.per_frame_object_points
        if observations.per_frame_object_points is not None
        else [observations.object_points] * len(observations.left_pixels)
    )
    # Board points under the initial poses; used below only to validate that
    # each frame has as many object points as left pixels.
    P_left_frames = [
        transform_points(T, obj)
        for T, obj in zip(left_board_initial, all_obj_per_frame, strict=True)
    ]

    frame_data: list[tuple[np.ndarray, np.ndarray]] = []
    for uvL, uvR, P_L in zip(
        observations.left_pixels,
        observations.right_pixels,
        P_left_frames,
        strict=True,
    ):
        uvL_arr = np.asarray(uvL, dtype=np.float64).reshape(-1, 2)
        uvR_arr = np.asarray(uvR, dtype=np.float64).reshape(-1, 2)
        P_L_arr = np.asarray(P_L, dtype=np.float64).reshape(-1, 3)
        is_left_only = uvR_arr.shape[0] == 0
        if not is_left_only and uvL_arr.shape != uvR_arr.shape:
            raise ValueError("inconsistent observation shapes")
        if uvL_arr.shape[0] != P_L_arr.shape[0]:
            raise ValueError("inconsistent observation shapes")
        frame_data.append((uvL_arr, uvR_arr))

    # Pixel coordinates never change during optimisation; precompute Zernike
    # design matrices A(u,v) and pinhole directions d0(u,v) once for each frame.
    A_left_per_frame: list[np.ndarray] = []
    A_right_per_frame: list[np.ndarray] = []
    d0_left_per_frame: list[np.ndarray] = []
    d0_right_per_frame: list[np.ndarray] = []
    for uvL_arr, uvR_arr in frame_data:
        A_left_per_frame.append(left0.basis(uvL_arr[:, 0], uvL_arr[:, 1]))
        d0_left_per_frame.append(left0.direction(uvL_arr[:, 0], uvL_arr[:, 1]))
        if uvR_arr.shape[0] > 0:
            A_right_per_frame.append(right0.basis(uvR_arr[:, 0], uvR_arr[:, 1]))
            d0_right_per_frame.append(right0.direction(uvR_arr[:, 0], uvR_arr[:, 1]))
        else:
            # Left-only frame: keep the lists index-aligned with empty placeholders.
            A_right_per_frame.append(np.zeros((0, n_terms), dtype=np.float64))
            d0_right_per_frame.append(np.zeros((0, 3), dtype=np.float64))

    _gauge_left = config_left.enforce_transverse_gauge
    _gauge_right = config_right.enforce_transverse_gauge

    # Per-mode regularisation scales, one column vector per channel. Each
    # channel uses its own modes (only the mode *count* is checked to match),
    # and the square roots are taken once instead of on every residual call.
    sqrt_w_left = np.sqrt(_mode_weights(left0))[:, None]
    sqrt_w_right = np.sqrt(_mode_weights(right0))[:, None]

    # Negative weights are clipped to zero before taking the square root.
    sqrt_reg = np.sqrt(float(max(regularization, 0.0)))
    direction_reg = (
        regularization if direction_regularization is None else float(direction_regularization)
    )
    sqrt_dir_reg = np.sqrt(float(max(direction_reg, 0.0)))
    sqrt_pose_reg = np.sqrt(float(max(pose_regularization, 0.0)))
    sqrt_rig_reg = np.sqrt(float(max(rig_regularization, 0.0)))

    n_coeff = n_terms * 3
    pose0 = _pose_params_from_transforms(left_board_initial)
    rig0 = _se3_params_from_transform(T_RL_initial)

    # Parameter vector layout:
    #   [left origin | right origin | (left dir | right dir) | (poses) | (rig)]
    # where the parenthesised blocks are present only when the matching
    # optimize_* flag is set. Coefficients start at zero; poses and rig start
    # at their initial values.
    p0_parts = [np.zeros((2 * n_coeff,), dtype=np.float64)]
    if optimize_directions:
        p0_parts.append(np.zeros((2 * n_coeff,), dtype=np.float64))
    if optimize_board_poses:
        p0_parts.append(pose0.copy())
    if optimize_stereo_extrinsics:
        p0_parts.append(rig0.copy())
    p0 = np.concatenate(p0_parts, axis=0)

    def unpack(p: np.ndarray):
        """Unpack the parameter vector into origin, direction, poses, and rig components."""
        arr = np.asarray(p, dtype=np.float64).reshape(-1)
        cursor = 0
        left_origin = arr[cursor : cursor + n_coeff].reshape(n_terms, 3)
        cursor += n_coeff
        right_origin = arr[cursor : cursor + n_coeff].reshape(n_terms, 3)
        cursor += n_coeff
        if optimize_directions:
            left_direction = arr[cursor : cursor + n_coeff].reshape(n_terms, 3)
            cursor += n_coeff
            right_direction = arr[cursor : cursor + n_coeff].reshape(n_terms, 3)
            cursor += n_coeff
        else:
            left_direction = np.zeros_like(left_origin)
            right_direction = np.zeros_like(right_origin)
        if optimize_board_poses:
            n_pose_params = 6 * len(frame_data)
            pose_params = arr[cursor : cursor + n_pose_params]
            cursor += n_pose_params
            left_board_poses = _pose_params_to_transforms(pose_params, len(frame_data))
        else:
            pose_params = pose0
            left_board_poses = left_board_initial
        if optimize_stereo_extrinsics:
            rig_params = arr[cursor : cursor + 6]
            cursor += 6
        else:
            rig_params = rig0
        # Guard against a layout mismatch between p0 construction and unpacking.
        if cursor != arr.size:
            raise ValueError(f"unused optimization parameters: {arr.size - cursor}")
        T_RL_current = _se3_params_to_transform(rig_params)
        return (
            left_origin,
            right_origin,
            left_direction,
            right_direction,
            pose_params,
            left_board_poses,
            rig_params,
            T_RL_current,
        )

    def make_fields(
        left_origin: np.ndarray,
        right_origin: np.ndarray,
        left_direction: np.ndarray,
        right_direction: np.ndarray,
    ):
        """Build left and right ZernikeRayField from coefficient arrays.

        Uses closure variables ``config_left``, ``config_right``, ``K_left``,
        ``K_right``, and ``optimize_directions`` set by the enclosing fit
        function. When directions are not optimised, plain
        ``ZernikeOriginField`` objects are built and the direction arguments
        are ignored.

        Parameters
        ----------
        left_origin : ndarray
            Left-channel origin coefficients, shape (n_terms, 3), in mm.
        right_origin : ndarray
            Right-channel origin coefficients.
        left_direction : ndarray
            Left-channel direction coefficients, shape (n_terms, 3), unitless.
        right_direction : ndarray
            Right-channel direction coefficients.
        """
        if optimize_directions:
            left = ZernikeRayField(
                K_left,
                config_left,
                ZernikeRayFieldCoefficients(left_origin, left_direction),
            )
            right = ZernikeRayField(
                K_right,
                config_right,
                ZernikeRayFieldCoefficients(right_origin, right_direction),
            )
        else:
            left = ZernikeOriginField(
                K_left,
                config_left,
                ZernikeOriginFieldCoefficients(left_origin),
            )
            right = ZernikeOriginField(
                K_right,
                config_right,
                ZernikeOriginFieldCoefficients(right_origin),
            )
        return left, right

    def _ray_cached(
        A: np.ndarray,
        d0: np.ndarray,
        origin_coeffs: np.ndarray,
        direction_coeffs: np.ndarray,
        enforce_gauge: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Compute (O, d) using precomputed design matrix A and pinhole directions d0.

        Mirrors ZernikeOriginField.origin() / ZernikeRayField.direction() using
        _project_transverse so both code paths stay in sync automatically.
        """
        O_raw = A @ origin_coeffs  # (N, 3)
        if optimize_directions:
            # Perturb the pinhole direction transversally, then renormalise.
            d = d0 + _project_transverse(A @ direction_coeffs, d0)
            d = d / np.linalg.norm(d, axis=1, keepdims=True)
        else:
            d = d0
        origin = _project_transverse(O_raw, d) if enforce_gauge else O_raw
        return origin, d

    def residuals(p: np.ndarray) -> np.ndarray:
        """Compute ray-space residuals for the current BA state."""
        (
            left_origin,
            right_origin,
            left_direction,
            right_direction,
            pose_params,
            left_board_poses,
            rig_params,
            T_RL_current,
        ) = unpack(p)
        P_left_current = [
            transform_points(T, obj)
            for T, obj in zip(left_board_poses, all_obj_per_frame, strict=True)
        ]
        R_RL = T_RL_current[:3, :3]
        t_RL = T_RL_current[:3, 3]

        parts: list[np.ndarray] = []
        for i, ((_uvL_arr, uvR_arr), P_L_arr) in enumerate(
            zip(frame_data, P_left_current, strict=True),
        ):
            O_L, d_L = _ray_cached(
                A_left_per_frame[i],
                d0_left_per_frame[i],
                left_origin,
                left_direction,
                _gauge_left,
            )
            parts.append(_point_line_residual(P_L_arr, O_L, d_L).reshape(-1))
            if uvR_arr.shape[0] > 0:
                # Express the board points in the right-camera frame.
                P_R_arr = (R_RL @ P_L_arr.T).T + t_RL.reshape(1, 3)
                O_R, d_R = _ray_cached(
                    A_right_per_frame[i],
                    d0_right_per_frame[i],
                    right_origin,
                    right_direction,
                    _gauge_right,
                )
                parts.append(_point_line_residual(P_R_arr, O_R, d_R).reshape(-1))

        # Regularisation terms are appended as extra residual rows.
        if sqrt_reg > 0:
            parts.append(sqrt_reg * (sqrt_w_left * left_origin).reshape(-1))
            parts.append(sqrt_reg * (sqrt_w_right * right_origin).reshape(-1))
        if optimize_directions and sqrt_dir_reg > 0:
            parts.append(sqrt_dir_reg * (sqrt_w_left * left_direction).reshape(-1))
            parts.append(sqrt_dir_reg * (sqrt_w_right * right_direction).reshape(-1))
        if optimize_board_poses and sqrt_pose_reg > 0:
            parts.append(sqrt_pose_reg * (pose_params - pose0))
        if optimize_stereo_extrinsics and sqrt_rig_reg > 0:
            parts.append(sqrt_rig_reg * (rig_params - rig0))
        return np.concatenate(parts, axis=0)

    from scipy.optimize import least_squares  # type: ignore

    sol = least_squares(
        residuals,
        p0,
        method="trf",
        loss=robust_loss,
        f_scale=1.0,
        max_nfev=int(max_nfev),
    )
    (
        left_origin,
        right_origin,
        left_direction,
        right_direction,
        _pose_params,
        left_board_poses,
        _rig_params,
        T_RL_final,
    ) = unpack(sol.x)
    left_field, right_field = make_fields(
        left_origin,
        right_origin,
        left_direction,
        right_direction,
    )

    # Diagnostics are recomputed through the public field objects rather than
    # the cached design matrices used during optimisation.
    P_left_final = [
        transform_points(T, obj)
        for T, obj in zip(left_board_poses, all_obj_per_frame, strict=True)
    ]
    R_RL = T_RL_final[:3, :3]
    t_RL = T_RL_final[:3, 3]
    residual_norms: list[np.ndarray] = []
    for (uvL_arr, uvR_arr), P_L_arr in zip(frame_data, P_left_final, strict=True):
        O_L, d_L = left_field.ray(uvL_arr[:, 0], uvL_arr[:, 1])
        residual_norms.append(np.linalg.norm(_point_line_residual(P_L_arr, O_L, d_L), axis=1))
        if uvR_arr.shape[0] > 0:
            P_R_arr = (R_RL @ P_L_arr.T).T + t_RL.reshape(1, 3)
            O_R, d_R = right_field.ray(uvR_arr[:, 0], uvR_arr[:, 1])
            residual_norms.append(
                np.linalg.norm(_point_line_residual(P_R_arr, O_R, d_R), axis=1)
            )
    all_norms = np.concatenate(residual_norms, axis=0)
    rms, median, p95 = _residual_stats(all_norms)

    return StereoZernikeOriginFieldFitResult(
        left_field=left_field,
        right_field=right_field,
        stereo_transform=T_RL_final,
        board_poses=_world_board_poses(left_board_poses, observations.T_left_world),
        residual_rms=rms,
        residual_median=median,
        residual_p95=p95,
        # Counts left-camera observations only.
        n_observations=int(sum(uvL.shape[0] for uvL, _uvR in frame_data)),
        success=bool(sol.success),
        message=str(sol.message),
    )