tomodrgn train_nn
Purpose
Train a homogeneous tomoDRGN network (i.e. decoder-only) to generate a “consensus” 3-D reconstruction from pre-aligned 2-D tilt-series projections.
Sample usage
The examples below are adapted from tomodrgn/testing/commandtest*.py and rely on other outputs from commandtest.py to execute successfully.
# Warp v1 style inputs
# note: tomoDRGN tries to automatically infer the software used to export particles, but --source-software allows this value to be set explicitly
tomodrgn \
train_nn \
data/10076_classE_32_sim.star \
--source-software cryosrpnt \
--outdir output/nn_classE_sim \
--uninvert-data \
--num-epochs 40 \
--l-dose-mask \
--recon-dose-weight \
--recon-tilt-weight
# WarpTools style inputs
# note: --lazy is used because the separate .mrcs file per particle, as used by WarpTools, is well suited to lazy loading
# note: --num-workers, --prefetch-factor, and --persistent-workers are best used only if --lazy is enabled, to avoid excessive memory utilization
tomodrgn \
train_nn \
data/warptools_test_4-tomos_10-ptcls_box-32_angpix-12_optimisation_set.star \
--outdir output/nn_warptools_70S_dosetiltweightmask \
--uninvert-data \
--num-epochs 40 \
--l-dose-mask \
--recon-dose-weight \
--recon-tilt-weight \
--lazy \
--num-workers 2 \
--prefetch-factor 2 \
--persistent-workers
Arguments
usage: train_nn [-h] --outdir OUTDIR [--load WEIGHTS.PKL]
[--checkpoint CHECKPOINT] [--log-interval LOG_INTERVAL]
[--verbose] [--seed SEED] [--plot-format {png,svgz}]
[--source-software {auto,warp,cryosrpnt,nextpyp,cistem,warptools,relion}]
[--ind-ptcls PKL] [--ind-imgs IND_IMGS]
[--sort-ptcl-imgs {unsorted,dose_ascending,random}]
[--use-first-ntilts USE_FIRST_NTILTS]
[--use-first-nptcls USE_FIRST_NPTCLS] [--uninvert-data]
[--no-window] [--window-r WINDOW_R]
[--window-r-outer WINDOW_R_OUTER] [--datadir DATADIR] [--lazy]
[--sequential-tilt-sampling] [--recon-tilt-weight]
[--recon-dose-weight] [--l-dose-mask] [-n NUM_EPOCHS]
[-b BATCH_SIZE] [--wd WD] [--lr LR] [--norm NORM NORM]
[--no-amp] [--multigpu] [--layers LAYERS] [--dim DIM]
[--l-extent L_EXTENT]
[--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
[--pe-dim PE_DIM] [--activation {relu,leaky_relu}]
[--feat-sigma FEAT_SIGMA] [--num-workers NUM_WORKERS]
[--prefetch-factor PREFETCH_FACTOR] [--persistent-workers]
[--pin-memory]
particles
Positional Arguments
- particles
Input particles (.mrcs, .star, or .txt)
Core arguments
- --outdir
Output directory to save model
- --load
Initialize training from a checkpoint
- --checkpoint
Checkpointing interval in N_EPOCHS (default: 1)
Default:
1
- --log-interval
Logging interval in N_PTCLS (default: 200)
Default:
200
- --verbose
Increases verbosity
Default:
False
- --seed
Random seed
Default:
5135
- --plot-format
Possible choices: png, svgz
File format with which to save plots
Default:
'png'
Particle starfile loading and filtering
- --source-software
Possible choices: auto, warp, cryosrpnt, nextpyp, cistem, warptools, relion
Manually set the software used to extract particles. Default is to auto-detect.
Default:
'auto'
- --ind-ptcls
Filter starfile by particles (unique rlnGroupName values), using a pickled NumPy array of indices
- --ind-imgs
Filter starfile by particle images (star file rows), using a pickled NumPy array of indices
- --sort-ptcl-imgs
Possible choices: unsorted, dose_ascending, random
Sort the star file images on a per-particle basis by the specified criteria
Default:
'unsorted'
- --use-first-ntilts
Keep the first use_first_ntilts images of each particle in the sorted star file. Default -1 means to use all. Particles with fewer than this many tilt images are dropped.
Default:
-1
- --use-first-nptcls
Keep the first use_first_nptcls particles in the sorted star file. Default -1 means to use all.
Default:
-1
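The filtering options above can also be prepared ahead of training. As a minimal sketch, assuming that --ind-imgs expects a pickled NumPy integer array of star file row indices (as its description suggests), the code below reproduces the described effect of --sort-ptcl-imgs dose_ascending combined with --use-first-ntilts N and writes the selected rows to such a file. The helper names first_n_lowest_dose_rows and save_index_pkl, the toy particle names, and the per-image doses are hypothetical, not part of tomoDRGN.

```python
import pickle
from collections import defaultdict

import numpy as np


def first_n_lowest_dose_rows(group_names, doses, n):
    """Return star file row indices that keep the n lowest-dose images of each particle.

    Particles with fewer than n images are dropped entirely, mirroring the described
    behaviour of --sort-ptcl-imgs dose_ascending combined with --use-first-ntilts n.
    """
    rows_by_particle = defaultdict(list)
    for row, (name, dose) in enumerate(zip(group_names, doses)):
        rows_by_particle[name].append((dose, row))

    keep = []
    for entries in rows_by_particle.values():
        if len(entries) < n:
            continue  # too few tilt images: drop this particle
        entries.sort()  # ascending dose; ties broken by original row order
        keep.extend(row for _, row in entries[:n])
    return np.array(sorted(keep), dtype=int)  # restore star file row order


def save_index_pkl(indices, path):
    """Pickle indices as a NumPy integer array (assumed format for --ind-imgs / --ind-ptcls)."""
    with open(path, "wb") as f:
        pickle.dump(np.asarray(indices, dtype=int), f)


# Hypothetical toy star file: one entry per image, grouped by rlnGroupName
names = ["p1", "p1", "p1", "p2", "p3", "p3", "p3"]
doses = [6.0, 0.0, 3.0, 0.0, 9.0, 3.0, 6.0]  # cumulative dose per image, e-/A^2
ind_imgs = first_n_lowest_dose_rows(names, doses, n=2)
save_index_pkl(ind_imgs, "ind_imgs_lowdose2.pkl")
print(ind_imgs)  # [1 2 5 6]: particle p2 has only one image and is dropped
```

Under the assumed file format, the result would be passed as --ind-imgs ind_imgs_lowdose2.pkl.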
Dataset loading and preprocessing
- --uninvert-data
Do not invert data sign. The listed default (True) means the data sign is inverted unless this flag is passed.
Default:
True
- --no-window
Turn off real space windowing of dataset. The listed default (True) means windowing is applied unless this flag is passed.
Default:
True
- --window-r
Real space inner windowing radius to begin cosine falloff
Default:
0.8
- --window-r-outer
Real space outer windowing radius to end cosine falloff
Default:
0.9
- --datadir
Path prefix to particle stack if loading relative paths from a .star file
- --lazy
Lazily load particle images from disk when the full dataset is too large to fit in memory (copying the dataset to an SSD is recommended)
Default:
False
- --sequential-tilt-sampling
Supply particle images of one particle to encoder in filtered starfile order
Default:
False
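To illustrate what --window-r and --window-r-outer control, here is a minimal NumPy sketch of a circular real-space window with a cosine falloff between the two radii. It assumes the radii are fractions of the box half-width (so 1.0 touches the box edge); that normalization and the function name cosine_window are assumptions, not tomoDRGN's code.

```python
import numpy as np


def cosine_window(box_size, r_inner=0.8, r_outer=0.9):
    """Circular real-space window with a raised-cosine edge.

    Radii are fractions of the box half-width (assumed convention): weight is 1 for
    r <= r_inner, 0 for r >= r_outer, and follows a half cosine in between.
    """
    coords = np.linspace(-1.0, 1.0, box_size, endpoint=False)
    x, y = np.meshgrid(coords, coords)
    r = np.sqrt(x**2 + y**2)
    # fractional position within the falloff band, clipped to [0, 1]
    t = np.clip((r - r_inner) / (r_outer - r_inner), 0.0, 1.0)
    return 0.5 * (1.0 + np.cos(np.pi * t))


mask = cosine_window(32)  # defaults mirror --window-r 0.8 and --window-r-outer 0.9
images = np.random.default_rng(0).normal(size=(4, 32, 32))  # hypothetical image stack
windowed = images * mask  # the 2-D mask broadcasts over the stack
```

A cosine edge rather than a hard cutoff avoids a sharp discontinuity at the window boundary, which would otherwise introduce ringing in Fourier space.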
Weighting and masking
- --recon-tilt-weight
Weight reconstruction loss by cosine(tilt_angle)
Default:
False
- --recon-dose-weight
Weight reconstruction loss per tilt per pixel by dose-dependent amplitude attenuation
Default:
False
- --l-dose-mask
Do not train on frequencies exposed to > 2.5x critical dose. The training lattice is the intersection of this mask with --l-extent
Default:
False
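The three options above can be sketched numerically for a single tilt image. This sketch assumes the critical-exposure model of Grant & Grigorieff (2015), N_c(k) = 0.245 k^-1.665 + 2.81 e-/Å² for spatial frequency k in 1/Å, with amplitude attenuation exp(-N / (2 N_c)) at cumulative exposure N; whether tomoDRGN uses exactly these constants, and how it combines the weights inside its loss, are assumptions. The box size and pixel size follow the WarpTools test file name above (box-32, angpix-12); the exposure and tilt angle are hypothetical.

```python
import numpy as np

# Critical-exposure fit from Grant & Grigorieff (2015); assumed, not read from tomoDRGN.
A, B, C = 0.245, -1.665, 2.81


def critical_exposure(freq):
    """Critical exposure (e-/A^2) at spatial frequency `freq` (1/A); infinite at k = 0."""
    with np.errstate(divide="ignore"):
        return A * np.power(freq, B) + C


def dose_weights(freq, exposure):
    """Dose-dependent amplitude attenuation exp(-N / (2 Nc)) for one tilt image."""
    return np.exp(-exposure / (2.0 * critical_exposure(freq)))


def dose_mask(freq, exposure):
    """True where exposure is at most 2.5x the critical exposure (frequencies kept for training)."""
    return exposure <= 2.5 * critical_exposure(freq)


def tilt_weight(tilt_deg):
    """Per-image weight cos(tilt angle)."""
    return np.cos(np.deg2rad(tilt_deg))


box, angpix = 32, 12.0
k = np.fft.fftshift(np.fft.fftfreq(box, d=angpix))  # 1/A, DC term at index box // 2
kx, ky = np.meshgrid(k, k)
freq = np.sqrt(kx**2 + ky**2)

exposure, tilt = 100.0, 54.0  # hypothetical cumulative exposure (e-/A^2) and tilt (degrees)
weights = dose_weights(freq, exposure) * tilt_weight(tilt)
mask = dose_mask(freq, exposure)
```

In this toy case the corner frequencies (beyond Nyquist along the axes) exceed 2.5x their critical exposure and are masked out, while frequencies up to axial Nyquist are kept; a late, high-dose tilt therefore contributes only low-frequency information, and with a reduced weight.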
Training parameters
- -n, --num-epochs
Number of training epochs
Default:
20
- -b, --batch-size
Minibatch size
Default:
1
- --wd
Weight decay in Adam optimizer
Default:
0
- --lr
Learning rate in Adam optimizer for batch size 1. It is automatically scaled by the square root of the batch size.
Default:
0.0001
- --norm
Data normalization as shift, 1/scale (default: mean, std of dataset)
- --no-amp
Disable use of mixed-precision training
Default:
False
- --multigpu
Parallelize training across all detected GPUs. To use only GPUs i and j, run export CUDA_VISIBLE_DEVICES=i,j before running tomodrgn train_nn
Default:
False
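The square-root scaling described for --lr amounts to simple arithmetic. Assuming it is applied exactly as stated, -b 4 with the default --lr 0.0001 would train at an effective learning rate of 0.0002. The function name below is illustrative only.

```python
import math


def scaled_learning_rate(base_lr=1e-4, batch_size=1):
    """Learning rate after square-root batch-size scaling, as described for --lr."""
    return base_lr * math.sqrt(batch_size)


for b in (1, 4, 16):
    print(b, scaled_learning_rate(1e-4, b))
```

Scaling by the square root rather than linearly keeps the effective step size from growing too quickly as the minibatch gets larger.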
Network Architecture
- --layers
Number of hidden layers
Default:
3
- --dim
Number of nodes in hidden layers
Default:
512
- --l-extent
Coordinate lattice size (if not using positional encoding)
Default:
0.5
- --pe-type
Possible choices: geom_ft, geom_full, geom_lowf, geom_nohighf, linear_lowf, gaussian, none
Type of positional encoding
Default:
'gaussian'
- --pe-dim
Number of sinusoid features in positional encoding (default: D/2, where D is the box size in pixels)
- --activation
Possible choices: relu, leaky_relu
Activation function
Default:
'relu'
- --feat-sigma
Scale for random Gaussian features
Default:
0.5
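As a rough illustration of how --pe-type gaussian, --pe-dim, --feat-sigma, --layers, --dim, and --activation fit together, the NumPy sketch below encodes 3-D lattice coordinates with random Gaussian Fourier features and passes them through a plain multilayer perceptron with random weights. This is a generic random-feature encoding and MLP, not tomoDRGN's decoder: the 2π factor, the mapping of --pe-dim to the number of frequencies, the coordinate range, the absence of biases, and the single output value per coordinate are all assumptions.

```python
import numpy as np


def gaussian_positional_encoding(coords, pe_dim, feat_sigma=0.5, rng=None):
    """Map (N, 3) coordinates to (N, 2 * pe_dim) random Fourier features.

    Frequencies are drawn once from N(0, feat_sigma^2) and kept fixed.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    freqs = rng.normal(scale=feat_sigma, size=(coords.shape[-1], pe_dim))
    proj = 2.0 * np.pi * coords @ freqs  # (N, pe_dim)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)


def mlp_forward(x, layers=3, dim=512, activation="relu", out_dim=1, rng=None):
    """Forward pass through `layers` hidden layers of width `dim` (random weights, no biases)."""
    if rng is None:
        rng = np.random.default_rng(1)
    activations = {
        "relu": lambda h: np.maximum(h, 0.0),
        "leaky_relu": lambda h: np.where(h > 0, h, 0.01 * h),
    }
    act = activations[activation]
    sizes = [x.shape[-1]] + [dim] * layers + [out_dim]
    h = x
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        w = rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out))
        h = h @ w
        if i < len(sizes) - 2:  # hidden layers get the activation; the output layer does not
            h = act(h)
    return h


box = 32
pe_dim = box // 2  # mirrors the documented default of D/2
coords = np.random.default_rng(2).uniform(-0.5, 0.5, size=(100, 3))  # hypothetical lattice points
features = gaussian_positional_encoding(coords, pe_dim, feat_sigma=0.5)
output = mlp_forward(features, layers=3, dim=64, activation="relu")
print(features.shape, output.shape)  # (100, 32) (100, 1)
```

A larger --feat-sigma draws higher random frequencies, which lets the network represent finer detail at the cost of potentially noisier fits.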
Dataloader arguments
- --num-workers
Number of workers to use when batching particles for training. Has moderate impact on epoch time
Default:
0
- --prefetch-factor
Number of particles to prefetch per worker for training. Has moderate impact on epoch time
- --persistent-workers
Whether to persist workers after dataset has been fully consumed. Has minimal impact on run time
Default:
False
- --pin-memory
Whether to use pinned memory for dataloader. Has large impact on epoch time. Recommended.
Default:
False
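These options share names with arguments of PyTorch's torch.utils.data.DataLoader (num_workers, prefetch_factor, persistent_workers, pin_memory); note that PyTorch counts prefetch_factor in batches per worker. In PyTorch, prefetch_factor and persistent_workers only apply with background workers (num_workers > 0), and recent versions raise a ValueError if they are set with num_workers=0. The sketch below is a hypothetical helper, not tomoDRGN's code, that assembles keyword arguments accordingly; it runs without PyTorch, and the resulting dict could be unpacked into DataLoader(dataset, **kwargs).

```python
def dataloader_kwargs(batch_size=1, num_workers=0, prefetch_factor=None,
                      persistent_workers=False, pin_memory=False):
    """Build keyword arguments for torch.utils.data.DataLoader (hypothetical helper).

    prefetch_factor and persistent_workers are only included when num_workers > 0,
    because PyTorch rejects them for single-process loading.
    """
    kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": pin_memory}
    if num_workers > 0:
        kwargs["persistent_workers"] = persistent_workers
        if prefetch_factor is not None:
            kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


# Dataloader settings from the WarpTools example: --num-workers 2 --prefetch-factor 2 --persistent-workers
print(dataloader_kwargs(num_workers=2, prefetch_factor=2, persistent_workers=True))
# Default --num-workers 0: single-process loading, so the worker-only options are omitted
print(dataloader_kwargs(prefetch_factor=2, persistent_workers=True))
```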
Common next steps
Assess model convergence with tomodrgn convergence_nn