tomodrgn.commands.train_vae.encode_batch

encode_batch(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, batch_images: Tensor) → tuple[Tensor, Tensor, Tensor]

Encode a batch of particles, each represented by multiple tilt images, into per-particle latent embeddings.
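The key idea is that several images of one particle collapse into a single latent embedding. The following is a minimal NumPy sketch of that shape flow, not tomodrgn's implementation. The function `toy_encode_batch`, its random linear weights, the `tanh` nonlinearity and the mean pooling across tilts are all hypothetical choices made for illustration. tomodrgn's real encoder and pooling options may differ.

```python
import numpy as np


def toy_encode_batch(batch_images, w_per_image, w_latent, zdim):
    """Hypothetical two-stage encoder: per-image features, pooled over tilts.

    batch_images: (batchsize, ntilts, boxsize_ht**2)
    w_per_image:  (boxsize_ht**2, hidden), shared across every tilt image
    w_latent:     (hidden, 2 * zdim), produces [z_mu | z_logvar]
    """
    # Stage 1: encode each tilt image independently -> (batchsize, ntilts, hidden)
    per_image = np.tanh(batch_images @ w_per_image)
    # Pool across the tilt axis so each particle gets one feature vector -> (batchsize, hidden)
    # (mean pooling is an illustrative assumption)
    pooled = per_image.mean(axis=1)
    # Stage 2: map pooled features to Gaussian parameters -> (batchsize, 2 * zdim)
    out = pooled @ w_latent
    z_mu, z_logvar = out[:, :zdim], out[:, zdim:]
    return z_mu, z_logvar


rng = np.random.default_rng(0)
batchsize, ntilts, boxsize_ht, hidden, zdim = 4, 10, 8, 16, 3
batch_images = rng.standard_normal((batchsize, ntilts, boxsize_ht ** 2))
w_per_image = rng.standard_normal((boxsize_ht ** 2, hidden)) * 0.1
w_latent = rng.standard_normal((hidden, 2 * zdim)) * 0.1

z_mu, z_logvar = toy_encode_batch(batch_images, w_per_image, w_latent, zdim)
print(z_mu.shape, z_logvar.shape)  # (4, 3) (4, 3)
```

Pooling over the tilt axis makes the output independent of tilt order and lets the tilt axis disappear, which is why both outputs have shape (batchsize, zdim).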

Parameters:
  • model – TiltSeriesHetOnlyVAE object to be trained, optionally wrapped in DataParallelPassthrough

  • batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht**2)

Returns:

  • z_mu – Direct output of the encoder module parameterizing the mean of the latent embedding for each particle, shape (batchsize, zdim)

  • z_logvar – Direct output of the encoder module parameterizing the log variance of the latent embedding for each particle, shape (batchsize, zdim)

  • z – Resampled latent embedding for each particle, drawn from a Gaussian with mean z_mu and log variance z_logvar, shape (batchsize, zdim)
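The resampled z is typically obtained with the standard VAE reparameterization trick, z = z_mu + eps * exp(0.5 * z_logvar) with eps drawn from a unit Gaussian. Below is a minimal NumPy sketch of that step, assuming tomodrgn follows this standard formulation. The helper name `reparameterize` is hypothetical, and NumPy arrays stand in for the PyTorch tensors used by the real function.

```python
import numpy as np


def reparameterize(z_mu, z_logvar, rng):
    """Hypothetical helper: draw z ~ N(z_mu, exp(z_logvar)) elementwise."""
    std = np.exp(0.5 * z_logvar)           # log variance -> standard deviation
    eps = rng.standard_normal(z_mu.shape)  # unit Gaussian noise, same shape as z_mu
    return z_mu + eps * std                # differentiable w.r.t. z_mu and z_logvar


rng = np.random.default_rng(0)
z_mu = np.array([[0.0, 1.0], [2.0, -1.0]])          # shape (batchsize=2, zdim=2)
z_logvar = np.array([[0.0, 0.0], [np.log(4.0), 0.0]])
z = reparameterize(z_mu, z_logvar, rng)
print(z.shape)  # (2, 2)
```

Sampling through eps, rather than directly from the distribution, keeps z differentiable with respect to z_mu and z_logvar so gradients can reach the encoder. In PyTorch the same step is commonly written with `torch.exp` and `torch.randn_like`.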