Source code for tomodrgn.so3_grid

"""
Implementation of Yershova et al. "Generating uniform incremental grids on SO(3) using the Hopf fibration"
See source code implementation: https://lavalle.pl/software/so3/so3.html
"""

import numpy as np
import healpy as hp


def grid_s1(resol: int) -> np.ndarray:
    """
    Uniformly sample the circle S1 at the requested resolution.

    The grid has ``6 * 2 ** resol`` points (6 at resolution 0), each placed at the center of its
    angular bin, so the first sample sits half a bin width above 0.

    :param resol: resolution at which to sample S1
    :return: 1-D array of sampled angles in radians
    """
    n_samples = 6 * 2 ** resol
    step = 2 * np.pi / n_samples
    # offset by half a step so that samples are bin centers rather than bin edges
    return np.arange(n_samples) * step + step / 2
def grid_s2(resol: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Sample the sphere S2 at the requested resolution using HEALPix pixel centers (nested ordering).

    The grid has ``12 * (2 ** resol) ** 2`` points (12 at resolution 0).

    :param resol: resolution at which to sample S2
    :return: tuple ``(theta, phi)`` of angles returned by ``healpy.pix2ang`` for every pixel of the grid
    """
    nside = 2 ** resol
    pixel_ids = np.arange(12 * nside * nside)
    theta, phi = hp.pix2ang(nside, pixel_ids, nest=True)
    return theta, phi
def hopf_to_quat(theta: np.ndarray, phi: np.ndarray, psi: np.ndarray) -> np.ndarray:
    """
    Convert Hopf coordinates ``(theta, phi, psi)`` to quaternions.

    :param theta: [0, pi), parameterizes spherical coordinates on S2 together with phi
    :param phi: [0, 2pi), parameterizes spherical coordinates on S2 together with theta
    :param psi: [0, 2pi), parameterizes circle S1
    :return: float32 quaternions; for 1-D inputs of length N the result has shape (N, 4)
    """
    half_theta = theta / 2
    half_psi = psi / 2
    cos_ht = np.cos(half_theta)
    sin_ht = np.sin(half_theta)
    components = [
        cos_ht * np.cos(half_psi),
        cos_ht * np.sin(half_psi),
        sin_ht * np.cos(phi + half_psi),
        sin_ht * np.sin(phi + half_psi),
    ]
    # components are stacked along the first axis; transpose so the quaternion is the last axis
    return np.array(components).T.astype(np.float32)
def grid_SO3(resol: int) -> np.ndarray:
    """
    Sample SO(3) at the requested resolution as the product of the S2 grid and the S1 grid.

    Points are ordered S2-major: all S1 (psi) samples for the first S2 point come first, then all
    S1 samples for the second S2 point, and so on.

    :param resol: resolution at which to sample SO(3)
    :return: array of sampled points as quaternions
    """
    theta, phi = grid_s2(resol)
    psi = grid_s1(resol)
    n_s2 = len(theta)
    n_s1 = len(psi)
    # pair every S2 point with every S1 point: S2 angles vary slowly, psi varies fastest
    theta_all = np.repeat(theta, repeats=n_s1)
    phi_all = np.repeat(phi, repeats=n_s1)
    psi_all = np.tile(psi, reps=n_s2)
    return hopf_to_quat(theta=theta_all, phi=phi_all, psi=psi_all)
def base_SO3_grid() -> np.ndarray:
    """
    Build the SO(3) grid used as the base grid, i.e. ``grid_SO3`` evaluated at resolution 1.

    :return: array of sampled points as quaternions
    """
    base_resol = 1
    return grid_SO3(resol=base_resol)
####################
# Neighbor finding #
####################
def get_s1_neighbor(mini: int, curr_res: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the 4 candidate neighbors on S1 at the next resolution level.

    The two children of ``mini`` on the finer grid (``2 * mini`` and ``2 * mini + 1``) are padded with
    one extra grid point on each side, because the nearest neighbors on SO(3) at the next resolution
    level are not necessarily the nearest grid points on S1.

    :param mini: index of the S1 grid point at the current resolution
    :param curr_res: current grid resolution
    :return: psi of the candidate neighbors on S1, indices of the candidate neighbors on S1;
        indices are wrapped periodically so that they always lie in ``[0, npix)`` of the finer grid
    """
    npix = 6 * 2 ** (curr_res + 1)
    dt = 2 * np.pi / npix
    # S1 is periodic: wrap at BOTH ends so that every returned index is a valid index of the
    # next-resolution grid (the padded window reaches -1 for the first point and npix for the last)
    ind = np.arange(2 * mini - 1, 2 * mini + 3) % npix
    return ind * dt + dt / 2, ind
def get_s2_neighbor(mini: int, curr_res: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the 4 candidate neighbors on S2 at the next resolution level.

    The candidates are the pixels ``4 * mini`` through ``4 * mini + 3`` of the finer grid, looked up
    with nested ordering.

    :param mini: index of the S2 grid point at the current resolution
    :param curr_res: current grid resolution
    :return: theta and phi of nearest neighbors on S2 (as returned by ``healpy.pix2ang``), indices of nearest neighbors on S2
    """
    next_nside = 2 ** (curr_res + 1)
    child_ids = np.arange(4) + 4 * mini
    angles = hp.pix2ang(next_nside, child_ids, nest=True)
    return angles, child_ids
def get_base_ind(ind: int) -> tuple[int, int]:
    """
    Split a flat index on the base SO3 grid into its S2 and S1 grid indices.

    The split uses a row length of 12: the S2 index is the quotient and the S1 index the remainder.

    :param ind: flat index of a point on the base SO3 grid
    :return: corresponding S2 and S1 grid indices
    """
    s2_index, s1_index = divmod(ind, 12)
    return s2_index, s1_index
def get_neighbor(quat: np.ndarray, s2i: int, s1i: int, curr_res: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the 8 nearest neighbors on SO3 at the next resolution level.

    Candidate points are formed by pairing every S2 candidate with every S1 candidate at the next
    resolution; the 8 candidates closest to ``quat`` are kept.

    :param quat: rotation quaternion to find neighbors of
    :param s2i: index of the current resolution level of S2
    :param s1i: index of the current resolution level of S1
    :param curr_res: current grid resolution
    :return: nearest neighbor quaternions, nearest neighbor indices on SO3 as (S2 index, S1 index) rows
    """
    (theta, phi), s2_candidates = get_s2_neighbor(s2i, curr_res)
    psi, s1_candidates = get_s1_neighbor(s1i, curr_res)
    n_s2 = len(theta)
    n_s1 = len(psi)

    # pair every S2 candidate with every S1 candidate (S2 varies slowly, S1 varies fastest)
    candidate_quats = hopf_to_quat(np.repeat(theta, n_s1),
                                   np.repeat(phi, n_s1),
                                   np.tile(psi, n_s2))
    candidate_inds = np.column_stack((np.repeat(s2_candidates, n_s1),
                                      np.tile(s1_candidates, n_s2)))

    # find the 8 nearest neighbors of 16 possible points
    # need to check distance from both +q and -q
    dist_to_pos = np.sum((candidate_quats - quat) ** 2, axis=1)
    dist_to_neg = np.sum((candidate_quats + quat) ** 2, axis=1)
    nearest = np.argsort(np.minimum(dist_to_pos, dist_to_neg))[:8]
    return candidate_quats[nearest], candidate_inds[nearest]