tomodrgn.commands.train_nn.loss_function
- loss_function(*, batch_images: Tensor, batch_images_recon: Tensor, batch_ctf_weights: Tensor, batch_recon_error_weights: Tensor, batch_hartley_2d_mask: Tensor, image_ctf_premultiplied: bool, image_dose_weighted: bool) → Tensor
Calculate the generative loss between reconstructed slices and input images.
- Parameters:
batch_images – Batch of images to be used for training, shape (batchsize, ntilts, boxsize_ht**2)
batch_images_recon – Reconstructed central slices of Fourier space volumes corresponding to each particle in the batch, flattened and restricted to the pixels selected by batch_hartley_2d_mask, i.e. shape (batchsize * ntilts * boxsize_ht**2)[batch_hartley_2d_mask]
batch_ctf_weights – CTF evaluated at each spatial frequency corresponding to input images, shape (batchsize, ntilts, boxsize_ht**2) or None if no CTF should be applied
batch_recon_error_weights – Batch of 2-D weights to be applied to the per-spatial-frequency error between each reconstructed slice and input image, shape (batchsize, ntilts, boxsize_ht**2). Calculated from critical dose exposure curves and the electron beam versus sample tilt geometry.
batch_hartley_2d_mask – Batch of 2-D masks to be applied per-spatial-frequency. Calculated as the intersection of critical dose exposure curves and a Nyquist-limited circular mask in reciprocal space, including masking the DC component.
image_ctf_premultiplied – Whether images were multiplied by their CTF during particle extraction.
image_dose_weighted – Whether images were multiplied by their exposure-dependent frequency weighting during particle extraction.
- Returns:
Generative loss between reconstructed slices and input images.
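To make the parameter shapes and the role of the mask concrete, here is a minimal sketch of one plausible formulation, assuming a weighted mean squared error over the masked pixels. It is not tomodrgn's implementation: `loss_function_sketch` is a hypothetical name, and two choices are assumptions rather than facts read from the source. First, CTF-premultiplied images are treated as carrying the CTF twice, so the reconstruction is multiplied by CTF squared. Second, dose-weighted images are treated as already carrying the exposure weighting, so the weights are applied to the reconstruction instead of to the error.

```python
from typing import Optional

import torch
from torch import Tensor


def loss_function_sketch(*,
                         batch_images: Tensor,
                         batch_images_recon: Tensor,
                         batch_ctf_weights: Optional[Tensor],
                         batch_recon_error_weights: Tensor,
                         batch_hartley_2d_mask: Tensor,
                         image_ctf_premultiplied: bool,
                         image_dose_weighted: bool) -> Tensor:
    """Hypothetical weighted-MSE loss; NOT the tomodrgn implementation."""
    # Boolean-mask indexing flattens (batchsize, ntilts, boxsize_ht**2) to the same
    # 1-D masked layout as batch_images_recon.
    images = batch_images[batch_hartley_2d_mask]
    recon = batch_images_recon
    weights = batch_recon_error_weights[batch_hartley_2d_mask]

    # Apply the CTF to the reconstruction so it is comparable to the observed images.
    if batch_ctf_weights is not None:
        ctf = batch_ctf_weights[batch_hartley_2d_mask]
        # Assumption: premultiplied images carry the CTF twice (imaging + extraction).
        recon = recon * (ctf ** 2 if image_ctf_premultiplied else ctf)

    if image_dose_weighted:
        # Assumption: images already carry the dose weighting, so weight the reconstruction.
        sq_err = (recon * weights - images) ** 2
    else:
        # Otherwise scale the per-frequency squared error by the weights.
        sq_err = weights * (recon - images) ** 2

    # Reduce to a scalar by averaging over all unmasked pixels in the batch.
    return sq_err.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    batchsize, ntilts, npix = 2, 3, 16  # npix stands in for boxsize_ht**2
    signal = torch.randn(batchsize, ntilts, npix)
    ctf = torch.randn(batchsize, ntilts, npix)
    mask = torch.rand(batchsize, ntilts, npix) > 0.3
    ones = torch.ones(batchsize, ntilts, npix)
    # A perfect reconstruction of CTF-corrupted images gives (near) zero loss.
    loss = loss_function_sketch(batch_images=ctf * signal,
                                batch_images_recon=signal[mask],
                                batch_ctf_weights=ctf,
                                batch_recon_error_weights=ones,
                                batch_hartley_2d_mask=mask,
                                image_ctf_premultiplied=False,
                                image_dose_weighted=False)
    print(loss.item())
```

Averaging over only the masked pixels means frequencies outside the Nyquist circle, the DC component, and frequencies beyond the critical-dose cutoff contribute nothing to the loss, which mirrors the purpose of batch_hartley_2d_mask described above.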