"""Visualize latent space and generate volumes"""importargparseimportosimportshutilfromdatetimeimportdatetimeasdtfromtypingimportLiteralimportmatplotlibimportmatplotlib.pyplotaspltimportnumpyasnpimportseabornassnsfromimportlib_resourcesimportfilesfromtomodrgnimportanalysis,utils,starfile,modelslog=utils.log
[docs]defadd_args(parser:argparse.ArgumentParser|None=None)->argparse.ArgumentParser:ifparserisNone:# this script is called directly; need to create a parserparser=argparse.ArgumentParser(description=__doc__,formatter_class=argparse.ArgumentDefaultsHelpFormatter)else:# this script is called from tomodrgn.__main__ entry point, in which case a parser is already createdpassparser.add_argument('workdir',type=os.path.abspath,help='Directory with tomoDRGN results')group=parser.add_argument_group('Core arguments')parser.add_argument('--epoch',type=str,default='latest',help='Epoch number N to analyze (0-based indexing, corresponding to z.N.pkl, weights.N.pkl). ''Supplying `latest` will auto-detect the latest completed epoch of training.')group.add_argument('--device',type=int,help='Optionally specify CUDA device')group.add_argument('-o','--outdir',help='Output directory for analysis results (default: [workdir]/analyze.[epoch])')group.add_argument('--skip-vol',action='store_true',help='Skip generation of volumes')group.add_argument('--skip-umap',action='store_true',help='Skip running UMAP')group.add_argument('--plot-format',type=str,choices=['png','svgz'],default='png',help='File format with which to save plots')group=parser.add_argument_group('Arguments for latent space analysis')group.add_argument('--pc',type=int,default=2,help='Number of principal component traversals to generate (default: %(default)s)')group.add_argument('--pc-ondata',action='store_true',help='Find closest on-data latent point to each PC percentile')group.add_argument('--ksample',type=int,default=20,help='Number of kmeans samples to generate (default: %(default)s)')group=parser.add_argument_group('Arguments for volume generation')group.add_argument('--downsample',type=int,help='Downsample volumes to this box size (pixels)')group.add_argument('--lowpass',type=float,default=None,help='Lowpass filter to this resolution in Å')group.add_argument('--flip',action='store_true',help='Flip handedness of output 
volumes')group.add_argument('--invert',action='store_true',help='Invert contrast of output volumes')returnparser
def analyze_z_onedimensional(z: np.ndarray,
                             outdir: str,
                             plot_format: Literal['png', 'svgz'],
                             vg: models.VolumeGenerator,
                             skip_vol: bool = False,
                             ondata: bool = False,
                             downsample: int | None = None,
                             lowpass: float | None = None,
                             flip: bool = False,
                             invert: bool = False,
                             ) -> None:
    """
    Plotting and volume generation for 1D z

    :param z: array of 1-D latent embeddings, shape (nptcls, 1)
    :param outdir: directory in which to save all outputs (plots and generated volumes)
    :param plot_format: file format with which to save plots
    :param vg: VolumeGenerator instance to aid volume generation at specified z values
    :param skip_vol: whether to skip generation of volumes
    :param ondata: whether to use the closest on-data latent point to each z percentile for plotting and volume generation
    :param downsample: downsample reconstructed volumes to this box size (units: px) by Fourier cropping, None means to skip downsampling
    :param lowpass: lowpass filter reconstructed volumes to this resolution (units: Å), None means to skip lowpass filtering
    :param flip: flip the chirality of the reconstructed volumes by inverting along the z axis
    :param invert: invert the data sign of the reconstructed volumes (light-on-dark vs dark-on-light)
    :return: None
    """
    assert z.shape[1] == 1
    z = z.reshape(-1)
    nptcls = len(z)

    # scatter plot of particle index against latent embedding
    plt.scatter(np.arange(nptcls), z, alpha=.1, s=2, rasterized=True)
    plt.xlabel('particle index')
    plt.ylabel('z')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z.{plot_format}')
    plt.close()

    # histogram of latent embeddings with KDE overlay
    sns.displot(z, kde=True)
    plt.xlabel('z')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_hist.{plot_format}')
    plt.close()

    if not skip_vol:
        # sample z values at 5th, 15th, ..., 95th percentiles of the latent distribution
        ztraj = np.percentile(z, np.linspace(start=5, stop=95, num=10))
        if ondata:
            # analyze_z_multidimensional unpacks get_nearest_point as (nearest points, their indices) and calls it
            # with 2-D arrays; do the same here with column vectors, then flatten back so that ztraj stays a 1-D
            # array of z values exactly like the off-data percentile path.
            # NOTE(review): assumes get_nearest_point accepts (n, 1) inputs -- confirm against tomodrgn.analysis
            nearest_z, _ = analysis.get_nearest_point(z.reshape(-1, 1), ztraj.reshape(-1, 1))
            ztraj = np.asarray(nearest_z).reshape(-1)

        # histogram of latent embeddings with KDE overlay, marking the z values sampled for volume generation
        sns.displot(z, kde=True)
        for percentile in ztraj:
            plt.axvline(percentile, color='red', linestyle='-')
        plt.xlabel('z')
        plt.tight_layout()
        plt.savefig(f'{outdir}/z_hist_percentile_volumes.{plot_format}')
        plt.close()

        # generate corresponding volumes
        vg.generate_volumes(z=ztraj, out_dir=outdir, downsample=downsample, lowpass=lowpass, flip=flip, invert=invert)
