Source code for cellmaps_vnn.train

# train.py
import os
import shutil
from datetime import date
import yaml

from cellmaps_utils import constants as constants
import cellmaps_vnn.constants as vnnconstants
import logging

import cellmaps_vnn
from cellmaps_vnn.data_wrapper import TrainingDataWrapper
from cellmaps_vnn.exceptions import CellmapsvnnError
from cellmaps_vnn.vnn_trainer import VNNTrainer
from cellmaps_vnn.util import copy_and_register_gene2id_file
from cellmaps_vnn.optuna_vnn_trainer import OptunaVNNTrainer

logger = logging.getLogger(__name__)


[docs] class VNNTrain: COMMAND = 'train' DEFAULT_EPOCH = 50 DEFAULT_LR = 0.001 DEFAULT_WD = 0.001 DEFAULT_ALPHA = 0.3 DEFAULT_PATIENCE = 30 DEFAULT_DELTA = 0.001 DEFAULT_MIN_DROPOUT_LAYER = 2 DEFAULT_DROPOUT_FRACTION = 0.3 DEFAULT_OPTIMIZE = 0 DEFAULT_N_TRIALS = 3 DEFAULT_STD = 'std.txt' def __init__(self, outdir, inputdir, gene_attribute_name=vnnconstants.GENE_SET_COLUMN_NAME, config_file=None, training_data=None, gene2id=None, cell2id=None, mutations=None, cn_deletions=None, cn_amplifications=None, batchsize=vnnconstants.DEFAULT_BATCHSIZE, zscore_method=vnnconstants.DEFAULT_ZSCORE_METHOD, epoch=DEFAULT_EPOCH, lr=DEFAULT_LR, wd=DEFAULT_WD, alpha=DEFAULT_ALPHA, genotype_hiddens=vnnconstants.DEFAULT_GENOTYPE_HIDDENS, patience=DEFAULT_PATIENCE, delta=DEFAULT_DELTA, min_dropout_layer=DEFAULT_MIN_DROPOUT_LAYER, dropout_fraction=DEFAULT_DROPOUT_FRACTION, optimize=DEFAULT_OPTIMIZE, n_trials=DEFAULT_N_TRIALS, cuda=vnnconstants.DEFAULT_CUDA, skip_parent_copy=False, slurm=False, use_gpu=False, slurm_partition=None, slurm_account=None, hierarchy=None, parent_network=None): """ Constructor for training a Visual Neural Network. :param outdir: Directory to write results to. :type outdir: str :param inputdir: Path to directory or RO-Crate with hierarchy.cx2 file. :type inputdir: str :param gene_attribute_name: Name of the node attribute with genes/proteins. :type gene_attribute_name: str :param config_file: Path to configuration file for populating arguments. :type config_file: str, optional :param training_data: Training data file path. :type training_data: str, optional :param gene2id: File mapping genes to IDs. :type gene2id: str, optional :param cell2id: File mapping cells to IDs. :type cell2id: str, optional :param mutations: File with mutation information for cell lines. :type mutations: str, optional :param cn_deletions: File with copy number deletions for cell lines. :type cn_deletions: str, optional :param cn_amplifications: File with copy number amplifications for cell lines. :type cn_amplifications: str, optional :param batchsize: Batch size for training. Default is 64. :type batchsize: int :param zscore_method: Z-score method. Default is 'auc'. :type zscore_method: str :param epoch: Number of epochs for training. Default is 50. :type epoch: int :param lr: Learning rate. Default is 0.001. :type lr: float or list or tuple :param wd: Weight decay. Default is 0.001. :type wd: float :param alpha: Loss parameter alpha. Default is 0.3. :type alpha: float :param genotype_hiddens: Number of neurons in genotype parts. Default is 4. :type genotype_hiddens: int :param patience: Early stopping epoch limit. Default is 30. :type patience: int :param delta: Minimum loss improvement for early stopping. Default is 0.001. :type delta: float :param min_dropout_layer: Layer number to start applying dropout. Default is 2. :type min_dropout_layer: int :param dropout_fraction: Dropout fraction. Default is 0.3. :type dropout_fraction: float :param optimize: Hyperparameter optimization flag. Default is 0. :type optimize: int :param cuda: GPU index. Default is 0. :type cuda: int :param skip_parent_copy: If True, do not copy hierarchy parent. Default is False. :type skip_parent_copy: bool :param slurm: If True, generate SLURM script for training. Default is False. :type slurm: bool :param use_gpu: If True, adjust SLURM script to run on GPU. Default is False. :type use_gpu: bool :param slurm_partition: SLURM partition to use. Default is 'nrnb-gpu' if use_gpu is True. :type slurm_partition: str, optional :param slurm_account: SLURM account name. :type slurm_account: str, optional """ self._outdir = os.path.abspath(outdir) self._inputdir = inputdir self._hierarchy = hierarchy self._parent_network = parent_network self._gene_attribute_name = gene_attribute_name self._config_file = config_file self._training_data = training_data self._gene2id = gene2id self._cell2id = cell2id self._mutations = mutations self._cn_deletions = cn_deletions self._cn_amplifications = cn_amplifications self._zscore_method = zscore_method self._epoch = epoch self._optimize = optimize self._batchsize, self._optimize_batchsize = self._process_param(batchsize, 'batchsize') self._lr, self._optimize_lr = self._process_param(lr, 'lr') self._wd, self._optimize_wd = self._process_param(wd, 'wd') self._alpha, self._optimize_alpha = self._process_param(alpha, 'alpha') self._genotype_hiddens, self._optimize_genotype_hiddens = self._process_param(genotype_hiddens, 'genotype_hiddens') self._patience, self._optimize_patience = self._process_param(patience, 'patience') self._delta, self._optimize_delta = self._process_param(delta, 'delta') self._min_dropout_layer, self._optimize_min_dropout_layer = self._process_param(min_dropout_layer, 'min_dropout_layer') self._dropout_fraction, self._optimize_dropout_fraction = self._process_param(dropout_fraction, 'dropout_fraction') self._n_trials = n_trials self._cuda = cuda self._skip_parent_copy = skip_parent_copy self._slurm = slurm self._use_gpu = use_gpu self._slurm_partition = slurm_partition self._slurm_account = slurm_account self._modelfile = self._get_model_dest_file() self._stdfile = self._get_std_dest_file() def _process_param(self, param, name): if isinstance(param, list) or isinstance(param, tuple): if isinstance(param, tuple) or len(param) > 1: if self._optimize != 0: return param[0], param else: raise CellmapsvnnError( f'Hyperparameter optimization is not enabled (optimize set to 0), but multiple values ' f'for {name} are provided. Please specify only one value, or enable optimization.' ) else: return param[0], None else: return param, None
[docs] @staticmethod def add_subparser(subparsers): """ Adds a subparser for the 'train' command. """ # TODO: modify description later desc = """ Version: todo The 'train' command trains a Visual Neural Network using specified hierarchy files and data from drugcell or NeSTVNN repository. The resulting model is stored in a directory specified by the user. """ parser = subparsers.add_parser(VNNTrain.COMMAND, help='Train a Visual Neural Network', description=desc, formatter_class=constants.ArgParseFormatter) parser.add_argument('outdir', help='Directory to write results to') parser.add_argument('--inputdir', help='Path to directory or RO-Crate with hierarchy.cx2 file.' 'Note that the name of the hierarchy should be hierarchy.cx2.') parser.add_argument('--hierarchy', help='Path to hierarchy (optional). If not set, the process will search for ' 'hierarchy.cx2 in inputdir. It will fail if not found.', type=str) parser.add_argument('--parent_network', help='Path to interactome (parent network, optional) of hierarchy ' 'or NDEx UUID of parent network. If not set, the process will ' 'search for hierarchy_parent.cx2 in inputdir, but it will not fail' 'if not found.', type=str) parser.add_argument('--gene_attribute_name', help='Name of the node attribute of the hierarchy with list of ' 'genes/ proteins of this node. Default: CD_MemberList.', type=str, default=vnnconstants.GENE_SET_COLUMN_NAME) parser.add_argument('--config_file', help='Config file that can be used to populate arguments for training. ' 'If a given argument is set, it will override the default value.') parser.add_argument('--training_data', help='Training data') parser.add_argument('--gene2id', help='Gene to ID mapping file', type=str) parser.add_argument('--cell2id', help='Cell to ID mapping file', type=str) parser.add_argument('--mutations', help='Mutation information for cell lines', type=str) parser.add_argument('--cn_deletions', help='Copy number deletions for cell lines', type=str) parser.add_argument('--cn_amplifications', help='Copy number amplifications for cell lines', type=str) parser.add_argument('--epoch', type=int, default=VNNTrain.DEFAULT_EPOCH, help='Training epochs') parser.add_argument('--zscore_method', type=str, choices=['auc', 'zscore', 'robustz'], help='Z-score method (choose from: auc, zscore, robustz). Default: auc') parser.add_argument('--batchsize', type=int, nargs='+', help='Batch size') parser.add_argument('--lr', type=float, nargs='+', help='Learning rate') parser.add_argument('--wd', type=float, nargs='+', help='Weight decay') parser.add_argument('--alpha', type=float, nargs='+', help='Loss parameter alpha') parser.add_argument('--genotype_hiddens', type=int, nargs='+', help='Neurons in genotype parts') parser.add_argument('--patience', type=int, nargs='+', help='Early stopping epoch limit') parser.add_argument('--delta', type=float, nargs='+', help='Minimum change in loss for improvement') parser.add_argument('--min_dropout_layer', type=int, nargs='+', help='Start dropout from this layer') parser.add_argument('--dropout_fraction', type=float, nargs='+', help='Dropout fraction') parser.add_argument('--optimize', type=int, help='Hyperparameter optimization (0- no optimization, ' '1- optuna optimizer)') parser.add_argument('--n_trials', type=int, help='Number of trials in hyperparameter optimization') parser.add_argument('--cuda', type=int, help='Specify GPU') parser.add_argument('--skip_parent_copy', help='If set, hierarchy parent (interactome) will not be copied', action='store_true') parser.add_argument('--slurm', help='If set, slurm script for training will be generated.', action='store_true') parser.add_argument('--use_gpu', help='If set, slurm script will be adjusted to run on GPU.', action='store_true') parser.add_argument('--slurm_partition', help='Slurm partition. If use_gpu is set, the default is nrnb-gpu.', type=str) parser.add_argument('--slurm_account', help='Slurm account. If use_gpu is set, the default is nrnb-gpu.', type=str) return parser
[docs] def run(self): """ The logic for training the Visual Neural Network. """ if self._optimize not in [0, 1]: logger.error(f"The value {self._optimize} is wrong value for optimize.") raise CellmapsvnnError(f"The value {self._optimize} is wrong value for optimize.") data_wrapper = TrainingDataWrapper(self._outdir, self._inputdir, self._gene_attribute_name, self._training_data, self._cell2id, self._gene2id, self._mutations, self._cn_deletions, self._cn_amplifications, self._modelfile, self._genotype_hiddens, self._lr, self._wd, self._alpha, self._epoch, self._batchsize, self._cuda, self._zscore_method, self._stdfile, self._patience, self._delta, self._min_dropout_layer, self._dropout_fraction, self._hierarchy) if self._optimize == 1: trial_params = OptunaVNNTrainer(data_wrapper, n_trials=self._n_trials, batchsize_vals=self._optimize_batchsize, lr_vals=self._optimize_lr, wd_vals=self._optimize_wd, alpha_vals=self._optimize_alpha, genotype_hiddens_vals=self._optimize_genotype_hiddens, patience_vals=self._optimize_patience, delta_vals=self._optimize_delta, min_dropout_layer_vals=self._optimize_min_dropout_layer, dropout_fraction_vals=self._optimize_dropout_fraction ).exec_study() self._save_final_config(trial_params) for key, value in trial_params.items(): if hasattr(data_wrapper, key): setattr(data_wrapper, key, value) try: VNNTrainer(data_wrapper).train_model() except Exception as e: logger.error(f"Training error: {e}") raise CellmapsvnnError(f"Encountered problem in training: {e}")
def _save_final_config(self, best_params): """ Writes a flattened config file that includes best Optuna parameters and all original (non-optimized) settings. :param best_params: dict from Optuna best trial """ if self._config_file is None: return # skip if no config file was used with open(self._config_file, 'r') as f: config = yaml.safe_load(f) # Overwrite optimized params with best values for key, value in best_params.items(): config[key] = value # Save new config next to original final_config_path = os.path.join(self._outdir, 'config.yaml') with open(final_config_path, 'w') as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) logger.info(f'Saved final config to: {final_config_path}') def _get_model_dest_file(self): """ Returns the file path for saving the trained model file. :return: The file path for the model file. """ return os.path.join(self._outdir, 'model_final.pt') def _get_std_dest_file(self): """ Returns the file path for saving the standard deviation file. :return: The file path for the standard deviation file. """ return os.path.join(self._outdir, VNNTrain.DEFAULT_STD)
[docs] def register_outputs(self, outdir, description, keywords, provenance_utils): """ Registers the model and standard deviation files with the FAIRSCAPE service for data provenance. It generates dataset IDs for each registered file. :param outdir: The directory where the output files are stored. :param description: Description for the output files. :param keywords: List of keywords associated with the files. :param provenance_utils: The utility class for provenance registration. :return: A list of dataset IDs for the registered model and standard deviation files. """ return_ids = [self._register_model_file(outdir, description, keywords, provenance_utils), self._register_std_file(outdir, description, keywords, provenance_utils), self._copy_and_register_hierarchy(outdir, description, keywords, provenance_utils), self._register_pruned_hierarchy(outdir, description, keywords, provenance_utils), copy_and_register_gene2id_file(self._gene2id, outdir, description, keywords, provenance_utils)] if not self._skip_parent_copy: id_hierarchy_parent = self._copy_and_register_hierarchy_parent(outdir, description, keywords, provenance_utils) if id_hierarchy_parent is not None: return_ids.append(id_hierarchy_parent) if self._optimize != 0: id_config = self._register_config_file(outdir, description, keywords, provenance_utils) return_ids.append(id_config) return return_ids
def _register_model_file(self, outdir, description, keywords, provenance_utils): """ Registers the trained model file with the FAIRSCAPE service for data provenance. :param outdir: The output directory where the model file is stored. :param description: Description of the model file for provenance registration. :param keywords: List of keywords associated with the model file. :param provenance_utils: The utility class for provenance registration. :return: The dataset ID assigned to the registered model file. """ dest_path = self._get_model_dest_file() description = description description += ' Model file' keywords = keywords keywords.extend(['file']) data_dict = {'name': os.path.basename(dest_path) + ' trained model file', 'description': description, 'keywords': keywords, 'data-format': 'pt', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime(provenance_utils.get_default_date_format_str())} dataset_id = provenance_utils.register_dataset(outdir, source_file=dest_path, data_dict=data_dict) return dataset_id def _register_std_file(self, outdir, description, keywords, provenance_utils): """ Registers the standard deviation file with the FAIRSCAPE service for data provenance. :param outdir: The output directory where the standard deviation file is stored. :param description: Description of the standard deviation file for provenance registration. :param keywords: List of keywords associated with the standard deviation file. :param provenance_utils: The utility class for provenance registration. :return: The dataset ID assigned to the registered standard deviation file. """ dest_path = self._get_std_dest_file() description = description description += ' standard deviation file' keywords = keywords keywords.extend(['file']) data_dict = {'name': os.path.basename(dest_path) + ' standard deviation file', 'description': description, 'keywords': keywords, 'data-format': 'txt', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime(provenance_utils.get_default_date_format_str())} dataset_id = provenance_utils.register_dataset(outdir, source_file=dest_path, data_dict=data_dict) return dataset_id def _register_config_file(self, outdir, description, keywords, provenance_utils): """ Registers the standard deviation file with the FAIRSCAPE service for data provenance. :param outdir: The output directory where the standard deviation file is stored. :param description: Description of the standard deviation file for provenance registration. :param keywords: List of keywords associated with the standard deviation file. :param provenance_utils: The utility class for provenance registration. :return: The dataset ID assigned to the registered standard deviation file. """ dest_path = os.path.join(self._outdir, "config.yaml") description = description description += ' config file' keywords = keywords keywords.extend(['file']) data_dict = {'name': os.path.basename(dest_path) + ' config file', 'description': description, 'keywords': keywords, 'data-format': 'yaml', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime(provenance_utils.get_default_date_format_str())} dataset_id = provenance_utils.register_dataset(outdir, source_file=dest_path, data_dict=data_dict) return dataset_id def _copy_and_register_hierarchy(self, outdir, description, keywords, provenance_utils): hierarchy_out_file = os.path.join(outdir, vnnconstants.ORIGINAL_HIERARCHY_FILENAME) hierarchy_in_file = os.path.join(self._inputdir, vnnconstants.HIERARCHY_FILENAME) if self._hierarchy is None \ else self._hierarchy shutil.copy(hierarchy_in_file, hierarchy_out_file) data_dict = {'name': os.path.basename(hierarchy_out_file) + ' Hierarchy network file', 'description': description + ' Hierarchy network file', 'keywords': keywords, 'data-format': 'CX2', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime('%m-%d-%Y')} dataset_id = provenance_utils.register_dataset(outdir, source_file=hierarchy_out_file, data_dict=data_dict) return dataset_id def _register_pruned_hierarchy(self, outdir, description, keywords, provenance_utils): hierarchy_out_file = os.path.join(outdir, vnnconstants.HIERARCHY_FILENAME) data_dict = {'name': os.path.basename(hierarchy_out_file) + ' Hierarchy network file used to build VNN', 'description': description + ' Hierarchy network file used to build VNN', 'keywords': keywords, 'data-format': 'CX2', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime('%m-%d-%Y')} dataset_id = provenance_utils.register_dataset(outdir, source_file=hierarchy_out_file, data_dict=data_dict) return dataset_id def _copy_and_register_hierarchy_parent(self, outdir, description, keywords, provenance_utils): hierarchy_parent_in_file = None if self._parent_network is not None: hierarchy_parent_in_file = self._parent_network if self._inputdir is not None: hierarchy_parent_in_file = os.path.join(self._inputdir, vnnconstants.PARENT_NETWORK_NAME) if hierarchy_parent_in_file is None or not os.path.exists(hierarchy_parent_in_file): logger.warning("No hierarchy parent in the input directory. Cannot copy.") return None hierarchy_parent_out_file = os.path.join(outdir, vnnconstants.PARENT_NETWORK_NAME) shutil.copy(hierarchy_parent_in_file, hierarchy_parent_out_file) data_dict = {'name': os.path.basename(hierarchy_parent_out_file) + ' Hierarchy parent network file', 'description': description + ' Hierarchy parent network file', 'keywords': keywords, 'data-format': 'CX2', 'author': cellmaps_vnn.__name__, 'version': cellmaps_vnn.__version__, 'date-published': date.today().strftime('%m-%d-%Y')} dataset_id = provenance_utils.register_dataset(outdir, source_file=hierarchy_parent_out_file, data_dict=data_dict) return dataset_id