tomodrgn.commands.train_vae.train_batch
- train_batch(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, scaler: GradScaler, optim: Optimizer, lat: Lattice, batch_images: Tensor, batch_rots: Tensor, batch_trans: Tensor, batch_ctf_params: Tensor, batch_recon_error_weights: Tensor, batch_hartley_2d_mask: Tensor, image_ctf_premultiplied: bool, image_dose_weighted: bool, beta: float, beta_control: float | None = None, use_amp: bool = False) → ndarray
Train a TiltSeriesHetOnlyVAE model on a batch of tilt series particle images.
- Parameters:
model – TiltSeriesHetOnlyVAE object to be trained, or a DataParallelPassthrough wrapping one
scaler – GradScaler object to be used for scaling loss involving fp16 tensors to avoid over/underflow
optim – torch.optim.Optimizer object to be used for optimizing the model
lat – Hartley-transform lattice of points for voxel grid operations
batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht, boxsize_ht)
batch_rots – Batch of 3-D rotation matrices giving the known poses of batch_images, shape (batchsize, ntilts, 3, 3)
batch_trans – Batch of 2-D translation vectors giving the known poses of batch_images, shape (batchsize, ntilts, 2). May instead be torch.zeros((batchsize)) to indicate that no translations should be applied to the input images.
batch_ctf_params – Batch of known CTF parameters for batch_images, shape (batchsize, ntilts, 9). May instead be torch.zeros((batchsize)) to indicate that no CTF corruption should be applied to the reconstructed slice.
batch_recon_error_weights – Batch of 2-D weights applied to the per-spatial-frequency error between the reconstructed slice and the input image. Calculated from critical dose exposure curves and the tilt of the sample relative to the electron beam. May instead be torch.zeros((batchsize)) to indicate that no weighting should be applied to the reconstructed slice error.
batch_hartley_2d_mask – Batch of 2-D per-spatial-frequency masks. Calculated as the intersection of critical dose exposure curves and a Nyquist-limited circular mask in reciprocal space; the DC component is also masked.
image_ctf_premultiplied – Whether images were multiplied by their CTF during particle extraction.
image_dose_weighted – Whether images were multiplied by their exposure-dependent frequency weighting during particle extraction.
beta – Scaling factor applied to the KLD term during loss calculation.
beta_control – Gamma for a KL-controlled VAE. When set, beta is used as the target KL divergence rather than as a scaling factor.
use_amp – If True, use Automatic Mixed Precision (via autocast and GradScaler) to reduce memory consumption and accelerate code execution.
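The roles of batch_recon_error_weights and batch_hartley_2d_mask can be illustrated with a minimal NumPy sketch. It assumes an odd boxsize_ht with the DC component at the central index (as in a centered Hartley transform), treats a 1-D batch_recon_error_weights as the "no weighting" placeholder, omits the dose-dependent part of the mask, and averages the weighted squared error over unmasked frequencies. The helper names nyquist_dc_mask and masked_weighted_error are hypothetical, and tomodrgn's actual normalization may differ.

```python
import numpy as np


def nyquist_dc_mask(boxsize_ht: int) -> np.ndarray:
    """Circular Nyquist-limited mask with the DC component excluded.

    Assumption: boxsize_ht is odd (e.g. D + 1 for an even box size D) and the
    DC component sits at the central index of a centered Hartley transform.
    """
    center = boxsize_ht // 2
    coords = np.arange(boxsize_ht) - center
    yy, xx = np.meshgrid(coords, coords, indexing="ij")
    radius = np.sqrt(xx ** 2 + yy ** 2)
    mask = radius <= center        # keep frequencies up to Nyquist
    mask[center, center] = False   # exclude the DC component
    return mask


def masked_weighted_error(recon: np.ndarray, images: np.ndarray,
                          weights: np.ndarray, mask: np.ndarray) -> float:
    """Per-frequency weighted squared error, averaged over unmasked frequencies.

    recon, images: (batchsize, ntilts, boxsize_ht, boxsize_ht)
    weights: same shape as images, or a 1-D placeholder of shape (batchsize,)
        which this sketch (hypothetically) interprets as "no weighting"
    mask: boolean array broadcastable to images.shape
    """
    if weights.ndim == 1:
        weights = np.ones_like(images)
    sq_err = weights * (recon - images) ** 2
    full_mask = np.broadcast_to(mask, images.shape)
    return float(sq_err[full_mask].mean())
```

Passing a single 2-D mask relies on broadcasting; a per-image mask of shape (batchsize, ntilts, boxsize_ht, boxsize_ht) works unchanged.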
- Returns:
NumPy array of losses: total loss, generative loss between reconstructed slices and input images, and KLD loss between latent embeddings and a standard normal distribution.