tomodrgn train_vae#

Purpose#

Train a heterogeneous tomoDRGN network (i.e. encoder and decoder modules) to embed pre-aligned 2-D tilt-series projections of each particle in a continuous latent space, and to generate unique 3-D reconstructions consistent with the input images given the corresponding latent embeddings.

Sample usage#

The examples below are adapted from tomodrgn/testing/commandtest*.py, and rely on other outputs from commandtest.py to execute successfully.

# Warp v1 style inputs
# note: tomoDRGN tries to automatically infer the software used to export particles,
# but --source-software allows this value to be set explicitly
tomodrgn \
    train_vae \
    data/10076_both_32_sim.star \
    --source-software cryosrpnt \
    --outdir output/vae_both_sim_zdim8_dosetiltweightmask_batchsize8 \
    --zdim 8 \
    --uninvert-data \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight \
    --batch-size 8

# WarpTools style inputs
# notes: --lazy is used because the separate .mrcs file per particle, as written by
# WarpTools, is well suited to lazy loading; --num-workers, --prefetch-factor, and
# --persistent-workers are best used only if --lazy is enabled, to avoid excessive
# memory utilization
tomodrgn \
    train_vae \
    data/warptools_test_4-tomos_10-ptcls_box-32_angpix-12_optimisation_set.star \
    --outdir output/vae_warptools_70S_zdim8_dosetiltweightmask_batchsize8 \
    --zdim 8 \
    --uninvert-data \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight \
    --batch-size 8 \
    --lazy \
    --num-workers 2 \
    --prefetch-factor 2 \
    --persistent-workers

Arguments#

usage: train_vae [-h] -o OUTDIR --zdim ZDIM [--load WEIGHTS.PKL]
                 [--checkpoint CHECKPOINT] [--log-interval LOG_INTERVAL] [-v]
                 [--seed SEED] [--plot-format {png,svgz}]
                 [--source-software {auto,warp,cryosrpnt,nextpyp,cistem,warptools,relion}]
                 [--ind-ptcls PKL] [--ind-imgs IND_IMGS]
                 [--sort-ptcl-imgs {unsorted,dose_ascending,random}]
                 [--use-first-ntilts USE_FIRST_NTILTS]
                 [--use-first-nptcls USE_FIRST_NPTCLS]
                 [--fraction-train FRACTION_TRAIN]
                 [--show-summary-stats SHOW_SUMMARY_STATS] [--uninvert-data]
                 [--no-window] [--window-r WINDOW_R]
                 [--window-r-outer WINDOW_R_OUTER] [--datadir DATADIR]
                 [--lazy] [--sequential-tilt-sampling] [--recon-tilt-weight]
                 [--recon-dose-weight] [--l-dose-mask] [-n NUM_EPOCHS]
                 [-b BATCH_SIZE] [--wd WD] [--lr LR] [--beta BETA]
                 [--beta-control BETA_CONTROL] [--norm NORM NORM] [--no-amp]
                 [--multigpu] [--enc-layers-A QLAYERSA] [--enc-dim-A QDIMA]
                 [--out-dim-A OUT_DIM_A] [--enc-layers-B QLAYERSB]
                 [--enc-dim-B QDIMB] [--enc-mask ENC_MASK]
                 [--pooling-function {concatenate,max,mean,median,set_encoder}]
                 [--num-seeds NUM_SEEDS] [--num-heads NUM_HEADS]
                 [--layer-norm] [--dec-layers PLAYERS] [--dec-dim PDIM]
                 [--l-extent L_EXTENT]
                 [--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
                 [--feat-sigma FEAT_SIGMA] [--pe-dim PE_DIM]
                 [--activation {relu,leaky_relu}] [--num-workers NUM_WORKERS]
                 [--prefetch-factor PREFETCH_FACTOR] [--persistent-workers]
                 [--pin-memory]
                 particles

Positional Arguments#

particles

Input particles_imageseries.star (if using Warp/M or NextPYP), or optimisation set star file (if using WarpTools or RELION v5)

Core arguments#

-o, --outdir

Output directory to save model

--zdim

Dimension of latent variable

Default: 128

--load

Initialize training from a checkpoint

--checkpoint

Checkpointing interval in N_EPOCHS (default: 1)

Default: 1

--log-interval

Logging interval in N_PTCLS (default: 200)

Default: 200

-v, --verbose

Increases verbosity

Default: False

--seed

Random seed

Default: 76166

--plot-format

Possible choices: png, svgz

File format with which to save plots

Default: 'png'

Particle starfile loading and filtering#

--source-software

Possible choices: auto, warp, cryosrpnt, nextpyp, cistem, warptools, relion

Manually set the software used to extract particles. Default is to auto-detect.

Default: 'auto'

--ind-ptcls

Filter the star file by particles (unique rlnGroupName values) using a pickled numpy array of indices

--ind-imgs

Filter the star file by particle images (star file rows) using a pickled numpy array of indices
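
Both flags expect a pickled numpy array of integer indices. A minimal sketch of creating such a file (the particle count and selection here are purely illustrative):

import pickle
import numpy as np

# Hypothetical selection: keep the first 100 particles
# (or the first 100 star file rows, if passing to --ind-imgs)
ind = np.arange(100)
with open('ind_ptcls.pkl', 'wb') as f:
    pickle.dump(ind, f)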

--sort-ptcl-imgs

Possible choices: unsorted, dose_ascending, random

Sort the star file images on a per-particle basis by the specified criteria

Default: 'unsorted'

--use-first-ntilts

Keep the first use_first_ntilts images of each particle in the sorted star file. Default -1 means to use all. Will drop particles with fewer than this many tilt images.

Default: -1

--use-first-nptcls

Keep the first use_first_nptcls particles in the sorted star file. Default -1 means to use all.

Default: -1

Particle starfile train/test split#

--fraction-train

Derive a new train/test split with this fraction of each particle's images assigned to the train split

Default: 1.0

--show-summary-stats

Log distribution statistics of particle sampling for test/train splits

Default: True

Dataset loading and preprocessing#

--uninvert-data

Do not invert data sign

Default: True

--no-window

Turn off real space windowing of dataset

Default: True

--window-r

Real space inner windowing radius to begin cosine falloff

Default: 0.8

--window-r-outer

Real space outer windowing radius to end cosine falloff

Default: 0.9
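
For intuition, a sketch of the kind of cosine-edged circular window these two radii describe, assuming the radii are expressed as fractions of half the box size (illustrative, not tomoDRGN's exact implementation):

import numpy as np

def cosine_window(boxsize, r_inner=0.8, r_outer=0.9):
    # Coordinates scaled so the box edge sits at radius 1.0
    x = np.linspace(-1, 1, boxsize, endpoint=False)
    r = np.sqrt(x[None, :] ** 2 + x[:, None] ** 2)
    # 1 inside r_inner, cosine falloff between the radii, 0 outside r_outer
    falloff = 0.5 + 0.5 * np.cos(np.pi * (r - r_inner) / (r_outer - r_inner))
    return np.where(r < r_inner, 1.0, np.where(r > r_outer, 0.0, falloff))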

--datadir

Path prefix to particle stack if loading relative paths from a .star file

--lazy

Lazy loading if the full dataset is too large to fit in memory (consider copying the dataset to a local SSD first)

Default: False

--sequential-tilt-sampling

Supply the images of each particle to the encoder in filtered star file order

Default: False

Weighting and masking#

--recon-tilt-weight

Weight reconstruction loss by cosine(tilt_angle)

Default: False

--recon-dose-weight

Weight reconstruction loss per tilt per pixel by dose-dependent amplitude attenuation

Default: False

--l-dose-mask

Do not train on frequencies exposed to > 2.5x the critical dose. The training lattice is the intersection of this mask with --l-extent

Default: False
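
Both dose-related flags derive from the frequency-dependent critical exposure model of Grant & Grigorieff (2015). A sketch of the underlying calculation (the fitted constants are from that paper; tomoDRGN's exact implementation may differ in detail):

import numpy as np

def dose_weight_and_mask(spatial_freq, cumulative_dose):
    # spatial_freq: spatial frequencies in 1/Angstrom; cumulative_dose in e-/A^2
    critical_dose = 0.24499 * spatial_freq ** -1.6649 + 2.8141  # Grant & Grigorieff fit (300 kV)
    weight = np.exp(-0.5 * cumulative_dose / critical_dose)     # --recon-dose-weight style weighting
    keep = cumulative_dose <= 2.5 * critical_dose               # --l-dose-mask style frequency mask
    return weight, keep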

Training parameters#

-n, --num-epochs

Number of training epochs

Default: 20

-b, --batch-size

Minibatch size

Default: 1

--wd

Weight decay in Adam optimizer

Default: 0

--lr

Learning rate in Adam optimizer for batch size 1. It is automatically scaled further by the square root of the batch size.

Default: 0.0001
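
For example, the scaling described above implies the following effective learning rate (illustrative arithmetic only):

import math

# --lr 0.0001 with --batch-size 8 yields an effective rate of ~2.8e-4
effective_lr = 0.0001 * math.sqrt(8)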

--beta

Choice of beta schedule or a constant for KLD weight

--beta-control

KL-Controlled VAE gamma. Beta is KL target
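
For orientation, --beta sets the weight on the KL divergence term of a standard beta-VAE objective, while --beta-control switches to a cryoDRGN-style controlled objective that penalizes deviation of the KL term from the target beta. A sketch of the forms these flags imply (not necessarily tomoDRGN's exact loss):

\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big)

\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + \gamma \big(\beta - D_{\mathrm{KL}}\big)^2 \quad \text{with --beta-control } \gamma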

--norm

Data normalization as shift, 1/scale (default: 0, std of dataset)

--no-amp

Disable use of mixed-precision training

Default: False

--multigpu

Parallelize training across all detected GPUs

Default: False

Encoder Network#

--enc-layers-A

Number of hidden layers for each tilt

Default: 3

--enc-dim-A

Number of nodes in hidden layers for each tilt

Default: 256

--out-dim-A

Number of nodes in the output layer of encA for each tilt; with concatenate pooling, encB receives ntilts * out_dim_A input nodes

Default: 128

--enc-layers-B

Number of hidden layers encoding merged tilts

Default: 3

--enc-dim-B

Number of nodes in hidden layers encoding merged tilts

Default: 256

--enc-mask

Diameter of circular mask of image for encoder in pixels (default: boxsize+1 to use up to Nyquist; -1 for no mask)

--pooling-function

Possible choices: concatenate, max, mean, median, set_encoder

Function used to pool features along ntilts dimension after encA

Default: 'concatenate'

--num-seeds

Number of seeds for PMA (pooling by multihead attention) when using the set_encoder pooling function

Default: 1

--num-heads

Number of heads for multi-head attention blocks

Default: 4

--layer-norm

Whether to apply layer normalization in the set transformer block

Default: False
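
Putting the flags above together, a simplified sketch of the two-stage encoder they configure, assuming concatenate pooling (illustrative, not tomoDRGN's actual module):

import torch.nn as nn

class TiltSeriesEncoder(nn.Module):
    def __init__(self, ntilts, in_dim, qlayersA=3, qdimA=256, out_dim_A=128,
                 qlayersB=3, qdimB=256, zdim=8):
        super().__init__()
        layersA = [nn.Linear(in_dim, qdimA), nn.ReLU()]
        for _ in range(qlayersA - 1):
            layersA += [nn.Linear(qdimA, qdimA), nn.ReLU()]
        layersA.append(nn.Linear(qdimA, out_dim_A))       # --out-dim-A
        self.encA = nn.Sequential(*layersA)               # embeds each tilt independently
        layersB = [nn.Linear(ntilts * out_dim_A, qdimB), nn.ReLU()]
        for _ in range(qlayersB - 1):
            layersB += [nn.Linear(qdimB, qdimB), nn.ReLU()]
        layersB.append(nn.Linear(qdimB, 2 * zdim))        # mean and log-variance
        self.encB = nn.Sequential(*layersB)

    def forward(self, imgs):                              # imgs: (batch, ntilts, in_dim)
        feats = self.encA(imgs)                           # per-tilt features
        pooled = feats.flatten(start_dim=1)               # concatenate across tilts
        return self.encB(pooled).chunk(2, dim=-1)         # (mu, logvar)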

Decoder Network#

--dec-layers

Number of hidden layers

Default: 3

--dec-dim

Number of nodes in hidden layers

Default: 256

--l-extent

Coordinate lattice size (if not using positional encoding) (default: 0.5)

Default: 0.5

--pe-type

Possible choices: geom_ft, geom_full, geom_lowf, geom_nohighf, linear_lowf, gaussian, none

Type of positional encoding

Default: 'gaussian'

--feat-sigma

Scale for random Gaussian features

Default: 0.5

--pe-dim

Num features in positional encoding
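
The default gaussian encoding maps lattice coordinates through fixed random Fourier features whose bandwidth is set by --feat-sigma and whose width is set by --pe-dim. A sketch of the idea (illustrative, not tomoDRGN's exact code):

import numpy as np

def gaussian_pe(coords, pe_dim, feat_sigma=0.5, seed=0):
    # coords: (N, 3) lattice coordinates; returns (N, 2 * pe_dim) features
    rng = np.random.default_rng(seed)
    freqs = rng.normal(0, feat_sigma, size=(3, pe_dim))  # fixed random projection
    proj = 2 * np.pi * coords @ freqs
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)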

--activation

Possible choices: relu, leaky_relu

Activation function

Default: 'relu'

Dataloader arguments#

--num-workers

Number of workers to use when batching particles for training. Has moderate impact on epoch time

Default: 0

--prefetch-factor

Number of particles to prefetch per worker for training. Has moderate impact on epoch time

--persistent-workers

Whether to persist workers after dataset has been fully consumed. Has minimal impact on run time

Default: False

--pin-memory

Whether to use pinned memory for dataloader. Has large impact on epoch time. Recommended.

Default: False
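
These four flags appear to correspond to the identically named parameters of PyTorch's torch.utils.data.DataLoader. A sketch of the mapping, with a stand-in dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(100, 32, 32))  # stand-in for the particle dataset
loader = DataLoader(dataset,
                    batch_size=8,
                    num_workers=2,            # --num-workers
                    prefetch_factor=2,        # --prefetch-factor (needs num_workers > 0)
                    persistent_workers=True,  # --persistent-workers
                    pin_memory=True)          # --pin-memory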

Common next steps#

  • Assess model convergence with tomodrgn convergence_vae

  • Analyze model at a particular epoch in latent space with tomodrgn analyze

  • Analyze model at a particular epoch in volume space with tomodrgn analyze_volumes

  • Generate volumes for all particles at a particular epoch with tomodrgn eval_vol

  • Embed a (potentially related) dataset of images into the learned latent space with tomodrgn eval_images

  • Map back generated volumes (for all particles) to source tomograms to explore spatially contextualized heterogeneity with tomodrgn subtomo2chimerax