tomodrgn.commands.train_vae

Train a VAE for heterogeneous reconstruction with known pose for tomography data

Functions

add_args

decode_batch

Decode a batch of particles, each represented by multiple images, from per-particle latent embeddings and the corresponding lattice positions at which to evaluate the decoder.
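As an illustration only, here is one common way a coordinate-based decoder evaluates per-particle latent embeddings at lattice positions. Each particle's z is broadcast to every lattice point of every tilt image and concatenated with that point's coordinates before being passed to the network. The function name `decode_batch_sketch`, the array shapes, and the single linear layer that stands in for the decoder MLP are hypothetical assumptions, not tomodrgn's actual `decode_batch` signature. Positional encoding of the coordinates is omitted.

```python
import numpy as np


def decode_batch_sketch(z, lattice, weights):
    """Toy coordinate-based decoder (hypothetical, for illustration only).

    z:       (batch, zdim) per-particle latent embeddings
    lattice: (batch, ntilts, npoints, 3) lattice coordinates for each tilt image,
             already rotated into that image's pose
    weights: (3 + zdim,) hypothetical linear layer standing in for the decoder MLP
    Returns: (batch, ntilts, npoints) decoded values
    """
    batch, ntilts, npoints, _ = lattice.shape
    zdim = z.shape[-1]
    # Share each particle's single z across all of its tilts and lattice points
    z_tiled = np.broadcast_to(z[:, None, None, :], (batch, ntilts, npoints, zdim))
    # Concatenate coordinates with the latent embedding at every point
    decoder_input = np.concatenate([lattice, z_tiled], axis=-1)  # (..., 3 + zdim)
    # One linear layer replaces the real network in this sketch
    return decoder_input @ weights
```

Because z is shared across every tilt of a particle, all images of that particle are decoded from the same latent state, while the per-image lattice coordinates carry each tilt's pose.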

encode_batch

Encode a batch of particles, each represented by multiple images, into per-particle latent embeddings.
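Here is a minimal sketch of how multiple images of one particle can be reduced to a single latent embedding. It assumes a two-stage design. First, each tilt image is embedded independently. Next, the per-image features are mean-pooled over the tilt axis. Finally, the pooled vector is mapped to the mean and log-variance of q(z | images), and z is sampled from them with the reparameterization trick. The names `encode_batch_sketch`, `w_image`, and `w_particle`, the ReLU linear layers, and the choice of mean pooling are hypothetical. The real encoder may use a different architecture or pooling operation.

```python
import numpy as np


def encode_batch_sketch(images, w_image, w_particle, rng):
    """Toy two-stage tilt-series encoder (hypothetical, for illustration only).

    images:     (batch, ntilts, npixels) flattened tilt images of each particle
    w_image:    (npixels, nfeat) hypothetical per-image encoder weights
    w_particle: (nfeat, 2 * zdim) hypothetical per-particle encoder weights
    rng:        numpy Generator used to draw the reparameterization noise
    """
    # Stage 1: embed every tilt image independently (linear layer + ReLU)
    features = np.maximum(images @ w_image, 0.0)  # (batch, ntilts, nfeat)
    # Pool over the tilt axis so each particle gets one fixed-size vector
    pooled = features.mean(axis=1)  # (batch, nfeat)
    # Stage 2: map pooled features to mean and log-variance of q(z | images)
    out = pooled @ w_particle
    zdim = out.shape[-1] // 2
    z_mu, z_logvar = out[:, :zdim], out[:, zdim:]
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    eps = rng.standard_normal(z_mu.shape)
    z = z_mu + np.exp(0.5 * z_logvar) * eps
    return z, z_mu, z_logvar
```

Mean pooling makes the embedding independent of the order in which a particle's tilts are stored, at the cost of discarding which tilt contributed which feature.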

encoder_inference

Run inference with the encoder module, embedding the specified input data in latent space.
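Here is a hedged sketch of batched inference. The data are fed through the encoder in fixed-size batches, and the per-batch latent parameters are concatenated back in input order. `encoder_inference_sketch` and the `encode_fn` callable are hypothetical stand-ins for the encoder module. In a PyTorch implementation this loop would typically run under `torch.no_grad()` with the model in evaluation mode. The mean `z_mu` is usually reported as the embedding rather than a random sample.

```python
import numpy as np


def encoder_inference_sketch(data, encode_fn, batch_size=32):
    """Embed all particles in latent space, one batch at a time (hypothetical).

    data:      (nparticles, ...) array of encoder inputs
    encode_fn: hypothetical callable mapping a batch to (z_mu, z_logvar)
    """
    mus, logvars = [], []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        z_mu, z_logvar = encode_fn(batch)
        mus.append(z_mu)
        logvars.append(z_logvar)
    # Stitch per-batch results back into arrays ordered like the input
    return np.concatenate(mus, axis=0), np.concatenate(logvars, axis=0)
```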

loss_function

Calculate the generative loss between reconstructed and input images, and the beta-weighted KL divergence (KLD) between the latent embeddings and a standard normal distribution.
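Take a diagonal Gaussian posterior N(mu, sigma^2) and a standard normal prior. The KLD then has the closed form KLD = -1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2). The sketch below combines that term with a mean-squared-error generative loss. Several choices here are assumptions: MSE as the generative loss, the reductions (mean over pixels, sum over latent dimensions, mean over particles), and the name `loss_function_sketch`. tomodrgn's `loss_function` may weight or normalize the terms differently.

```python
import numpy as np


def loss_function_sketch(recon, target, z_mu, z_logvar, beta=1.0):
    """Generative (MSE) loss plus beta-weighted KLD to N(0, I) (hypothetical).

    recon, target:   (batch, ntilts, npixels) reconstructed and input images
    z_mu, z_logvar:  (batch, zdim) parameters of q(z | images)
    Returns (total_loss, generative_loss, kld).
    """
    # Generative loss: mean squared error over all pixels of all images
    gen_loss = np.mean((recon - target) ** 2)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)): sum over latent dims, mean over particles
    kld = np.mean(-0.5 * np.sum(1.0 + z_logvar - z_mu ** 2 - np.exp(z_logvar), axis=-1))
    return gen_loss + beta * kld, gen_loss, kld
```

Raising beta pushes the embeddings toward the prior, which regularizes the latent space but can blur the reconstructions.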

main

preprocess_batch

Center images via translation and phase-flip them for partial CTF correction, as needed.
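Here is a minimal numpy sketch of the two operations, assuming images are handled with numpy's standard FFT. The first is a translation applied in Fourier space via the shift theorem, which multiplies each component by exp(-2*pi*i*k.t). The second is phase flipping, which multiplies each Fourier component by the sign of the CTF. This undoes contrast inversions but leaves CTF amplitudes uncorrected, hence "partial" correction. The helper names `translate_fourier` and `phase_flip` are hypothetical. To center a particle you would pass the negative of its stored offset. The real implementation may use a different transform (e.g. a Hartley rather than Fourier transform) or a different sign convention for stored translations.

```python
import numpy as np


def translate_fourier(image, shift):
    """Shift a 2D real-space image by (dy, dx) pixels via the Fourier shift theorem.

    Returns the shifted image in Fourier space (hypothetical helper).
    """
    ny, nx = image.shape
    ky = np.fft.fftfreq(ny)[:, None]  # spatial frequency along y, cycles/pixel
    kx = np.fft.fftfreq(nx)[None, :]  # spatial frequency along x, cycles/pixel
    dy, dx = shift
    phase = np.exp(-2j * np.pi * (ky * dy + kx * dx))
    return np.fft.fft2(image) * phase


def phase_flip(image_ft, ctf):
    """Partial CTF correction: flip Fourier components where the CTF is negative."""
    return image_ft * np.sign(ctf)
```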

save_checkpoint

Save model weights and the latent encoding z.

train_batch

Train a TiltSeriesHetOnlyVAE model on a batch of tilt series particle images.