tomodrgn.commands.train_vae.encoder_inference

encoder_inference(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, lat: Lattice, data: TiltSeriesMRCData | TomoParticlesMRCData, num_workers: int = 0, prefetch_factor: int | None = None, pin_memory: bool = False, use_amp: bool = False, batchsize: int = 1) → tuple[ndarray, ndarray]

Run the encoder module on the specified data to embed each particle in latent space.
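Because inference is batched (see `batchsize`), the per-batch encoder outputs have to be concatenated along the particle axis to give one row per particle. The NumPy sketch below illustrates that pattern under simplifying assumptions. `toy_encoder` and `batched_encoder_inference` are hypothetical stand-ins: a random linear map replaces the real encoder, and a plain slicing loop replaces the dataloader. This is not tomodrgn's implementation.

```python
import numpy as np


def toy_encoder(batch: np.ndarray, weights: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical stand-in for the encoder: a single linear map whose output
    is split into a latent mean and a latent log variance."""
    zdim = weights.shape[1] // 2
    out = batch @ weights  # shape (particles in batch, 2 * zdim)
    return out[:, :zdim], out[:, zdim:]


def batched_encoder_inference(features: np.ndarray, weights: np.ndarray,
                              batchsize: int = 1) -> tuple[np.ndarray, np.ndarray]:
    """Run the toy encoder over all particles in batches and stack the results."""
    z_mu_batches, z_logvar_batches = [], []
    for start in range(0, features.shape[0], batchsize):
        batch = features[start:start + batchsize]  # the last batch may be smaller
        z_mu, z_logvar = toy_encoder(batch, weights)
        z_mu_batches.append(z_mu)
        z_logvar_batches.append(z_logvar)
    # Concatenate along the particle axis: one row per particle, whatever the batchsize
    return np.concatenate(z_mu_batches, axis=0), np.concatenate(z_logvar_batches, axis=0)


rng = np.random.default_rng(0)
features = rng.normal(size=(10, 6))    # 10 made-up particles, 6 features each
weights = rng.normal(size=(6, 2 * 4))  # zdim = 4
z_mu, z_logvar = batched_encoder_inference(features, weights, batchsize=3)
print(z_mu.shape, z_logvar.shape)
```

In this sketch, changing `batchsize` only alters how many particles are encoded per step, not the shape or values of the stacked output.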

Parameters:
  • model – TiltSeriesHetOnlyVAE object (optionally wrapped in DataParallelPassthrough) whose encoder module is used for inference. The model's device also determines the device on which inference runs.

  • lat – Hartley-transform lattice of points for voxel grid operations.

  • data – TiltSeriesMRCData or TomoParticlesMRCData object providing access to the tilt images, with known CTF and pose parameters, of the particles to be embedded in latent space.

  • use_amp – If True, use Automatic Mixed Precision via torch.autocast to reduce memory consumption and accelerate execution.

  • batchsize – Number of particles per batch used by the dataloader for model inference.

  • num_workers – Number of workers to use with dataloader when batching particles for inference.

  • prefetch_factor – Number of batches (each of batchsize particles) that each dataloader worker prefetches during inference.

  • pin_memory – Whether the dataloader should use pinned (page-locked) memory.

Returns:

A tuple of two arrays:

  • the direct output of the encoder module parameterizing the mean of each particle's latent embedding, with shape (number of particles in data, zdim);

  • the direct output of the encoder module parameterizing the log variance of each particle's latent embedding, with the same shape.
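Since the encoder returns the log variance rather than the variance, the per-particle standard deviation is σ = exp(0.5 · logvar). The NumPy sketch below uses made-up values for three particles with zdim = 2 to show that conversion, and how one latent sample per particle could be drawn with the standard VAE reparameterization z = μ + σ·ε, ε ~ N(0, I). `latent_std` and `sample_latent` are hypothetical helpers, not part of tomodrgn.

```python
import numpy as np


def latent_std(z_logvar: np.ndarray) -> np.ndarray:
    """Convert log variance to standard deviation: sigma = exp(0.5 * log(sigma**2))."""
    return np.exp(0.5 * z_logvar)


def sample_latent(z_mu: np.ndarray, z_logvar: np.ndarray,
                  rng: np.random.Generator) -> np.ndarray:
    """Draw one sample per particle via the reparameterization z = mu + sigma * eps."""
    eps = rng.standard_normal(z_mu.shape)  # eps ~ N(0, I), same shape as z_mu
    return z_mu + latent_std(z_logvar) * eps


# Made-up encoder outputs for 3 particles with zdim = 2
z_mu = np.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
z_logvar = np.log(np.array([[1.0, 4.0], [0.25, 1.0], [9.0, 1.0]]))

print(latent_std(z_logvar))
print(sample_latent(z_mu, z_logvar, np.random.default_rng(0)))
```

For downstream analysis, the mean alone is the usual point estimate of each particle's position in latent space.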