tomodrgn train_nn#

Purpose#

Train a homogeneous tomoDRGN network (i.e. decoder-only) to generate a “consensus” 3-D reconstruction from pre-aligned 2-D tilt-series projections.

Sample usage#

The examples below are adapted from tomodrgn/testing/commandtest*.py, and rely on other outputs from commandtest.py to execute successfully.

# Warp v1 style inputs
# Note: tomoDRGN tries to infer the software used to export particles automatically; --source-software sets this value explicitly.
tomodrgn \
    train_nn \
    data/10076_classE_32_sim.star \
    --source-software cryosrpnt \
    --outdir output/nn_classE_sim \
    --uninvert-data \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight

# WarpTools style inputs
# Note: --lazy is used because WarpTools writes a separate .mrcs file per particle, which is well suited to lazy loading.
# Note: --num-workers, --prefetch-factor, and --persistent-workers are best used only with --lazy, to avoid excessive memory utilization.
tomodrgn \
    train_nn \
    data/warptools_test_4-tomos_10-ptcls_box-32_angpix-12_optimisation_set.star \
    --outdir output/nn_warptools_70S_dosetiltweightmask \
    --uninvert-data \
    --num-epochs 40 \
    --l-dose-mask \
    --recon-dose-weight \
    --recon-tilt-weight \
    --lazy \
    --num-workers 2 \
    --prefetch-factor 2 \
    --persistent-workers

Arguments#

usage: train_nn [-h] --outdir OUTDIR [--load WEIGHTS.PKL]
                [--checkpoint CHECKPOINT] [--log-interval LOG_INTERVAL]
                [--verbose] [--seed SEED] [--plot-format {png,svgz}]
                [--source-software {auto,warp,cryosrpnt,nextpyp,cistem,warptools,relion}]
                [--ind-ptcls PKL] [--ind-imgs IND_IMGS]
                [--sort-ptcl-imgs {unsorted,dose_ascending,random}]
                [--use-first-ntilts USE_FIRST_NTILTS]
                [--use-first-nptcls USE_FIRST_NPTCLS] [--uninvert-data]
                [--no-window] [--window-r WINDOW_R]
                [--window-r-outer WINDOW_R_OUTER] [--datadir DATADIR] [--lazy]
                [--sequential-tilt-sampling] [--recon-tilt-weight]
                [--recon-dose-weight] [--l-dose-mask] [-n NUM_EPOCHS]
                [-b BATCH_SIZE] [--wd WD] [--lr LR] [--norm NORM NORM]
                [--no-amp] [--multigpu] [--layers LAYERS] [--dim DIM]
                [--l-extent L_EXTENT]
                [--pe-type {geom_ft,geom_full,geom_lowf,geom_nohighf,linear_lowf,gaussian,none}]
                [--pe-dim PE_DIM] [--activation {relu,leaky_relu}]
                [--feat-sigma FEAT_SIGMA] [--num-workers NUM_WORKERS]
                [--prefetch-factor PREFETCH_FACTOR] [--persistent-workers]
                [--pin-memory]
                particles

Positional Arguments#

particles

Input particles (.mrcs, .star, or .txt)

Core arguments#

--outdir

Output directory to save model

--load

Initialize training from a checkpoint

--checkpoint

Checkpointing interval, in epochs

Default: 1

--log-interval

Logging interval, in particles

Default: 200

--verbose

Increases verbosity

Default: False

--seed

Random seed

Default: 5135

--plot-format

Possible choices: png, svgz

File format with which to save plots

Default: 'png'

Particle starfile loading and filtering#

--source-software

Possible choices: auto, warp, cryosrpnt, nextpyp, cistem, warptools, relion

Manually set the software used to extract particles. Default is to auto-detect.

Default: 'auto'

--ind-ptcls

Filter the star file by particles (unique rlnGroupName values), using a pickled NumPy array (.pkl) as indices

--ind-imgs

Filter the star file by particle images (star file rows), using a pickled NumPy array (.pkl) as indices
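
As an illustration of the index files that --ind-ptcls and --ind-imgs describe, the following minimal sketch writes a NumPy array to a .pkl file and reads it back. The index values, the output path, and the use of zero-based indexing are assumptions made for this example, not details confirmed by this page.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical selection: keep particles 0, 2, and 5
# (zero-based indexing is assumed here).
ptcl_indices = np.array([0, 2, 5], dtype=int)

# Hypothetical output location; any writable path works.
out_path = Path(tempfile.mkdtemp()) / "ind_ptcls.pkl"
with open(out_path, "wb") as f:
    pickle.dump(ptcl_indices, f)

# Read the file back to confirm it round-trips as a NumPy array.
with open(out_path, "rb") as f:
    reloaded = pickle.load(f)
```

The resulting file path would then be supplied to --ind-ptcls (indices into unique rlnGroupName values) or, with row indices instead, to --ind-imgs.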

--sort-ptcl-imgs

Possible choices: unsorted, dose_ascending, random

Sort the star file images on a per-particle basis by the specified criteria

Default: 'unsorted'

--use-first-ntilts

Keep the first use_first_ntilts images of each particle in the sorted star file. The default of -1 uses all images. Particles with fewer than this many tilt images are dropped.

Default: -1
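
The interaction between --sort-ptcl-imgs dose_ascending and --use-first-ntilts can be sketched in a few lines. This is an illustrative reimplementation of the behavior described above, not tomoDRGN's code; the record layout and the function name keep_first_ntilts are hypothetical.

```python
# Hypothetical per-image records: (particle name, cumulative dose) pairs.
images = [
    ("ptcl_a", 9.0), ("ptcl_a", 3.0), ("ptcl_a", 6.0),
    ("ptcl_b", 3.0), ("ptcl_b", 6.0),
]


def keep_first_ntilts(images, ntilts):
    """Sort each particle's images by ascending dose, then keep the first ntilts.

    ntilts == -1 keeps every image. Particles with fewer than ntilts
    images are dropped entirely.
    """
    by_particle = {}
    for name, dose in images:
        by_particle.setdefault(name, []).append(dose)

    kept = {}
    for name, doses in by_particle.items():
        doses = sorted(doses)
        if ntilts == -1:
            kept[name] = doses
        elif len(doses) >= ntilts:
            kept[name] = doses[:ntilts]
    return kept


selected = keep_first_ntilts(images, ntilts=3)
everything = keep_first_ntilts(images, ntilts=-1)
```

In this toy input, ptcl_b has only two images, so requesting three tilts per particle removes it from the selection.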

--use-first-nptcls

Keep the first use_first_nptcls particles in the sorted star file. Default -1 means to use all.

Default: -1

Dataset loading and preprocessing#

--uninvert-data

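Do not invert data sign

Default: True (this refers to the underlying invert-data setting: the data sign is inverted unless this flag is passed)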

--no-window

Turn off real space windowing of dataset

Default: True (this refers to the underlying windowing setting: windowing is applied unless this flag is passed)

--window-r

Real space inner windowing radius to begin cosine falloff

Default: 0.8

--window-r-outer

Real space outer windowing radius to end cosine falloff

Default: 0.9
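
To make the roles of --window-r and --window-r-outer concrete, here is a minimal sketch of a radial window with a cosine falloff between the two radii. It assumes radii are expressed as a fraction of the half box size and that the falloff is a raised cosine; tomoDRGN's exact window shape may differ, and the function name window_value is hypothetical.

```python
import math


def window_value(r, r_inner=0.8, r_outer=0.9):
    """Return the window weight at normalized radius r.

    The weight is 1 inside r_inner, 0 outside r_outer, and follows a
    raised-cosine falloff in between (an assumed shape).
    """
    if r <= r_inner:
        return 1.0
    if r >= r_outer:
        return 0.0
    t = (r - r_inner) / (r_outer - r_inner)  # 0 at r_inner, 1 at r_outer
    return 0.5 * (1.0 + math.cos(math.pi * t))


# Sample the profile at a few radii.
profile = [window_value(r) for r in (0.5, 0.8, 0.85, 0.9, 1.0)]
```

Widening the gap between the two radii gives a softer edge; setting them close together approaches a hard circular mask.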

--datadir

Path prefix to particle stack if loading relative paths from a .star file

--lazy

Load images lazily, for use when the full dataset is too large to fit in memory (copying the dataset to an SSD is recommended)

Default: False

--sequential-tilt-sampling

Supply the images of each particle to the network in filtered star file order

Default: False

Weighting and masking#

--recon-tilt-weight

Weight reconstruction loss by cosine(tilt_angle)

Default: False
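
As a small worked example of the cosine(tilt_angle) weighting, the sketch below computes per-tilt weights for a hypothetical tilt scheme. How tomoDRGN normalizes or combines these weights with the loss is not covered here.

```python
import math

# Hypothetical tilt angles in degrees.
tilt_angles_deg = [-60.0, -30.0, 0.0, 30.0, 60.0]

# Weight each tilt by the cosine of its tilt angle, so high-tilt images
# (which pass through more material) contribute less to the loss.
tilt_weights = [math.cos(math.radians(angle)) for angle in tilt_angles_deg]
```

The untilted image receives a weight of 1, and images at +/-60 degrees receive a weight of about 0.5.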

--recon-dose-weight

Weight reconstruction loss per tilt per pixel by dose dependent amplitude attenuation

Default: False

--l-dose-mask

Do not train on frequencies exposed to more than 2.5x the critical dose. The training lattice is the intersection of this mask with --l-extent.

Default: False
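
The dose weighting and dose mask can be illustrated with a short sketch. It assumes the empirical critical-dose curve of Grant and Grigorieff (2015), with spatial frequency in 1/Angstrom and dose in electrons per square Angstrom; this page does not state which curve or constants tomoDRGN uses, so treat the constants and function names below as assumptions.

```python
import math

# Assumed constants from the Grant and Grigorieff (2015) critical-dose fit.
A, B, C = 0.245, -1.665, 2.81


def critical_dose(freq):
    """Critical dose at spatial frequency freq (must be > 0)."""
    return A * freq ** B + C


def dose_weight(cumulative_dose, freq):
    """Amplitude attenuation factor for a pixel at freq after cumulative_dose."""
    return math.exp(-cumulative_dose / (2.0 * critical_dose(freq)))


def in_dose_mask(cumulative_dose, freq, cutoff=2.5):
    """True if this frequency is still used for training at this dose."""
    return cumulative_dose <= cutoff * critical_dose(freq)


# Example: a frequency of 0.1 (10 Angstrom resolution) at two cumulative doses.
early = (dose_weight(10.0, 0.1), in_dose_mask(10.0, 0.1))
late = (dose_weight(40.0, 0.1), in_dose_mask(40.0, 0.1))
```

Under these assumptions, later tilts are both down-weighted and masked at progressively lower frequencies, since high frequencies have the smallest critical dose.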

Training parameters#

-n, --num-epochs

Number of training epochs

Default: 20

-b, --batch-size

Minibatch size

Default: 1

--wd

Weight decay in Adam optimizer

Default: 0

--lr

Learning rate in Adam optimizer for a batch size of 1. The learning rate is automatically scaled by the square root of the batch size.

Default: 0.0001
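
A quick calculation shows what the square-root scaling implies for the effective learning rate. This sketch assumes the scaling is a simple multiplication of --lr by the square root of --batch-size; the function name scaled_lr is hypothetical.

```python
import math


def scaled_lr(base_lr, batch_size):
    """Scale a batch-size-1 learning rate by sqrt(batch_size) (assumed rule)."""
    return base_lr * math.sqrt(batch_size)


# With the default --lr of 0.0001:
lr_b1 = scaled_lr(1e-4, 1)    # unchanged at batch size 1
lr_b4 = scaled_lr(1e-4, 4)    # doubled at batch size 4
lr_b16 = scaled_lr(1e-4, 16)  # quadrupled at batch size 16
```

In practice this means that --lr does not need to be retuned by hand each time the batch size changes.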

--norm

Data normalization as shift, 1/scale (default: mean, std of dataset)
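
The "shift, 1/scale" wording can be read as subtracting the first value and dividing by the second. The sketch below applies that reading to a handful of hypothetical pixel values, using the mean and standard deviation as the default description suggests; how tomoDRGN computes the dataset statistics is not specified here.

```python
import statistics

# Hypothetical pixel values standing in for a dataset.
pixels = [2.0, 4.0, 6.0, 8.0]

# Default-style normalization parameters: mean and standard deviation.
shift = statistics.fmean(pixels)
scale = statistics.pstdev(pixels)

# Apply normalization as (value - shift) / scale.
normalized = [(p - shift) / scale for p in pixels]
```

Passing two explicit values to --norm would replace the computed shift and scale, which can help keep normalization consistent across runs on related datasets.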

--no-amp

Disable use of mixed-precision training

Default: False

--multigpu

Parallelize training across all detected GPUs. To restrict training to GPUs i and j, run export CUDA_VISIBLE_DEVICES=i,j before tomodrgn train_nn.

Default: False

Network Architecture#

--layers

Number of hidden layers

Default: 3

--dim

Number of nodes in hidden layers

Default: 512

--l-extent

Coordinate lattice size (if not using positional encoding)

Default: 0.5

--pe-type

Possible choices: geom_ft, geom_full, geom_lowf, geom_nohighf, linear_lowf, gaussian, none

Type of positional encoding

Default: 'gaussian'

--pe-dim

Number of sinusoid features in positional encoding (default: D/2)

--activation

Possible choices: relu, leaky_relu

Activation function

Default: 'relu'

--feat-sigma

Scale for random Gaussian features

Default: 0.5
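
For readers unfamiliar with the gaussian positional encoding, the following sketch shows the general idea of random Gaussian Fourier features: coordinates are projected onto random frequencies drawn from a normal distribution with standard deviation feat_sigma, then passed through sine and cosine. This is a generic illustration, not tomoDRGN's implementation, which may scale the frequencies differently; the function name and seed handling are hypothetical.

```python
import math
import random


def gaussian_features(coord, num_features, feat_sigma=0.5, seed=0):
    """Encode a coordinate with random Gaussian Fourier features.

    A fixed seed keeps the random frequencies identical across calls,
    so every coordinate is encoded with the same basis.
    """
    rng = random.Random(seed)
    # One random frequency vector per feature, matching the coordinate's dimension.
    freqs = [[rng.gauss(0.0, feat_sigma) for _ in coord] for _ in range(num_features)]
    # Project the coordinate onto each frequency vector.
    projections = [sum(f * x for f, x in zip(row, coord)) for row in freqs]
    # Concatenate sine and cosine features.
    return [math.sin(p) for p in projections] + [math.cos(p) for p in projections]


features = gaussian_features((0.1, -0.2, 0.3), num_features=4)
```

A larger feat_sigma draws higher frequencies on average, which lets the encoding represent finer detail at the cost of a noisier basis.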

Dataloader arguments#

--num-workers

Number of workers to use when batching particles for training. Has moderate impact on epoch time

Default: 0

--prefetch-factor

Number of particles to prefetch per worker for training. Has moderate impact on epoch time

--persistent-workers

Whether to persist workers after dataset has been fully consumed. Has minimal impact on run time

Default: False

--pin-memory

Whether to use pinned memory for dataloader. Has large impact on epoch time. Recommended.

Default: False

Common next steps#

  • Assess model convergence with tomodrgn convergence_nn