tomodrgn.models

Classes for creating, loading, training, and evaluating PyTorch models.

Functions

mlp_ascii

Create ASCII art of a fully connected multi-layer perceptron.

print_tiltserieshetonlyvae_ascii

Print an ASCII art representation of a TiltSeriesHetOnlyVAE model.

Classes

DataParallelPassthrough

Class that wraps a module in DataParallel for GPU-parallelized computation while still allowing access to the wrapped module's attributes and methods.
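The passthrough behavior described for DataParallelPassthrough can be sketched as follows, assuming the common pattern of overriding `__getattr__` on a `torch.nn.DataParallel` subclass. The names `PassthroughSketch` and `ToyModel` are hypothetical and used only for illustration; this is not necessarily how tomoDRGN implements the class.

```python
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model with a custom attribute and a custom method."""

    def __init__(self):
        super().__init__()
        self.zdim = 8
        self.linear = nn.Linear(4, self.zdim)

    def forward(self, x):
        return self.linear(x)

    def describe(self):
        return f"ToyModel(zdim={self.zdim})"


class PassthroughSketch(nn.DataParallel):
    """Hypothetical sketch: fall back to the wrapped module for unknown attributes."""

    def __getattr__(self, name):
        try:
            # Parameters, buffers, and submodules registered on the wrapper itself
            return super().__getattr__(name)
        except AttributeError:
            # Anything else is looked up on the wrapped module
            return getattr(self.module, name)


wrapped = PassthroughSketch(ToyModel())
zdim = wrapped.zdim               # resolved on the wrapped ToyModel
description = wrapped.describe()  # custom method is reachable as well
```

With a plain `nn.DataParallel` wrapper, the same two lookups would have to be written as `wrapped.module.zdim` and `wrapped.module.describe()`.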

FTPositionalDecoder

A module to decode a (batch of tilts of) spatial frequency coordinates spanning (-0.5, 0.5) to the corresponding spatial frequency amplitude.

MedianPool1d

Median pool module. Exists primarily because of limitations in existing layer-based definitions of the median: torch.nn does not have a MedianPool module, and einops does not support a callable (such as torch.median) when defining a Reduce layer.
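A median pooling layer of the kind described for MedianPool1d can be sketched by wrapping `torch.median` in an `nn.Module`. The class name `MedianPoolSketch` and its `dim` and `keepdim` parameters are assumptions made for illustration and may differ from the actual MedianPool1d signature.

```python
import torch
import torch.nn as nn


class MedianPoolSketch(nn.Module):
    """Hypothetical layer: reduce one dimension of a tensor to its median."""

    def __init__(self, dim=-1, keepdim=False):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        # With a dim argument, torch.median returns (values, indices); keep the values
        return torch.median(x, dim=self.dim, keepdim=self.keepdim).values


pool = MedianPoolSketch(dim=1)
x = torch.tensor([[1.0, 5.0, 3.0], [2.0, 2.0, 9.0]])
pooled = pool(x)  # one median per row
```

Because the reduction is an `nn.Module`, it can be placed inside `nn.Sequential` alongside other layers. Note that `torch.median` returns the lower of the two middle values when the reduced dimension has an even number of elements.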

ResidLinear

A Residual Block layer consisting of a single linear layer with an identity skip connection.
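The residual block described for ResidLinear, a single linear layer plus an identity skip connection, can be sketched as below. The class name `ResidLinearSketch` and the single `width` parameter are hypothetical; the sketch assumes equal input and output widths so that the skip connection is shape-compatible.

```python
import torch
import torch.nn as nn


class ResidLinearSketch(nn.Module):
    """Hypothetical residual block: y = Linear(x) + x."""

    def __init__(self, width):
        super().__init__()
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        # Identity skip connection: add the input back onto the linear output
        return self.linear(x) + x


layer = ResidLinearSketch(16)
# Zero the linear layer so that only the skip connection contributes
nn.init.zeros_(layer.linear.weight)
nn.init.zeros_(layer.linear.bias)

x = torch.randn(4, 16)
y = layer(x)
```

With the linear layer zeroed, the block reduces to the identity, which is the property that lets stacked residual blocks start close to a pass-through mapping.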

ResidLinearMLP

Multiple connected Residual Blocks as a Multi-Layer Perceptron.
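One way to chain residual blocks into a multi-layer perceptron, as described for ResidLinearMLP, is sketched below. The helper `make_resid_mlp`, the `ResidBlock` class, the ReLU activations, and the input and output projection layers are all assumptions made for illustration; the actual layer arrangement in tomoDRGN may differ.

```python
import torch
import torch.nn as nn


class ResidBlock(nn.Module):
    """Hypothetical residual block: a linear layer with an identity skip connection."""

    def __init__(self, width):
        super().__init__()
        self.linear = nn.Linear(width, width)

    def forward(self, x):
        return self.linear(x) + x


def make_resid_mlp(in_dim, hidden_dim, n_blocks, out_dim):
    """Hypothetical builder: project in, apply n_blocks residual blocks, project out."""
    layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
    for _ in range(n_blocks):
        layers += [ResidBlock(hidden_dim), nn.ReLU()]
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)


mlp = make_resid_mlp(in_dim=10, hidden_dim=32, n_blocks=3, out_dim=2)
out = mlp(torch.randn(5, 10))  # batch of 5 inputs, each with 10 features
```

The plain linear layers at either end handle the change in width, so every residual block operates at a constant hidden width.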

TiltSeriesEncoder

A module to encode multiple (tilt) images of a particle to a latent embedding.

TiltSeriesHetOnlyVAE

A module to encode multiple tilt images of a particle to a learned low-dimensional latent space embedding, then decode spatial frequency coordinates to corresponding voxel amplitudes conditioned on the per-particle latent embedding.

VolumeGenerator

Convenience class to generate volume(s) from a trained tomoDRGN model.