# Source code for cellmaps_vnn.predict

import os
import logging
import shutil
from datetime import date

from cellmaps_vnn.importance_score import FakeGeneImportanceScoreCalculator
from cellmaps_vnn.util import copy_and_register_gene2id_file
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import torch.utils.data as du
from torch.autograd import Variable
from cellmaps_utils import constants
import cellmaps_vnn.constants as vnnconstants
from ndex2.cx2 import RawCX2NetworkFactory

import cellmaps_vnn
from cellmaps_vnn import util
from cellmaps_vnn.exceptions import CellmapsvnnError
from cellmaps_vnn.rlipp_calculator import RLIPPCalculator

logger = logging.getLogger(__name__)


class VNNPredict:
    """
    Runs prediction with a trained VNN model, stores the predictions, the
    per-feature gradients and the hidden-layer outputs/gradients in the output
    directory, and then runs the interpretation (RLIPP and gene importance)
    step on those results.
    """
    COMMAND = 'predict'
    DEFAULT_CPU_COUNT = 1
    DEFAULT_DRUG_COUNT = 0

    def __init__(self, outdir, inputdir, config_file=None, predict_data=None, gene2id=None, cell2id=None,
                 mutations=None, cn_deletions=None, cn_amplifications=None,
                 batchsize=vnnconstants.DEFAULT_BATCHSIZE,
                 zscore_method=vnnconstants.DEFAULT_ZSCORE_METHOD,
                 cpu_count=DEFAULT_CPU_COUNT, drug_count=DEFAULT_DRUG_COUNT,
                 genotype_hiddens=vnnconstants.DEFAULT_GENOTYPE_HIDDENS,
                 cuda=vnnconstants.DEFAULT_CUDA, std=None, slurm=False, use_gpu=False,
                 slurm_partition=None, slurm_account=None):
        """
        Constructor for predicting with a trained model.

        :param outdir: Directory to write results to (stored as an absolute path).
        :param inputdir: Path to RO-Crate with the trained model.
        :param config_file: Config file that can be used to populate arguments.
        :param predict_data: Path to the file with the data to predict on.
        :param gene2id: Gene to ID mapping file.
        :param cell2id: Cell to ID mapping file.
        :param mutations: Mutation information for cell lines.
        :param cn_deletions: Copy number deletions for cell lines.
        :param cn_amplifications: Copy number amplifications for cell lines.
        :param batchsize: Batch size; must not be a list or tuple.
        :param zscore_method: zscore method (zscore/robustz).
        :param cpu_count: Number of available cores.
        :param drug_count: Number of top performing drugs.
        :param genotype_hiddens: Number of neurons in each term in genotype parts;
                                 must not be a list or tuple.
        :param cuda: GPU to use; CUDA is only used when this is not ``None`` and
                     CUDA is available.
        :param std: Path to standardization file. If ``None``, ``std.txt`` from
                    the input directory is used.
        :param slurm: Slurm flag, stored on the instance.
        :param use_gpu: Slurm GPU flag, stored on the instance.
        :param slurm_partition: Slurm partition, stored on the instance.
        :param slurm_account: Slurm account, stored on the instance.
        :raises CellmapsvnnError: If ``batchsize`` or ``genotype_hiddens`` is a
                                  list or tuple.
        """
        self._inputdir = inputdir
        self._hierarchy_file = os.path.join(self._inputdir, vnnconstants.HIERARCHY_FILENAME)
        self._outdir = os.path.abspath(outdir)
        self._config_file = config_file
        self._predict_data = predict_data
        self._gene2id = gene2id
        self._cell2id = cell2id
        self._mutations = mutations
        self._cn_deletions = cn_deletions
        self._cn_amplifications = cn_amplifications
        self._batchsize = batchsize
        self._zscore_method = zscore_method
        self._cpu_count = cpu_count
        self._drug_count = drug_count
        self._genotype_hiddens = genotype_hiddens
        self._std = std
        self._cuda = cuda
        self._slurm = slurm
        self._use_gpu = use_gpu
        self._slurm_partition = slurm_partition
        self._slurm_account = slurm_account
        # Set by _save_gradients(); register_outputs() uses it to know how many
        # feature gradient files exist.
        self._number_feature_grads = 0
        self.use_cuda = torch.cuda.is_available() and self._cuda is not None
        self.excluded_terms = []

        if isinstance(self._batchsize, (list, tuple)) or isinstance(self._genotype_hiddens, (list, tuple)):
            raise CellmapsvnnError(
                "Batch size and genotype hidden layer sizes must be integers during testing or prediction. Lists of "
                "values for these parameters are only supported during hyperparameter optimization in training."
            )

    @staticmethod
    def add_subparser(subparsers):
        """
        Adds a subparser for the 'predict' command.

        :param subparsers: Object returned by ``ArgumentParser.add_subparsers()``.
        :return: The parser created for the 'predict' command.
        """
        # TODO: modify description later
        desc = """
        Version: todo

        The 'predict' command takes a trained model and input data to run predictions.
        The results are stored in a specified output directory.
        """
        parser = subparsers.add_parser(VNNPredict.COMMAND,
                                       help='Run prediction using a trained model',
                                       description=desc,
                                       formatter_class=constants.ArgParseFormatter)
        parser.add_argument('outdir', help='Directory to write results to')
        parser.add_argument('--inputdir', required=True, help='Path to RO-Crate with the trained model', type=str)
        parser.add_argument('--config_file', help='Config file that can be used to populate arguments for training. '
                                                  'If a given argument is set, it will override the default value.')
        parser.add_argument('--predict_data', help='Path to the file with text data', type=str)
        parser.add_argument('--gene2id', help='Gene to ID mapping file', type=str)
        parser.add_argument('--cell2id', help='Cell to ID mapping file', type=str)
        parser.add_argument('--mutations', help='Mutation information for cell lines', type=str)
        parser.add_argument('--cn_deletions', help='Copy number deletions for cell lines', type=str)
        parser.add_argument('--cn_amplifications', help='Copy number amplifications for cell lines', type=str)
        parser.add_argument('--batchsize', help='Batchsize', type=int)
        parser.add_argument('--zscore_method', help='zscore method (zscore/robustz)', type=str)
        parser.add_argument('--cpu_count', help='No of available cores', type=int)
        parser.add_argument('--drug_count', help='No of top performing drugs', type=int)
        parser.add_argument('--genotype_hiddens',
                            help='Mapping for the number of neurons in each term in genotype parts', type=int)
        parser.add_argument('--cuda', help='Specify GPU', type=int)
        parser.add_argument('--std', help='Path to standardization File (if not set, standardization file from '
                                          'RO-Crate will be used)', type=str)
        parser.add_argument('--slurm', help='If set, slurm script for training will be generated.',
                            action='store_true')
        parser.add_argument('--use_gpu', help='If set, slurm script will be adjusted to run on GPU.',
                            action='store_true')
        parser.add_argument('--slurm_partition', help='Slurm partition. If use_gpu is set, the default is nrnb-gpu.',
                            type=str)
        parser.add_argument('--slurm_account', help='Slurm account. If use_gpu is set, the default is nrnb-gpu.',
                            type=str)
        return parser

    def run(self):
        """
        The logic for running predictions with the model. It executes the
        prediction process using the trained model and input data, then runs
        the RLIPP and gene importance calculations on the results.

        :raises CellmapsvnnError: If an error occurs during the prediction process.
        """
        try:
            self._check_inputdir()
            self._populate_excluded_terms()
            model = os.path.join(self._inputdir, 'model_final.pt')
            std = os.path.join(self._inputdir, 'std.txt') if self._std is None else os.path.abspath(self._std)
            torch.set_printoptions(precision=5)

            # Load data and model for prediction
            predict_data, cell2id_mapping = self._prepare_predict_data(self._predict_data, self._cell2id,
                                                                       self._zscore_method, std)

            # Load cell features
            cell_features = util.load_cell_features(self._mutations, self._cn_deletions, self._cn_amplifications)

            hidden_dir = self._get_hidden_dir_path()
            os.makedirs(hidden_dir, exist_ok=True)

            # Perform prediction
            self.predict(predict_data, model, hidden_dir, self._batchsize, cell_features)

            factory = RawCX2NetworkFactory()
            hierarchy = factory.get_cx2network(self._hierarchy_file)

            # Perform interpretation
            calc = RLIPPCalculator(self._outdir, hierarchy, self._predict_data, self._get_predict_dest_file(),
                                   self._gene2id, self._cell2id, hidden_dir, self._cpu_count,
                                   self._genotype_hiddens, self._drug_count, self.excluded_terms)
            calc.calc_scores()

            gene_calc = FakeGeneImportanceScoreCalculator(self._outdir, hierarchy)
            gene_calc.calc_scores()

            logger.info('Prediction and interpretation executed successfully')
            print('Prediction and interpretation executed successfully')
        except Exception as e:
            logger.error(f"Error in prediction flow: {e}")
            # Chain the original exception so its traceback is kept.
            raise CellmapsvnnError(f"Encountered problem in prediction flow: {e}") from e

    def _check_inputdir(self):
        """
        Points the input directory at its ``out_train`` subdirectory when
        ``model_final.pt`` is not found directly in the input directory.
        """
        # NOTE(review): self._hierarchy_file was computed in __init__ from the
        # original input directory and is not updated here - confirm intended.
        if not os.path.exists(os.path.join(self._inputdir, 'model_final.pt')):
            self._inputdir = os.path.join(self._inputdir, 'out_train')

    def _populate_excluded_terms(self):
        """
        Loads excluded term ids from ``vnn_excluded_terms.txt`` in the input
        directory, if that file exists. Lines that are not purely digits are
        skipped.
        """
        excluded_terms_path = os.path.join(self._inputdir, 'vnn_excluded_terms.txt')
        if os.path.exists(excluded_terms_path):
            with open(excluded_terms_path, 'r') as file:
                self.excluded_terms = set(int(line.strip()) for line in file if line.strip().isdigit())

    def _prepare_predict_data(self, test_file, cell2id_mapping_file, zscore_method, std_file):
        """
        Prepares the prediction data for the model.

        :param test_file: Path to the file containing the test dataset.
        :type test_file: str
        :param cell2id_mapping_file: Path to the file containing the cell to ID mapping.
        :type cell2id_mapping_file: str
        :param zscore_method: Method used for z-score standardization.
        :type zscore_method: str
        :param std_file: Path to the standardization file.
        :type std_file: str
        :return: A tuple containing test features and labels as tensors, and the cell2id mapping.
        :rtype: Tuple(Tensor, Tensor), dict
        """
        cell2id_mapping = util.load_mapping(cell2id_mapping_file, 'cell lines')
        test_features, test_labels = self._load_pred_data(test_file, cell2id_mapping, zscore_method, std_file)
        return (torch.Tensor(test_features), torch.Tensor(test_labels)), cell2id_mapping

    @staticmethod
    def _load_pred_data(test_file, cell2id, zscore_method, train_std_file):
        """
        Loads and processes prediction data from a file.

        Rows whose cell line is not in ``cell2id`` are dropped. Standardization
        values are computed from the prediction data and then, for every
        dataset that also appears in the training standardization file,
        replaced by the training center/scale.

        :param test_file: Path to the file containing the test dataset.
        :type test_file: str
        :param cell2id: Dictionary mapping cell lines to their respective IDs.
        :type cell2id: dict
        :param zscore_method: Method used for z-score standardization.
        :type zscore_method: str
        :param train_std_file: Path to the training standardization file.
        :type train_std_file: str
        :return: Features and labels for the prediction data.
        :rtype: List, List
        """
        train_std_df = pd.read_csv(train_std_file, sep='\t', header=None, names=['dataset', 'center', 'scale'])
        all_df = pd.read_csv(test_file, sep='\t', header=None, names=['cell_line', 'smiles', 'auc', 'dataset'])
        test_df = all_df[all_df['cell_line'].isin(cell2id.keys())]
        test_std_df = util.calc_std_vals(test_df, zscore_method)

        for idx, std_row in test_std_df.iterrows():
            # 'dataset' is referenced by name from the query expression below.
            dataset = std_row['dataset']
            train_entry = train_std_df.query('dataset == @dataset')
            if not train_entry.empty:
                test_std_df.loc[idx, 'center'] = float(train_entry['center'].iloc[0])
                test_std_df.loc[idx, 'scale'] = float(train_entry['scale'].iloc[0])

        test_df = util.standardize_data(test_df, test_std_df)

        rows = test_df.values
        # Column 0 is used as the cell line key and column 2 as the label.
        feature = [[cell2id[row[0]]] for row in rows]
        label = [[float(row[2])] for row in rows]
        return feature, label

    def _get_predict_dest_file(self):
        """
        Returns the file path for saving the prediction results.

        :return: The file path to the prediction results file.
        """
        return os.path.join(self._outdir, 'predict.txt')

    def _get_feature_grad_dest_file(self, grad):
        """
        Returns the file path for saving the gradient of a specific feature.

        :param grad: Index of the feature.
        :return: The file path to the prediction feature grad file.
        """
        return os.path.join(self._outdir, f'predict_feature_grad_{grad}.txt')

    def _get_hidden_dir_path(self):
        """
        Returns the path to the directory where hidden layer outputs will be stored.

        :return: The file path to the hidden directory.
        """
        return os.path.join(self._outdir, 'hidden/')

    def _to_device(self, tensor):
        """
        Moves the tensor to the configured CUDA device when CUDA is in use,
        otherwise returns it unchanged.

        :param tensor: Tensor to move.
        :return: The tensor on the device used for prediction.
        """
        if self.use_cuda:
            return tensor.cuda(self._cuda)
        return tensor

    def predict(self, predict_data, model_file, hidden_folder, batch_size, cell_features=None):
        """
        Perform prediction using the trained model and write the predictions
        to the prediction results file.

        :param predict_data: Tuple of features and labels for prediction.
        :param model_file: Path to the trained model file.
        :param hidden_folder: Directory to store hidden layer outputs.
        :param batch_size: Size of each batch for prediction.
        :param cell_features: Additional cell features for prediction.
        :raises CellmapsvnnError: If any step of the prediction fails.
        """
        try:
            logger.info('Starting prediction process')
            print('Starting prediction process')
            model = self._load_model(model_file)
            test_loader = self._create_data_loader(predict_data, batch_size)
            test_predict, _ = self._predict(model, test_loader, cell_features, hidden_folder)

            predict_label_gpu = self._to_device(predict_data[1])
            test_corr = util.pearson_corr(test_predict, predict_label_gpu)
            logger.info(f"Test correlation {model.root}: {test_corr:.4f}")
            np.savetxt(self._get_predict_dest_file(), test_predict.cpu().numpy(), '%.4e')
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            # Chain the original exception so its traceback is kept.
            raise CellmapsvnnError(f"Encountered problem in prediction: {e}") from e

    def _load_model(self, model_file):
        """
        Load the trained model for prediction and put it in evaluation mode.

        :param model_file: Path to the trained model file.
        :return: Loaded model.
        """
        map_location = f'cuda:{self._cuda}' if self.use_cuda else torch.device("cpu")
        # NOTE: weights_only=False unpickles arbitrary Python objects, so only
        # model files from a trusted source should be loaded here.
        model = torch.load(model_file, map_location=map_location, weights_only=False)
        if self.use_cuda:
            model.cuda(self._cuda)
        model.eval()
        return model

    def _create_data_loader(self, predict_data, batch_size):
        """
        Create a DataLoader for the prediction data.

        :param predict_data: Tuple of features and labels for prediction.
        :param batch_size: Size of each batch for prediction.
        :return: DataLoader for the prediction data (not shuffled).
        """
        predict_feature, predict_label = predict_data
        return du.DataLoader(du.TensorDataset(predict_feature, predict_label),
                             batch_size=batch_size, shuffle=False)

    def _predict(self, model, data_loader, cell_features, hidden_folder):
        """
        Run the prediction process and save outputs.

        For every batch the hidden outputs, feature gradients and hidden
        gradients are appended to their files.

        :param model: Trained model for prediction.
        :param data_loader: DataLoader containing the prediction data.
        :param cell_features: Additional cell features for prediction.
        :param hidden_folder: Directory to store hidden layer outputs.
        :return: Tuple of prediction results and saved gradients.
        """
        batch_outputs = []
        saved_grads = {}

        for inputdata, _labels in data_loader:
            cuda_features = self._process_input(inputdata, cell_features)
            aux_out_map, hidden_embeddings_map = model(cuda_features)
            batch_outputs.append(aux_out_map['final'].detach())

            self._save_hidden_outputs(hidden_embeddings_map, hidden_folder)
            # Hooks must be registered before backward() so they can capture
            # the gradients of the hidden embeddings.
            self._register_gradient_hooks(hidden_embeddings_map, saved_grads)
            self._backpropagate(aux_out_map)
            self._save_gradients(cuda_features)
            self._save_hidden_gradients(saved_grads, hidden_folder)

        # Concatenate once instead of re-copying the accumulated tensor for
        # every batch.
        if batch_outputs:
            test_predict = torch.cat(batch_outputs, dim=0)
        else:
            test_predict = self._to_device(torch.zeros(0, 0))
        return test_predict, saved_grads

    def _process_input(self, inputdata, cell_features):
        """
        Process input data for the model.

        :param inputdata: Input data for the model.
        :param cell_features: Additional cell features for prediction.
        :return: Input features on the prediction device, with gradient
                 tracking enabled.
        """
        features = util.build_input_vector(inputdata, cell_features)
        return Variable(self._to_device(features), requires_grad=True)

    def _save_hidden_outputs(self, hidden_embeddings_map, hidden_folder):
        """
        Save outputs from hidden layers. Each element is appended to
        ``<element>.hidden`` in the hidden folder.

        :param hidden_embeddings_map: Dictionary of hidden layer outputs.
        :param hidden_folder: Directory to save hidden layer outputs.
        """
        for element, hidden_map in hidden_embeddings_map.items():
            hidden_file = os.path.join(hidden_folder, element + '.hidden')
            with open(hidden_file, 'ab') as f:
                np.savetxt(f, hidden_map.detach().cpu().numpy(), '%.4e')

    def _register_gradient_hooks(self, hidden_embeddings_map, saved_grads):
        """
        Register gradient hooks to save gradients of hidden layers.

        :param hidden_embeddings_map: Dictionary of hidden layer outputs.
        :param saved_grads: Dictionary to store saved gradients.
        """
        def save_grad(elem):
            # Factory so that each hook captures its own element name.
            def savegrad_hook(grad):
                saved_grads[elem] = grad
            return savegrad_hook

        for element, hidden_embedding in hidden_embeddings_map.items():
            hidden_embedding.register_hook(save_grad(element))

    def _backpropagate(self, aux_out_map):
        """
        Perform backpropagation from the 'final' output using a gradient of
        ones.

        :param aux_out_map: Auxiliary output map from the model.
        """
        final_output = aux_out_map['final']
        final_output.backward(torch.ones_like(final_output))

    def _save_gradients(self, cuda_features):
        """
        Save gradients for each feature. Each feature index gets its own file
        to which the gradients of the batch are appended.

        :param cuda_features: Input features whose gradients were computed;
                              indexed as a three dimensional tensor with the
                              feature on the last axis.
        """
        self._number_feature_grads = len(cuda_features[0, 0, :])
        feature_grads = cuda_features.grad.detach()
        for i in range(self._number_feature_grads):
            grad_file = self._get_feature_grad_dest_file(i)
            with open(grad_file, 'ab') as f:
                np.savetxt(f, feature_grads[:, :, i].cpu().numpy(), '%.4e', delimiter='\t')

    def _save_hidden_gradients(self, saved_grads, hidden_folder):
        """
        Save the gradients of the hidden layer outputs. Each element is
        appended to ``<element>.hidden_grad`` in the hidden folder.

        :param saved_grads: Dictionary containing the saved gradients.
        :param hidden_folder: Directory to save the hidden layer gradients.
        """
        for element, hidden_grad in saved_grads.items():
            hidden_grad_file = os.path.join(hidden_folder, f'{element}.hidden_grad')
            with open(hidden_grad_file, 'ab') as f:
                np.savetxt(f, hidden_grad.detach().cpu().numpy(), '%.4e', delimiter='\t')

    def register_outputs(self, outdir, description, keywords, provenance_utils):
        """
        Registers all output files (predictions, feature gradients, and hidden files)
        with the FAIRSCAPE service for data provenance. The hierarchy files
        are copied from the input directory to ``outdir`` and registered too.

        :param outdir: The directory where the output files are stored.
        :param description: Description for the output files.
        :param keywords: List of keywords associated with the files.
        :param provenance_utils: The utility class for provenance registration.
        :return: A list of dataset IDs for the registered files.
        """
        output_ids = [copy_and_register_gene2id_file(self._gene2id, outdir, description, keywords, provenance_utils),
                      self._register_predict_file(outdir, description, keywords, provenance_utils)]

        for i in range(self._number_feature_grads):
            output_ids.append(self._register_feature_grad_file(outdir, description, keywords, provenance_utils, i))

        output_ids.extend(self._register_hidden_files(outdir, description, keywords, provenance_utils))

        orginal_hierarchy_id = self._copy_and_register_original_hierarchy(outdir, description, keywords,
                                                                          provenance_utils)
        if orginal_hierarchy_id is not None:
            output_ids.append(orginal_hierarchy_id)

        output_ids.append(self._copy_and_register_hierarchy(outdir, description, keywords, provenance_utils))

        id_parent = self._copy_and_register_hierarchy_parent(outdir, description, keywords, provenance_utils)
        if id_parent is not None:
            output_ids.append(id_parent)
        return output_ids

    @staticmethod
    def _build_data_dict(name, description, keywords, data_format, provenance_utils):
        """
        Builds the metadata dictionary passed to ``register_dataset``.

        :param name: Name of the dataset.
        :param description: Description of the dataset.
        :param keywords: List of keywords associated with the dataset.
        :param data_format: Value for the 'data-format' entry.
        :param provenance_utils: The utility class for provenance registration;
                                 supplies the date format.
        :return: Dictionary with the dataset metadata.
        :rtype: dict
        """
        return {'name': name,
                'description': description,
                'keywords': keywords,
                'data-format': data_format,
                'author': cellmaps_vnn.__name__,
                'version': cellmaps_vnn.__version__,
                'date-published': date.today().strftime(provenance_utils.get_default_date_format_str())}

    def _register_predict_file(self, outdir, description, keywords, provenance_utils):
        """
        Registers the prediction result file with the FAIRSCAPE service for data provenance.

        :param outdir: The output directory where the outputs are stored.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file; not modified.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID assigned to the registered file.
        """
        dest_path = self._get_predict_dest_file()
        # Build a new list so the caller's keywords are not modified.
        file_keywords = list(keywords or []) + ['file']
        data_dict = self._build_data_dict(os.path.basename(dest_path) + ' prediction result file',
                                          description + ' prediction result file',
                                          file_keywords, 'txt', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=dest_path, data_dict=data_dict)

    def _register_feature_grad_file(self, outdir, description, keywords, provenance_utils, grad):
        """
        Registers the feature gradient file with the FAIRSCAPE service for data provenance.

        :param outdir: The output directory where the file is stored.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file; not modified.
        :param provenance_utils: The utility class for provenance registration.
        :param grad: The specific gradient index for the feature.
        :return: The dataset ID assigned to the registered file.
        """
        dest_path = self._get_feature_grad_dest_file(grad)
        # Build a new list so the caller's keywords are not modified.
        file_keywords = list(keywords or []) + ['file']
        data_dict = self._build_data_dict(os.path.basename(dest_path) + f' prediction feature grad {grad} file',
                                          description + f' prediction feature grad {grad} file',
                                          file_keywords, 'txt', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=dest_path, data_dict=data_dict)

    def _register_hidden_files(self, outdir, description, keywords, provenance_utils):
        """
        Registers the output files from the hidden layers with the FAIRSCAPE service for data provenance.
        Registration stops after six files (see the linked FAIRSCAPE issue).

        :param outdir: The directory where the hidden layer output files are stored.
        :param description: A general description for the hidden files.
        :param keywords: A list of keywords associated with the hidden files
                         (unused; hidden files get the keywords 'hidden' and 'file').
        :param provenance_utils: An instance of the utility class used for handling the provenance registration.
        :return: A list of dataset IDs, each corresponding to a registered hidden file.
        """
        hidden_files_ids = []
        hidden_dir = self._get_hidden_dir_path()
        entries = tqdm(os.listdir(hidden_dir), desc='FAIRSCAPE hidden files registration')

        for cntr, entry in enumerate(entries, start=1):
            dest_path = os.path.join(hidden_dir, entry)
            # A fresh dictionary per file, so no state is shared between the
            # register_dataset() calls.
            data_dict = self._build_data_dict(os.path.basename(dest_path) + ' hidden file',
                                              description + ' hidden layer output file',
                                              ['hidden', 'file'],
                                              entry.split('.')[-1],
                                              provenance_utils)
            hidden_files_ids.append(provenance_utils.register_dataset(outdir, source_file=dest_path,
                                                                      data_dict=data_dict))
            if cntr > 5:
                # Todo: https://github.com/fairscape/fairscape-cli/issues/9
                logger.warning('FAIRSCAPE cannot handle too many files, skipping rest')
                break

        return hidden_files_ids

    def _copy_and_register_original_hierarchy(self, outdir, description, keywords, provenance_utils):
        """
        Copies the original hierarchy file from the input directory to
        ``outdir`` and registers it, if the file exists.

        :param outdir: The output directory the file is copied to.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID of the registered file, or ``None`` when the
                 input directory has no original hierarchy file.
        """
        hierarchy_in_file = os.path.join(self._inputdir, vnnconstants.ORIGINAL_HIERARCHY_FILENAME)
        if not os.path.exists(hierarchy_in_file):
            return None

        hierarchy_out_file = os.path.join(outdir, vnnconstants.ORIGINAL_HIERARCHY_FILENAME)
        shutil.copy(hierarchy_in_file, hierarchy_out_file)

        data_dict = self._build_data_dict(os.path.basename(hierarchy_out_file) + ' Hierarchy network file',
                                          description + ' Hierarchy network file',
                                          keywords, 'CX2', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=hierarchy_out_file, data_dict=data_dict)

    def _copy_and_register_hierarchy(self, outdir, description, keywords, provenance_utils):
        """
        Copies the hierarchy file used to build the VNN to ``outdir`` and
        registers it.

        :param outdir: The output directory the file is copied to.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID assigned to the registered file.
        """
        hierarchy_out_file = os.path.join(outdir, vnnconstants.HIERARCHY_FILENAME)
        shutil.copy(self._hierarchy_file, hierarchy_out_file)

        data_dict = self._build_data_dict(
            os.path.basename(hierarchy_out_file) + ' Hierarchy network file used to build VNN',
            description + ' Hierarchy network file used to build VNN',
            keywords, 'CX2', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=hierarchy_out_file, data_dict=data_dict)

    def _copy_and_register_hierarchy_parent(self, outdir, description, keywords, provenance_utils):
        """
        Copies the hierarchy parent network file from the input directory to
        ``outdir`` and registers it, if the file exists.

        :param outdir: The output directory the file is copied to.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID of the registered file, or ``None`` when the
                 input directory has no parent network file.
        """
        hierarchy_parent_in_file = os.path.join(self._inputdir, vnnconstants.PARENT_NETWORK_NAME)
        if not os.path.exists(hierarchy_parent_in_file):
            return None

        hierarchy_parent_out_file = os.path.join(outdir, vnnconstants.PARENT_NETWORK_NAME)
        shutil.copy(hierarchy_parent_in_file, hierarchy_parent_out_file)

        data_dict = self._build_data_dict(
            os.path.basename(hierarchy_parent_out_file) + ' Hierarchy parent network file',
            description + ' Hierarchy parent network file',
            keywords, 'CX2', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=hierarchy_parent_out_file,
                                                 data_dict=data_dict)

    def _register_rlipp_file(self, outdir, description, keywords, provenance_utils):
        """
        Registers the RLIPP output file located in ``outdir``.

        :param outdir: The output directory where the file is stored.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID assigned to the registered file.
        """
        rlipp_file = os.path.join(outdir, vnnconstants.RLIPP_OUTPUT_FILE)
        data_dict = self._build_data_dict(os.path.basename(rlipp_file) + ' RLIPP output file',
                                          description + ' RLIPP output file',
                                          keywords, 'txt', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=rlipp_file, data_dict=data_dict)

    def _register_gene_rho_file(self, outdir, description, keywords, provenance_utils):
        """
        Registers the ``gene_rho.out`` file located in ``outdir``.

        :param outdir: The output directory where the file is stored.
        :param description: Description of the file for provenance registration.
        :param keywords: List of keywords associated with the file.
        :param provenance_utils: The utility class for provenance registration.
        :return: The dataset ID assigned to the registered file.
        """
        gene_rho_file = os.path.join(outdir, 'gene_rho.out')
        data_dict = self._build_data_dict(os.path.basename(gene_rho_file) + ' Gene Rho file',
                                          description + ' Gene Rho file',
                                          keywords, 'txt', provenance_utils)
        return provenance_utils.register_dataset(outdir, source_file=gene_rho_file, data_dict=data_dict)