def analyze_z_multidimensional(z: np.ndarray,
                               outdir: str,
                               plot_format: Literal['png', 'svgz'],
                               vg: models.VolumeGenerator,
                               starfile_path: str,
                               source_software: str,
                               datadir: str | None = None,
                               skip_vol: bool = False,
                               skip_umap: bool = False,
                               num_pcs: int = 2,
                               pc_ondata: bool = False,
                               num_ksamples: int = 20,
                               downsample: int | None = None,
                               lowpass: float | None = None,
                               flip: bool = False,
                               invert: bool = False,
                               ) -> None:
    """
    Plotting and volume generation for multidimensional z

    Output sub-directories (``pc{i}``, ``kmeans{num_ksamples}``, ``controls``) are created if missing and reused if
    they already exist, so the analysis can be re-run into the same ``outdir``.

    :param z: array of latent embeddings, shape (nptcls, zdim)
    :param outdir: directory in which to save all outputs (plots and generated volumes)
    :param plot_format: file format with which to save plots
    :param vg: VolumeGenerator instance to aid volume generation at specified z values
    :param starfile_path: path to star file used during model training through which to load images
    :param source_software: type of source software used to create the star file, used to indicate the appropriate star file handling class to instantiate;
        passed through unchanged to ``starfile.load_sta_starfile``.
    :param datadir: path to particle images on disk, used when plotting images per kmeans class
    :param skip_vol: whether to skip generation of volumes
    :param skip_umap: whether to skip latent embeddings UMAP dimensionality reduction
    :param num_pcs: number of principal components along which to generate volumes.
        If 0, no PC trajectories are sampled (PCA itself is still run because the PCA plots need it).
    :param pc_ondata: whether to use the closest on-data latent point to each PCA axis trajectory for plotting and volume generation
    :param num_ksamples: number of latent clusters to form by k-means clustering for plotting and volume generation
    :param downsample: downsample reconstructed volumes to this box size (units: px) by Fourier cropping, None means to skip downsampling
    :param lowpass: lowpass filter reconstructed volumes to this resolution (units: Å), None means to skip lowpass filtering
    :param flip: flip the chirality of the reconstructed volumes by inverting along the z axis
    :param invert: invert the data sign of the reconstructed volumes (light-on-dark vs dark-on-light)
    :return: None
    """
    zdim = z.shape[1]

    # Principal component analysis
    log('Perfoming principal component analysis ...')
    pc, pca = analysis.run_pca(z)
    z_trajectories = []
    for i in range(num_pcs):
        pc_dir = f'{outdir}/pc{i + 1}'
        # exist_ok: do not crash when re-running the analysis into an existing output directory
        os.makedirs(pc_dir, exist_ok=True)
        # sample this PC at its 5th, 15th, ..., 95th percentiles and map the samples back to latent space
        z_pc_trajectory = np.percentile(pc[:, i], np.linspace(start=5, stop=95, num=10))
        z_trajectory = analysis.get_pc_traj(pca=pca, dim=i + 1, sampling_points=z_pc_trajectory)
        if pc_ondata:
            z_trajectory, z_pc_ind = analysis.get_nearest_point(z, z_trajectory)
            np.savetxt(f'{pc_dir}/z_percentiles_ind.txt', z_pc_ind, fmt='%d')
        z_trajectories.append(z_trajectory)
        np.savetxt(f'{pc_dir}/z_percentiles.txt', z_trajectory)
        if not skip_vol:
            vg.generate_volumes(z=z_trajectory, out_dir=pc_dir, downsample=downsample, lowpass=lowpass, flip=flip, invert=invert)

    # K-means clustering
    log('Performing K-means clustering ...')
    kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, num_ksamples)
    # replace each cluster center by the nearest on-data latent point (and remember which particle that is)
    kmeans_centers, kmeans_centers_ind = analysis.get_nearest_point(z, kmeans_centers)
    kmeans_dir = f'{outdir}/kmeans{num_ksamples}'
    os.makedirs(kmeans_dir, exist_ok=True)
    utils.save_pkl(kmeans_labels, f'{kmeans_dir}/labels.pkl')
    np.savetxt(f'{kmeans_dir}/centers.txt', kmeans_centers)
    np.savetxt(f'{kmeans_dir}/centers_ind.txt', kmeans_centers_ind, fmt='%d')
    if not skip_vol:
        vg.generate_volumes(z=kmeans_centers, out_dir=kmeans_dir, downsample=downsample, lowpass=lowpass, flip=flip, invert=invert)

    # annotation labels shared by every kmeans-annotated plot below
    kmeans_annotations = [f'k{i}' for i in range(num_ksamples)]
    if num_pcs >= 2:
        # annotation labels shared by every PC-trajectory-annotated plot below
        pc_traj_annotations = ([f'l-PC1_{i}' for i in range(len(z_trajectories[0]))]
                               + [f'l-PC2_{i}' for i in range(len(z_trajectories[1]))])

    # Make some plots using PCA transformation
    # bar plot PCA explained variance ratio
    log('Plotting latent PCA explained variance ratio')
    plt.bar(np.arange(z.shape[1]) + 1, pca.explained_variance_ratio_)
    plt.xticks(np.arange(z.shape[1]) + 1)
    plt.xlabel('principal components')
    plt.ylabel('explained variance')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_pca_explainedvariance.{plot_format}')
    plt.close()

    # scatter plot latent PCA
    log('Plotting latent principal components as scatter')
    g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], alpha=.1, s=2)
    g.set_axis_labels('l-PC1', 'l-PC2')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_pca_scatter.{plot_format}')
    plt.close()

    # hexbin plot latent PCA
    log('Plotting latent principal components as hexbin')
    g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], kind='hex')
    g.set_axis_labels('l-PC1', 'l-PC2')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_pca_hexbin.{plot_format}')
    plt.close()

    # scatter plot latent PCA with kmeans center annotations
    log('Plotting latent principal components as scatter with latent clustering kmeans-centers annotations')
    analysis.scatter_annotate(x=pc[:, 0], y=pc[:, 1], centers_ind=kmeans_centers_ind, annotate=True, labels=kmeans_annotations)
    plt.xlabel('l-PC1')
    plt.ylabel('l-PC2')
    plt.tight_layout()
    plt.savefig(f'{kmeans_dir}/z_pca_scatter_annotatekmeans.{plot_format}')
    plt.close()

    # hexbin plot latent PCA with kmeans center annotations
    log('Plotting latent principal components as hexbin with latent clustering kmeans-centers annotations')
    g = analysis.scatter_annotate_hex(x=pc[:, 0], y=pc[:, 1], centers_ind=kmeans_centers_ind, annotate=True, labels=kmeans_annotations)
    g.set_axis_labels('l-PC1', 'l-PC2')
    plt.tight_layout()
    plt.savefig(f'{kmeans_dir}/z_pca_hexbin_annotatekmeans.{plot_format}')
    plt.close()

    if num_pcs >= 2:
        # PC1 and PC2 trajectories projected into PCA space; trajectories are pc-dimensional so just take the
        # first two dims for plotting. Computed once and shared by the scatter and hexbin plots.
        pc_traj_centers_pca = np.vstack([pca.transform(z_trajectories[0])[:, :2],
                                         pca.transform(z_trajectories[1])[:, :2]])

        # scatter plot latent PCA with PCA trajectory annotations
        log('Plotting latent principal components as scatter with volume-sampled principal components annotations')
        analysis.scatter_annotate(x=pc[:, 0], y=pc[:, 1], centers_xy=pc_traj_centers_pca, annotate=True, labels=pc_traj_annotations)
        plt.xlabel('l-PC1')
        plt.ylabel('l-PC2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/pc1/z_pca_scatter_annotatepca.{plot_format}')
        plt.close()

        # hexbin plot latent PCA with PCA trajectory annotations
        log('Plotting latent principal components as hexbin with volume-sampled principal components annotations')
        g = analysis.scatter_annotate_hex(x=pc[:, 0], y=pc[:, 1], centers_xy=pc_traj_centers_pca, annotate=True, labels=pc_traj_annotations)
        g.set_axis_labels('l-PC1', 'l-PC2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/pc1/z_pca_hexbin_annotatepca.{plot_format}')
        plt.close()

    # scatter plot latent PCA colored by k-means clusters
    log('Plotting latent principal components as scatter colored by latent clustering kmeans-centers annotations')
    analysis.plot_by_cluster(x=pc[:, 0], y=pc[:, 1], labels=kmeans_labels, labels_sel=num_ksamples, centers_ind=kmeans_centers_ind, annotate=True)
    plt.xlabel('l-PC1')
    plt.ylabel('l-PC2')
    plt.tight_layout()
    plt.savefig(f'{kmeans_dir}/z_pca_scatter_colorkmeanslabel.{plot_format}')
    plt.close()

    # scatter subplots latent PCA colored by k-means clusters
    log('Plotting latent principal components as scatter subplots colored by latent clustering kmeans-centers annotations')
    analysis.plot_by_cluster_subplot(x=pc[:, 0], y=pc[:, 1], labels=kmeans_labels, labels_sel=num_ksamples)
    plt.xlabel('l-PC1')
    plt.ylabel('l-PC2')
    plt.savefig(f'{kmeans_dir}/z_pca_scatter_subplotkmeanslabel.{plot_format}')
    plt.close()

    # UMAP dimensionality reduction
    if zdim > 2 and not skip_umap:
        log('Running UMAP ...')
        umap_emb, umap_reducer = analysis.run_umap(z)
        utils.save_pkl(umap_emb, f'{outdir}/umap.pkl')

        # scatter plot latent UMAP
        log('Plotting latent UMAP embeddings as scatter')
        g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], alpha=.1, s=2)
        g.set_axis_labels('l-UMAP1', 'l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/z_umap_scatter.{plot_format}')
        plt.close()

        # hexbin plot latent UMAP
        log('Plotting latent UMAP embeddings as hexbin')
        g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind='hex')
        g.set_axis_labels('l-UMAP1', 'l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/z_umap_hexbin.{plot_format}')
        plt.close()

        # scatter plot latent UMAP with kmeans center annotations
        log('Plotting latent UMAP embeddings as scatter with latent clustering kmeans-centers annotations')
        analysis.scatter_annotate(x=umap_emb[:, 0], y=umap_emb[:, 1], centers_ind=kmeans_centers_ind, annotate=True, labels=kmeans_annotations)
        plt.xlabel('l-UMAP1')
        plt.ylabel('l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{kmeans_dir}/z_umap_scatter_annotatekmeans.{plot_format}')
        plt.close()

        # hexbin plot latent UMAP with kmeans center annotations
        log('Plotting latent UMAP embeddings as hexbin with latent clustering kmeans-centers annotations')
        g = analysis.scatter_annotate_hex(x=umap_emb[:, 0], y=umap_emb[:, 1], centers_ind=kmeans_centers_ind, annotate=True, labels=kmeans_annotations)
        g.set_axis_labels('l-UMAP1', 'l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{kmeans_dir}/z_umap_hexbin_annotatekmeans.{plot_format}')
        plt.close()

        if num_pcs >= 2:
            # PC1 and PC2 trajectories in latent space, transformed to UMAP space. Transformed once and shared by
            # the scatter and hexbin plots so both annotate identical coordinates.
            pc_traj_centers_umap = np.vstack([umap_reducer.transform(z_trajectories[0]),
                                              umap_reducer.transform(z_trajectories[1])])

            # scatter plot latent UMAP with PCA trajectory annotations
            log('Plotting latent UMAP embeddings as scatter with volume-sampled principal components annotations')
            analysis.scatter_annotate(x=umap_emb[:, 0], y=umap_emb[:, 1], centers_xy=pc_traj_centers_umap, annotate=True, labels=pc_traj_annotations)
            plt.xlabel('l-UMAP1')
            plt.ylabel('l-UMAP2')
            plt.tight_layout()
            plt.savefig(f'{outdir}/pc1/z_umap_scatter_annotatepca.{plot_format}')
            plt.close()

            # hexbin plot latent UMAP with PCA trajectory annotations
            log('Plotting latent UMAP embeddings as hexbin with volume-sampled principal components annotations')
            g = analysis.scatter_annotate_hex(x=umap_emb[:, 0], y=umap_emb[:, 1], centers_xy=pc_traj_centers_umap, annotate=True, labels=pc_traj_annotations)
            g.set_axis_labels('l-UMAP1', 'l-UMAP2')
            plt.tight_layout()
            plt.savefig(f'{outdir}/pc1/z_umap_hexbin_annotatepca.{plot_format}')
            plt.close()

        # scatter plot latent UMAP colored by k-means clusters
        log('Plotting latent UMAP embeddings as scatter colored by latent clustering kmeans-centers annotations')
        analysis.plot_by_cluster(x=umap_emb[:, 0], y=umap_emb[:, 1], labels=kmeans_labels, labels_sel=num_ksamples, centers_ind=kmeans_centers_ind, annotate=True)
        plt.xlabel('l-UMAP1')
        plt.ylabel('l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{kmeans_dir}/z_umap_scatter_colorkmeanslabel.{plot_format}')
        plt.close()

        # scatter subplots latent UMAP colored by k-means clusters
        log('Plotting latent UMAP embeddings as scatter subplots colored by latent clustering kmeans-centers annotations')
        analysis.plot_by_cluster_subplot(x=umap_emb[:, 0], y=umap_emb[:, 1], labels=kmeans_labels, labels_sel=num_ksamples)
        plt.xlabel('l-UMAP1')
        plt.ylabel('l-UMAP2')
        plt.tight_layout()
        plt.savefig(f'{kmeans_dir}/z_umap_scatter_subplotkmeanslabel.{plot_format}')
        plt.close()

        # scatter plot latent UMAP colored by each sampled principal component
        for i in range(num_pcs):
            if i == 0:
                log('Plotting latent UMAP embeddings as scatter colored by each latent principal components')
            analysis.scatter_color(x=umap_emb[:, 0], y=umap_emb[:, 1], c=pc[:, i], cbar_label=f'l-PC{i + 1}')
            plt.xlabel('UMAP1')
            plt.ylabel('UMAP2')
            plt.tight_layout()
            plt.savefig(f'{outdir}/pc{i + 1}/z_umap_colorlatentpca.{plot_format}')
            plt.close()

    # make plots of first 6 images of each kmeans class
    log('Plotting (up to) first 6 images of each latent clustering kmeans class')
    s = starfile.load_sta_starfile(star_path=starfile_path, source_software=source_software)
    # s.filter() is followed by restoring s.df from this backup, so keep a pristine copy of the full table
    star_df_backup = s.df.copy(deep=True)
    for label in range(num_ksamples):
        # get indices of particles within this kmeans class
        ptcl_inds_this_label = np.nonzero(kmeans_labels == label)[0]
        # randomly select N particles in the class and sort their indices
        ptcl_inds_random_subset = np.sort(np.random.choice(ptcl_inds_this_label, min(len(ptcl_inds_this_label), 6), replace=False))
        # skip classes with no particles
        if len(ptcl_inds_random_subset) == 0:
            continue
        s.filter(ind_ptcls=ptcl_inds_random_subset, sort_ptcl_imgs='dose_ascending', use_first_ntilts=1)
        imgs = s.get_particles_stack(datadir=datadir, lazy=False)
        analysis.plot_projections(images=imgs,
                                  labels=[f'{ptcl_ind}' for ptcl_ind in ptcl_inds_random_subset],
                                  width_between_imgs_px=30,
                                  height_between_imgs_px=50)
        plt.savefig(f'{kmeans_dir}/particle_images_kmeanslabel{label}.{plot_format}')
        plt.close()
        # restore the unfiltered table before processing the next class
        s.df = star_df_backup.copy(deep=True)

    # make plot of class label distribution versus tomogram / micrograph in star file order
    log('Plotting distribution of latent clustering kmeans annotations per tomogram')
    analysis.plot_label_count_distribution(ptcl_star=s, class_labels=kmeans_labels)
    plt.savefig(f'{kmeans_dir}/tomogram_label_distribution.{plot_format}')
    plt.close()

    # make plots of numeric columns in star file (e.g. pose, coordinate, ctf) for correlations with UMAP
    log('Plotting latent principal components or latent UMAP embeddings against numeric columns in input star file for potential correlations')
    os.makedirs(f'{outdir}/controls', exist_ok=True)
    if zdim > 2 and not skip_umap:
        ref_array = utils.load_pkl(f'{outdir}/umap.pkl')
        ref_names = ['l-UMAP1', 'l-UMAP2']
    else:
        ref_array = pc
        ref_names = ['l-PC1', 'l-PC2']
    s.filter(sort_ptcl_imgs='dose_ascending', use_first_ntilts=1)  # only want one value per particle
    for numeric_column in s.df.select_dtypes(include=[np.number]).columns:
        analysis.plot_three_column_correlation(reference_array=ref_array,
                                               query_array=s.df[numeric_column].to_numpy(),
                                               reference_names=ref_names,
                                               query_name=numeric_column)
        plt.tight_layout()
        plt.savefig(f'{outdir}/controls/{numeric_column}.{plot_format}')
        plt.close()
