tomodrgn train_vae#
Purpose#
Train a heterogeneous tomoDRGN network (i.e. encoder and decoder modules) to learn an embedding of pre-aligned 2-D tilt-series projections to a continuous latent space, and to learn to generate unique 3-D reconstructions consistent with input images given the corresponding latent embedding.
Sample usage#
The examples below are adapted from tomodrgn/testing/commandtest*.py, and rely on other outputs from commandtest.py to execute successfully.
# Warp v1 style inputs
# note: tomoDRGN tries to automatically infer the software used to export particles, but --source-software allows this value to be set explicitly
tomodrgn \
train_vae \
data/10076_both_32_sim.star \
--source-software cryosrpnt \
--outdir output/vae_both_sim_zdim8_dosetiltweightmask_batchsize8 \
--zdim 8 \
--uninvert-data \
--num-epochs 40 \
--l-dose-mask \
--recon-dose-weight \
--recon-tilt-weight \
--batch-size 8
# WarpTools style inputs
# note: --lazy is used because the separate .mrcs file per particle, as used by WarpTools, is well suited to lazy loading
# note: --num-workers, --prefetch-factor, and --persistent-workers are best used only if --lazy is enabled, to avoid excessive memory utilization
tomodrgn \
train_vae \
data/warptools_test_4-tomos_10-ptcls_box-32_angpix-12_optimisation_set.star \
--outdir output/vae_warptools_70S_zdim8_dosetiltweightmask_batchsize8 \
--zdim 8 \
--uninvert-data \
--num-epochs 40 \
--l-dose-mask \
--recon-dose-weight \
--recon-tilt-weight \
--batch-size 8 \
--lazy \
--num-workers 2 \
--prefetch-factor 2 \
--persistent-workers
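When training is launched from a Python script rather than a shell, the same command can be assembled as an argument list, which also leaves room for per-argument comments. The following is a minimal sketch that uses only flags shown in the examples above; the helper name build_train_vae_command is hypothetical and is not part of tomoDRGN.

```python
import shlex


def build_train_vae_command(star_file, outdir, zdim=8, num_epochs=40,
                            batch_size=8, lazy=False):
    """Assemble a `tomodrgn train_vae` argument list (hypothetical helper)."""
    cmd = [
        "tomodrgn", "train_vae", star_file,
        "--outdir", outdir,
        "--zdim", str(zdim),
        "--uninvert-data",
        "--num-epochs", str(num_epochs),
        "--l-dose-mask",
        "--recon-dose-weight",
        "--recon-tilt-weight",
        "--batch-size", str(batch_size),
    ]
    if lazy:
        # Dataloader options that the WarpTools example pairs with --lazy.
        cmd += ["--lazy", "--num-workers", "2",
                "--prefetch-factor", "2", "--persistent-workers"]
    return cmd


cmd = build_train_vae_command(
    "data/10076_both_32_sim.star",
    "output/vae_both_sim_zdim8_dosetiltweightmask_batchsize8",
)
# Print a shell-quoted version of the command for inspection.
print(shlex.join(cmd))
# To actually launch training: subprocess.run(cmd, check=True)
```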
Arguments#
usage: train_vae [-h] -o OUTDIR --zdim ZDIM [--load WEIGHTS.PKL]
[--checkpoint CHECKPOINT] [--log-interval LOG_INTERVAL] [-v]
[--seed SEED] [--plot-format {png,svgz}]
[--source-software {auto,warp,cryosrpnt,nextpyp,cistem,warptools,relion}]
[--ind-ptcls PKL] [--ind-imgs IND_IMGS]
[--sort-ptcl-imgs {unsorted,dose_ascending,random}]
[--use-first-ntilts USE_FIRST_NTILTS]
[--use-first-nptcls USE_FIRST_NPTCLS]
[--fraction-train FRACTION_TRAIN]
[--show-summary-stats SHOW_SUMMARY_STATS] [--uninvert-data]
[--no-window] [--window-r WINDOW_R]
[--window-r-outer WINDOW_R_OUTER] [--datadir DATADIR]
[--lazy] [--sequential-tilt-sampling] [--recon-tilt-weight]
[--recon-dose-weight] [--l-dose-mask] [-n NUM_EPOCHS]
[-b BATCH_SIZE] [--wd WD] [--lr LR] [--beta BETA]
[--beta-control BETA_CONTROL] [--norm NORM NORM] [--no-amp]
[--multigpu] [--enc-layers-A QLAYERSA] [--enc-dim-A QDIMA]
[--out-dim-A OUT_DIM_A] [--enc-layers-B QLAYERSB]
[--enc-dim-B QDIMB] [--enc-mask ENC_MASK]
[--pooling-function {concatenate,max,mean,median,set_encoder}]
[--num-seeds NUM_SEEDS] [--num-heads NUM_HEADS]
[--layer-norm] [--dec-layers PLAYERS] [--dec-dim PDIM]
[--l-extent L_EXTENT]
[--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
[--feat-sigma FEAT_SIGMA] [--pe-dim PE_DIM]
[--activation {relu,leaky_relu}] [--num-workers NUM_WORKERS]
[--prefetch-factor PREFETCH_FACTOR] [--persistent-workers]
[--pin-memory]
particles
Positional Arguments#
- particles
Input particles_imageseries.star (if using Warp/M or NextPYP), or optimisation set star file (if using WarpTools or RELION v5)
Core arguments#
- -o, --outdir
Output directory to save model
- --zdim
Dimension of latent variable
Default:
128
- --load
Initialize training from a checkpoint
- --checkpoint
Checkpointing interval in N_EPOCHS (default: 1)
Default:
1
- --log-interval
Logging interval in N_PTCLS (default: 200)
Default:
200
- -v, --verbose
Increases verbosity
Default:
False
- --seed
Random seed
Default:
76166
- --plot-format
Possible choices: png, svgz
File format with which to save plots
Default:
'png'
Particle starfile loading and filtering#
- --source-software
Possible choices: auto, warp, cryosrpnt, nextpyp, cistem, warptools, relion
Manually set the software used to extract particles. Default is to auto-detect.
Default:
'auto'
- --ind-ptcls
Filter starfile by particles (unique rlnGroupName values) using np array pkl as indices
- --ind-imgs
Filter starfile by particle images (star file rows) using np array pkl as indices
- --sort-ptcl-imgs
Possible choices: unsorted, dose_ascending, random
Sort the star file images on a per-particle basis by the specified criteria
Default:
'unsorted'
- --use-first-ntilts
Keep the first use_first_ntilts images of each particle in the sorted star file. Default -1 means to use all. Will drop particles with fewer than this many tilt images.
Default:
-1
- --use-first-nptcls
Keep the first use_first_nptcls particles in the sorted star file. Default -1 means to use all.
Default:
-1
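The index files accepted by --ind-ptcls and --ind-imgs are described above as pickled NumPy arrays. A minimal sketch of writing and re-reading such a file follows; the selected indices and the file name ind_ptcls.pkl are hypothetical, and the assumption that a one-dimensional integer array is expected is not confirmed by this page.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical selection: keep the particles at positions 0, 2, and 5.
ind_ptcls = np.array([0, 2, 5], dtype=int)

# A temporary directory is used here only to keep the sketch self-contained;
# replace it with the directory where the index file should live.
pkl_path = Path(tempfile.mkdtemp()) / "ind_ptcls.pkl"

with open(pkl_path, "wb") as f:
    pickle.dump(ind_ptcls, f)

# Read the file back to confirm its contents before passing it to --ind-ptcls.
with open(pkl_path, "rb") as f:
    loaded = pickle.load(f)

print(loaded.dtype, loaded.shape)
```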
Particle starfile train/test split#
- --fraction-train
Derive new train/test split with this fraction of each particle's images assigned to train
Default:
1.0
- --show-summary-stats
Log distribution statistics of particle sampling for test/train splits
Default:
True
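The per-particle split controlled by --fraction-train can be pictured as follows. This is an illustrative sketch only: random assignment and rounding to the nearest image count are assumptions, and the function split_particle_images is hypothetical rather than tomoDRGN's implementation.

```python
import numpy as np


def split_particle_images(n_images, fraction_train, rng):
    """Randomly assign one particle's image indices to train and test sets."""
    n_train = int(round(fraction_train * n_images))
    perm = rng.permutation(n_images)
    train = np.sort(perm[:n_train])
    test = np.sort(perm[n_train:])
    return train, test


rng = np.random.default_rng(0)
train_idx, test_idx = split_particle_images(n_images=41, fraction_train=0.8, rng=rng)
print(len(train_idx), len(test_idx))  # 33 8
```

With the default --fraction-train 1.0, every image of every particle lands in the train set and the test set is empty.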
Dataset loading and preprocessing#
- --uninvert-data
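Do not invert data sign. Data are inverted by default (the default below refers to the underlying inversion setting); pass this flag to disable inversion.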
Default:
True
- --no-window
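Turn off real space windowing of dataset. Windowing is on by default (the default below refers to the underlying windowing setting); pass this flag to disable it.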
Default:
True
- --window-r
Real space inner windowing radius to begin cosine falloff
Default:
0.8
- --window-r-outer
Real space outer windowing radius to end cosine falloff
Default:
0.9
- --datadir
Path prefix to particle stack if loading relative paths from a .star file
- --lazy
Lazy loading if full dataset is too large to fit in memory (Should copy dataset to SSD)
Default:
False
- --sequential-tilt-sampling
Supply particle images of one particle to encoder in filtered starfile order
Default:
False
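The --window-r and --window-r-outer radii describe a circular real-space window with a cosine falloff between them. The sketch below shows one way such a window could be built; expressing the radii as fractions of the half box size and the exact falloff formula are assumptions, and cosine_window is a hypothetical name.

```python
import numpy as np


def cosine_window(box_size, r_inner=0.8, r_outer=0.9):
    """2-D window: 1 inside r_inner, cosine falloff to 0 at r_outer."""
    # Coordinates span [-1, 1) so that radius 1 corresponds to the box edge.
    coords = np.linspace(-1, 1, box_size, endpoint=False)
    x, y = np.meshgrid(coords, coords)
    r = np.sqrt(x ** 2 + y ** 2)
    # t is 0 up to r_inner, rises linearly to 1 at r_outer, then stays at 1.
    t = np.clip((r - r_inner) / (r_outer - r_inner), 0.0, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * t))


window = cosine_window(32)
# An image would be windowed by elementwise multiplication: image * window
print(window.shape, window[16, 16], window[0, 0])
```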
Weighting and masking#
- --recon-tilt-weight
Weight reconstruction loss by cosine(tilt_angle)
Default:
False
- --recon-dose-weight
Weight reconstruction loss per tilt per pixel by dose dependent amplitude attenuation
Default:
False
- --l-dose-mask
Do not train on frequencies exposed to > 2.5x critical dose. Training lattice is intersection of this with --l-extent
Default:
False
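The three options above can be illustrated with a small numerical sketch. The cosine tilt weight and the 2.5x critical dose cutoff come from the descriptions above; the critical dose curve (the empirical fit a * k^b + c with a = 0.245, b = -1.665, c = 2.81 from Grant and Grigorieff, 2015) and the exponential attenuation form are assumptions about what tomoDRGN uses, and all function names are hypothetical.

```python
import numpy as np


def tilt_weights(tilt_angles_deg):
    """Weight each tilt image by cosine(tilt_angle)."""
    return np.cos(np.deg2rad(tilt_angles_deg))


def critical_dose(freq, a=0.245, b=-1.665, c=2.81):
    """Assumed critical dose (e-/A^2) at spatial frequency freq (1/A, > 0)."""
    return a * np.power(freq, b) + c


def dose_weights(cumulative_dose, freq):
    """Assumed amplitude attenuation exp(-dose / (2 * critical dose))."""
    return np.exp(-cumulative_dose / (2.0 * critical_dose(freq)))


def dose_mask(cumulative_dose, freq, factor=2.5):
    """True where the frequency has received at most factor x critical dose."""
    return cumulative_dose <= factor * critical_dose(freq)


print(tilt_weights(np.array([0.0, 30.0, 60.0])))
freqs = np.array([0.05, 0.1, 0.2])  # spatial frequencies in 1/A
print(dose_weights(30.0, freqs))
print(dose_mask(30.0, freqs))
```

Under these assumptions, higher spatial frequencies are attenuated more strongly and are masked out at lower cumulative dose than low frequencies.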
Training parameters#
- -n, --num-epochs
Number of training epochs
Default:
20
- -b, --batch-size
Minibatch size
Default:
1
- --wd
Weight decay in Adam optimizer
Default:
0
- --lr
Learning rate in Adam optimizer for batch size 1. The learning rate is automatically further scaled by the square root of the batch size.
Default:
0.0001
- --beta
Choice of beta schedule or a constant for KLD weight
- --beta-control
KL-Controlled VAE gamma. Beta is KL target
- --norm
Data normalization as shift, 1/scale (default: 0, std of dataset)
- --no-amp
Disable use of mixed-precision training
Default:
False
- --multigpu
Parallelize training across all detected GPUs
Default:
False
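As a worked example of the learning rate scaling described for --lr, the sketch below multiplies the base rate by the square root of the batch size. Reading "scaled as square-root of batch size" as a multiplication is an assumption, and effective_learning_rate is a hypothetical helper.

```python
import math


def effective_learning_rate(base_lr, batch_size):
    """Scale the batch-size-1 learning rate by sqrt(batch_size)."""
    return base_lr * math.sqrt(batch_size)


# Default --lr 0.0001 with --batch-size 8, as in the sample usage.
print(effective_learning_rate(1e-4, 8))  # approximately 2.83e-04
```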
Encoder Network#
- --enc-layers-A
Number of hidden layers for each tilt
Default:
3
- --enc-dim-A
Number of nodes in hidden layers for each tilt
Default:
256
- --out-dim-A
Number of nodes in output layer of encA == ntilts * number of nodes input to encB
Default:
128
- --enc-layers-B
Number of hidden layers encoding merged tilts
Default:
3
- --enc-dim-B
Number of nodes in hidden layers encoding merged tilts
Default:
256
- --enc-mask
Diameter of circular mask of image for encoder in pixels (default: boxsize+1 to use up to Nyquist; -1 for no mask)
- --pooling-function
Possible choices: concatenate, max, mean, median, set_encoder
Function used to pool features along ntilts dimension after encA
Default:
'concatenate'
- --num-seeds
Number of seeds for PMA (pooling by multihead attention)
Default:
1
- --num-heads
Number of heads for multi head attention blocks
Default:
4
- --layer-norm
Whether to apply layer normalization in the set transformer block
Default:
False
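The effect of --pooling-function on the shape of the features passed from the per-tilt encoder (encA) to the merged encoder (encB) can be sketched with NumPy. This illustrates only the four fixed pooling choices; set_encoder is a learned attention-based pooling and is omitted. The function pool_tilts and the example dimensions are hypothetical.

```python
import numpy as np


def pool_tilts(features, pooling_function):
    """Pool per-tilt features of shape (ntilts, out_dim_A) into one vector."""
    if pooling_function == "concatenate":
        # Output length is ntilts * out_dim_A, so ntilts must be fixed.
        return features.reshape(-1)
    if pooling_function == "max":
        return features.max(axis=0)
    if pooling_function == "mean":
        return features.mean(axis=0)
    if pooling_function == "median":
        return np.median(features, axis=0)
    raise ValueError(f"unsupported pooling function: {pooling_function}")


rng = np.random.default_rng(0)
features = rng.normal(size=(41, 128))  # 41 tilts, out_dim_A = 128
for name in ("concatenate", "max", "mean", "median"):
    print(name, pool_tilts(features, name).shape)
```

In this sketch, concatenate preserves per-tilt ordering at the cost of an input size that grows with the number of tilts, while max, mean, and median give a fixed-size vector regardless of how many tilts each particle has.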
Decoder Network#
- --dec-layers
Number of hidden layers
Default:
3
- --dec-dim
Number of nodes in hidden layers
Default:
256
- --l-extent
Coordinate lattice size (if not using positional encoding) (default: 0.5)
Default:
0.5
- --pe-type
Possible choices: geom_ft, geom_full, geom_lowf, geom_nohighf, linear_lowf, gaussian, none
Type of positional encoding
Default:
'gaussian'
- --feat-sigma
Scale for random Gaussian features
Default:
0.5
- --pe-dim
Num features in positional encoding
- --activation
Possible choices: relu, leaky_relu
Activation
Default:
'relu'
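For orientation, the default gaussian positional encoding can be thought of as random Fourier features whose frequencies are drawn from a Gaussian with scale --feat-sigma. The sketch below shows the generic technique, not tomoDRGN's implementation: the 2*pi factor, the sine/cosine layout, and the output width of 2 * pe_dim are assumptions, and gaussian_fourier_features is a hypothetical name.

```python
import numpy as np


def gaussian_fourier_features(coords, pe_dim, feat_sigma=0.5, seed=0):
    """Map coordinates (N, D) to random Fourier features (N, 2 * pe_dim)."""
    rng = np.random.default_rng(seed)
    # Random frequency matrix, fixed once and reused for all coordinates.
    freqs = rng.normal(0.0, feat_sigma, size=(coords.shape[-1], pe_dim))
    projected = 2.0 * np.pi * coords @ freqs
    return np.concatenate([np.sin(projected), np.cos(projected)], axis=-1)


# Ten 3-D lattice coordinates within the default --l-extent of 0.5.
coords = np.random.default_rng(1).uniform(-0.5, 0.5, size=(10, 3))
encoded = gaussian_fourier_features(coords, pe_dim=16)
print(encoded.shape)
```

A larger feat_sigma draws higher frequencies, which in this kind of encoding lets the decoder represent finer detail at the risk of fitting noise.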
Dataloader arguments#
- --num-workers
Number of workers to use when batching particles for training. Has moderate impact on epoch time
Default:
0
- --prefetch-factor
Number of particles to prefetch per worker for training. Has moderate impact on epoch time
- --persistent-workers
Whether to persist workers after dataset has been fully consumed. Has minimal impact on run time
Default:
False
- --pin-memory
Whether to use pinned memory for dataloader. Has large impact on epoch time. Recommended.
Default:
False
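The four dataloader options share their names with keyword arguments of PyTorch's torch.utils.data.DataLoader, which rejects prefetch_factor and persistent_workers when num_workers is 0. The sketch below assembles a consistent set of keyword arguments; that tomoDRGN forwards these values unchanged is an assumption, and dataloader_kwargs is a hypothetical helper.

```python
def dataloader_kwargs(batch_size, num_workers=0, prefetch_factor=None,
                      persistent_workers=False, pin_memory=False):
    """Build keyword arguments for a PyTorch DataLoader (hypothetical helper)."""
    if num_workers == 0 and (prefetch_factor is not None or persistent_workers):
        # PyTorch only accepts these options when worker processes are used.
        raise ValueError(
            "prefetch_factor and persistent_workers require num_workers > 0"
        )
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "persistent_workers": persistent_workers,
        "pin_memory": pin_memory,
    }
    if prefetch_factor is not None:
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


# Settings matching the WarpTools sample usage, plus pinned memory.
kwargs = dataloader_kwargs(batch_size=8, num_workers=2, prefetch_factor=2,
                           persistent_workers=True, pin_memory=True)
print(kwargs)
# Usage with PyTorch installed: torch.utils.data.DataLoader(dataset, **kwargs)
```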
Common next steps#
Assess model convergence with
tomodrgn convergence_vae
Analyze model at a particular epoch in latent space with
tomodrgn analyze
Analyze model at a particular epoch in volume space with
tomodrgn analyze_volumes
Generate volumes for all particles at a particular epoch with
tomodrgn eval_vol
Embed a (potentially related) dataset of images into the learned latent space with
tomodrgn eval_images
Map back generated volumes (for all particles) to source tomograms to explore spatially contextualized heterogeneity with
tomodrgn subtomo2chimerax