Source code for tomodrgn.commands.filter_star

"""
Filter a .star file by selected particle or image indices, optionally per-tomogram
"""
import argparse
import copy
import os
from typing import Literal
import warnings

import numpy as np

from tomodrgn import starfile, utils

log = utils.log


def add_args(parser: argparse.ArgumentParser | None = None) -> argparse.ArgumentParser:
    if parser is None:
        # this script is called directly; need to create a parser
        parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    else:
        # this script is called from tomodrgn.__main__ entry point, in which case a parser is already created
        pass

    parser.add_argument('input', help='Input .star file')

    group = parser.add_argument_group('Core arguments')
    group.add_argument('--starfile-type', type=str, default='imageseries', choices=('imageseries', 'volumeseries', 'optimisation_set'),
                       help='Type of star file to filter. Select imageseries if rows correspond to particle images. Select volumeseries if rows correspond to particle volumes. '
                            'Select optimisation_set if passing in an optimisation set star file.')
    group.add_argument('--action', choices=('keep', 'drop'), default='keep',
                       help='keep or remove particles associated with ind.pkl')
    group.add_argument('--tomogram', type=str,
                       help='optionally select by individual tomogram name (if `all`, then writes individual star files per tomogram)')
    group.add_argument('--tomo-id-col', type=str, default='_rlnMicrographName',
                       help='Name of column in input starfile with unique values per tomogram')
    group.add_argument('-o', required=True,
                       help='Output .star file (treated as output base name suffixed by tomogram name if specifying `--tomogram`). '
                            'The output star file name must contain the string `_optimisation_set` if the input star file is of --starfile-type optimisation_set')

    group = parser.add_argument_group('Index-based filtering arguments')
    group.add_argument('--ind', help='selected indices array (.pkl)')
    group.add_argument('--ind-type', choices=('particle', 'image'), default='particle',
                       help='use indices to filter by particle (multiple images) or by image (individual images). '
                            'Only relevant for imageseries star files filtered using ``--ind``')

    group = parser.add_argument_group('Class-label-based filtering arguments')
    group.add_argument('--labels', type=os.path.abspath,
                       help='path to labels array (.pkl). The labels.pkl must contain a 1-D numpy array of integer class labels '
                            'with length matching the number of particles referenced in the star file to be filtered.')
    group.add_argument('--labels-sel', type=int, nargs='+',
                       help='space-separated list of integer class labels to be selected (to be kept or dropped in accordance with ``--action``)')

    return parser


def check_args_compatible(args: argparse.Namespace) -> None:
    # only one of --ind or --labels can be specified
    if args.ind is not None:
        assert args.labels is None, 'Cannot specify both `--ind` and `--labels`'
    elif args.labels is not None:
        assert args.ind is None, 'Cannot specify both `--ind` and `--labels`'

    # if --ind-type is image, --ind must be specified
    if args.ind_type == 'image':
        assert args.ind is not None, 'Filtering with --ind-type image is only supported when filtering with --ind'

    # if --labels is provided, number of selected labels must be greater than 0
    if args.labels is not None:
        assert args.labels_sel is not None and len(args.labels_sel) > 0, 'At least one class label must be specified with --labels-sel when filtering with --labels'


def filter_image_series_starfile(star_path: str,
                                 ind_path: str,
                                 labels_path: str,
                                 labels_sel: list[int],
                                 ind_type: Literal['particle', 'image'] = 'particle',
                                 ind_action: Literal['keep', 'drop'] = 'keep') -> starfile.TiltSeriesStarfile:
    """
    Filter an imageseries star file by specified indices in-place.

    :param star_path: path to image series star file on disk
    :param ind_path: path to indices pkl file on disk
    :param labels_path: path to labels pkl file on disk
    :param labels_sel: space-separated list of integer class labels to be selected (to be kept or dropped in accordance with ``ind_action``)
    :param ind_type: should indices be interpreted per particle (multiple images, i.e. multiple rows of df) or per image (individual images, i.e. individual row of df)
    :param ind_action: are specified indices being kept in the output star file, or dropped from the output star file
    :return: filtered TiltSeriesStarfile object
    """
    # load the star file
    star = starfile.TiltSeriesStarfile(star_path)
    ptcl_img_indices = star.get_ptcl_img_indices()
    log(f'Input star file contains {len(ptcl_img_indices)} particles consisting of {len(np.hstack(ptcl_img_indices))} images.')

    # establish indices to drop
    if ind_path:
        ind = utils.load_pkl(ind_path)

        # determine the appropriate set of indices to pass to .filter to be preserved
        if ind_type == 'particle':
            if ind_action == 'drop':
                # invert indices on particle level (groups of rows)
                ind_ptcls = np.array([i for i in range(len(ptcl_img_indices)) if i not in ind])
            elif ind_action == 'keep':
                ind_ptcls = ind
            else:
                raise ValueError
            ind_imgs = None

            # validate indices
            assert ind_ptcls.max() < len(ptcl_img_indices), 'A supplied index exceeds the number of unique particles detected'
            assert ind_ptcls.min() >= 0, 'A supplied index is negative (which is not a valid index)'
            assert len(set(ind_ptcls)) == len(ind_ptcls), 'An index was specified multiple times (which is not supported)'
        elif ind_type == 'image':
            if ind_action == 'drop':
                # invert indices on image level (individual rows)
                ind_imgs = np.array([i for i in np.hstack(ptcl_img_indices) if i not in ind])
            elif ind_action == 'keep':
                ind_imgs = ind
            else:
                raise ValueError
            ind_ptcls = None

            # validate indices
            assert ind_imgs.max() < len(np.hstack(ptcl_img_indices)), 'A supplied index exceeds the number of images detected'
            assert ind_imgs.min() >= 0, 'A supplied index is negative (which is not a valid index)'
            assert len(set(ind_imgs)) == len(ind_imgs), 'An index was specified multiple times (which is not supported)'
        else:
            raise ValueError
    elif labels_path:
        labels = utils.load_pkl(labels_path)

        # validate labels pkl
        assert len(labels) == len(ptcl_img_indices), f'The length of the labels array ({len(labels)}) does not match the number of particles in the star file ({len(ptcl_img_indices)})'

        # validate selected labels
        labels_sel = list(set(labels_sel))
        labels_sel.sort()
        for label_sel in labels_sel:
            assert label_sel in labels, f'The selected label {label_sel} was not found in the supplied labels array'

        # generate ind_sel from labels_sel
        ind_sel = np.asarray([i for i, label in enumerate(labels) if label in labels_sel])

        # determine the appropriate set of indices to pass to .filter to be preserved
        if ind_action == 'drop':
            # invert indices on particle level (groups of rows)
            ind_ptcls = np.array([i for i in range(len(ptcl_img_indices)) if i not in ind_sel])
        elif ind_action == 'keep':
            ind_ptcls = ind_sel
        else:
            raise ValueError
        ind_imgs = None
    else:
        ind_imgs = None
        ind_ptcls = None
        warnings.warn('Neither --ind nor --labels was specified, continuing without filtering to a subset of particles')

    # apply filtering
    star.filter(ind_imgs=ind_imgs, ind_ptcls=ind_ptcls)
    log(f'Filtered star file has {len(star.get_ptcl_img_indices())} particles consisting of {len(star.df)} images.')

    return star

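# Example (a minimal sketch of calling the imageseries filter directly from Python;
# the file names are hypothetical, and filtering here is by particle indices only):
#
#   star = filter_image_series_starfile(star_path='particles_imageseries.star',
#                                       ind_path='ind_keep.pkl',
#                                       labels_path=None,
#                                       labels_sel=None,
#                                       ind_type='particle',
#                                       ind_action='keep')
#   star.write('particles_filtered.star')
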

def filter_volume_series_starfile(star_path: str,
                                  ind_path: str,
                                  labels_path: str,
                                  labels_sel: list[int],
                                  ind_action: Literal['keep', 'drop'] = 'keep') -> starfile.GenericStarfile:
    """
    Filter a volumeseries star file by specified indices in-place.

    :param star_path: path to volume series star file on disk
    :param ind_path: path to indices pkl file on disk
    :param labels_path: path to labels pkl file on disk
    :param labels_sel: space-separated list of integer class labels to be selected (to be kept or dropped in accordance with ``ind_action``)
    :param ind_action: are specified indices being kept in the output star file, or dropped from the output star file
    :return: filtered GenericStarfile object
    """
    # load the star file
    if starfile.is_starfile_optimisation_set(star_path):
        star = starfile.TomoParticlesStarfile(star_path)
        ptcl_block_name = star.block_particles
        df = star.df
    else:
        star = starfile.GenericStarfile(star_path)
        ptcl_block_name = star.identify_particles_data_block()
        df = star.blocks[ptcl_block_name]
    log(f'Input star file contains {len(df)} particles.')

    # establish indices to drop
    if ind_path is not None:
        ind_ptcls = utils.load_pkl(ind_path)
        if ind_action == 'drop':
            ind_ptcls_to_drop = ind_ptcls
        elif ind_action == 'keep':
            # invert indices on particle level (individual rows)
            ind_ptcls_to_drop = np.array([i for i in df.index.to_numpy() if i not in ind_ptcls])
        else:
            raise ValueError

        # validate indices
        assert ind_ptcls.max() < len(df), 'A supplied index exceeds the number of unique particles detected'
        assert ind_ptcls.min() >= 0, 'A supplied index is negative (which is not a valid index)'
        assert len(set(ind_ptcls)) == len(ind_ptcls), 'An index was specified multiple times (which is not supported)'
    elif labels_path is not None:
        labels = utils.load_pkl(labels_path)

        # validate labels pkl
        assert len(labels) == len(df), f'The length of the labels array ({len(labels)}) does not match the number of particles in the star file ({len(df)})'

        # validate selected labels
        labels_sel = list(set(labels_sel))
        labels_sel.sort()
        for label_sel in labels_sel:
            assert label_sel in labels, f'The selected label {label_sel} was not found in the supplied labels array'

        # generate ind_sel from labels_sel
        ind_sel = np.asarray([i for i, label in enumerate(labels) if label in labels_sel])

        # determine the appropriate set of row indices to drop from the particle dataframe
        if ind_action == 'drop':
            ind_ptcls_to_drop = ind_sel
        elif ind_action == 'keep':
            # invert indices on particle level (individual rows)
            ind_ptcls_to_drop = np.array([i for i in df.index.to_numpy() if i not in ind_sel])
        else:
            raise ValueError
    else:
        ind_ptcls_to_drop = np.array([])
        warnings.warn('Neither --ind nor --labels was specified, continuing without filtering to a subset of particles')

    # apply filtering
    df = df.drop(ind_ptcls_to_drop).reset_index(drop=True)
    star.blocks[ptcl_block_name] = df
    log(f'Filtered star file contains {len(df)} particles.')

    return star

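# Example (a minimal sketch of calling the volumeseries filter directly from Python;
# the file names are hypothetical, and the labels .pkl is assumed to contain one integer
# class label per particle, with particles labeled 0 or 2 being kept):
#
#   star = filter_volume_series_starfile(star_path='particles_volumeseries.star',
#                                        ind_path=None,
#                                        labels_path='labels.pkl',
#                                        labels_sel=[0, 2],
#                                        ind_action='keep')
#   star.write('particles_filtered_volumeseries.star')
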

def main(args):
    # log inputs
    log(args)

    # check that selected arguments are mutually compatible
    check_args_compatible(args)

    # filter using the appropriate type of star file
    if args.starfile_type == 'imageseries':
        star = filter_image_series_starfile(star_path=args.input,
                                            ind_path=args.ind,
                                            ind_type=args.ind_type,
                                            labels_path=args.labels,
                                            labels_sel=args.labels_sel,
                                            ind_action=args.action)
    elif args.starfile_type == 'volumeseries' or args.starfile_type == 'optimisation_set':
        star = filter_volume_series_starfile(star_path=args.input,
                                             ind_path=args.ind,
                                             labels_path=args.labels,
                                             labels_sel=args.labels_sel,
                                             ind_action=args.action)
    else:
        raise ValueError('Unknown starfile type')

    # write the filtered star file
    star.write(args.o)

    # apply further filtering to the specified tomograms and write corresponding star files
    if args.tomogram:
        # first find the block containing particle data and ensure the specified column for tomogram ID is present
        tomo_block_name = star.identify_particles_data_block(column_substring=args.tomo_id_col)

        if args.tomogram == 'all':
            # write each tomo's starfile out separately
            tomos_to_write = star.blocks[tomo_block_name][args.tomo_id_col].unique()
        else:
            # alternatively, specify one tomogram to preserve in output star file
            tomos_to_write = [args.tomogram]

        for tomo_name in tomos_to_write:
            # filter a copy of the star file to the requested tomogram name
            star_copy_this_tomo = copy.deepcopy(star)
            star_copy_this_tomo.blocks[tomo_block_name] = star_copy_this_tomo.blocks[tomo_block_name][star_copy_this_tomo.blocks[tomo_block_name][args.tomo_id_col].str.contains(tomo_name)]

            # write the star file
            print(f'{len(star_copy_this_tomo.blocks[tomo_block_name])} rows after filtering by tomogram {tomo_name}')
            if args.o.endswith('.star'):
                outpath = args.o.split('.star')[0]
            else:
                outpath = args.o
            star_copy_this_tomo.write(f'{outpath}_{tomo_name}.star')


if __name__ == '__main__':
    main(add_args().parse_args())