def main(args: argparse.Namespace) -> None:
    """
    Run the latent space analysis for one epoch of a trained tomoDRGN model.

    Reads ``config.pkl``, the latent embeddings and the weights for the requested epoch from ``args.workdir``, plots the
    loss curve parsed from ``run.log``, dispatches to the 1-D or multidimensional analysis depending on the latent
    dimensionality, and copies the template Jupyter notebooks into the output directory if they are not already there.

    :param args: parsed command-line arguments as produced by the parser from ``add_args``
    :return: None
    """
    # log arguments
    t1 = dt.now()
    log(args)

    # set files to use as inputs for analysis
    config = f'{args.workdir}/config.pkl'
    cfg = utils.load_pkl(config)
    star_path = cfg['starfile_args']['sourcefile_filtered']
    datadir = cfg['dataset_args']['datadir']

    args.epoch = utils.get_latest_epoch(args.workdir) if args.epoch == 'latest' else int(args.epoch)
    if args.epoch == -1:
        # epoch -1 selects the un-numbered z.train.pkl / weights.pkl files
        zfile = f'{args.workdir}/z.train.pkl'
        weights = f'{args.workdir}/weights.pkl'
        outdir = f'{args.workdir}/analyze'
    else:
        zfile = f'{args.workdir}/z.{args.epoch}.train.pkl'
        weights = f'{args.workdir}/weights.{args.epoch}.pkl'
        outdir = f'{args.workdir}/analyze.{args.epoch}'

    # override outdir if provided as input argument
    if args.outdir:
        outdir = args.outdir
    log(f'Saving results to {outdir}')
    # makedirs (not mkdir) so that a nested user-supplied --outdir with missing parent directories is created too
    os.makedirs(outdir, exist_ok=True)

    z = utils.load_pkl(zfile)
    zdim = z.shape[1]

    vg = models.VolumeGenerator(config=config, weights_path=weights)

    # do not render figure window when drawing/saving figures
    plt.ioff()

    # plot the loss curve
    loss = analysis.parse_loss(f'{args.workdir}/run.log')
    plt.plot(loss)
    plt.xlabel('epoch')
    plt.ylabel('total loss')
    plt.savefig(f'{outdir}/model_loss.{args.plot_format}')
    plt.close()

    if zdim == 1:
        analyze_z_onedimensional(z=z,
                                 outdir=outdir,
                                 plot_format=args.plot_format,
                                 vg=vg,
                                 downsample=args.downsample,
                                 lowpass=args.lowpass,
                                 flip=args.flip,
                                 invert=args.invert,
                                 skip_vol=args.skip_vol,
                                 ondata=args.pc_ondata)
    else:
        analyze_z_multidimensional(z=z,
                                   outdir=outdir,
                                   plot_format=args.plot_format,
                                   skip_vol=args.skip_vol,
                                   vg=vg,
                                   downsample=args.downsample,
                                   lowpass=args.lowpass,
                                   flip=args.flip,
                                   invert=args.invert,
                                   num_pcs=args.pc,
                                   pc_ondata=args.pc_ondata,
                                   skip_umap=args.skip_umap,
                                   num_ksamples=args.ksample,
                                   starfile_path=star_path,
                                   source_software=cfg['starfile_args']['source_software'],
                                   datadir=datadir)

    # copy over template if file doesn't exist
    ipynbs = [['tomoDRGN_viz+filt_template_legacy.ipynb', 'tomoDRGN_viz+filt_legacy.ipynb'],
              ['tomoDRGN_interactive_viz_template.ipynb', 'tomoDRGN_interactive_viz.ipynb']]
    for template_ipynb, out_ipynb in ipynbs:
        template_path = str(files('tomodrgn.templates').joinpath(template_ipynb))
        out_path = f'{outdir}/{out_ipynb}'
        if not os.path.exists(out_path):
            log('Creating jupyter notebook...')
            assert os.path.isfile(template_path)
            shutil.copyfile(template_path, out_path)
            log(out_path)
        else:
            # never overwrite a notebook the user may already have edited
            log(f'{out_path} already exists. Skipping')

    log(f'Finished in {dt.now() - t1}')