tomodrgn train_vae


Train a heterogeneous tomoDRGN network (i.e. encoder and decoder modules) to learn an embedding of pre-aligned 2-D tilt-series projections to a continuous latent space, and to learn to generate unique 3-D reconstructions consistent with input images given the corresponding latent embedding.

Sample usage

The examples below are adapted from tomodrgn/testing/commandtest*.py and rely on other outputs to execute successfully.

# Warp v1 style inputs
# note: tomoDRGN tries to automatically infer the software used to export particles, but --source-software allows this value to be set explicitly
tomodrgn \
    train_vae \
    data/ \
    --source-software cryosrpnt \
    --outdir output/vae_both_sim_zdim8_dosetiltweightmask_batchsize8 \
    --zdim 8 \
    --uninvert-data \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight \
    --batch-size 8

# WarpTools style inputs
# note: --lazy is used because the separate .mrcs file per particle, as used by WarpTools, is well suited to lazy loading
# note: --num-workers, --prefetch-factor, and --persistent-workers are best used only if --lazy is enabled, to avoid excessive memory utilization
tomodrgn \
    train_vae \
    data/ \
    --outdir output/vae_warptools_70S_zdim8_dosetiltweightmask_batchsize8 \
    --zdim 8 \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight \
    --batch-size 8 \
    --lazy \
    --num-workers 2 \
    --prefetch-factor 2


usage: train_vae [-h] -o OUTDIR --zdim ZDIM [--load WEIGHTS.PKL]
                 [--checkpoint CHECKPOINT] [--log-interval LOG_INTERVAL] [-v]
                 [--seed SEED] [--plot-format {png,svgz}]
                 [--source-software {auto,warp,cryosrpnt,nextpyp,cistem,warptools,relion}]
                 [--ind-ptcls PKL] [--ind-imgs IND_IMGS]
                 [--sort-ptcl-imgs {unsorted,dose_ascending,random}]
                 [--use-first-ntilts USE_FIRST_NTILTS]
                 [--use-first-nptcls USE_FIRST_NPTCLS]
                 [--fraction-train FRACTION_TRAIN]
                 [--show-summary-stats SHOW_SUMMARY_STATS] [--uninvert-data]
                 [--no-window] [--window-r WINDOW_R]
                 [--window-r-outer WINDOW_R_OUTER] [--datadir DATADIR]
                 [--lazy] [--sequential-tilt-sampling] [--recon-tilt-weight]
                 [--recon-dose-weight] [--l-dose-mask] [-n NUM_EPOCHS]
                 [-b BATCH_SIZE] [--wd WD] [--lr LR] [--beta BETA]
                 [--beta-control BETA_CONTROL] [--norm NORM NORM] [--no-amp]
                 [--multigpu] [--enc-layers-A QLAYERSA] [--enc-dim-A QDIMA]
                 [--out-dim-A OUT_DIM_A] [--enc-layers-B QLAYERSB]
                 [--enc-dim-B QDIMB] [--enc-mask ENC_MASK]
                 [--pooling-function {concatenate,max,mean,median,set_encoder}]
                 [--num-seeds NUM_SEEDS] [--num-heads NUM_HEADS]
                 [--layer-norm] [--dec-layers PLAYERS] [--dec-dim PDIM]
                 [--l-extent L_EXTENT]
                 [--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
                 [--feat-sigma FEAT_SIGMA] [--pe-dim PE_DIM]
                 [--activation {relu,leaky_relu}] [--num-workers NUM_WORKERS]
                 [--prefetch-factor PREFETCH_FACTOR] [--persistent-workers]

Positional Arguments


Input particle star file (if using Warp/M or NextPYP), or optimisation set star file (if using WarpTools or RELION v5)

Core arguments

-o, --outdir

Output directory to save model

--zdim

Dimension of latent variable

Default: 128

--load

Initialize training from a checkpoint

--checkpoint

Checkpointing interval in N_EPOCHS (default: 1)

Default: 1

--log-interval

Logging interval in N_PTCLS (default: 200)

Default: 200

-v, --verbose

Increases verbosity

Default: False

--seed

Random seed

Default: 76166

--plot-format

Possible choices: png, svgz

File format with which to save plots

Default: 'png'

Particle starfile loading and filtering

--source-software

Possible choices: auto, warp, cryosrpnt, nextpyp, cistem, warptools, relion

Manually set the software used to extract particles. Default is to auto-detect.

Default: 'auto'

--ind-ptcls

Filter starfile by particles (unique rlnGroupName values) using np array pkl as indices

--ind-imgs

Filter starfile by particle images (star file rows) using np array pkl as indices

--sort-ptcl-imgs

Possible choices: unsorted, dose_ascending, random

Sort the star file images on a per-particle basis by the specified criteria

Default: 'unsorted'

--use-first-ntilts

Keep the first use_first_ntilts images of each particle in the sorted star file. Default -1 means to use all. Will drop particles with fewer than this many tilt images.

Default: -1

--use-first-nptcls

Keep the first use_first_nptcls particles in the sorted star file. Default -1 means to use all.

Default: -1
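
The index files accepted by --ind-ptcls and --ind-imgs are described above only as "np array pkl". A minimal sketch of writing such a file follows, assuming a pickled one-dimensional integer NumPy array of zero-based indices is what is expected; the file name and the selected indices are purely illustrative.

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical selection: keep particles 0, 2, and 5 (zero-based indexing is an assumption).
keep_particles = np.array([0, 2, 5], dtype=int)

# Write the index array to a pickle file in a scratch directory (file name is illustrative).
workdir = tempfile.mkdtemp()
pkl_path = os.path.join(workdir, "ind_ptcls.pkl")
with open(pkl_path, "wb") as f:
    pickle.dump(keep_particles, f)

# Read the file back to confirm the array survives the round trip.
with open(pkl_path, "rb") as f:
    loaded = pickle.load(f)
```

The resulting path would then be passed as the value of --ind-ptcls (or --ind-imgs for an array of star file row indices).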

Particle starfile train/test split

--fraction-train

Derive new train/test split with this fraction of each particle's images assigned to train

Default: 1.0

--show-summary-stats

Log distribution statistics of particle sampling for test/train splits

Default: True
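
To make the per-particle meaning of --fraction-train concrete, here is a rough sketch of splitting one particle's tilt images into train and test subsets. The random selection and rounding rule are assumptions for illustration and are not taken from the tomoDRGN source; split_particle_images is a hypothetical helper.

```python
import random


def split_particle_images(n_images, fraction_train, rng):
    """Return (train, test) image indices for a single particle.

    Assumption: images are chosen at random and the train count is rounded
    to the nearest integer.
    """
    indices = list(range(n_images))
    rng.shuffle(indices)
    n_train = round(fraction_train * n_images)
    train = sorted(indices[:n_train])
    test = sorted(indices[n_train:])
    return train, test


rng = random.Random(0)
# A particle with 41 tilt images, 80% of which are assigned to train.
train, test = split_particle_images(n_images=41, fraction_train=0.8, rng=rng)
# With the default fraction of 1.0, every image is used for training.
all_train, no_test = split_particle_images(n_images=41, fraction_train=1.0, rng=rng)
```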

Dataset loading and preprocessing

--uninvert-data

Do not invert data sign

Default: True

--no-window

Turn off real space windowing of dataset

Default: True

--window-r

Real space inner windowing radius to begin cosine falloff

Default: 0.8

--window-r-outer

Real space outer windowing radius to end cosine falloff

Default: 0.9

--datadir

Path prefix to particle stack if loading relative paths from a .star file

--lazy

Lazy loading if full dataset is too large to fit in memory (copying the dataset to an SSD is recommended)

Default: False

--sequential-tilt-sampling

Supply particle images of one particle to encoder in filtered starfile order

Default: False
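
The two windowing radii above describe a soft circular mask: full weight inside the inner radius, zero outside the outer radius, and a cosine falloff in between. A small sketch of such a profile follows, assuming the radii are fractions of the box half-width; cosine_window is a hypothetical helper, not a tomoDRGN function.

```python
import math


def cosine_window(r, r_inner=0.8, r_outer=0.9):
    """Window value at normalized radius r (0 = box center, 1 = box edge; an assumption)."""
    if r <= r_inner:
        return 1.0  # fully retained inside the inner radius
    if r >= r_outer:
        return 0.0  # fully masked beyond the outer radius
    # Raised-cosine ramp from 1 at r_inner down to 0 at r_outer.
    t = (r - r_inner) / (r_outer - r_inner)
    return 0.5 * (1.0 + math.cos(math.pi * t))


# Sample the profile at a few radii across the falloff region.
profile = [cosine_window(r) for r in (0.0, 0.8, 0.85, 0.9, 1.0)]
```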

Weighting and masking

--recon-tilt-weight

Weight reconstruction loss by cosine(tilt_angle)

Default: False

--recon-dose-weight

Weight reconstruction loss per tilt per pixel by dose dependent amplitude attenuation

Default: False

--l-dose-mask

Do not train on frequencies exposed to > 2.5x critical dose. Training lattice is intersection of this with --l-extent

Default: False
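
The three options above can be illustrated with a short numerical sketch. The cosine tilt weight follows directly from the description. The critical-dose curve and exponential attenuation below use the empirical form published by Grant and Grigorieff (2015) for 300 kV data; whether tomoDRGN uses exactly these constants is an assumption, and all function names are hypothetical.

```python
import math


def tilt_weight(tilt_deg):
    """Weight for an image acquired at the given stage tilt angle, in degrees."""
    return math.cos(math.radians(tilt_deg))


def critical_dose(freq):
    """Empirical critical dose (e-/A^2) at spatial frequency freq (1/A); assumed model."""
    return 0.24499 * freq ** -1.6649 + 2.8141


def dose_weight(cumulative_dose, freq):
    """Amplitude attenuation for a pixel at freq after cumulative_dose (e-/A^2)."""
    return math.exp(-0.5 * cumulative_dose / critical_dose(freq))


def in_dose_mask(cumulative_dose, freq, factor=2.5):
    """True if the frequency has received no more than factor x its critical dose."""
    return cumulative_dose <= factor * critical_dose(freq)


w_untilted = tilt_weight(0.0)
w_high_tilt = tilt_weight(60.0)
early_kept = in_dose_mask(cumulative_dose=30.0, freq=0.1)
late_kept = in_dose_mask(cumulative_dose=40.0, freq=0.1)
```

Under these assumptions, a 10 A feature (freq = 0.1) has a critical dose of roughly 14 e-/A^2, so it stays in the training lattice for a tilt recorded after 30 e-/A^2 but is masked out after 40 e-/A^2.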

Training parameters

-n, --num-epochs

Number of training epochs

Default: 20

-b, --batch-size

Minibatch size

Default: 1

--wd

Weight decay in Adam optimizer

Default: 0

--lr

Learning rate in Adam optimizer for batch size 1. The learning rate is automatically scaled by the square root of the batch size.

Default: 0.0001

--beta

Choice of beta schedule or a constant for KLD weight

--beta-control

KL-Controlled VAE gamma. Beta is KL target

--norm

Data normalization as shift, 1/scale (default: 0, std of dataset)

--no-amp

Disable use of mixed-precision training

Default: False

--multigpu

Parallelize training across all detected GPUs

Default: False
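
As a worked example of the square-root learning rate scaling described for --lr, the sketch below computes the rate that would result from the default base rate and the batch size of 8 used in the sample commands. The helper name effective_lr is hypothetical.

```python
import math


def effective_lr(base_lr, batch_size):
    """Scale the batch-size-1 learning rate by the square root of the batch size."""
    return base_lr * math.sqrt(batch_size)


# Default --lr of 0.0001 with --batch-size 8 gives roughly 2.83e-4.
lr_batch8 = effective_lr(1e-4, 8)
# With --batch-size 1 the learning rate is unchanged.
lr_batch1 = effective_lr(1e-4, 1)
```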

Encoder Network

--enc-layers-A

Number of hidden layers for each tilt

Default: 3

--enc-dim-A

Number of nodes in hidden layers for each tilt

Default: 256

--out-dim-A

Number of nodes in output layer of encA == ntilts * number of nodes input to encB

Default: 128

--enc-layers-B

Number of hidden layers encoding merged tilts

Default: 3

--enc-dim-B

Number of nodes in hidden layers encoding merged tilts

Default: 256

--enc-mask

Diameter of circular mask of image for encoder in pixels (default: boxsize+1 to use up to Nyquist; -1 for no mask)

--pooling-function

Possible choices: concatenate, max, mean, median, set_encoder

Function used to pool features along ntilts dimension after encA

Default: 'concatenate'

--num-seeds

Number of seeds for PMA

Default: 1

--num-heads

Number of heads for multi head attention blocks

Default: 4

--layer-norm

Whether to apply layer normalization in the set transformer block

Default: False
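
The non-learned pooling choices differ mainly in the shape of what is handed from the per-tilt encoder to the merged-tilt encoder. The sketch below illustrates this on toy per-tilt feature vectors using plain Python lists; it assumes that concatenation flattens the per-tilt outputs end to end, and it omits set_encoder, which is a learned attention-based pooling. pool_tilts is a hypothetical helper.

```python
import statistics


def pool_tilts(features, mode):
    """Pool a list of per-tilt feature vectors (ntilts x out_dim) along the tilt axis."""
    if mode == "concatenate":
        # Output length grows with the number of tilts: ntilts * out_dim.
        return [value for tilt in features for value in tilt]
    columns = list(zip(*features))  # regroup values by feature index
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "mean":
        return [statistics.fmean(col) for col in columns]
    if mode == "median":
        return [statistics.median(col) for col in columns]
    raise ValueError(f"unsupported pooling mode: {mode}")


# Three tilts, each encoded to a 2-element feature vector.
per_tilt = [[1.0, 4.0], [3.0, 2.0], [5.0, 0.0]]
concatenated = pool_tilts(per_tilt, "concatenate")
pooled_max = pool_tilts(per_tilt, "max")
pooled_mean = pool_tilts(per_tilt, "mean")
pooled_median = pool_tilts(per_tilt, "median")
```

Note that only concatenation depends on the order and number of tilts; the max, mean, and median reductions produce a fixed-length vector regardless of how many tilt images a particle has.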

Decoder Network

--dec-layers

Number of hidden layers

Default: 3

--dec-dim

Number of nodes in hidden layers

Default: 256

--l-extent

Coordinate lattice size (if not using positional encoding) (default: 0.5)

Default: 0.5

--pe-type

Possible choices: geom_ft, geom_full, geom_lowf, geom_nohighf, linear_lowf, gaussian, none

Type of positional encoding

Default: 'gaussian'

--feat-sigma

Scale for random Gaussian features

Default: 0.5

--pe-dim

Num features in positional encoding

--activation

Possible choices: relu, leaky_relu


Default: 'relu'
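
For readers unfamiliar with the default 'gaussian' positional encoding, the following is a generic sketch of random Gaussian Fourier features, in which lattice coordinates are projected onto randomly drawn frequencies and passed through cosine and sine. How tomoDRGN maps --pe-dim and --feat-sigma onto this scheme (including the 2*pi factor) is an assumption here, and both function names are hypothetical.

```python
import math
import random


def make_frequencies(n_features, n_coords, feat_sigma, rng):
    """Draw a fixed (n_features x n_coords) frequency matrix from N(0, feat_sigma^2)."""
    return [[rng.gauss(0.0, feat_sigma) for _ in range(n_coords)]
            for _ in range(n_features)]


def encode(coord, frequencies):
    """Map one coordinate to [cos(2*pi*f.x) ..., sin(2*pi*f.x) ...]."""
    phases = [2.0 * math.pi * sum(c * f for c, f in zip(coord, freq))
              for freq in frequencies]
    return [math.cos(p) for p in phases] + [math.sin(p) for p in phases]


rng = random.Random(0)
freqs = make_frequencies(n_features=4, n_coords=3, feat_sigma=0.5, rng=rng)
encoded = encode((0.1, -0.2, 0.3), freqs)
encoded_origin = encode((0.0, 0.0, 0.0), freqs)
```

The frequencies are drawn once and then held fixed, so the same coordinate always maps to the same feature vector; a larger sigma yields higher-frequency features.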

Dataloader arguments

--num-workers

Number of workers to use when batching particles for training. Has moderate impact on epoch time

Default: 0

--prefetch-factor

Number of particles to prefetch per worker for training. Has moderate impact on epoch time

--persistent-workers

Whether to persist workers after dataset has been fully consumed. Has minimal impact on run time

Default: False


Whether to use pinned memory for dataloader. Has large impact on epoch time. Recommended.

Default: False

Common next steps

  • Assess model convergence with tomodrgn convergence_vae

  • Analyze model at a particular epoch in latent space with tomodrgn analyze

  • Analyze model at a particular epoch in volume space with tomodrgn analyze_volumes

  • Generate volumes for all particles at a particular epoch with tomodrgn eval_vol

  • Embed a (potentially related) dataset of images into the learned latent space with tomodrgn eval_images

  • Map generated volumes (for all particles) back to source tomograms to explore spatially contextualized heterogeneity with tomodrgn subtomo2chimerax