tomodrgn.commands.train_vae.encoder_inference
- encoder_inference(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, lat: Lattice, data: TiltSeriesMRCData | TomoParticlesMRCData, num_workers: int = 0, prefetch_factor: int | None = None, pin_memory: bool = False, use_amp: bool = False, batchsize: int = 1) -> tuple[ndarray, ndarray]
Run inference on the encoder module using the specified data as input to be embedded in latent space.
- Parameters:
model – TiltSeriesHetOnlyVAE object to be used for encoder module inference. Inference runs on the device this model is on.
lat – Hartley-transform lattice of points for voxel grid operations.
data – TiltSeriesMRCData or TomoParticlesMRCData object for accessing tilt images with known CTF and pose parameters, to be embedded in latent space
use_amp – If True, use Automatic Mixed Precision (via torch.autocast) to reduce memory consumption and accelerate code execution.
batchsize – Batch size used by the dataloader for model inference.
num_workers – Number of workers to use with dataloader when batching particles for inference.
prefetch_factor – Number of particles to prefetch per worker with dataloader for inference.
pin_memory – Whether to use pinned memory for dataloader.
- Returns:
A tuple of two arrays: (1) the direct output of the encoder module parameterizing the mean of the latent embedding for each particle, shape (batchsize, zdim); (2) the direct output of the encoder module parameterizing the log variance of the latent embedding for each particle, shape (batchsize, zdim).
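The general pattern described above (batch the data with a dataloader, run the encoder without gradients, optionally under torch.autocast, and collect the mean and log variance as NumPy arrays) can be sketched with plain PyTorch. This is a minimal illustration, not tomodrgn's implementation: `ToyEncoder` and `toy_encoder_inference` are hypothetical names, the toy encoder is a single linear layer standing in for the real encoder module, and the sketch assumes per-batch outputs are concatenated across the whole dataset. The lattice, CTF, and pose handling are omitted entirely.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder; NOT tomodrgn's TiltSeriesHetOnlyVAE."""

    def __init__(self, in_dim: int, zdim: int):
        super().__init__()
        self.zdim = zdim
        # One layer emits both the mean and the log variance of the embedding.
        self.net = nn.Linear(in_dim, 2 * zdim)

    def forward(self, x: torch.Tensor):
        out = self.net(x)
        return out[:, :self.zdim], out[:, self.zdim:]


def toy_encoder_inference(model, dataset, batchsize=1, num_workers=0,
                          pin_memory=False, use_amp=False):
    """Hypothetical helper illustrating batched, gradient-free encoder inference."""
    # The model's own device decides where inference runs.
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batchsize, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    z_mu_all, z_logvar_all = [], []
    model.eval()
    with torch.no_grad():
        for (batch,) in loader:
            batch = batch.to(device)
            # Mixed precision is applied only when use_amp is True.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                z_mu, z_logvar = model(batch)
            z_mu_all.append(z_mu.float().cpu().numpy())
            z_logvar_all.append(z_logvar.float().cpu().numpy())
    # Assumption of this sketch: outputs are concatenated over all batches.
    return np.concatenate(z_mu_all), np.concatenate(z_logvar_all)


# Usage on random data: 10 items of dimension 16, embedded into zdim = 4.
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(10, 16))
model = ToyEncoder(in_dim=16, zdim=4)
z_mu, z_logvar = toy_encoder_inference(model, dataset, batchsize=3)
print(z_mu.shape, z_logvar.shape)
```

Keeping `shuffle=False` preserves the dataset order, so row i of each returned array corresponds to item i of the input. Larger values of `batchsize` and `num_workers` trade memory for throughput.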