tomodrgn.commands.train_vae.save_checkpoint
- save_checkpoint(*, model: TiltSeriesHetOnlyVAE | DataParallelPassthrough, scaler: GradScaler, optim: Optimizer, epoch: int, z_mu_train: ndarray, z_logvar_train: ndarray, out_weights: str, out_z_train: str, z_mu_test: ndarray | None = None, z_logvar_test: ndarray | None = None, out_z_test: str | None = None) -> None
Save model weights (with optimizer and scaler state dicts) and the latent embeddings z for the train split and, optionally, the test split.
- Parameters:
model – TiltSeriesHetOnlyVAE object, or a DataParallelPassthrough instance, used for model training and evaluation
scaler – GradScaler object used for scaling loss involving fp16 tensors to avoid over/underflow
optim – torch.optim.Optimizer object used for optimizing the model
epoch – epoch count at which checkpoint is being saved
z_mu_train – array of latent embedding means for the dataset train split
z_logvar_train – array of latent embedding log variances for the dataset train split
out_weights – name of output file to save model, optimizer, and scaler state dicts
out_z_train – name of output file to save latent embeddings for dataset train split
z_mu_test – array of latent embedding means for the dataset test split
z_logvar_test – array of latent embedding log variances for the dataset test split
out_z_test – name of output file to save latent embeddings for dataset test split
- Returns:
None
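
The signature above begins with `*`, so every argument must be passed by keyword, and the three test-split parameters default to None. The following is a minimal sketch of a function with that interface, not tomoDRGN's implementation. It assumes the standard library's pickle in place of the real serialization, plain Python lists in place of ndarrays, and a hypothetical `_HasState` class standing in for the model, optimizer, and scaler. The name `save_checkpoint_sketch`, the dictionary keys, and the file names are also illustrative assumptions.

```python
import os
import pickle
import tempfile


class _HasState:
    """Hypothetical stand-in for any object exposing state_dict()."""

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return dict(self._state)


def save_checkpoint_sketch(*, model, scaler, optim, epoch,
                           z_mu_train, z_logvar_train,
                           out_weights, out_z_train,
                           z_mu_test=None, z_logvar_test=None,
                           out_z_test=None):
    """Illustrative only: mirrors the documented keyword-only interface."""
    # One file holds the epoch and the model, optimizer, and scaler state.
    # The key names here are assumptions, not the library's format.
    with open(out_weights, "wb") as f:
        pickle.dump({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optim.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
        }, f)

    # Train-split latent means and log variances go to their own file.
    with open(out_z_train, "wb") as f:
        pickle.dump({"z_mu": z_mu_train, "z_logvar": z_logvar_train}, f)

    # Test-split outputs are optional: write them only if all three are given.
    if (z_mu_test is not None and z_logvar_test is not None
            and out_z_test is not None):
        with open(out_z_test, "wb") as f:
            pickle.dump({"z_mu": z_mu_test, "z_logvar": z_logvar_test}, f)


with tempfile.TemporaryDirectory() as tmp:
    weights_path = os.path.join(tmp, "weights.pkl")   # hypothetical file names
    z_train_path = os.path.join(tmp, "z_train.pkl")
    z_test_path = os.path.join(tmp, "z_test.pkl")

    # All arguments are passed by keyword; the test-split ones are omitted.
    save_checkpoint_sketch(
        model=_HasState(weight=[1.0, 2.0]),
        scaler=_HasState(scale=65536.0),
        optim=_HasState(lr=1e-4),
        epoch=3,
        z_mu_train=[[0.1, -0.2]],
        z_logvar_train=[[-1.0, -1.5]],
        out_weights=weights_path,
        out_z_train=z_train_path,
    )

    with open(weights_path, "rb") as f:
        saved = pickle.load(f)
    with open(z_train_path, "rb") as f:
        z_train = pickle.load(f)
    saved_epoch = saved["epoch"]
    wrote_test_file = os.path.exists(z_test_path)
```

Because the test-split arguments were left at None, the sketch writes only the weights file and the train-split latent file. Passing any argument positionally raises a TypeError, which is the effect of the leading `*` in the signature.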