tomodrgn.commands.train_nn.loss_function#

loss_function(*, batch_images: Tensor, batch_images_recon: Tensor, batch_ctf_weights: Tensor, batch_recon_error_weights: Tensor, batch_hartley_2d_mask: Tensor, image_ctf_premultiplied: bool, image_dose_weighted: bool) → Tensor[source]#

Calculate generative loss between reconstructed and input images

Parameters:
  • batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht**2)

  • batch_images_recon – Reconstructed central slices of Fourier space volumes corresponding to each particle in the batch, flattened to shape (batchsize * ntilts * boxsize_ht**2) and indexed by batch_hartley_2d_mask

  • batch_ctf_weights – CTF evaluated at each spatial frequency of the input images, shape (batchsize, ntilts, boxsize_ht**2), or None if no CTF should be applied

  • batch_recon_error_weights – Batch of 2-D weights to be applied to the per-spatial-frequency error between each reconstructed slice and input image, shape (batchsize, ntilts, boxsize_ht**2). Calculated from critical dose exposure curves and electron beam vs sample tilt geometry.

  • batch_hartley_2d_mask – Batch of 2-D masks to be applied per-spatial-frequency. Calculated as the intersection of critical dose exposure curves and a Nyquist-limited circular mask in reciprocal space, including masking the DC component.

  • image_ctf_premultiplied – Whether images were multiplied by their CTF during particle extraction.

  • image_dose_weighted – Whether images were multiplied by their exposure-dependent frequency weighting during particle extraction.

Returns:

Generative loss between reconstructed slices and input images.