tomodrgn.commands.train_vae.encode_batch
- encode_batch(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, batch_images: Tensor) → tuple[Tensor, Tensor, Tensor]
Encode a batch of particles, each represented by multiple images, into per-particle latent embeddings.
- Parameters:
model – TiltSeriesHetOnlyVAE object to be trained, optionally wrapped in DataParallelPassthrough
batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht**2)
- Returns:
z_mu: Direct output of the encoder module parameterizing the mean of the latent embedding for each particle, shape (batchsize, zdim)
z_logvar: Direct output of the encoder module parameterizing the log variance of the latent embedding for each particle, shape (batchsize, zdim)
z: Sample of the latent embedding for each particle, drawn from a Gaussian with mean z_mu and log variance z_logvar, shape (batchsize, zdim)
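To make the shape contract concrete, here is a minimal sketch using NumPy in place of torch. The function name toy_encode_batch and its random linear "encoder" (one projection per image followed by a mean over tilts) are hypothetical and do not reproduce the actual TiltSeriesHetOnlyVAE architecture. Only the input/output shapes and the final Gaussian sampling step (the standard VAE reparameterization, z = z_mu + eps * exp(0.5 * z_logvar)) are meant to mirror the documented behavior.

```python
import numpy as np


def toy_encode_batch(batch_images: np.ndarray, zdim: int, rng: np.random.Generator):
    """Hypothetical stand-in for encode_batch (NumPy instead of torch).

    batch_images: shape (batchsize, ntilts, boxsize_ht**2)
    Returns z_mu, z_logvar, z, each of shape (batchsize, zdim).
    """
    batchsize, ntilts, npix = batch_images.shape
    # Hypothetical encoder: a random linear projection of each image,
    # then a mean over the tilt axis. NOT the real tomoDRGN encoder.
    w_mu = rng.standard_normal((npix, zdim)) / np.sqrt(npix)
    w_logvar = rng.standard_normal((npix, zdim)) / np.sqrt(npix)
    z_mu = (batch_images @ w_mu).mean(axis=1)          # (batchsize, zdim)
    z_logvar = (batch_images @ w_logvar).mean(axis=1)  # (batchsize, zdim)
    # Reparameterization: std = exp(0.5 * logvar), z = mu + eps * std
    eps = rng.standard_normal(z_mu.shape)
    z = z_mu + eps * np.exp(0.5 * z_logvar)
    return z_mu, z_logvar, z


rng = np.random.default_rng(0)
boxsize_ht = 8
images = rng.standard_normal((4, 10, boxsize_ht**2))  # 4 particles, 10 tilts each
z_mu, z_logvar, z = toy_encode_batch(images, zdim=2, rng=rng)
print(z_mu.shape, z_logvar.shape, z.shape)  # (4, 2) (4, 2) (4, 2)
```

Exponentiating half the log variance gives the standard deviation while keeping the encoder output unconstrained. In PyTorch the same sampling step is commonly written as `z_mu + torch.randn_like(z_mu) * torch.exp(0.5 * z_logvar)`.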