tomodrgn.commands.train_vae.save_checkpoint#

save_checkpoint(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, scaler: GradScaler, optim: Optimizer, epoch: int, z_mu_train: ndarray, z_logvar_train: ndarray, out_weights: str, out_z_train: str, z_mu_test: ndarray | None = None, z_logvar_test: ndarray | None = None, out_z_test: str | None = None) None[source]#

Save model weights and the latent embeddings z to disk

Parameters:
  • model – TiltSeriesHetOnlyVAE object (optionally wrapped in DataParallelPassthrough) used for model training and evaluation

  • scaler – GradScaler object used for scaling loss involving fp16 tensors to avoid over/underflow

  • optim – torch.optim.Optimizer object used for optimizing the model

  • epoch – epoch count at which checkpoint is being saved

  • z_mu_train – array of latent embedding means for the dataset train split

  • z_logvar_train – array of latent embedding log variances for the dataset train split

  • out_weights – name of output file to save model, optimizer, and scaler state dicts

  • out_z_train – name of output file to save latent embeddings for dataset train split

  • z_mu_test – array of latent embedding means for the dataset test split

  • z_logvar_test – array of latent embedding log variances for the dataset test split

  • out_z_test – name of output file to save latent embeddings for dataset test split

Returns:

None
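
A minimal usage sketch follows, based only on the signature above. It assumes that `model`, `scaler`, and `optim` already exist from an earlier training loop and that the latent mean/log-variance arrays were gathered during evaluation; the array shapes and output filenames are hypothetical placeholders, not values prescribed by tomodrgn.

    import numpy as np
    from tomodrgn.commands.train_vae import save_checkpoint

    # Hypothetical latent arrays: one row per particle in each split, zdim columns
    zdim = 8
    z_mu_train = np.zeros((900, zdim), dtype=np.float32)
    z_logvar_train = np.zeros((900, zdim), dtype=np.float32)
    z_mu_test = np.zeros((100, zdim), dtype=np.float32)
    z_logvar_test = np.zeros((100, zdim), dtype=np.float32)

    epoch = 9
    save_checkpoint(
        model=model,    # TiltSeriesHetOnlyVAE (or DataParallelPassthrough wrapper) from the training loop
        scaler=scaler,  # torch GradScaler used for mixed-precision loss scaling during training
        optim=optim,    # torch.optim.Optimizer used during training
        epoch=epoch,
        z_mu_train=z_mu_train,
        z_logvar_train=z_logvar_train,
        out_weights=f'weights.{epoch}.pkl',  # hypothetical output filename
        out_z_train=f'z.{epoch}.train.pkl',  # hypothetical output filename
        z_mu_test=z_mu_test,
        z_logvar_test=z_logvar_test,
        out_z_test=f'z.{epoch}.test.pkl',    # hypothetical output filename
    )

The test-split arguments may be omitted (they default to None) when training without a held-out split.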