tomodrgn.commands.train_vae.loss_function

loss_function(*, z_mu: Tensor, z_logvar: Tensor, batch_images: Tensor, batch_images_recon: Tensor, batch_ctf_weights: Tensor, batch_recon_error_weights: Tensor, batch_hartley_2d_mask: Tensor, image_ctf_premultiplied: bool, image_dose_weighted: bool, beta: float, beta_control: float | None = None) → tuple[Tensor, Tensor, Tensor]

Calculate the generative loss between reconstructed and input images, and the beta-weighted KL divergence (KLD) between the latent embeddings and a standard normal distribution.
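The KLD between a diagonal Gaussian parameterized by z_mu and z_logvar and the standard normal has a well-known closed form. The sketch below is illustrative only: it uses NumPy arrays in place of torch Tensors, and the helper name kld_standard_normal and the choice to sum over zdim and then average over the batch are assumptions, not necessarily what tomodrgn does internally.

```python
import numpy as np


def kld_standard_normal(z_mu: np.ndarray, z_logvar: np.ndarray) -> float:
    """KL(N(mu, exp(logvar)) || N(0, I)) for inputs of shape (batchsize, zdim).

    Hypothetical helper: summed over zdim, then averaged over the batch
    (the batch reduction is an assumption of this sketch).
    """
    # Closed-form KL divergence between a diagonal Gaussian and the standard normal
    per_particle = -0.5 * np.sum(1.0 + z_logvar - z_mu ** 2 - np.exp(z_logvar), axis=1)
    # Average over particles in the batch
    return float(np.mean(per_particle))


# One particle whose latent mean is shifted by 1 along the first dimension
z_mu = np.array([[1.0, 0.0]])
z_logvar = np.zeros((1, 2))
print(kld_standard_normal(z_mu, z_logvar))  # → 0.5
```

An encoder output that exactly matches the prior (z_mu = 0, z_logvar = 0) gives zero divergence, and any other output gives a positive value.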

Parameters:
  • z_mu – Direct output of encoder module parameterizing the mean of the latent embedding for each particle, shape (batchsize, zdim)

  • z_logvar – Direct output of encoder module parameterizing the log variance of the latent embedding for each particle, shape (batchsize, zdim)

  • batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht**2)

  • batch_images_recon – Reconstructed central slices of the Fourier-space volumes corresponding to each particle in the batch, flattened and restricted to the spatial frequencies selected by batch_hartley_2d_mask, i.e. shape (batchsize * ntilts * boxsize_ht**2)[batch_hartley_2d_mask]

  • batch_ctf_weights – CTF evaluated at each spatial frequency corresponding to input images, shape (batchsize, ntilts, boxsize_ht**2) or None if no CTF should be applied

  • batch_recon_error_weights – Batch of 2-D weights to be applied to the per-spatial-frequency error between each reconstructed slice and input image, shape (batchsize, ntilts, boxsize_ht**2). Calculated from critical dose exposure curves and the electron beam vs. sample tilt geometry.

  • batch_hartley_2d_mask – Batch of 2-D masks to be applied per spatial frequency. Calculated as the intersection of critical dose exposure curves and a Nyquist-limited circular mask in reciprocal space, which also masks out the DC component.

  • image_ctf_premultiplied – Whether images were multiplied by their CTF during particle extraction.

  • image_dose_weighted – Whether images were multiplied by their exposure-dependent frequency weighting during particle extraction.

  • beta – Scaling factor applied to the KLD during loss calculation.

  • beta_control – Gamma of the KL-controlled VAE. If set, beta is interpreted as the target KLD rather than as a scaling factor.
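The parameters above suggest how the generative loss is assembled: select the masked spatial frequencies, apply the CTF to the reconstruction, and weight the squared error per frequency. The following is a minimal sketch under stated assumptions, not tomodrgn's implementation. It uses NumPy arrays instead of torch Tensors, assumes batch_hartley_2d_mask is boolean with the same shape as batch_images, assumes CTF-premultiplied images carry the CTF twice (so the reconstruction is multiplied by CTF**2), and does not model image_dose_weighted. The function name generative_loss is hypothetical.

```python
import numpy as np


def generative_loss(batch_images, batch_images_recon, batch_ctf_weights,
                    batch_recon_error_weights, batch_hartley_2d_mask,
                    image_ctf_premultiplied: bool) -> float:
    """Hypothetical weighted MSE over the masked spatial frequencies.

    batch_images, batch_ctf_weights, batch_recon_error_weights, batch_hartley_2d_mask:
        arrays of shape (batchsize, ntilts, boxsize_ht**2); the mask is boolean.
    batch_images_recon: 1-D array with one value per True entry of the mask.
    """
    # Boolean indexing keeps only the masked frequencies and flattens to 1-D
    images = batch_images[batch_hartley_2d_mask]
    weights = batch_recon_error_weights[batch_hartley_2d_mask]
    recon = batch_images_recon
    if batch_ctf_weights is not None:
        ctf = batch_ctf_weights[batch_hartley_2d_mask]
        # Assumption: premultiplied images contain CTF**2, raw images contain CTF once
        recon = recon * (ctf ** 2 if image_ctf_premultiplied else ctf)
    # Per-frequency squared error, scaled by the dose/tilt weights, then averaged
    return float(np.mean(weights * (recon - images) ** 2))


rng = np.random.default_rng(0)
shape = (2, 3, 16)  # (batchsize, ntilts, boxsize_ht**2) for a toy boxsize_ht of 4
true_slices = rng.normal(size=shape)
ctf = rng.uniform(-1.0, 1.0, size=shape)
weights = np.ones(shape)
mask = np.ones(shape, dtype=bool)
mask[..., 0] = False  # toy mask: drop the first frequency of every image
images = ctf * true_slices  # raw (not CTF-premultiplied) images
recon = true_slices[mask]   # perfect reconstruction at the masked frequencies
print(generative_loss(images, recon, ctf, weights, mask, image_ctf_premultiplied=False))  # → 0.0
```

Because the reconstruction is compared only at masked frequencies, batch_images_recon needs just one value per selected frequency, which matches its documented flattened shape.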

Returns:

Total summed loss, the generative loss between reconstructed slices and input images, and the beta-weighted KLD loss between the latent embeddings and the standard normal distribution.
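The three returned values can be related as in the sketch below, which shows how beta and beta_control might enter the KLD term. This is an illustration under assumptions: the function name combine_losses is hypothetical, the KL-controlled branch assumes a squared penalty gamma * (beta - KLD)**2 on the deviation from the KL target, and any normalization of the KLD term relative to the per-frequency generative loss is omitted.

```python
from __future__ import annotations


def combine_losses(gen_loss: float, kld: float, beta: float,
                   beta_control: float | None = None) -> tuple[float, float, float]:
    """Hypothetical helper returning (total_loss, gen_loss, kld_loss)."""
    if beta_control is None:
        # Standard beta-VAE: scale the KLD by beta
        kld_loss = beta * kld
    else:
        # KL-controlled VAE (assumed form): beta is the KL target and
        # beta_control (gamma) scales the squared deviation from it
        kld_loss = beta_control * (beta - kld) ** 2
    return gen_loss + kld_loss, gen_loss, kld_loss


print(combine_losses(0.25, 2.0, beta=0.5))                    # → (1.25, 0.25, 1.0)
print(combine_losses(0.25, 2.0, beta=1.0, beta_control=0.5))  # → (0.75, 0.25, 0.5)
```

In the KL-controlled branch the penalty vanishes when the KLD equals the target beta, so the optimizer is pushed toward a fixed latent capacity rather than simply toward a smaller KLD.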