diff --git a/cluster/data/medinfmk/ddi/processed/runs/replicatedIndividualScores_max_epochs-20-batch_size-200-H1-300-H2-400.csv b/cluster/data/medinfmk/ddi/processed/runs/replicatedIndividualScores_max_epochs-20-batch_size-200-H1-300-H2-400.csv index fa41e35..ede8a2e 100644 --- a/cluster/data/medinfmk/ddi/processed/runs/replicatedIndividualScores_max_epochs-20-batch_size-200-H1-300-H2-400.csv +++ b/cluster/data/medinfmk/ddi/processed/runs/replicatedIndividualScores_max_epochs-20-batch_size-200-H1-300-H2-400.csv @@ -1,9 +1,9 @@ Similarity,AUC,AUPR,F-measure,Recall,Precision -chem,0.5034279703813386,0.3707883127432703,0.016485193853668464,0.008336764100452861,0.7297297297297297 -target,0.6365752726156747,0.48225391753253,0.510043846232283,0.5148209139563606,0.5053546170943625 -transporter,0.5929480035719402,0.42082198127211157,0.4441033241583739,0.43176204199258955,0.45717088055797733 -enzyme,0.6176284992224306,0.49677802489333955,0.5204540928101972,0.6723960477562783,0.42452401065696277 -pathway,0.5977482655600098,0.44625723361169034,0.4693583964707011,0.5037052284890902,0.4393966600826001 -indication,0.6849664612920969,0.6082905116097612,0.5798574821852732,0.6281391519143681,0.538468325392624 -sideeffect,0.5500630462517399,0.35279234448866786,0.5009900404882229,0.8723754631535612,0.3513950499564695 -offsideeffect,0.7412231878157685,0.7075307466626528,0.6496950411031557,0.8825648414985591,0.5140579101972303 +chem,0.0,0.0,0.0,0.0,0.0 +target,0.0,0.0,0.0,0.0,0.0 +transporter,0.0,0.0,0.0,0.0,0.0 +enzyme,0.0,0.0,0.0,0.0,0.0 +pathway,0.0,0.0,0.0,0.0,0.0 +indication,0.0,0.0,0.0,0.0,0.0 +sideeffect,0.0,0.0,0.0,0.0,0.0 +offsideeffect,0.0,0.0,0.0,0.0,0.0 diff --git a/notebooks/.ipynb_checkpoints/01_KS_Skorch_DDI-Copy1-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/01_KS_Skorch_DDI-Copy1-checkpoint.ipynb new file mode 100644 index 0000000..ac0e629 --- /dev/null +++ b/notebooks/.ipynb_checkpoints/01_KS_Skorch_DDI-Copy1-checkpoint.ipynb @@ -0,0 +1,605 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://scikit-learn.org/stable/_images/grid_search_workflow.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 1358, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": 1359, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import pickle\n", + "\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import StratifiedKFold\n", + "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef, precision_recall_curve, auc\n", + "\n", + "from keras.utils import np_utils\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torch.optim import SGD\n", + "\n", + "import skorch\n", + "from skorch import NeuralNetClassifier\n", + "from skorch.callbacks import EpochScoring\n", + "from skorch.callbacks import TensorBoard\n", + "from skorch.helper import predefined_split" + ] + }, + { + "cell_type": "code", + "execution_count": 1360, + "metadata": {}, + "outputs": [], + "source": [ + "# import configurations (file paths, etc.)\n", + "import yaml\n", + "try:\n", + " from yaml import CLoader as Loader, CDumper as Dumper\n", + "except ImportError:\n", + " from yaml import Loader, Dumper\n", + " \n", + "configFile = '../cluster/data/medinfmk/ddi/config/config.yml'\n", + "\n", + "with open(configFile, 'r') as ymlfile:\n", + " cfg = yaml.load(ymlfile, Loader=Loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 1361, + "metadata": {}, + "outputs": [], + "source": [ + "pathInput = cfg['filePaths']['dirRaw']\n", + "pathOutput = cfg['filePaths']['dirProcessed']\n", + "# path to store python binary files (pickles)\n", + "# in order not to recalculate them every time\n", + "pathPickles = cfg['filePaths']['dirProcessedFiles']['dirPickles']\n", + "pathRuns = cfg['filePaths']['dirProcessedFiles']['dirRuns']\n", + "pathPaperScores = cfg['filePaths']['dirRawFiles']['paper-individual-metrics-scores']\n", + "datasetDirs = cfg['filePaths']['dirRawDatasets']\n", + "DS1_path = str(datasetDirs[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 1362, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(input_fea, input_lab, seperate=False):\n", + " offside_sim_path = input_fea\n", + " drug_interaction_matrix_path = input_lab\n", + " drug_fea = np.loadtxt(offside_sim_path,dtype=float,delimiter=\",\")\n", + " interaction = np.loadtxt(drug_interaction_matrix_path,dtype=int,delimiter=\",\")\n", + " \n", + " train = []\n", + " label = []\n", + " tmp_fea=[]\n", + " drug_fea_tmp = []\n", + " \n", + " for i in range(0, (interaction.shape[0]-1)):\n", + " for j in range((i+1), interaction.shape[1]):\n", + " label.append(interaction[i,j])\n", + " drug_fea_tmp_1 = list(drug_fea[i])\n", + " drug_fea_tmp_2 = list(drug_fea[j])\n", + " if seperate:\n", + " tmp_fea = (drug_fea_tmp_1,drug_fea_tmp_2)\n", + " else:\n", + " tmp_fea = drug_fea_tmp_1 + drug_fea_tmp_2\n", + " train.append(tmp_fea)\n", + "\n", + " return np.array(train), np.array(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 1363, + "metadata": {}, + "outputs": [], + "source": [ + "def transfer_array_format(data):\n", + " formated_matrix1 = []\n", + " formated_matrix2 = []\n", + " for val in data:\n", + " formated_matrix1.append(val[0])\n", + " formated_matrix2.append(val[1])\n", + " return np.array(formated_matrix1), np.array(formated_matrix2)" + ] + }, + { + "cell_type": "code", + "execution_count": 1364, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_labels(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " y = encoder.transform(labels).astype(np.int32)\n", + " if categorical:\n", + " y = np_utils.to_categorical(y)\n", + "# print(y)\n", + " return y, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1365, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_names(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " if categorical:\n", + " labels = np_utils.to_categorical(labels)\n", + " return labels, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1366, + "metadata": {}, + "outputs": [], + "source": [ + "def getStratifiedKFoldSplit(X,y,n_splits):\n", + " skf = StratifiedKFold(n_splits=n_splits, random_state=42)\n", + " return skf.split(X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 1367, + "metadata": {}, + "outputs": [], + "source": [ + "class NDD(nn.Module):\n", + " def __init__(self, D_in=1096, H1=300, H2=400, D_out=1, drop=0.5):\n", + " super(NDD, self).__init__()\n", + " # an affine operation: y = Wx + b\n", + " self.fc1 = nn.Linear(D_in, H1) # Fully Connected\n", + " self.fc2 = nn.Linear(H1, H2)\n", + " self.fc3 = nn.Linear(H2, D_out)\n", + " self.drop = nn.Dropout(drop)\n", + " self._init_weights()\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.fc1(x))\n", + " x = self.drop(x)\n", + " x = F.relu(self.fc2(x))\n", + " x = self.drop(x)\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def _init_weights(self):\n", + " for m in self.modules():\n", + " if(isinstance(m, nn.Linear)):\n", + " m.weight.data.normal_(0, 0.05)\n", + " m.bias.data.uniform_(-1,0)" + ] + }, + { + "cell_type": "code", + "execution_count": 1368, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDFSingleMetric(df, sim_type, metric, value):\n", + " df.loc[df['Similarity'] == sim_type, metric ] = round(value,3)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1369, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDF(df, sim_type, AUROC, AUPR, F1, Rec, Prec):\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUC', AUROC)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUPR', AUPR)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'F-measure', F1)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Recall', Rec)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Precision', Prec)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1370, + "metadata": {}, + "outputs": [], + "source": [ + "def getNetParamsStr(net, str_hidden_layers_params, net_params_to_print=[\"max_epochs\", \"batch_size\"]):\n", + " net_params = [val for sublist in [[x,net.get_params()[x]] for x in net_params_to_print] for val in sublist]\n", + " net_params_str = '-'.join(map(str, flattened))\n", + " return(net_params_str+str_hidden_layers_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 1371, + "metadata": {}, + "outputs": [], + "source": [ + "def writeReplicatedIndividualScoresCSV(net, df, destination, str_hidden_layers_params):\n", + " filePath = destination + \"replicatedIndividualScores_\" + getNetParamsStr(net, str_hidden_layers_params) + \".csv\"\n", + " df.to_csv(path_or_buf = filePath, index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 1372, + "metadata": {}, + "outputs": [], + "source": [ + "def getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test):\n", + " model = NDD(D_in, H1, H2, D_out, drop)\n", + " \n", + " net = NeuralNetClassifier(\n", + " model,\n", + "# criterion=nn.CrossEntropyLoss,\n", + " criterion=nn.BCEWithLogitsLoss,\n", + " max_epochs=20,\n", + " optimizer=SGD,\n", + " optimizer__lr=0.01,\n", + " optimizer__momentum=0.9, \n", + " optimizer__weight_decay=1e-6, \n", + " optimizer__nesterov=True, \n", + " batch_size=200,\n", + " callbacks=callbacks,\n", + " # Shuffle training data on each epoch\n", + " iterator_train__shuffle=True,\n", + " device=device,\n", + " train_split=predefined_split(Xy_test),\n", + " )\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": 1373, + "metadata": {}, + "outputs": [], + "source": [ + "def avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits):\n", + " AUROC /= kfold_nsplits\n", + " AUPR /= kfold_nsplits\n", + " F1 /= kfold_nsplits\n", + " Rec /= kfold_nsplits\n", + " Prec /= kfold_nsplits\n", + " return AUROC, AUPR, F1, Rec, Prec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run" + ] + }, + { + "cell_type": "code", + "execution_count": 1374, + "metadata": {}, + "outputs": [], + "source": [ + "df_paperIndividualScores = pd.read_csv(pathPaperScores)\n", + "\n", + "df_replicatedIndividualScores = df_paperIndividualScores.copy()\n", + "\n", + "for col in df_replicatedIndividualScores.columns:\n", + " if col != 'Similarity':\n", + " df_replicatedIndividualScores[col].values[:] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 1375, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "soft = nn.Softmax(dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1376, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(119902, 1096)\n", + "(119902, 1)\n" + ] + } + ], + "source": [ + "print(X_train.shape)\n", + "print(y_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 1377, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running fold0 for sideeffect...\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "result type Float can't be cast to the desired output type Long", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mmodelPicklePath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpathPickles\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m\"model_params/model_params_fold\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr_hidden_layers_params\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msimilarity\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".p\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdo_train_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodelPicklePath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;31m# this is actually a pylint bug:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# https://github.com/PyCQA/pylint/issues/1085\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNeuralNetClassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 848\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 849\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mpartial_fit\u001b[0;34m(self, X, y, classes, **fit_params)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_train_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(self, X, y, epochs, **fit_params)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0myi_res\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0myi\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0my_train_is_ph\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_batch_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0myi_res\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 739\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 740\u001b[0m \u001b[0mtrain_batch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train_loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 664\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 665\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mstep_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 659\u001b[0m \u001b[0mstep_accumulator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_step_accumulator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstep_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 661\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_single\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step_single\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 604\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 605\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, *args, **kwargs)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0meps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;31m# pylint: disable=signature-differs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, X, training)\u001b[0m\n\u001b[1;32m 1097\u001b[0m \"\"\"\n\u001b[1;32m 1098\u001b[0m \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1099\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m reduction=self.reduction)\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m 2112\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Target size ({}) must be the same as input size ({})\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2114\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy_with_logits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction_enum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: result type Float can't be cast to the desired output type Long" + ] + } + ], + "source": [ + "do_prepare_data = False\n", + "do_train_model = True\n", + "kfold_nsplits = 5\n", + "# similaritiesToRun = df_paperIndividualScores['Similarity']\n", + "similaritiesToRun = [\"sideeffect\"]\n", + "\n", + "for similarity in similaritiesToRun:\n", + " input_fea = pathInput+DS1_path+\"/\" + similarity + \"_Jacarrd_sim.csv\"\n", + " input_lab = pathInput+DS1_path+\"/drug_drug_matrix.csv\"\n", + " dataPicklePath = pathPickles+\"data_X_y_\" + similarity + \"_Jaccard.p\"\n", + "\n", + " # Define model\n", + " D_in, H1, H2, D_out, drop = X.shape[1], 300, 400, 1, 0.5\n", + " str_hidden_layers_params = \"-H1-\" + str(H1) + \"-H2-\" + str(H2)\n", + " callbacks = []\n", + " \n", + " # Prepare data if not available\n", + " if do_prepare_data:\n", + " print(\"Preparing \" + similarity + \" data...\")\n", + " X,y = prepare_data(input_fea, input_lab, seperate = False)\n", + "\n", + " with open(dataPicklePath, 'wb') as f:\n", + " pickle.dump([X, y], f)\n", + "\n", + " # Load X,y and split in to train, test\n", + " with open(dataPicklePath, 'rb') as f:\n", + " X, y = pickle.load(f)\n", + " \n", + "\n", + " y = np.reshape(y, (y.shape[0], 1))\n", + " \n", + " X = X.astype(np.float32)\n", + " y = y.astype(np.int64) \n", + "\n", + " \n", + "# y_cat = np_utils.to_categorical(y)\n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = 0,0,0,0,0\n", + " kFoldSplit = getStratifiedKFoldSplit(X,y,n_splits=kfold_nsplits)\n", + " for i, indices in enumerate(kFoldSplit):\n", + " print(\"Running fold\" + str(i) + \" for \" + similarity +\"...\")\n", + " \n", + " train_index = indices[0]\n", + " test_index = indices[1]\n", + " X_train, X_test = X[train_index], X[test_index]\n", + " y_train, y_test = y[train_index], y[test_index]\n", + "# y_train, y_test = y_cat[train_index], y_cat[test_index]\n", + " \n", + " # Create Network Classifier\n", + " Xy_test = skorch.dataset.Dataset(X_test, y_test)\n", + " net = getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test)\n", + " \n", + " # Fit and save OR load model\n", + " modelPicklePath = pathPickles+\"model_params/model_params_fold\" + str(i) + \"_\" + str_hidden_layers_params+ \"_\" + similarity + \".p\"\n", + " if do_train_model:\n", + " net.fit(X_train, y_train)\n", + " net.save_params(f_params=modelPicklePath)\n", + " else:\n", + " net.initialize() # This is important!\n", + " net.load_params(f_params=modelPicklePath)\n", + "\n", + " # Make predictions\n", + " y_pred = net.predict(X_test)\n", + " lr_probs = soft(net.forward(X_test))[:,1]\n", + " lr_precision, lr_recall, _ = precision_recall_curve(y_test, lr_probs)\n", + "\n", + " AUROC += roc_auc_score(y_test, y_pred)\n", + " AUPR += auc(lr_recall, lr_precision)\n", + " F1 += f1_score(y_test, y_pred)\n", + " Rec += recall_score(y_test, y_pred)\n", + " Prec += precision_score(y_test, y_pred)\n", + " \n", + " print(i, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits)\n", + " print(similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " # Fill replicated metrics\n", + " updateSimilarityDF(df_replicatedIndividualScores, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + "# Write CSV\n", + "writeReplicatedIndividualScoresCSV(net, df_replicatedIndividualScores, pathRuns, str_hidden_layers_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare to Paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(df_paperIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(df_replicatedIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "diff_metrics = ['AUC', 'AUPR', 'F-measure', 'Recall', 'Precision']\n", + "df_diff = df_paperIndividualScores[diff_metrics] - df_replicatedIndividualScores[diff_metrics]\n", + "df_diff_abs = df_diff.abs()\n", + "df_diff_percent = (df_diff_abs / df_paperIndividualScores[diff_metrics]) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from seaborn import heatmap\n", + "heatmap(df_diff, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_abs, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_percent, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import mean_squared_error\n", + "mean_squared_error(df_paperIndividualScores[diff_metrics],\n", + " df_replicatedIndividualScores[diff_metrics])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/.ipynb_checkpoints/02_AA_Skorch_DDI-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/02_AA_Skorch_DDI-checkpoint.ipynb new file mode 100644 index 0000000..111f7ab --- /dev/null +++ b/notebooks/.ipynb_checkpoints/02_AA_Skorch_DDI-checkpoint.ipynb @@ -0,0 +1,581 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://scikit-learn.org/stable/_images/grid_search_workflow.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 1230, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": 1231, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import pickle\n", + "\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import StratifiedKFold\n", + "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef, precision_recall_curve, auc\n", + "\n", + "from keras.utils import np_utils\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torch.optim import SGD\n", + "\n", + "import skorch\n", + "from skorch import NeuralNetClassifier\n", + "from skorch.callbacks import EpochScoring\n", + "from skorch.callbacks import TensorBoard\n", + "from skorch.helper import predefined_split" + ] + }, + { + "cell_type": "code", + "execution_count": 1232, + "metadata": {}, + "outputs": [], + "source": [ + "# import configurations (file paths, etc.)\n", + "import yaml\n", + "try:\n", + " from yaml import CLoader as Loader, CDumper as Dumper\n", + "except ImportError:\n", + " from yaml import Loader, Dumper\n", + " \n", + "configFile = '../cluster/data/medinfmk/ddi/config/config.yml'\n", + "\n", + "with open(configFile, 'r') as ymlfile:\n", + " cfg = yaml.load(ymlfile, Loader=Loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 1233, + "metadata": {}, + "outputs": [], + "source": [ + "pathInput = cfg['filePaths']['dirRaw']\n", + "pathOutput = cfg['filePaths']['dirProcessed']\n", + "# path to store python binary files (pickles)\n", + "# in order not to recalculate them every time\n", + "pathPickles = cfg['filePaths']['dirProcessedFiles']['dirPickles']\n", + "pathRuns = cfg['filePaths']['dirProcessedFiles']['dirRuns']\n", + "pathPaperScores = cfg['filePaths']['dirRawFiles']['paper-individual-metrics-scores']\n", + "datasetDirs = cfg['filePaths']['dirRawDatasets']\n", + "DS1_path = str(datasetDirs[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 1234, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(input_fea, input_lab, seperate=False):\n", + " offside_sim_path = input_fea\n", + " drug_interaction_matrix_path = input_lab\n", + " drug_fea = np.loadtxt(offside_sim_path,dtype=float,delimiter=\",\")\n", + " interaction = np.loadtxt(drug_interaction_matrix_path,dtype=int,delimiter=\",\")\n", + " \n", + " train = []\n", + " label = []\n", + " tmp_fea=[]\n", + " drug_fea_tmp = []\n", + " \n", + " for i in range(0, (interaction.shape[0]-1)):\n", + " for j in range((i+1), interaction.shape[1]):\n", + " label.append(interaction[i,j])\n", + " drug_fea_tmp_1 = list(drug_fea[i])\n", + " drug_fea_tmp_2 = list(drug_fea[j])\n", + " if seperate:\n", + " tmp_fea = (drug_fea_tmp_1,drug_fea_tmp_2)\n", + " else:\n", + " tmp_fea = drug_fea_tmp_1 + drug_fea_tmp_2\n", + " train.append(tmp_fea)\n", + "\n", + " return np.array(train), np.array(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 1235, + "metadata": {}, + "outputs": [], + "source": [ + "def transfer_array_format(data):\n", + " formated_matrix1 = []\n", + " formated_matrix2 = []\n", + " for val in data:\n", + " formated_matrix1.append(val[0])\n", + " formated_matrix2.append(val[1])\n", + " return np.array(formated_matrix1), np.array(formated_matrix2)" + ] + }, + { + "cell_type": "code", + "execution_count": 1236, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_labels(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " y = encoder.transform(labels).astype(np.int32)\n", + " if categorical:\n", + " y = np_utils.to_categorical(y)\n", + "# print(y)\n", + " return y, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1237, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_names(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " if categorical:\n", + " labels = np_utils.to_categorical(labels)\n", + " return labels, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1238, + "metadata": {}, + "outputs": [], + "source": [ + "def getStratifiedKFoldSplit(X,y,n_splits):\n", + " skf = StratifiedKFold(n_splits=n_splits)\n", + " return skf.split(X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 1239, + "metadata": {}, + "outputs": [], + "source": [ + "class NDD(nn.Module):\n", + " def __init__(self, D_in=1096, H1=300, H2=400, D_out=2, drop=0.5):\n", + " super(NDD, self).__init__()\n", + " # an affine operation: y = Wx + b\n", + " self.fc1 = nn.Linear(D_in, H1) # Fully Connected\n", + " self.fc2 = nn.Linear(H1, H2)\n", + " self.fc3 = nn.Linear(H2, D_out)\n", + " self.drop = nn.Dropout(drop)\n", + " self._init_weights()\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.fc1(x))\n", + " x = self.drop(x)\n", + " x = F.relu(self.fc2(x))\n", + " x = self.drop(x)\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def _init_weights(self):\n", + " for m in self.modules():\n", + " if(isinstance(m, nn.Linear)):\n", + " m.weight.data.normal_(0, 0.05)\n", + " m.bias.data.uniform_(-1,0)" + ] + }, + { + "cell_type": "code", + "execution_count": 1240, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDFSingleMetric(df, sim_type, metric, value):\n", + " df.loc[df['Similarity'] == sim_type, metric ] = round(value,3)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1241, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDF(df, sim_type, AUROC, AUPR, F1, Rec, Prec):\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUC', AUROC)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUPR', AUPR)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'F-measure', F1)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Recall', Rec)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Precision', Prec)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1242, + "metadata": {}, + "outputs": [], + "source": [ + "def getNetParamsStr(net, str_hidden_layers_params, net_params_to_print=[\"max_epochs\", \"batch_size\"]):\n", + " net_params = [val for sublist in [[x,net.get_params()[x]] for x in net_params_to_print] for val in sublist]\n", + " net_params_str = '-'.join(map(str, flattened))\n", + " return(net_params_str+str_hidden_layers_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 1243, + "metadata": {}, + "outputs": [], + "source": [ + "def writeReplicatedIndividualScoresCSV(net, df, destination, str_hidden_layers_params):\n", + " filePath = destination + \"replicatedIndividualScores_\" + getNetParamsStr(net, str_hidden_layers_params) + \".csv\"\n", + " df.to_csv(path_or_buf = filePath, index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 1244, + "metadata": {}, + "outputs": [], + "source": [ + "def getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test):\n", + " model = NDD(D_in, H1, H2, D_out, drop)\n", + " \n", + " net = NeuralNetClassifier(\n", + " model,\n", + "# criterion=nn.CrossEntropyLoss,\n", + " criterion=nn.BCEWithLogitsLoss,\n", + " max_epochs=20,\n", + " optimizer=SGD,\n", + " optimizer__lr=0.01,\n", + " optimizer__momentum=0.9, \n", + " optimizer__weight_decay=1e-6, \n", + " optimizer__nesterov=True, \n", + " batch_size=200,\n", + " callbacks=callbacks,\n", + " # Shuffle training data on each epoch\n", + " iterator_train__shuffle=True,\n", + " device=device,\n", + " train_split=predefined_split(Xy_test),\n", + " )\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": 1245, + "metadata": {}, + "outputs": [], + "source": [ + "def avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits):\n", + " AUROC /= kfold_nsplits\n", + " AUPR /= kfold_nsplits\n", + " F1 /= kfold_nsplits\n", + " Rec /= kfold_nsplits\n", + " Prec /= kfold_nsplits\n", + " return AUROC, AUPR, F1, Rec, Prec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run" + ] + }, + { + "cell_type": "code", + "execution_count": 1246, + "metadata": {}, + "outputs": [], + "source": [ + "df_paperIndividualScores = pd.read_csv(pathPaperScores)\n", + "\n", + "df_replicatedIndividualScores = df_paperIndividualScores.copy()\n", + "\n", + "for col in df_replicatedIndividualScores.columns:\n", + " if col != 'Similarity':\n", + " df_replicatedIndividualScores[col].values[:] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 1247, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "soft = nn.Softmax(dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1248, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preparing sideeffect data...\n", + "Running fold0 for sideeffect...\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Classification metrics can't handle a mix of multilabel-indicator and binary targets", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mmodelPicklePath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpathPickles\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m\"model_params/model_params_fold\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr_hidden_layers_params\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msimilarity\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".p\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdo_train_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodelPicklePath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;31m# this is actually a pylint bug:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# https://github.com/PyCQA/pylint/issues/1085\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNeuralNetClassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 848\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 849\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mpartial_fit\u001b[0;34m(self, X, y, classes, **fit_params)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_train_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(self, X, y, epochs, **fit_params)\u001b[0m\n\u001b[1;32m 760\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"valid_batch_count\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_batch_count\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 762\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_epoch_end'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mon_epoch_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 763\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 764\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mnotify\u001b[0;34m(self, method_name, **cb_kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcb_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks_\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcb_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;31m# pylint: disable=unused-argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/callbacks/scoring.py\u001b[0m in \u001b[0;36mon_epoch_end\u001b[0;34m(self, net, dataset_train, dataset_valid, **kwargs)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcache_net_infer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_caching\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcached_net\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 412\u001b[0;31m \u001b[0mcurrent_score\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_scoring\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcached_net\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 413\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_record_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/callbacks/scoring.py\u001b[0m in \u001b[0;36m_scoring\u001b[0;34m(self, net, X_test, y_test)\u001b[0m\n\u001b[1;32m 119\u001b[0m instead of running inference again, if available.\"\"\"\n\u001b[1;32m 120\u001b[0m \u001b[0mscorer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_scoring\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscoring_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_is_best_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, estimator, X, y_true, sample_weight)\u001b[0m\n\u001b[1;32m 167\u001b[0m stacklevel=2)\n\u001b[1;32m 168\u001b[0m return self._score(partial(_cached_call, None), estimator, X, y_true,\n\u001b[0;32m--> 169\u001b[0;31m sample_weight=sample_weight)\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_factory_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m_score\u001b[0;34m(self, method_caller, estimator, X, y_true, sample_weight)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m return self._sign * self._score_func(y_true, y_pred,\n\u001b[0;32m--> 212\u001b[0;31m **self._kwargs)\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36maccuracy_score\u001b[0;34m(y_true, y_pred, normalize, sample_weight)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;31m# Compute accuracy for each possible representation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0my_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_targets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0mcheck_consistent_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0my_type\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'multilabel'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_type\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m raise ValueError(\"Classification metrics can't handle a mix of {0} \"\n\u001b[0;32m---> 90\u001b[0;31m \"and {1} targets\".format(type_true, type_pred))\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# We can't have more than one value on y_type => The set is no more needed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Classification metrics can't handle a mix of multilabel-indicator and binary targets" + ] + } + ], + "source": [ + "do_prepare_data = True\n", + "do_train_model = True\n", + "kfold_nsplits = 5\n", + "# similaritiesToRun = df_paperIndividualScores['Similarity']\n", + "similaritiesToRun = [\"sideeffect\"]\n", + "\n", + "for similarity in similaritiesToRun:\n", + " input_fea = pathInput+DS1_path+\"/\" + similarity + \"_Jacarrd_sim.csv\"\n", + " input_lab = pathInput+DS1_path+\"/drug_drug_matrix.csv\"\n", + " dataPicklePath = pathPickles+\"data_X_y_\" + similarity + \"_Jaccard.p\"\n", + "\n", + " # Define model\n", + " D_in, H1, H2, D_out, drop = X.shape[1], 300, 400, 2, 0.5\n", + " str_hidden_layers_params = \"-H1-\" + str(H1) + \"-H2-\" + str(H2)\n", + " callbacks = []\n", + " \n", + " # Prepare data if not available\n", + " if do_prepare_data:\n", + " print(\"Preparing \" + similarity + \" data...\")\n", + " X,y = prepare_data(input_fea, input_lab, seperate = False)\n", + "\n", + " with open(dataPicklePath, 'wb') as f:\n", + " pickle.dump([X, y], f)\n", + "\n", + " # Load X,y and split in to train, test\n", + " with open(dataPicklePath, 'rb') as f:\n", + " X, y = pickle.load(f)\n", + " \n", + " X = X.astype(np.float32)\n", + " y = y.astype(np.int64) \n", + " \n", + " y_cat = np_utils.to_categorical(y)\n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = 0,0,0,0,0\n", + " kFoldSplit = getStratifiedKFoldSplit(X,y,n_splits=kfold_nsplits)\n", + " for i, indices in enumerate(kFoldSplit):\n", + " print(\"Running fold\" + str(i) + \" for \" + similarity +\"...\")\n", + " \n", + " train_index = indices[0]\n", + " test_index = indices[1]\n", + " X_train, X_test = X[train_index], X[test_index]\n", + "# y_train, y_test = y[train_index], y[test_index]\n", + " y_train, y_test = y_cat[train_index], y_cat[test_index]\n", + " \n", + " # Create Network Classifier\n", + " Xy_test = skorch.dataset.Dataset(X_test, y_test)\n", + " net = getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test)\n", + " \n", + " # Fit and save OR load model\n", + " modelPicklePath = pathPickles+\"model_params/model_params_fold\" + str(i) + \"_\" + str_hidden_layers_params+ \"_\" + similarity + \".p\"\n", + " if do_train_model:\n", + " net.fit(X_train, y_train)\n", + " net.save_params(f_params=modelPicklePath)\n", + " else:\n", + " net.initialize() # This is important!\n", + " net.load_params(f_params=modelPicklePath)\n", + "\n", + " # Make predictions\n", + " y_pred = net.predict(X_test)\n", + " lr_probs = soft(net.forward(X_test))[:,1]\n", + " lr_precision, lr_recall, _ = precision_recall_curve(y_test, lr_probs)\n", + "\n", + " AUROC += roc_auc_score(y_test, y_pred)\n", + " AUPR += auc(lr_recall, lr_precision)\n", + " F1 += f1_score(y_test, y_pred)\n", + " Rec += recall_score(y_test, y_pred)\n", + " Prec += precision_score(y_test, y_pred)\n", + " \n", + " print(i, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits)\n", + " print(similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " # Fill replicated metrics\n", + " updateSimilarityDF(df_replicatedIndividualScores, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + "# Write CSV\n", + "writeReplicatedIndividualScoresCSV(net, df_replicatedIndividualScores, pathRuns, str_hidden_layers_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare to Paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(df_paperIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(df_replicatedIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "diff_metrics = ['AUC', 'AUPR', 'F-measure', 'Recall', 'Precision']\n", + "df_diff = df_paperIndividualScores[diff_metrics] - df_replicatedIndividualScores[diff_metrics]\n", + "df_diff_abs = df_diff.abs()\n", + "df_diff_percent = (df_diff_abs / df_paperIndividualScores[diff_metrics]) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from seaborn import heatmap\n", + "heatmap(df_diff, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_abs, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_percent, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import mean_squared_error\n", + "mean_squared_error(df_paperIndividualScores[diff_metrics],\n", + " df_replicatedIndividualScores[diff_metrics])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/01_KS_Skorch_DDI-Copy1.ipynb b/notebooks/01_KS_Skorch_DDI-Copy1.ipynb new file mode 100644 index 0000000..ac0e629 --- /dev/null +++ b/notebooks/01_KS_Skorch_DDI-Copy1.ipynb @@ -0,0 +1,605 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://scikit-learn.org/stable/_images/grid_search_workflow.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 1358, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": 1359, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import pickle\n", + "\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import StratifiedKFold\n", + "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef, precision_recall_curve, auc\n", + "\n", + "from keras.utils import np_utils\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torch.optim import SGD\n", + "\n", + "import skorch\n", + "from skorch import NeuralNetClassifier\n", + "from skorch.callbacks import EpochScoring\n", + "from skorch.callbacks import TensorBoard\n", + "from skorch.helper import predefined_split" + ] + }, + { + "cell_type": "code", + "execution_count": 1360, + "metadata": {}, + "outputs": [], + "source": [ + "# import configurations (file paths, etc.)\n", + "import yaml\n", + "try:\n", + " from yaml import CLoader as Loader, CDumper as Dumper\n", + "except ImportError:\n", + " from yaml import Loader, Dumper\n", + " \n", + "configFile = '../cluster/data/medinfmk/ddi/config/config.yml'\n", + "\n", + "with open(configFile, 'r') as ymlfile:\n", + " cfg = yaml.load(ymlfile, Loader=Loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 1361, + "metadata": {}, + "outputs": [], + "source": [ + "pathInput = cfg['filePaths']['dirRaw']\n", + "pathOutput = cfg['filePaths']['dirProcessed']\n", + "# path to store python binary files (pickles)\n", + "# in order not to recalculate them every time\n", + "pathPickles = cfg['filePaths']['dirProcessedFiles']['dirPickles']\n", + "pathRuns = cfg['filePaths']['dirProcessedFiles']['dirRuns']\n", + "pathPaperScores = cfg['filePaths']['dirRawFiles']['paper-individual-metrics-scores']\n", + "datasetDirs = cfg['filePaths']['dirRawDatasets']\n", + "DS1_path = str(datasetDirs[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 1362, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(input_fea, input_lab, seperate=False):\n", + " offside_sim_path = input_fea\n", + " drug_interaction_matrix_path = input_lab\n", + " drug_fea = np.loadtxt(offside_sim_path,dtype=float,delimiter=\",\")\n", + " interaction = np.loadtxt(drug_interaction_matrix_path,dtype=int,delimiter=\",\")\n", + " \n", + " train = []\n", + " label = []\n", + " tmp_fea=[]\n", + " drug_fea_tmp = []\n", + " \n", + " for i in range(0, (interaction.shape[0]-1)):\n", + " for j in range((i+1), interaction.shape[1]):\n", + " label.append(interaction[i,j])\n", + " drug_fea_tmp_1 = list(drug_fea[i])\n", + " drug_fea_tmp_2 = list(drug_fea[j])\n", + " if seperate:\n", + " tmp_fea = (drug_fea_tmp_1,drug_fea_tmp_2)\n", + " else:\n", + " tmp_fea = drug_fea_tmp_1 + drug_fea_tmp_2\n", + " train.append(tmp_fea)\n", + "\n", + " return np.array(train), np.array(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 1363, + "metadata": {}, + "outputs": [], + "source": [ + "def transfer_array_format(data):\n", + " formated_matrix1 = []\n", + " formated_matrix2 = []\n", + " for val in data:\n", + " formated_matrix1.append(val[0])\n", + " formated_matrix2.append(val[1])\n", + " return np.array(formated_matrix1), np.array(formated_matrix2)" + ] + }, + { + "cell_type": "code", + "execution_count": 1364, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_labels(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " y = encoder.transform(labels).astype(np.int32)\n", + " if categorical:\n", + " y = np_utils.to_categorical(y)\n", + "# print(y)\n", + " return y, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1365, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_names(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " if categorical:\n", + " labels = np_utils.to_categorical(labels)\n", + " return labels, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1366, + "metadata": {}, + "outputs": [], + "source": [ + "def getStratifiedKFoldSplit(X,y,n_splits):\n", + " skf = StratifiedKFold(n_splits=n_splits, random_state=42)\n", + " return skf.split(X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 1367, + "metadata": {}, + "outputs": [], + "source": [ + "class NDD(nn.Module):\n", + " def __init__(self, D_in=1096, H1=300, H2=400, D_out=1, drop=0.5):\n", + " super(NDD, self).__init__()\n", + " # an affine operation: y = Wx + b\n", + " self.fc1 = nn.Linear(D_in, H1) # Fully Connected\n", + " self.fc2 = nn.Linear(H1, H2)\n", + " self.fc3 = nn.Linear(H2, D_out)\n", + " self.drop = nn.Dropout(drop)\n", + " self._init_weights()\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.fc1(x))\n", + " x = self.drop(x)\n", + " x = F.relu(self.fc2(x))\n", + " x = self.drop(x)\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def _init_weights(self):\n", + " for m in self.modules():\n", + " if(isinstance(m, nn.Linear)):\n", + " m.weight.data.normal_(0, 0.05)\n", + " m.bias.data.uniform_(-1,0)" + ] + }, + { + "cell_type": "code", + "execution_count": 1368, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDFSingleMetric(df, sim_type, metric, value):\n", + " df.loc[df['Similarity'] == sim_type, metric ] = round(value,3)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1369, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDF(df, sim_type, AUROC, AUPR, F1, Rec, Prec):\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUC', AUROC)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUPR', AUPR)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'F-measure', F1)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Recall', Rec)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Precision', Prec)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1370, + "metadata": {}, + "outputs": [], + "source": [ + "def getNetParamsStr(net, str_hidden_layers_params, net_params_to_print=[\"max_epochs\", \"batch_size\"]):\n", + " net_params = [val for sublist in [[x,net.get_params()[x]] for x in net_params_to_print] for val in sublist]\n", + " net_params_str = '-'.join(map(str, flattened))\n", + " return(net_params_str+str_hidden_layers_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 1371, + "metadata": {}, + "outputs": [], + "source": [ + "def writeReplicatedIndividualScoresCSV(net, df, destination, str_hidden_layers_params):\n", + " filePath = destination + \"replicatedIndividualScores_\" + getNetParamsStr(net, str_hidden_layers_params) + \".csv\"\n", + " df.to_csv(path_or_buf = filePath, index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 1372, + "metadata": {}, + "outputs": [], + "source": [ + "def getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test):\n", + " model = NDD(D_in, H1, H2, D_out, drop)\n", + " \n", + " net = NeuralNetClassifier(\n", + " model,\n", + "# criterion=nn.CrossEntropyLoss,\n", + " criterion=nn.BCEWithLogitsLoss,\n", + " max_epochs=20,\n", + " optimizer=SGD,\n", + " optimizer__lr=0.01,\n", + " optimizer__momentum=0.9, \n", + " optimizer__weight_decay=1e-6, \n", + " optimizer__nesterov=True, \n", + " batch_size=200,\n", + " callbacks=callbacks,\n", + " # Shuffle training data on each epoch\n", + " iterator_train__shuffle=True,\n", + " device=device,\n", + " train_split=predefined_split(Xy_test),\n", + " )\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": 1373, + "metadata": {}, + "outputs": [], + "source": [ + "def avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits):\n", + " AUROC /= kfold_nsplits\n", + " AUPR /= kfold_nsplits\n", + " F1 /= kfold_nsplits\n", + " Rec /= kfold_nsplits\n", + " Prec /= kfold_nsplits\n", + " return AUROC, AUPR, F1, Rec, Prec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run" + ] + }, + { + "cell_type": "code", + "execution_count": 1374, + "metadata": {}, + "outputs": [], + "source": [ + "df_paperIndividualScores = pd.read_csv(pathPaperScores)\n", + "\n", + "df_replicatedIndividualScores = df_paperIndividualScores.copy()\n", + "\n", + "for col in df_replicatedIndividualScores.columns:\n", + " if col != 'Similarity':\n", + " df_replicatedIndividualScores[col].values[:] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 1375, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "soft = nn.Softmax(dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1376, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(119902, 1096)\n", + "(119902, 1)\n" + ] + } + ], + "source": [ + "print(X_train.shape)\n", + "print(y_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 1377, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running fold0 for sideeffect...\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "result type Float can't be cast to the desired output type Long", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mmodelPicklePath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpathPickles\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m\"model_params/model_params_fold\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr_hidden_layers_params\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msimilarity\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".p\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdo_train_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodelPicklePath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;31m# this is actually a pylint bug:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# https://github.com/PyCQA/pylint/issues/1085\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNeuralNetClassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 848\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 849\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mpartial_fit\u001b[0;34m(self, X, y, classes, **fit_params)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_train_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(self, X, y, epochs, **fit_params)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0myi_res\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0myi\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0my_train_is_ph\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_batch_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0myi_res\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 739\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 740\u001b[0m \u001b[0mtrain_batch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train_loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 664\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 665\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mstep_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 659\u001b[0m \u001b[0mstep_accumulator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_step_accumulator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstep_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 661\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_single\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step_single\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 604\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 605\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, *args, **kwargs)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0meps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;31m# pylint: disable=signature-differs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, X, training)\u001b[0m\n\u001b[1;32m 1097\u001b[0m \"\"\"\n\u001b[1;32m 1098\u001b[0m \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1099\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m reduction=self.reduction)\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m 2112\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Target size ({}) must be the same as input size ({})\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2114\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy_with_logits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction_enum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: result type Float can't be cast to the desired output type Long" + ] + } + ], + "source": [ + "do_prepare_data = False\n", + "do_train_model = True\n", + "kfold_nsplits = 5\n", + "# similaritiesToRun = df_paperIndividualScores['Similarity']\n", + "similaritiesToRun = [\"sideeffect\"]\n", + "\n", + "for similarity in similaritiesToRun:\n", + " input_fea = pathInput+DS1_path+\"/\" + similarity + \"_Jacarrd_sim.csv\"\n", + " input_lab = pathInput+DS1_path+\"/drug_drug_matrix.csv\"\n", + " dataPicklePath = pathPickles+\"data_X_y_\" + similarity + \"_Jaccard.p\"\n", + "\n", + " # Define model\n", + " D_in, H1, H2, D_out, drop = X.shape[1], 300, 400, 1, 0.5\n", + " str_hidden_layers_params = \"-H1-\" + str(H1) + \"-H2-\" + str(H2)\n", + " callbacks = []\n", + " \n", + " # Prepare data if not available\n", + " if do_prepare_data:\n", + " print(\"Preparing \" + similarity + \" data...\")\n", + " X,y = prepare_data(input_fea, input_lab, seperate = False)\n", + "\n", + " with open(dataPicklePath, 'wb') as f:\n", + " pickle.dump([X, y], f)\n", + "\n", + " # Load X,y and split in to train, test\n", + " with open(dataPicklePath, 'rb') as f:\n", + " X, y = pickle.load(f)\n", + " \n", + "\n", + " y = np.reshape(y, (y.shape[0], 1))\n", + " \n", + " X = X.astype(np.float32)\n", + " y = y.astype(np.int64) \n", + "\n", + " \n", + "# y_cat = np_utils.to_categorical(y)\n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = 0,0,0,0,0\n", + " kFoldSplit = getStratifiedKFoldSplit(X,y,n_splits=kfold_nsplits)\n", + " for i, indices in enumerate(kFoldSplit):\n", + " print(\"Running fold\" + str(i) + \" for \" + similarity +\"...\")\n", + " \n", + " train_index = indices[0]\n", + " test_index = indices[1]\n", + " X_train, X_test = X[train_index], X[test_index]\n", + " y_train, y_test = y[train_index], y[test_index]\n", + "# y_train, y_test = y_cat[train_index], y_cat[test_index]\n", + " \n", + " # Create Network Classifier\n", + " Xy_test = skorch.dataset.Dataset(X_test, y_test)\n", + " net = getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test)\n", + " \n", + " # Fit and save OR load model\n", + " modelPicklePath = pathPickles+\"model_params/model_params_fold\" + str(i) + \"_\" + str_hidden_layers_params+ \"_\" + similarity + \".p\"\n", + " if do_train_model:\n", + " net.fit(X_train, y_train)\n", + " net.save_params(f_params=modelPicklePath)\n", + " else:\n", + " net.initialize() # This is important!\n", + " net.load_params(f_params=modelPicklePath)\n", + "\n", + " # Make predictions\n", + " y_pred = net.predict(X_test)\n", + " lr_probs = soft(net.forward(X_test))[:,1]\n", + " lr_precision, lr_recall, _ = precision_recall_curve(y_test, lr_probs)\n", + "\n", + " AUROC += roc_auc_score(y_test, y_pred)\n", + " AUPR += auc(lr_recall, lr_precision)\n", + " F1 += f1_score(y_test, y_pred)\n", + " Rec += recall_score(y_test, y_pred)\n", + " Prec += precision_score(y_test, y_pred)\n", + " \n", + " print(i, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits)\n", + " print(similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " # Fill replicated metrics\n", + " updateSimilarityDF(df_replicatedIndividualScores, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + "# Write CSV\n", + "writeReplicatedIndividualScoresCSV(net, df_replicatedIndividualScores, pathRuns, str_hidden_layers_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare to Paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(df_paperIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(df_replicatedIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "diff_metrics = ['AUC', 'AUPR', 'F-measure', 'Recall', 'Precision']\n", + "df_diff = df_paperIndividualScores[diff_metrics] - df_replicatedIndividualScores[diff_metrics]\n", + "df_diff_abs = df_diff.abs()\n", + "df_diff_percent = (df_diff_abs / df_paperIndividualScores[diff_metrics]) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from seaborn import heatmap\n", + "heatmap(df_diff, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_abs, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_percent, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import mean_squared_error\n", + "mean_squared_error(df_paperIndividualScores[diff_metrics],\n", + " df_replicatedIndividualScores[diff_metrics])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/01_KS_Skorch_DDI.ipynb b/notebooks/01_KS_Skorch_DDI.ipynb index 6f0fed1..ac0e629 100644 --- a/notebooks/01_KS_Skorch_DDI.ipynb +++ b/notebooks/01_KS_Skorch_DDI.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 810, + "execution_count": 1358, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 811, + "execution_count": 1359, "metadata": {}, "outputs": [], "source": [ @@ -47,14 +47,16 @@ "from torch.utils.tensorboard import SummaryWriter\n", "from torch.optim import SGD\n", "\n", + "import skorch\n", "from skorch import NeuralNetClassifier\n", "from skorch.callbacks import EpochScoring\n", - "from skorch.callbacks import TensorBoard" + "from skorch.callbacks import TensorBoard\n", + "from skorch.helper import predefined_split" ] }, { "cell_type": "code", - "execution_count": 812, + "execution_count": 1360, "metadata": {}, "outputs": [], "source": [ @@ -73,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 813, + "execution_count": 1361, "metadata": {}, "outputs": [], "source": [ @@ -97,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 814, + "execution_count": 1362, "metadata": {}, "outputs": [], "source": [ @@ -128,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 815, + "execution_count": 1363, "metadata": {}, "outputs": [], "source": [ @@ -143,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 816, + "execution_count": 1364, "metadata": {}, "outputs": [], "source": [ @@ -154,13 +156,13 @@ " y = encoder.transform(labels).astype(np.int32)\n", " if categorical:\n", " y = np_utils.to_categorical(y)\n", - " print(y)\n", + "# print(y)\n", " return y, encoder" ] }, { "cell_type": "code", - "execution_count": 817, + "execution_count": 1365, "metadata": {}, "outputs": [], "source": [ @@ -175,59 +177,30 @@ }, { "cell_type": "code", - "execution_count": 818, + "execution_count": 1366, "metadata": {}, "outputs": [], "source": [ "def getStratifiedKFoldSplit(X,y,n_splits):\n", - " skf = StratifiedKFold(n_splits=n_splits)\n", - " return skf.split(X,y)\n", - "# skf.get_n_splits(X, y)\n", - "# for train_index, test_index in skf.split(X, y):\n", - "# X_train, X_test = X[train_index], X[test_index]\n", - "# y_train, y_test = y[train_index], y[test_index]\n", - "# return X_train, X_test, y_train, y_test" + " skf = StratifiedKFold(n_splits=n_splits, random_state=42)\n", + " return skf.split(X,y)" ] }, { "cell_type": "code", - "execution_count": 819, - "metadata": {}, - "outputs": [], - "source": [ - "# x = np.arange(100)\n", - "# y = np.random.binomial(1,0.5,100)\n", - "\n", - "# #print(x)\n", - "# #print(y)\n", - "\n", - "# skf = StratifiedKFold(n_splits=5)\n", - "# s = skf.split(x,y)\n", - "\n", - "# for i, indices in enumerate(s):\n", - "# train = indices[0]\n", - "# test = indices[1]\n", - "# print(train)\n", - "# print(test)\n", - "# # print(indices)\n", - "# print(i)\n", - "# print(\"######################\")" - ] - }, - { - "cell_type": "code", - "execution_count": 820, + "execution_count": 1367, "metadata": {}, "outputs": [], "source": [ "class NDD(nn.Module):\n", - " def __init__(self, D_in=123, H1=300, H2=400, D_out=2, drop=0.5):\n", + " def __init__(self, D_in=1096, H1=300, H2=400, D_out=1, drop=0.5):\n", " super(NDD, self).__init__()\n", " # an affine operation: y = Wx + b\n", " self.fc1 = nn.Linear(D_in, H1) # Fully Connected\n", " self.fc2 = nn.Linear(H1, H2)\n", " self.fc3 = nn.Linear(H2, D_out)\n", " self.drop = nn.Dropout(drop)\n", + " self._init_weights()\n", "\n", " def forward(self, x):\n", " x = F.relu(self.fc1(x))\n", @@ -235,12 +208,18 @@ " x = F.relu(self.fc2(x))\n", " x = self.drop(x)\n", " x = self.fc3(x)\n", - " return x" + " return x\n", + " \n", + " def _init_weights(self):\n", + " for m in self.modules():\n", + " if(isinstance(m, nn.Linear)):\n", + " m.weight.data.normal_(0, 0.05)\n", + " m.bias.data.uniform_(-1,0)" ] }, { "cell_type": "code", - "execution_count": 821, + "execution_count": 1368, "metadata": {}, "outputs": [], "source": [ @@ -251,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 822, + "execution_count": 1369, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 823, + "execution_count": 1370, "metadata": {}, "outputs": [], "source": [ @@ -278,7 +257,7 @@ }, { "cell_type": "code", - "execution_count": 824, + "execution_count": 1371, "metadata": {}, "outputs": [], "source": [ @@ -289,14 +268,17 @@ }, { "cell_type": "code", - "execution_count": 825, + "execution_count": 1372, "metadata": {}, "outputs": [], "source": [ - "def getNDDClassifier():\n", + "def getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test):\n", + " model = NDD(D_in, H1, H2, D_out, drop)\n", + " \n", " net = NeuralNetClassifier(\n", " model,\n", - " criterion=nn.CrossEntropyLoss,\n", + "# criterion=nn.CrossEntropyLoss,\n", + " criterion=nn.BCEWithLogitsLoss,\n", " max_epochs=20,\n", " optimizer=SGD,\n", " optimizer__lr=0.01,\n", @@ -308,13 +290,14 @@ " # Shuffle training data on each epoch\n", " iterator_train__shuffle=True,\n", " device=device,\n", + " train_split=predefined_split(Xy_test),\n", " )\n", " return net" ] }, { "cell_type": "code", - "execution_count": 826, + "execution_count": 1373, "metadata": {}, "outputs": [], "source": [ @@ -336,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 827, + "execution_count": 1374, "metadata": {}, "outputs": [], "source": [ @@ -351,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 828, + "execution_count": 1375, "metadata": {}, "outputs": [], "source": [ @@ -361,985 +344,57 @@ }, { "cell_type": "code", - "execution_count": 829, - "metadata": { - "scrolled": false - }, + "execution_count": 1376, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6264\u001b[0m \u001b[32m0.6758\u001b[0m \u001b[35m0.6258\u001b[0m 1.9311\n", - " 2 \u001b[36m0.6171\u001b[0m 0.6758 \u001b[35m0.6226\u001b[0m 1.9490\n", - " 3 \u001b[36m0.6103\u001b[0m 0.6758 0.6514 1.7302\n", - " 4 \u001b[36m0.6050\u001b[0m 0.6758 0.6253 1.8612\n", - " 5 0.6053 0.6758 0.6260 1.9754\n", - " 6 0.6061 0.6758 \u001b[35m0.6154\u001b[0m 1.9608\n", - " 7 \u001b[36m0.6030\u001b[0m 0.6757 0.6267 1.9462\n", - " 8 0.6063 0.6758 0.6342 1.9354\n", - " 9 \u001b[36m0.6029\u001b[0m 0.6758 0.6302 1.9364\n", - " 10 0.6039 \u001b[32m0.6762\u001b[0m 0.6189 1.9230\n", - " 11 0.6061 0.6755 0.6390 1.9202\n", - " 12 \u001b[36m0.6024\u001b[0m 0.6761 0.6331 1.9181\n", - " 13 0.6025 0.6758 0.6311 1.9044\n", - " 14 0.6031 \u001b[32m0.6766\u001b[0m 0.6247 1.9200\n", - " 15 \u001b[36m0.5993\u001b[0m 0.6715 0.6303 1.8904\n", - " 16 0.6001 0.6762 0.6434 1.9544\n", - " 17 0.5995 0.6723 0.6275 1.9512\n", - " 18 \u001b[36m0.5969\u001b[0m 0.6765 0.6410 1.9422\n", - " 19 0.6000 0.6758 0.6608 1.9362\n", - " 20 0.6033 0.6765 0.6247 1.9891\n", - "0 chem 0.5020708210276184 0.35337736975618034 0.009411764705882354 0.004733971390346815 0.7931034482758621\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6070\u001b[0m \u001b[32m0.6765\u001b[0m \u001b[35m0.6216\u001b[0m 1.9519\n", - " 2 \u001b[36m0.5995\u001b[0m \u001b[32m0.6772\u001b[0m \u001b[35m0.6208\u001b[0m 2.0301\n", - " 3 0.6031 0.6770 \u001b[35m0.6174\u001b[0m 2.0206\n", - " 4 0.6042 \u001b[32m0.6775\u001b[0m 0.6235 1.9298\n", - " 5 0.6045 0.6772 0.6236 1.9657\n", - " 6 \u001b[36m0.5993\u001b[0m 0.6773 0.6253 2.1007\n", - " 7 0.6041 0.6772 0.6196 1.9751\n", - " 8 \u001b[36m0.5982\u001b[0m 0.6768 0.6210 2.1276\n", - " 9 \u001b[36m0.5920\u001b[0m 0.6769 \u001b[35m0.6148\u001b[0m 2.5294\n", - " 10 0.6020 0.6770 0.6174 2.2505\n", - " 11 0.5981 0.6767 0.6264 2.2731\n", - " 12 0.6014 0.6769 0.6183 2.1471\n", - " 13 0.6046 0.6729 0.6205 2.0925\n", - " 14 0.6017 0.6703 0.6271 2.2000\n", - " 15 0.6050 0.6594 0.6287 2.0615\n", - " 16 0.5954 0.6762 0.6207 2.2425\n", - " 17 0.6001 0.6697 0.6238 2.0013\n", - " 18 0.5982 0.6567 0.6321 2.2593\n", - " 19 0.6024 0.6640 0.6302 1.9879\n", - " 20 0.6062 0.6605 0.6339 2.0203\n", - "1 chem 1.0032093649358742 0.6969387441862649 0.046459634453371874 0.024287331480909745 1.144955300127714\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6106\u001b[0m \u001b[32m0.6769\u001b[0m \u001b[35m0.6226\u001b[0m 1.8913\n", - " 2 \u001b[36m0.6067\u001b[0m \u001b[32m0.6771\u001b[0m \u001b[35m0.6220\u001b[0m 1.9373\n", - " 3 \u001b[36m0.6027\u001b[0m \u001b[32m0.6777\u001b[0m \u001b[35m0.6157\u001b[0m 1.9948\n", - " 4 \u001b[36m0.5949\u001b[0m \u001b[32m0.6830\u001b[0m \u001b[35m0.6149\u001b[0m 1.9998\n", - " 5 0.5952 0.6762 0.6223 1.9508\n", - " 6 0.6029 0.6769 0.6207 1.9928\n", - " 7 0.6039 0.6773 0.6196 1.9644\n", - " 8 0.6035 0.6771 0.6183 2.2406\n", - " 9 0.6017 0.6768 0.6226 1.9536\n", - " 10 0.6080 0.6765 0.6225 2.0521\n", - " 11 0.6027 0.6765 0.6202 2.0107\n", - " 12 0.6029 0.6768 0.6223 1.9672\n", - " 13 0.5984 0.6780 0.6256 1.9688\n", - " 14 0.5992 0.6775 0.6320 1.9475\n", - " 15 0.6015 0.6765 0.6316 1.9421\n", - " 16 0.5979 0.6767 0.6244 1.9307\n", - " 17 0.6016 0.6767 0.6280 2.4530\n", - " 18 0.6017 0.6767 0.6228 2.1558\n", - " 19 0.6033 0.6768 0.6214 2.2501\n", - " 20 0.6085 0.6771 0.6170 1.8949\n", - "2 chem 1.5064641436161814 1.1424150280898917 0.06293141889668656 0.03262323762478131 1.8313959780938156\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6054\u001b[0m \u001b[32m0.6739\u001b[0m \u001b[35m0.6245\u001b[0m 1.9645\n", - " 2 0.6060 0.6659 0.6272 2.1329\n", - " 3 0.6096 0.6692 0.6281 1.9609\n", - " 4 \u001b[36m0.6040\u001b[0m 0.6727 \u001b[35m0.6231\u001b[0m 1.9600\n", - " 5 \u001b[36m0.5943\u001b[0m 0.6669 0.6250 1.9544\n", - " 6 0.6010 0.6727 0.6264 2.0072\n", - " 7 0.6008 0.6719 \u001b[35m0.6187\u001b[0m 2.0353\n", - " 8 \u001b[36m0.5932\u001b[0m \u001b[32m0.6743\u001b[0m 0.6240 1.9501\n", - " 9 0.6068 \u001b[32m0.6765\u001b[0m 0.6260 2.1851\n", - " 10 0.6066 \u001b[32m0.6769\u001b[0m 0.6228 1.9564\n", - " 11 0.6042 \u001b[32m0.6770\u001b[0m 0.6246 1.8944\n", - " 12 0.5975 \u001b[32m0.6774\u001b[0m 0.6209 1.9510\n", - " 13 0.6076 0.6771 0.6238 1.9444\n", - " 14 0.6024 \u001b[32m0.6777\u001b[0m 0.6192 1.8995\n", - " 15 0.5996 0.6770 0.6285 1.9817\n", - " 16 0.6078 0.6776 0.6319 2.1958\n", - " 17 0.6061 \u001b[32m0.6779\u001b[0m 0.6243 2.2466\n", - " 18 0.6068 0.6758 0.6233 1.9894\n", - " 19 0.5997 0.6761 0.6236 2.0110\n", - " 20 0.6016 0.6757 0.6263 2.4612\n", - "3 chem 2.0055205553816124 1.4813323443132456 0.08796344806787015 0.04569311515899969 2.126744815303118\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6173\u001b[0m \u001b[32m0.6758\u001b[0m \u001b[35m0.6281\u001b[0m 2.2658\n", - " 2 0.6201 0.6758 \u001b[35m0.6246\u001b[0m 2.0893\n", - " 3 \u001b[36m0.6132\u001b[0m 0.6758 0.6262 2.3175\n", - " 4 0.6146 0.6758 0.6277 1.9338\n", - " 5 \u001b[36m0.6126\u001b[0m 0.6758 \u001b[35m0.6222\u001b[0m 1.9317\n", - " 6 \u001b[36m0.6125\u001b[0m \u001b[32m0.6759\u001b[0m 0.6227 1.9300\n", - " 7 \u001b[36m0.6120\u001b[0m 0.6758 0.6260 1.8792\n", - " 8 0.6174 0.6757 0.6264 1.9285\n", - " 9 0.6140 0.6758 0.6250 1.9282\n", - " 10 0.6126 0.6758 0.6278 1.9924\n", - " 11 0.6189 0.6758 0.6290 1.9383\n", - " 12 0.6138 0.6758 0.6289 1.9302\n", - " 13 0.6131 0.6758 0.6300 1.9215\n", - " 14 0.6128 0.6758 0.6298 1.9982\n", - " 15 0.6218 0.6758 0.6293 1.9817\n", - " 16 0.6226 0.6758 0.6271 1.9578\n", - " 17 0.6199 0.6758 0.6291 2.0317\n", - " 18 \u001b[36m0.6112\u001b[0m 0.6759 0.6282 1.9412\n", - " 19 0.6113 0.6758 0.6270 1.8891\n", - " 20 0.6157 0.6758 0.6232 2.2322\n", - "4 chem 2.5059837089427486 1.8974985160803042 0.08981434781080075 0.04661942228127223 3.126744815303118\n", - "chem 0.5011967417885497 0.37949970321606086 0.01796286956216015 0.009323884456254445 0.6253489630606236\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6262\u001b[0m \u001b[32m0.6759\u001b[0m \u001b[35m0.6335\u001b[0m 1.9797\n", - " 2 \u001b[36m0.5746\u001b[0m \u001b[32m0.6948\u001b[0m \u001b[35m0.6252\u001b[0m 2.2008\n", - " 3 \u001b[36m0.4794\u001b[0m \u001b[32m0.7114\u001b[0m 0.6631 1.9711\n", - " 4 \u001b[36m0.4259\u001b[0m \u001b[32m0.7167\u001b[0m 0.6398 2.0485\n", - " 5 \u001b[36m0.4047\u001b[0m \u001b[32m0.7232\u001b[0m 0.6334 1.9318\n", - " 6 \u001b[36m0.3918\u001b[0m 0.7146 0.7058 1.9769\n", - " 7 \u001b[36m0.3825\u001b[0m \u001b[32m0.7252\u001b[0m 0.6673 1.9610\n", - " 8 \u001b[36m0.3751\u001b[0m 0.7247 0.6693 1.9740\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 9 \u001b[36m0.3671\u001b[0m \u001b[32m0.7259\u001b[0m 0.6831 2.1066\n", - " 10 \u001b[36m0.3620\u001b[0m 0.7226 0.6636 1.9562\n", - " 11 \u001b[36m0.3584\u001b[0m 0.7247 0.7153 1.9796\n", - " 12 \u001b[36m0.3554\u001b[0m 0.7239 0.6994 1.9612\n", - " 13 \u001b[36m0.3493\u001b[0m 0.7234 0.7160 1.9435\n", - " 14 \u001b[36m0.3457\u001b[0m 0.7222 0.7878 1.9495\n", - " 15 \u001b[36m0.3423\u001b[0m \u001b[32m0.7277\u001b[0m 0.7111 1.9528\n", - " 16 \u001b[36m0.3395\u001b[0m 0.7231 0.7155 1.9537\n", - " 17 \u001b[36m0.3373\u001b[0m 0.7243 0.7194 1.9474\n", - " 18 \u001b[36m0.3346\u001b[0m 0.7244 0.7626 1.9957\n", - " 19 \u001b[36m0.3318\u001b[0m 0.7207 0.7284 2.2627\n", - " 20 \u001b[36m0.3295\u001b[0m \u001b[32m0.7286\u001b[0m 0.7410 1.9731\n", - "0 target 0.6321751182635624 0.5825907140925904 0.454974358974359 0.3423896264279098 0.6778728606356969\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3375\u001b[0m \u001b[32m0.7423\u001b[0m \u001b[35m0.6494\u001b[0m 1.9436\n", - " 2 \u001b[36m0.3319\u001b[0m \u001b[32m0.7508\u001b[0m \u001b[35m0.6391\u001b[0m 1.9278\n", - " 3 \u001b[36m0.3276\u001b[0m 0.7419 0.6625 1.9632\n", - " 4 \u001b[36m0.3239\u001b[0m 0.7467 0.6600 1.9579\n", - " 5 \u001b[36m0.3217\u001b[0m 0.7460 0.6562 2.0545\n", - " 6 \u001b[36m0.3196\u001b[0m 0.7458 0.6703 1.9170\n", - " 7 \u001b[36m0.3173\u001b[0m 0.7449 0.6628 2.2455\n", - " 8 \u001b[36m0.3152\u001b[0m 0.7444 0.6789 2.2802\n", - " 9 \u001b[36m0.3134\u001b[0m 0.7444 0.7019 1.9199\n", - " 10 \u001b[36m0.3124\u001b[0m 0.7490 0.6841 2.0376\n", - " 11 \u001b[36m0.3095\u001b[0m 0.7491 0.6997 1.9526\n", - " 12 \u001b[36m0.3088\u001b[0m 0.7505 0.6875 1.9527\n", - " 13 \u001b[36m0.3069\u001b[0m 0.7469 0.6888 1.9839\n", - " 14 \u001b[36m0.3052\u001b[0m \u001b[32m0.7513\u001b[0m 0.6887 1.9408\n", - " 15 \u001b[36m0.3027\u001b[0m 0.7458 0.6975 1.9466\n", - " 16 \u001b[36m0.3020\u001b[0m 0.7498 0.7064 2.4493\n", - " 17 \u001b[36m0.3018\u001b[0m 0.7464 0.7197 2.3998\n", - " 18 \u001b[36m0.2985\u001b[0m 0.7463 0.6906 2.3729\n", - " 19 \u001b[36m0.2973\u001b[0m 0.7469 0.7283 2.4179\n", - " 20 \u001b[36m0.2970\u001b[0m 0.7486 0.7242 2.3620\n", - "1 target 1.2839250081314222 1.1966116331896908 0.9530586501620986 0.7370587629926932 1.3528112626068338\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3532\u001b[0m \u001b[32m0.7370\u001b[0m \u001b[35m0.6903\u001b[0m 1.9463\n", - " 2 \u001b[36m0.3319\u001b[0m \u001b[32m0.7403\u001b[0m \u001b[35m0.6842\u001b[0m 1.9979\n", - " 3 \u001b[36m0.3224\u001b[0m \u001b[32m0.7409\u001b[0m \u001b[35m0.6801\u001b[0m 1.9669\n", - " 4 \u001b[36m0.3171\u001b[0m \u001b[32m0.7466\u001b[0m 0.6892 1.9785\n", - " 5 \u001b[36m0.3139\u001b[0m \u001b[32m0.7480\u001b[0m 0.6880 1.9434\n", - " 6 \u001b[36m0.3101\u001b[0m 0.7453 0.6815 1.9039\n", - " 7 \u001b[36m0.3062\u001b[0m 0.7473 0.6943 1.9172\n", - " 8 \u001b[36m0.3045\u001b[0m 0.7433 0.6969 1.9958\n", - " 9 \u001b[36m0.3015\u001b[0m 0.7325 0.7710 1.9902\n", - " 10 \u001b[36m0.2994\u001b[0m 0.7340 0.7264 1.9469\n", - " 11 \u001b[36m0.2983\u001b[0m 0.7393 0.7254 1.9816\n", - " 12 \u001b[36m0.2959\u001b[0m 0.7373 0.7486 1.9743\n", - " 13 \u001b[36m0.2946\u001b[0m 0.7420 0.7606 1.9401\n", - " 14 \u001b[36m0.2922\u001b[0m 0.7385 0.7949 1.9742\n", - " 15 \u001b[36m0.2909\u001b[0m 0.7371 0.7616 1.9609\n", - " 16 \u001b[36m0.2888\u001b[0m 0.7433 0.7248 1.9965\n", - " 17 \u001b[36m0.2875\u001b[0m 0.7420 0.7409 1.9803\n", - " 18 0.2878 \u001b[32m0.7521\u001b[0m 0.7031 2.0235\n", - " 19 \u001b[36m0.2862\u001b[0m 0.7385 0.7798 1.9928\n", - " 20 \u001b[36m0.2832\u001b[0m 0.7395 0.7997 1.9513\n", - "2 target 2.0640001575155913 1.9922932634389086 1.6589858674207605 1.4038283420808892 2.1027823240722814\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3053\u001b[0m \u001b[32m0.7439\u001b[0m \u001b[35m0.7349\u001b[0m 1.9421\n", - " 2 \u001b[36m0.2896\u001b[0m \u001b[32m0.7472\u001b[0m 0.7387 1.9075\n", - " 3 \u001b[36m0.2825\u001b[0m 0.7355 0.7866 1.9415\n", - " 4 \u001b[36m0.2798\u001b[0m 0.7366 0.8267 2.0192\n", - " 5 \u001b[36m0.2766\u001b[0m 0.7400 0.7578 1.9606\n", - " 6 \u001b[36m0.2725\u001b[0m 0.7341 0.8071 1.9675\n", - " 7 \u001b[36m0.2712\u001b[0m 0.7412 0.7890 1.9679\n", - " 8 \u001b[36m0.2677\u001b[0m 0.7409 0.8309 1.9612\n", - " 9 0.2680 0.7404 0.7768 2.2084\n", - " 10 \u001b[36m0.2655\u001b[0m 0.7326 0.8460 1.9428\n", - " 11 \u001b[36m0.2636\u001b[0m 0.7351 0.8235 1.9421\n", - " 12 \u001b[36m0.2626\u001b[0m 0.7383 0.8207 1.9317\n", - " 13 \u001b[36m0.2622\u001b[0m 0.7300 0.8517 1.9141\n", - " 14 \u001b[36m0.2598\u001b[0m 0.7342 0.8603 1.9179\n", - " 15 \u001b[36m0.2592\u001b[0m 0.7328 0.8541 1.9451\n", - " 16 \u001b[36m0.2580\u001b[0m 0.7334 0.8388 1.9098\n", - " 17 \u001b[36m0.2546\u001b[0m 0.7310 0.8543 1.9349\n", - " 18 0.2563 0.7303 0.8831 1.8856\n", - " 19 \u001b[36m0.2529\u001b[0m 0.7288 0.8675 2.0102\n", - " 20 0.2530 0.7306 0.8806 1.9563\n", - "3 target 2.83191881896769 2.7636896826616555 2.3447668478328176 2.0992075743542244 2.7792264125696287\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.2924\u001b[0m \u001b[32m0.7406\u001b[0m \u001b[35m0.8073\u001b[0m 1.9226\n", - " 2 \u001b[36m0.2779\u001b[0m 0.7356 0.8191 1.9693\n", - " 3 \u001b[36m0.2719\u001b[0m 0.7376 0.8348 1.9791\n", - " 4 \u001b[36m0.2696\u001b[0m \u001b[32m0.7422\u001b[0m 0.8160 1.9450\n", - " 5 \u001b[36m0.2658\u001b[0m 0.7417 0.8097 1.9260\n", - " 6 \u001b[36m0.2626\u001b[0m 0.7343 0.8447 1.9243\n", - " 7 \u001b[36m0.2614\u001b[0m 0.7285 0.8861 2.5562\n", - " 8 \u001b[36m0.2606\u001b[0m 0.7347 0.8403 2.5879\n", - " 9 \u001b[36m0.2594\u001b[0m 0.7357 0.8425 2.5267\n", - " 10 \u001b[36m0.2582\u001b[0m 0.7321 0.9104 2.4312\n", - " 11 \u001b[36m0.2558\u001b[0m 0.7346 0.8662 2.4281\n", - " 12 0.2563 0.7331 0.8522 1.9726\n", - " 13 \u001b[36m0.2529\u001b[0m 0.7371 0.8225 1.9481\n", - " 14 \u001b[36m0.2517\u001b[0m 0.7309 0.8443 1.9445\n", - " 15 \u001b[36m0.2515\u001b[0m 0.7340 0.8640 2.1017\n", - " 16 \u001b[36m0.2514\u001b[0m 0.7330 0.8448 1.9791\n", - " 17 \u001b[36m0.2495\u001b[0m 0.7325 0.8319 1.9765\n", - " 18 \u001b[36m0.2474\u001b[0m 0.7353 0.8358 2.0280\n", - " 19 \u001b[36m0.2474\u001b[0m 0.7336 0.8561 2.1248\n", - " 20 \u001b[36m0.2457\u001b[0m 0.7313 0.8815 1.8951\n", - "4 target 3.653140805176707 3.5984472146950415 3.102673455039748 2.860117413794323 3.534153401235002\n", - "target 0.7306281610353415 0.7196894429390083 0.6205346910079496 0.5720234827588646 0.7068306802470004\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6123\u001b[0m \u001b[32m0.6854\u001b[0m \u001b[35m0.6098\u001b[0m 1.7297\n", - " 2 \u001b[36m0.5820\u001b[0m 0.6741 \u001b[35m0.6050\u001b[0m 1.9344\n", - " 3 \u001b[36m0.5563\u001b[0m \u001b[32m0.6933\u001b[0m 0.6177 1.9465\n", - " 4 \u001b[36m0.5389\u001b[0m \u001b[32m0.6949\u001b[0m 0.6085 1.9076\n", - " 5 \u001b[36m0.5295\u001b[0m \u001b[32m0.6951\u001b[0m \u001b[35m0.5954\u001b[0m 1.8626\n", - " 6 \u001b[36m0.5226\u001b[0m \u001b[32m0.7021\u001b[0m 0.6155 1.8612\n", - " 7 \u001b[36m0.5187\u001b[0m 0.7008 0.6308 1.9678\n", - " 8 \u001b[36m0.5148\u001b[0m \u001b[32m0.7042\u001b[0m 0.6142 1.9504\n", - " 9 \u001b[36m0.5116\u001b[0m 0.7008 0.6255 1.9460\n", - " 10 \u001b[36m0.5092\u001b[0m 0.7038 0.6254 1.9301\n", - " 11 \u001b[36m0.5071\u001b[0m 0.7031 0.6130 1.9370\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 12 \u001b[36m0.5068\u001b[0m 0.7040 0.6102 1.9916\n", - " 13 \u001b[36m0.5042\u001b[0m 0.7037 0.6014 1.9493\n", - " 14 \u001b[36m0.5025\u001b[0m \u001b[32m0.7073\u001b[0m \u001b[35m0.5907\u001b[0m 1.9475\n", - " 15 \u001b[36m0.5014\u001b[0m 0.7036 0.6435 1.9280\n", - " 16 \u001b[36m0.5011\u001b[0m 0.7062 0.5945 1.9662\n", - " 17 \u001b[36m0.5001\u001b[0m 0.7057 0.6060 1.8562\n", - " 18 \u001b[36m0.4980\u001b[0m 0.7062 0.6120 1.9627\n", - " 19 0.4982 0.7066 0.6062 1.9322\n", - " 20 \u001b[36m0.4967\u001b[0m \u001b[32m0.7156\u001b[0m \u001b[35m0.5863\u001b[0m 1.9169\n", - "0 transporter 0.604867003182513 0.5284542578381921 0.39782362916902203 0.2896984666049192 0.6347237880496054\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4984\u001b[0m \u001b[32m0.7088\u001b[0m \u001b[35m0.6169\u001b[0m 2.1633\n", - " 2 \u001b[36m0.4964\u001b[0m 0.7006 0.6235 2.1607\n", - " 3 \u001b[36m0.4931\u001b[0m \u001b[32m0.7143\u001b[0m \u001b[35m0.6157\u001b[0m 1.9774\n", - " 4 0.4932 0.7066 \u001b[35m0.6123\u001b[0m 1.9500\n", - " 5 0.4932 0.6962 \u001b[35m0.6009\u001b[0m 2.1496\n", - " 6 \u001b[36m0.4921\u001b[0m 0.7046 0.6149 1.9375\n", - " 7 \u001b[36m0.4911\u001b[0m 0.7089 \u001b[35m0.5953\u001b[0m 1.9143\n", - " 8 0.4912 0.7044 0.6034 1.9727\n", - " 9 \u001b[36m0.4896\u001b[0m 0.7016 0.6148 2.0145\n", - " 10 0.4910 0.7068 0.5991 1.9927\n", - " 11 \u001b[36m0.4889\u001b[0m 0.6979 0.5989 1.9505\n", - " 12 \u001b[36m0.4880\u001b[0m 0.7081 0.5983 1.9237\n", - " 13 0.4883 0.6959 0.6144 1.9153\n", - " 14 \u001b[36m0.4875\u001b[0m 0.6986 0.6095 1.9863\n", - " 15 0.4885 0.6986 0.6061 1.9928\n", - " 16 \u001b[36m0.4872\u001b[0m 0.7062 0.6072 1.9888\n", - " 17 \u001b[36m0.4866\u001b[0m 0.6895 0.6117 2.2049\n", - " 18 \u001b[36m0.4859\u001b[0m 0.6988 0.6017 1.9367\n", - " 19 0.4860 0.6987 0.6014 1.9559\n", - " 20 0.4859 0.7008 0.6039 1.9841\n", - "1 transporter 1.2455616865634491 1.10045902471084 0.8910950356237082 0.7159617165791912 1.2200000350469207\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5110\u001b[0m \u001b[32m0.7208\u001b[0m \u001b[35m0.5811\u001b[0m 1.9756\n", - " 2 \u001b[36m0.5025\u001b[0m \u001b[32m0.7232\u001b[0m 0.6026 1.9281\n", - " 3 \u001b[36m0.4991\u001b[0m 0.7176 0.5875 1.8975\n", - " 4 \u001b[36m0.4971\u001b[0m \u001b[32m0.7285\u001b[0m \u001b[35m0.5794\u001b[0m 1.9587\n", - " 5 \u001b[36m0.4953\u001b[0m 0.7230 0.5884 1.9504\n", - " 6 \u001b[36m0.4942\u001b[0m 0.7250 0.5938 1.9355\n", - " 7 \u001b[36m0.4934\u001b[0m 0.7280 0.5839 1.9280\n", - " 8 \u001b[36m0.4927\u001b[0m \u001b[32m0.7294\u001b[0m 0.5819 1.8863\n", - " 9 \u001b[36m0.4914\u001b[0m 0.7275 0.5921 1.8466\n", - " 10 \u001b[36m0.4908\u001b[0m \u001b[32m0.7310\u001b[0m 0.5837 1.9351\n", - " 11 \u001b[36m0.4903\u001b[0m 0.7256 0.6033 1.9167\n", - " 12 \u001b[36m0.4890\u001b[0m \u001b[32m0.7328\u001b[0m \u001b[35m0.5761\u001b[0m 2.0199\n", - " 13 0.4891 \u001b[32m0.7341\u001b[0m 0.5812 1.9467\n", - " 14 \u001b[36m0.4889\u001b[0m 0.7275 0.5955 1.8931\n", - " 15 \u001b[36m0.4881\u001b[0m 0.7285 0.5877 1.9648\n", - " 16 \u001b[36m0.4878\u001b[0m 0.7295 0.5897 1.9696\n", - " 17 \u001b[36m0.4872\u001b[0m 0.7284 0.5806 1.9431\n", - " 18 \u001b[36m0.4864\u001b[0m 0.7240 0.5778 1.9727\n", - " 19 0.4866 0.7255 0.5991 1.9529\n", - " 20 0.4872 0.7314 0.5847 1.9538\n", - "2 transporter 1.9151766653330569 1.7452697608860523 1.420756597842213 1.1395492435937018 1.9266094771070064\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4783\u001b[0m \u001b[32m0.7279\u001b[0m \u001b[35m0.5876\u001b[0m 1.9465\n", - " 2 \u001b[36m0.4709\u001b[0m 0.7277 0.5888 1.9567\n", - " 3 \u001b[36m0.4688\u001b[0m \u001b[32m0.7295\u001b[0m 0.5896 2.6437\n", - " 4 \u001b[36m0.4666\u001b[0m 0.7282 0.5958 1.9544\n", - " 5 0.4669 \u001b[32m0.7300\u001b[0m 0.5988 1.9842\n", - " 6 \u001b[36m0.4656\u001b[0m 0.7246 0.5990 1.9546\n", - " 7 \u001b[36m0.4643\u001b[0m 0.7278 0.5995 1.9703\n", - " 8 \u001b[36m0.4626\u001b[0m 0.7263 0.5928 2.0193\n", - " 9 0.4627 0.7295 0.5988 2.0183\n", - " 10 \u001b[36m0.4616\u001b[0m 0.7284 0.5920 1.9733\n", - " 11 0.4621 0.7284 0.5980 1.9807\n", - " 12 \u001b[36m0.4612\u001b[0m 0.7274 0.5999 1.9993\n", - " 13 \u001b[36m0.4604\u001b[0m 0.7294 0.6000 1.9641\n", - " 14 0.4607 0.7210 0.6012 1.9891\n", - " 15 \u001b[36m0.4598\u001b[0m 0.7269 0.6010 2.0734\n", - " 16 \u001b[36m0.4593\u001b[0m \u001b[32m0.7301\u001b[0m 0.5906 1.9172\n", - " 17 \u001b[36m0.4593\u001b[0m \u001b[32m0.7303\u001b[0m 0.5935 1.7595\n", - " 18 \u001b[36m0.4587\u001b[0m \u001b[32m0.7312\u001b[0m 0.5937 1.7492\n", - " 19 \u001b[36m0.4585\u001b[0m 0.7259 0.6009 1.7498\n", - " 20 \u001b[36m0.4583\u001b[0m 0.7295 0.6002 1.7923\n", - "3 transporter 2.5311426903348315 2.2664165877221216 1.851704683630948 1.471441803025625 2.5408951913927207\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4865\u001b[0m \u001b[32m0.7242\u001b[0m \u001b[35m0.6016\u001b[0m 1.7812\n", - " 2 \u001b[36m0.4793\u001b[0m \u001b[32m0.7343\u001b[0m \u001b[35m0.5845\u001b[0m 1.8617\n", - " 3 \u001b[36m0.4769\u001b[0m 0.7302 0.5933 1.8455\n", - " 4 \u001b[36m0.4752\u001b[0m 0.7327 0.5849 1.7507\n", - " 5 \u001b[36m0.4745\u001b[0m 0.7308 0.5910 1.7677\n", - " 6 \u001b[36m0.4736\u001b[0m 0.7306 0.5846 1.7744\n", - " 7 0.4738 0.7307 0.5871 1.7992\n", - " 8 \u001b[36m0.4728\u001b[0m \u001b[32m0.7350\u001b[0m \u001b[35m0.5843\u001b[0m 1.8340\n", - " 9 \u001b[36m0.4718\u001b[0m \u001b[32m0.7350\u001b[0m \u001b[35m0.5790\u001b[0m 1.7645\n", - " 10 0.4719 0.7265 0.6009 1.7548\n", - " 11 \u001b[36m0.4707\u001b[0m 0.7320 0.5969 1.9489\n", - " 12 \u001b[36m0.4703\u001b[0m 0.7311 0.5913 1.7200\n", - " 13 \u001b[36m0.4699\u001b[0m 0.7348 0.5825 1.9105\n", - " 14 0.4703 0.7312 0.5875 1.7261\n", - " 15 \u001b[36m0.4699\u001b[0m 0.7322 0.5857 1.7617\n", - " 16 0.4704 0.7326 \u001b[35m0.5786\u001b[0m 1.7563\n", - " 17 0.4699 0.7316 0.5913 1.8520\n", - " 18 \u001b[36m0.4698\u001b[0m 0.7331 0.5904 1.7504\n", - " 19 0.4699 0.7290 0.5952 1.7533\n", - " 20 \u001b[36m0.4687\u001b[0m \u001b[32m0.7351\u001b[0m 0.5855 1.7479\n", - "4 transporter 3.195403209591113 2.8165947011608097 2.4018703366003393 2.0440025276036407 3.070351744167062\n", - "transporter 0.6390806419182227 0.5633189402321619 0.4803740673200679 0.4088005055207281 0.6140703488334124\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6094\u001b[0m \u001b[32m0.6815\u001b[0m \u001b[35m0.6094\u001b[0m 1.7799\n", - " 2 \u001b[36m0.5624\u001b[0m \u001b[32m0.6882\u001b[0m \u001b[35m0.6062\u001b[0m 1.8374\n", - " 3 \u001b[36m0.5152\u001b[0m 0.6868 \u001b[35m0.6033\u001b[0m 2.2703\n", - " 4 \u001b[36m0.4886\u001b[0m 0.6682 0.6115 1.8077\n", - " 5 \u001b[36m0.4749\u001b[0m \u001b[32m0.6934\u001b[0m 0.6041 2.0013\n", - " 6 \u001b[36m0.4667\u001b[0m 0.6933 \u001b[35m0.5980\u001b[0m 1.9615\n", - " 7 \u001b[36m0.4599\u001b[0m \u001b[32m0.7154\u001b[0m 0.5981 2.0180\n", - " 8 \u001b[36m0.4549\u001b[0m 0.6944 0.6014 1.9747\n", - " 9 \u001b[36m0.4507\u001b[0m 0.6848 0.6064 1.9987\n", - " 10 \u001b[36m0.4447\u001b[0m 0.6988 0.6047 2.0938\n", - " 11 \u001b[36m0.4426\u001b[0m 0.6915 0.6116 2.0070\n", - " 12 \u001b[36m0.4393\u001b[0m 0.6948 0.6062 2.0641\n", - " 13 \u001b[36m0.4380\u001b[0m 0.6969 0.6201 1.9517\n", - " 14 \u001b[36m0.4353\u001b[0m 0.7064 0.6021 1.9757\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 15 \u001b[36m0.4344\u001b[0m 0.6873 0.6205 2.0295\n", - " 16 \u001b[36m0.4314\u001b[0m 0.7136 0.6098 2.0667\n", - " 17 \u001b[36m0.4300\u001b[0m 0.6909 0.6374 2.3387\n", - " 18 \u001b[36m0.4276\u001b[0m 0.7127 0.6122 1.9974\n", - " 19 \u001b[36m0.4264\u001b[0m 0.6961 0.6166 2.5295\n", - " 20 \u001b[36m0.4253\u001b[0m 0.6985 0.5985 2.4709\n", - "0 enzyme 0.6190775352973376 0.5889849259751829 0.4264428121720882 0.31367706082124114 0.6657929226736566\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4277\u001b[0m \u001b[32m0.7210\u001b[0m \u001b[35m0.5865\u001b[0m 2.4114\n", - " 2 \u001b[36m0.4260\u001b[0m 0.7169 \u001b[35m0.5862\u001b[0m 2.3603\n", - " 3 \u001b[36m0.4239\u001b[0m \u001b[32m0.7219\u001b[0m 0.5938 2.4785\n", - " 4 \u001b[36m0.4227\u001b[0m \u001b[32m0.7363\u001b[0m \u001b[35m0.5797\u001b[0m 2.4253\n", - " 5 \u001b[36m0.4209\u001b[0m 0.7235 0.5882 2.3830\n", - " 6 \u001b[36m0.4207\u001b[0m 0.7139 0.6052 1.9438\n", - " 7 \u001b[36m0.4182\u001b[0m 0.7226 0.5930 1.9795\n", - " 8 \u001b[36m0.4177\u001b[0m 0.7207 0.5948 1.9741\n", - " 9 \u001b[36m0.4165\u001b[0m 0.7239 0.5901 1.7555\n", - " 10 \u001b[36m0.4159\u001b[0m 0.7185 0.6161 1.8001\n", - " 11 \u001b[36m0.4137\u001b[0m 0.7229 0.6067 1.9412\n", - " 12 0.4137 0.7210 0.6233 1.9256\n", - " 13 \u001b[36m0.4133\u001b[0m 0.7242 0.6082 1.9524\n", - " 14 \u001b[36m0.4120\u001b[0m 0.7193 0.6059 1.9634\n", - " 15 0.4121 0.7224 0.6103 1.9603\n", - " 16 \u001b[36m0.4105\u001b[0m 0.7219 0.6129 1.9536\n", - " 17 \u001b[36m0.4101\u001b[0m 0.7277 0.6071 1.9781\n", - " 18 \u001b[36m0.4088\u001b[0m 0.7254 0.6009 1.9513\n", - " 19 0.4089 0.7249 0.5934 1.9786\n", - " 20 \u001b[36m0.4075\u001b[0m 0.7298 0.5926 2.0322\n", - "1 enzyme 1.252099886586031 1.123179523900371 0.906873691884835 0.7267675208397654 1.2397954966762308\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4471\u001b[0m \u001b[32m0.7334\u001b[0m \u001b[35m0.5677\u001b[0m 2.0238\n", - " 2 \u001b[36m0.4348\u001b[0m 0.7244 0.5773 1.9706\n", - " 3 \u001b[36m0.4303\u001b[0m 0.7214 0.5857 2.0433\n", - " 4 \u001b[36m0.4267\u001b[0m 0.7213 0.5911 1.9804\n", - " 5 \u001b[36m0.4232\u001b[0m 0.7280 0.5875 1.9720\n", - " 6 \u001b[36m0.4217\u001b[0m 0.7199 0.6019 2.0173\n", - " 7 \u001b[36m0.4206\u001b[0m 0.7199 0.5883 1.9599\n", - " 8 \u001b[36m0.4197\u001b[0m 0.7177 0.6327 2.0358\n", - " 9 \u001b[36m0.4176\u001b[0m 0.7263 0.5865 1.9600\n", - " 10 \u001b[36m0.4171\u001b[0m 0.7205 0.5892 2.0079\n", - " 11 \u001b[36m0.4153\u001b[0m 0.7205 0.6047 1.9784\n", - " 12 \u001b[36m0.4149\u001b[0m 0.7197 0.6064 1.9925\n", - " 13 0.4155 0.7266 0.5950 2.2339\n", - " 14 0.4151 0.7212 0.6266 2.0097\n", - " 15 \u001b[36m0.4137\u001b[0m 0.7169 0.6459 2.0609\n", - " 16 \u001b[36m0.4136\u001b[0m 0.7135 0.6377 1.9680\n", - " 17 \u001b[36m0.4120\u001b[0m 0.7191 0.6222 2.0250\n", - " 18 \u001b[36m0.4119\u001b[0m 0.7210 0.6176 2.1641\n", - " 19 \u001b[36m0.4109\u001b[0m 0.7143 0.6315 2.2735\n", - " 20 \u001b[36m0.4103\u001b[0m 0.7156 0.6131 1.9396\n", - "2 enzyme 1.9373091374998799 1.8199478352548872 1.461235756333355 1.1622928887516724 2.0021806525040096\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4233\u001b[0m \u001b[32m0.7188\u001b[0m \u001b[35m0.5993\u001b[0m 2.1753\n", - " 2 \u001b[36m0.4145\u001b[0m 0.7182 \u001b[35m0.5886\u001b[0m 1.9684\n", - " 3 \u001b[36m0.4104\u001b[0m \u001b[32m0.7254\u001b[0m 0.5997 1.9410\n", - " 4 \u001b[36m0.4085\u001b[0m 0.7205 0.5899 1.9391\n", - " 5 \u001b[36m0.4069\u001b[0m 0.7134 0.5995 1.9285\n", - " 6 \u001b[36m0.4055\u001b[0m 0.7245 \u001b[35m0.5770\u001b[0m 1.9491\n", - " 7 \u001b[36m0.4045\u001b[0m 0.7179 0.5920 1.9267\n", - " 8 \u001b[36m0.4037\u001b[0m 0.7236 0.5825 1.9322\n", - " 9 \u001b[36m0.4022\u001b[0m 0.7184 0.5953 1.9136\n", - " 10 \u001b[36m0.4000\u001b[0m 0.7198 0.5835 1.9510\n", - " 11 \u001b[36m0.3995\u001b[0m 0.7214 0.5945 1.8964\n", - " 12 0.4002 0.7162 0.6043 1.9169\n", - " 13 \u001b[36m0.3980\u001b[0m 0.7175 0.5972 1.9553\n", - " 14 \u001b[36m0.3980\u001b[0m 0.7199 0.5888 1.8916\n", - " 15 \u001b[36m0.3975\u001b[0m 0.7228 0.6038 1.9653\n", - " 16 0.3979 0.7177 0.5904 2.3466\n", - " 17 \u001b[36m0.3964\u001b[0m 0.7214 0.6104 2.4008\n", - " 18 0.3967 0.7240 0.5872 2.3039\n", - " 19 \u001b[36m0.3951\u001b[0m 0.7219 0.5899 1.9289\n", - " 20 \u001b[36m0.3945\u001b[0m \u001b[32m0.7257\u001b[0m 0.5974 2.1648\n", - "3 enzyme 2.638070613725886 2.46250618476942 2.0496222500826247 1.6805598435731193 2.6826292405326546\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4241\u001b[0m \u001b[32m0.7196\u001b[0m \u001b[35m0.5639\u001b[0m 1.9783\n", - " 2 \u001b[36m0.4153\u001b[0m \u001b[32m0.7217\u001b[0m 0.5726 2.3067\n", - " 3 \u001b[36m0.4122\u001b[0m 0.7134 0.5981 2.0008\n", - " 4 \u001b[36m0.4105\u001b[0m 0.7210 0.5657 2.3765\n", - " 5 \u001b[36m0.4091\u001b[0m \u001b[32m0.7234\u001b[0m 0.5704 2.2906\n", - " 6 \u001b[36m0.4077\u001b[0m 0.7223 0.5706 2.4286\n", - " 7 \u001b[36m0.4068\u001b[0m \u001b[32m0.7326\u001b[0m 0.5684 2.3401\n", - " 8 \u001b[36m0.4066\u001b[0m \u001b[32m0.7360\u001b[0m 0.5670 2.2834\n", - " 9 \u001b[36m0.4052\u001b[0m 0.7225 0.5831 2.3926\n", - " 10 \u001b[36m0.4051\u001b[0m 0.7174 0.5897 1.9899\n", - " 11 \u001b[36m0.4048\u001b[0m 0.7189 0.5917 1.9511\n", - " 12 \u001b[36m0.4037\u001b[0m 0.7214 0.5824 1.9729\n", - " 13 \u001b[36m0.4033\u001b[0m 0.7216 0.5868 2.0706\n", - " 14 \u001b[36m0.4016\u001b[0m 0.7245 0.5874 2.0209\n", - " 15 \u001b[36m0.4015\u001b[0m 0.7170 0.5935 1.9888\n", - " 16 \u001b[36m0.4009\u001b[0m 0.7299 0.5788 1.9787\n", - " 17 0.4013 0.7223 0.5823 1.9918\n", - " 18 \u001b[36m0.4006\u001b[0m 0.7271 0.5790 1.9598\n", - " 19 \u001b[36m0.3996\u001b[0m 0.7237 0.5812 2.0040\n", - " 20 \u001b[36m0.3993\u001b[0m 0.7232 0.5834 1.9756\n", - "4 enzyme 3.3565628203824946 3.148863346675525 2.6739866478094116 2.43251538083125 3.216420429251877\n", - "enzyme 0.671312564076499 0.629772669335105 0.5347973295618823 0.48650307616625 0.6432840858503754\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.6185\u001b[0m \u001b[32m0.6798\u001b[0m \u001b[35m0.6228\u001b[0m 1.9648\n", - " 2 \u001b[36m0.5625\u001b[0m \u001b[32m0.6909\u001b[0m 0.6357 2.0414\n", - " 3 \u001b[36m0.5055\u001b[0m \u001b[32m0.7058\u001b[0m \u001b[35m0.6098\u001b[0m 2.0059\n", - " 4 \u001b[36m0.4752\u001b[0m \u001b[32m0.7156\u001b[0m 0.6121 2.1869\n", - " 5 \u001b[36m0.4578\u001b[0m \u001b[32m0.7197\u001b[0m \u001b[35m0.5987\u001b[0m 1.9789\n", - " 6 \u001b[36m0.4466\u001b[0m 0.7170 0.6457 1.9454\n", - " 7 \u001b[36m0.4338\u001b[0m 0.7099 0.6624 1.9970\n", - " 8 \u001b[36m0.4260\u001b[0m \u001b[32m0.7197\u001b[0m 0.6371 1.9291\n", - " 9 \u001b[36m0.4179\u001b[0m \u001b[32m0.7318\u001b[0m \u001b[35m0.5772\u001b[0m 1.9620\n", - " 10 \u001b[36m0.4124\u001b[0m 0.7237 0.6251 2.4665\n", - " 11 \u001b[36m0.4076\u001b[0m 0.7271 0.6208 2.4929\n", - " 12 \u001b[36m0.4027\u001b[0m 0.7198 0.6398 1.9545\n", - " 13 \u001b[36m0.3989\u001b[0m 0.7266 0.6092 1.9949\n", - " 14 \u001b[36m0.3970\u001b[0m 0.7210 0.6615 1.7808\n", - " 15 \u001b[36m0.3925\u001b[0m 0.7211 0.6475 1.8557\n", - " 16 \u001b[36m0.3899\u001b[0m 0.7217 0.6604 1.7936\n", - " 17 \u001b[36m0.3855\u001b[0m 0.7255 0.6487 1.8261\n", - " 18 \u001b[36m0.3838\u001b[0m \u001b[32m0.7324\u001b[0m 0.6489 1.8680\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 19 \u001b[36m0.3821\u001b[0m 0.7298 0.6348 1.8930\n", - " 20 \u001b[36m0.3799\u001b[0m 0.7257 0.6353 1.8213\n", - "0 pathway 0.6310810356302675 0.5974075914435812 0.4514884233737596 0.3371410929299166 0.6832116788321168\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3837\u001b[0m \u001b[32m0.7161\u001b[0m \u001b[35m0.5985\u001b[0m 1.8151\n", - " 2 \u001b[36m0.3781\u001b[0m \u001b[32m0.7300\u001b[0m \u001b[35m0.5892\u001b[0m 1.8137\n", - " 3 \u001b[36m0.3765\u001b[0m 0.7281 0.6098 2.0036\n", - " 4 \u001b[36m0.3738\u001b[0m 0.7244 0.6009 1.9845\n", - " 5 \u001b[36m0.3712\u001b[0m 0.7298 0.5986 1.9893\n", - " 6 \u001b[36m0.3710\u001b[0m 0.7265 0.6216 2.1948\n", - " 7 \u001b[36m0.3687\u001b[0m 0.7257 0.6215 1.9602\n", - " 8 \u001b[36m0.3664\u001b[0m \u001b[32m0.7333\u001b[0m 0.6058 1.9499\n", - " 9 0.3667 0.7315 0.6018 1.9440\n", - " 10 \u001b[36m0.3632\u001b[0m 0.7317 0.6152 1.9538\n", - " 11 \u001b[36m0.3627\u001b[0m 0.7151 0.6423 2.0355\n", - " 12 \u001b[36m0.3610\u001b[0m \u001b[32m0.7425\u001b[0m 0.5965 2.0036\n", - " 13 \u001b[36m0.3600\u001b[0m 0.7353 0.6151 1.9700\n", - " 14 \u001b[36m0.3572\u001b[0m 0.7301 0.6136 1.9744\n", - " 15 0.3590 0.7317 0.6260 1.9676\n", - " 16 \u001b[36m0.3543\u001b[0m 0.7362 0.6198 1.9907\n", - " 17 0.3550 0.7352 0.6331 1.9663\n", - " 18 \u001b[36m0.3542\u001b[0m 0.7317 0.6211 1.9278\n", - " 19 \u001b[36m0.3540\u001b[0m 0.7298 0.6395 2.0168\n", - " 20 \u001b[36m0.3530\u001b[0m 0.7144 0.6554 2.1076\n", - "1 pathway 1.2877781662329273 1.2195215505057897 0.9612599969778206 0.7505402902130287 1.3479477370813306\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3969\u001b[0m \u001b[32m0.7218\u001b[0m \u001b[35m0.6231\u001b[0m 1.9405\n", - " 2 \u001b[36m0.3795\u001b[0m 0.7214 \u001b[35m0.6203\u001b[0m 2.1512\n", - " 3 \u001b[36m0.3732\u001b[0m 0.7182 0.6651 2.0859\n", - " 4 \u001b[36m0.3684\u001b[0m 0.7165 0.6397 2.2913\n", - " 5 \u001b[36m0.3647\u001b[0m 0.7075 0.6759 2.3245\n", - " 6 \u001b[36m0.3626\u001b[0m 0.7147 0.6682 1.9366\n", - " 7 \u001b[36m0.3608\u001b[0m 0.7089 0.6614 1.9367\n", - " 8 \u001b[36m0.3587\u001b[0m 0.6995 0.6912 2.4325\n", - " 9 \u001b[36m0.3577\u001b[0m \u001b[32m0.7230\u001b[0m 0.6788 2.4160\n", - " 10 \u001b[36m0.3566\u001b[0m 0.7098 0.6787 2.1800\n", - " 11 \u001b[36m0.3542\u001b[0m \u001b[32m0.7348\u001b[0m 0.6380 1.9457\n", - " 12 \u001b[36m0.3530\u001b[0m 0.7116 0.6872 1.9435\n", - " 13 \u001b[36m0.3519\u001b[0m 0.7238 0.6677 2.3769\n", - " 14 \u001b[36m0.3513\u001b[0m 0.7059 0.6960 1.9401\n", - " 15 \u001b[36m0.3489\u001b[0m 0.7135 0.6928 1.9995\n", - " 16 0.3500 0.7267 0.6802 2.0365\n", - " 17 \u001b[36m0.3473\u001b[0m 0.7237 0.6691 1.9733\n", - " 18 \u001b[36m0.3464\u001b[0m 0.7059 0.7129 1.9769\n", - " 19 \u001b[36m0.3460\u001b[0m 0.7132 0.6888 2.1551\n", - " 20 \u001b[36m0.3452\u001b[0m 0.7145 0.6949 2.4517\n", - "2 pathway 2.0370664086556403 1.965438881329851 1.6227155795277308 1.3949778738293712 2.027344503925206\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3630\u001b[0m \u001b[32m0.7240\u001b[0m \u001b[35m0.6505\u001b[0m 1.9899\n", - " 2 \u001b[36m0.3511\u001b[0m 0.6846 0.7266 1.9880\n", - " 3 \u001b[36m0.3470\u001b[0m 0.7055 0.6819 1.9492\n", - " 4 \u001b[36m0.3435\u001b[0m 0.7232 0.6818 1.9628\n", - " 5 \u001b[36m0.3404\u001b[0m \u001b[32m0.7342\u001b[0m 0.6647 1.9699\n", - " 6 0.3417 0.7070 0.6942 1.9913\n", - " 7 \u001b[36m0.3399\u001b[0m 0.7219 0.6964 2.1267\n", - " 8 \u001b[36m0.3388\u001b[0m 0.7299 0.6821 1.9754\n", - " 9 \u001b[36m0.3366\u001b[0m 0.7311 0.6826 2.0034\n", - " 10 \u001b[36m0.3353\u001b[0m 0.7321 0.6830 2.0641\n", - " 11 \u001b[36m0.3341\u001b[0m 0.7086 0.6959 1.9657\n", - " 12 \u001b[36m0.3336\u001b[0m \u001b[32m0.7362\u001b[0m 0.6651 1.9808\n", - " 13 0.3346 0.7341 0.6746 2.0021\n", - " 14 \u001b[36m0.3313\u001b[0m 0.7247 0.7216 1.9558\n", - " 15 \u001b[36m0.3312\u001b[0m 0.7300 0.6847 1.9300\n", - " 16 \u001b[36m0.3297\u001b[0m 0.7292 0.6927 2.0166\n", - " 17 \u001b[36m0.3291\u001b[0m 0.7316 0.6796 1.9291\n", - " 18 \u001b[36m0.3287\u001b[0m 0.7246 0.6943 1.9555\n", - " 19 \u001b[36m0.3279\u001b[0m 0.7271 0.7041 1.9790\n", - " 20 \u001b[36m0.3268\u001b[0m 0.7225 0.7250 1.9810\n", - "3 pathway 2.7698516949896663 2.689668526843535 2.260701593513745 1.9817845013893178 2.7262903239791405\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.3559\u001b[0m \u001b[32m0.7278\u001b[0m \u001b[35m0.6812\u001b[0m 1.9887\n", - " 2 \u001b[36m0.3434\u001b[0m \u001b[32m0.7283\u001b[0m 0.6827 1.9738\n", - " 3 \u001b[36m0.3409\u001b[0m 0.7262 0.6892 1.9878\n", - " 4 \u001b[36m0.3389\u001b[0m \u001b[32m0.7313\u001b[0m 0.6848 1.9753\n", - " 5 \u001b[36m0.3365\u001b[0m 0.7287 0.6967 1.7787\n", - " 6 \u001b[36m0.3361\u001b[0m \u001b[32m0.7332\u001b[0m \u001b[35m0.6770\u001b[0m 1.8299\n", - " 7 \u001b[36m0.3322\u001b[0m \u001b[32m0.7348\u001b[0m \u001b[35m0.6683\u001b[0m 1.8337\n", - " 8 \u001b[36m0.3315\u001b[0m 0.7333 0.6899 1.7787\n", - " 9 \u001b[36m0.3308\u001b[0m 0.7155 0.7227 1.7837\n", - " 10 0.3308 \u001b[32m0.7356\u001b[0m 0.6758 1.7484\n", - " 11 \u001b[36m0.3302\u001b[0m 0.7300 0.7109 1.7557\n", - " 12 \u001b[36m0.3301\u001b[0m 0.7239 0.7111 1.7440\n", - " 13 \u001b[36m0.3285\u001b[0m \u001b[32m0.7382\u001b[0m 0.6794 1.7565\n", - " 14 0.3288 0.7301 0.6987 2.0276\n", - " 15 \u001b[36m0.3262\u001b[0m 0.7207 0.7144 1.9730\n", - " 16 0.3284 0.7275 0.7041 1.9839\n", - " 17 \u001b[36m0.3258\u001b[0m 0.7366 0.6871 1.9363\n", - " 18 \u001b[36m0.3251\u001b[0m 0.7288 0.6874 2.0183\n", - " 19 0.3265 0.7314 0.7254 1.9398\n", - " 20 \u001b[36m0.3239\u001b[0m 0.7254 0.7375 1.9593\n", - "4 pathway 3.531113608459308 3.452834373959601 2.932879303463298 2.7977581530978397 3.29776008538619\n", - "pathway 0.7062227216918616 0.6905668747919201 0.5865758606926595 0.5595516306195679 0.659552017077238\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5783\u001b[0m \u001b[32m0.7197\u001b[0m \u001b[35m0.5635\u001b[0m 2.0282\n", - " 2 \u001b[36m0.4543\u001b[0m \u001b[32m0.7408\u001b[0m \u001b[35m0.5450\u001b[0m 2.1332\n", - " 3 \u001b[36m0.3876\u001b[0m \u001b[32m0.7498\u001b[0m \u001b[35m0.5388\u001b[0m 1.9565\n", - " 4 \u001b[36m0.3519\u001b[0m 0.7446 0.5968 2.0606\n", - " 5 \u001b[36m0.3306\u001b[0m 0.7425 0.6145 2.0775\n", - " 6 \u001b[36m0.3175\u001b[0m \u001b[32m0.7514\u001b[0m 0.5845 2.0047\n", - " 7 \u001b[36m0.3057\u001b[0m 0.7493 0.6221 2.5536\n", - " 8 \u001b[36m0.2987\u001b[0m 0.7418 0.6571 2.0514\n", - " 9 \u001b[36m0.2914\u001b[0m 0.7452 0.6462 2.1195\n", - " 10 \u001b[36m0.2860\u001b[0m 0.7488 0.6390 1.9759\n", - " 11 \u001b[36m0.2799\u001b[0m 0.7486 0.6303 2.0376\n", - " 12 \u001b[36m0.2769\u001b[0m 0.7475 0.6587 2.0454\n", - " 13 \u001b[36m0.2708\u001b[0m 0.7480 0.6583 2.1140\n", - " 14 \u001b[36m0.2680\u001b[0m 0.7420 0.6920 1.9913\n", - " 15 \u001b[36m0.2624\u001b[0m 0.7480 0.6741 2.0122\n", - " 16 \u001b[36m0.2603\u001b[0m 0.7445 0.6642 1.9770\n", - " 17 \u001b[36m0.2570\u001b[0m 0.7451 0.7025 2.0386\n", - " 18 \u001b[36m0.2541\u001b[0m \u001b[32m0.7522\u001b[0m 0.6496 1.9465\n", - " 19 \u001b[36m0.2509\u001b[0m 0.7456 0.7107 1.9885\n", - " 20 \u001b[36m0.2487\u001b[0m 0.7460 0.7023 1.8316\n", - "0 indication 0.6464172952241307 0.5925051019049874 0.49153300530385335 0.39580117320160546 0.648347943358058\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.2598\u001b[0m \u001b[32m0.7455\u001b[0m \u001b[35m0.6612\u001b[0m 1.7486\n", - " 2 \u001b[36m0.2516\u001b[0m \u001b[32m0.7544\u001b[0m \u001b[35m0.6396\u001b[0m 1.7811\n", - " 3 \u001b[36m0.2464\u001b[0m 0.7533 \u001b[35m0.6386\u001b[0m 1.7823\n", - " 4 \u001b[36m0.2437\u001b[0m 0.7541 \u001b[35m0.6290\u001b[0m 1.7403\n", - " 5 \u001b[36m0.2423\u001b[0m \u001b[32m0.7553\u001b[0m \u001b[35m0.6288\u001b[0m 1.7978\n", - " 6 \u001b[36m0.2393\u001b[0m 0.7543 0.6295 1.7591\n", - " 7 \u001b[36m0.2351\u001b[0m \u001b[32m0.7573\u001b[0m 0.6375 1.7378\n", - " 8 \u001b[36m0.2332\u001b[0m \u001b[32m0.7596\u001b[0m \u001b[35m0.6187\u001b[0m 1.7545\n", - " 9 \u001b[36m0.2310\u001b[0m \u001b[32m0.7597\u001b[0m 0.6311 1.8734\n", - " 10 \u001b[36m0.2275\u001b[0m \u001b[32m0.7635\u001b[0m 0.6284 1.7605\n", - " 11 \u001b[36m0.2264\u001b[0m \u001b[32m0.7644\u001b[0m \u001b[35m0.6132\u001b[0m 1.7635\n", - " 12 \u001b[36m0.2255\u001b[0m 0.7582 0.6423 1.7717\n", - " 13 \u001b[36m0.2227\u001b[0m 0.7623 0.6452 1.7801\n", - " 14 \u001b[36m0.2191\u001b[0m 0.7597 0.6330 1.8238\n", - " 15 0.2202 0.7644 0.6444 1.7919\n", - " 16 \u001b[36m0.2168\u001b[0m \u001b[32m0.7701\u001b[0m 0.6182 1.8587\n", - " 17 \u001b[36m0.2141\u001b[0m 0.7668 0.6574 1.7678\n", - " 18 \u001b[36m0.2126\u001b[0m 0.7611 0.6500 1.7840\n", - " 19 \u001b[36m0.2113\u001b[0m 0.7619 0.6548 1.7648\n", - " 20 \u001b[36m0.2097\u001b[0m 0.7617 0.6647 1.7889\n", - "1 indication 1.3634953974617772 1.314116444731755 1.1028163283044732 0.9031594113409489 1.4170989179409366\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.2730\u001b[0m \u001b[32m0.7663\u001b[0m \u001b[35m0.5851\u001b[0m 1.8305\n", - " 2 \u001b[36m0.2462\u001b[0m \u001b[32m0.7689\u001b[0m \u001b[35m0.5777\u001b[0m 1.7648\n", - " 3 \u001b[36m0.2370\u001b[0m 0.7665 0.6087 1.7319\n", - " 4 \u001b[36m0.2306\u001b[0m 0.7655 0.6132 1.7460\n", - " 5 \u001b[36m0.2256\u001b[0m \u001b[32m0.7706\u001b[0m 0.5977 1.7488\n", - " 6 \u001b[36m0.2221\u001b[0m 0.7685 0.6116 1.7539\n", - " 7 \u001b[36m0.2194\u001b[0m 0.7616 0.6261 1.7593\n", - " 8 \u001b[36m0.2158\u001b[0m 0.7691 0.6112 1.7388\n", - " 9 \u001b[36m0.2124\u001b[0m 0.7639 0.6369 1.7857\n", - " 10 \u001b[36m0.2101\u001b[0m 0.7685 0.6334 1.7964\n", - " 11 \u001b[36m0.2077\u001b[0m \u001b[32m0.7716\u001b[0m 0.6175 1.7606\n", - " 12 \u001b[36m0.2073\u001b[0m \u001b[32m0.7744\u001b[0m 0.6113 1.7540\n", - " 13 \u001b[36m0.2054\u001b[0m 0.7690 0.6359 1.7485\n", - " 14 \u001b[36m0.2013\u001b[0m 0.7705 0.6475 1.7555\n", - " 15 \u001b[36m0.2000\u001b[0m 0.7709 0.6332 1.7708\n", - " 16 \u001b[36m0.1981\u001b[0m 0.7683 0.6620 1.7769\n", - " 17 \u001b[36m0.1971\u001b[0m 0.7692 0.6573 1.7846\n", - " 18 \u001b[36m0.1931\u001b[0m 0.7724 0.6414 1.7905\n", - " 19 \u001b[36m0.1927\u001b[0m 0.7703 0.6758 1.7965\n", - " 20 \u001b[36m0.1908\u001b[0m 0.7686 0.6828 1.7897\n", - "2 indication 2.196293130541763 2.164995762294784 1.87478878678969 1.6877637130801688 2.176839824767045\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.2210\u001b[0m \u001b[32m0.7727\u001b[0m \u001b[35m0.6559\u001b[0m 1.7762\n", - " 2 \u001b[36m0.2042\u001b[0m 0.7704 \u001b[35m0.6473\u001b[0m 1.7876\n", - " 3 \u001b[36m0.1978\u001b[0m 0.7706 \u001b[35m0.6303\u001b[0m 1.7898\n", - " 4 \u001b[36m0.1938\u001b[0m 0.7694 0.6576 1.7927\n", - " 5 \u001b[36m0.1900\u001b[0m \u001b[32m0.7753\u001b[0m 0.6567 1.7402\n", - " 6 \u001b[36m0.1870\u001b[0m 0.7748 0.6421 1.7659\n", - " 7 \u001b[36m0.1834\u001b[0m 0.7710 0.6517 1.8013\n", - " 8 \u001b[36m0.1828\u001b[0m 0.7737 \u001b[35m0.6297\u001b[0m 1.9263\n", - " 9 \u001b[36m0.1782\u001b[0m 0.7741 0.6510 1.8212\n", - " 10 \u001b[36m0.1779\u001b[0m \u001b[32m0.7773\u001b[0m 0.6648 1.7934\n", - " 11 \u001b[36m0.1750\u001b[0m 0.7752 0.6609 1.7979\n", - " 12 \u001b[36m0.1721\u001b[0m 0.7773 0.6583 1.7680\n", - " 13 0.1732 0.7699 0.6740 1.8162\n", - " 14 \u001b[36m0.1706\u001b[0m 0.7720 0.6835 2.2396\n", - " 15 \u001b[36m0.1672\u001b[0m 0.7737 0.6871 2.4444\n", - " 16 0.1682 0.7731 0.6757 2.0417\n", - " 17 \u001b[36m0.1656\u001b[0m 0.7723 0.6886 2.0084\n", - " 18 \u001b[36m0.1634\u001b[0m 0.7735 0.6793 2.1220\n", - " 19 \u001b[36m0.1630\u001b[0m 0.7770 0.6706 2.3860\n", - " 20 \u001b[36m0.1611\u001b[0m 0.7749 0.6690 2.2080\n", - "3 indication 3.0159181386686336 3.0149379804758283 2.6352165942763213 2.4194710301533395 2.96833487096542\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.2107\u001b[0m \u001b[32m0.7694\u001b[0m \u001b[35m0.6654\u001b[0m 2.0174\n", - " 2 \u001b[36m0.1937\u001b[0m \u001b[32m0.7744\u001b[0m 0.6686 1.9758\n", - " 3 \u001b[36m0.1880\u001b[0m 0.7734 0.6814 2.0618\n", - " 4 \u001b[36m0.1862\u001b[0m 0.7731 0.6737 1.9872\n", - " 5 \u001b[36m0.1821\u001b[0m 0.7740 0.6991 1.9876\n", - " 6 \u001b[36m0.1794\u001b[0m 0.7679 0.7135 1.9593\n", - " 7 \u001b[36m0.1781\u001b[0m 0.7653 0.7449 2.0057\n", - " 8 \u001b[36m0.1755\u001b[0m 0.7698 0.7121 2.0293\n", - " 9 \u001b[36m0.1746\u001b[0m 0.7704 0.7112 2.0200\n", - " 10 0.1751 0.7706 0.7038 1.9990\n", - " 11 \u001b[36m0.1708\u001b[0m 0.7729 0.7024 1.9563\n", - " 12 \u001b[36m0.1704\u001b[0m 0.7721 0.7094 1.9584\n", - " 13 \u001b[36m0.1677\u001b[0m 0.7707 0.7182 1.9675\n", - " 14 \u001b[36m0.1671\u001b[0m 0.7701 0.7136 2.0071\n", - " 15 \u001b[36m0.1667\u001b[0m 0.7661 0.7566 1.9508\n", - " 16 \u001b[36m0.1634\u001b[0m 0.7695 0.7468 2.0027\n", - " 17 0.1637 0.7691 0.7311 2.4390\n", - " 18 \u001b[36m0.1606\u001b[0m 0.7688 0.7514 1.9963\n", - " 19 0.1614 0.7662 0.7600 2.0525\n", - " 20 \u001b[36m0.1604\u001b[0m 0.7672 0.7416 1.9953\n", - "4 indication 3.9038291065176565 3.9398872414091155 3.458937000138779 3.3594669132327963 3.7013792565501378\n", - "indication 0.7807658213035313 0.7879774482818231 0.6917874000277557 0.6718933826465593 0.7402758513100276\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5901\u001b[0m \u001b[32m0.7030\u001b[0m \u001b[35m0.5807\u001b[0m 1.9732\n", - " 2 \u001b[36m0.5559\u001b[0m \u001b[32m0.7215\u001b[0m \u001b[35m0.5600\u001b[0m 1.9761\n", - " 3 \u001b[36m0.5345\u001b[0m \u001b[32m0.7229\u001b[0m \u001b[35m0.5577\u001b[0m 1.9633\n", - " 4 \u001b[36m0.5233\u001b[0m 0.7096 0.5681 2.4139\n", - " 5 \u001b[36m0.5192\u001b[0m 0.7124 0.5846 2.2775\n", - " 6 \u001b[36m0.5136\u001b[0m 0.7086 0.5792 1.9967\n", - " 7 \u001b[36m0.5132\u001b[0m 0.6778 0.7682 1.8504\n", - " 8 0.5167 \u001b[32m0.7279\u001b[0m \u001b[35m0.5521\u001b[0m 1.9034\n", - " 9 0.5202 0.6960 0.5996 1.9002\n", - " 10 0.5184 0.7152 0.5666 2.0161\n", - " 11 0.5142 0.7079 0.5797 1.9527\n", - " 12 0.5152 0.6986 0.6125 1.9900\n", - " 13 0.5218 0.7024 0.5990 1.9023\n", - " 14 0.5246 0.7002 0.6318 1.9648\n", - " 15 0.5216 0.7073 0.5981 1.9065\n", - " 16 0.5168 0.7213 0.5573 2.1858\n", - " 17 0.5269 0.7106 0.5842 2.2842\n", - " 18 0.5135 \u001b[32m0.7285\u001b[0m 0.5659 1.9842\n", - " 19 0.5205 0.7119 0.5587 1.9727\n", - " 20 0.5245 0.6997 0.6139 1.9652\n", - "0 sideeffect 0.5138384594402153 0.5746993251776205 0.057887407996817186 0.029947514665020068 0.8635014836795252\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5189\u001b[0m \u001b[32m0.6999\u001b[0m \u001b[35m0.6043\u001b[0m 2.1976\n" + "(119902, 1096)\n", + "(119902, 1)\n" ] - }, + } + ], + "source": [ + "print(X_train.shape)\n", + "print(y_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 1377, + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " 2 0.5227 0.6777 0.7452 2.3916\n", - " 3 \u001b[36m0.5120\u001b[0m \u001b[32m0.7157\u001b[0m \u001b[35m0.5630\u001b[0m 2.3480\n", - " 4 0.5158 0.6953 0.5837 1.9979\n", - " 5 0.5226 0.6975 0.5857 2.0067\n", - " 6 0.5271 0.6762 0.6361 1.9625\n", - " 7 0.5180 0.7138 \u001b[35m0.5619\u001b[0m 2.0029\n", - " 8 \u001b[36m0.5067\u001b[0m 0.6766 0.6831 1.9727\n", - " 9 \u001b[36m0.5007\u001b[0m \u001b[32m0.7188\u001b[0m 0.5634 2.0476\n", - " 10 0.5093 0.6875 0.6270 1.9644\n", - " 11 0.5252 0.6842 0.6061 1.9711\n", - " 12 0.5351 0.7048 0.5850 1.9322\n", - " 13 0.5204 0.7001 0.5868 1.9061\n", - " 14 0.5167 0.6851 0.6072 1.9837\n", - " 15 0.5216 0.7062 0.5808 2.0573\n", - " 16 0.5232 0.7105 0.5668 1.9783\n", - " 17 0.5168 0.6956 0.5971 2.3828\n", - " 18 0.5297 0.6982 0.5896 2.4299\n", - " 19 0.5213 0.6767 0.6191 2.2688\n", - " 20 0.5314 0.6831 0.5973 2.5076\n", - "1 sideeffect 1.0558402220116427 1.1660323358773546 0.23173693319769373 0.12792013996089327 1.6343516861086749\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5487\u001b[0m \u001b[32m0.6893\u001b[0m \u001b[35m0.6043\u001b[0m 1.9524\n", - " 2 \u001b[36m0.5304\u001b[0m \u001b[32m0.6972\u001b[0m \u001b[35m0.5896\u001b[0m 2.0687\n", - " 3 0.5369 \u001b[32m0.7135\u001b[0m \u001b[35m0.5714\u001b[0m 1.9579\n", - " 4 0.5511 0.7111 \u001b[35m0.5669\u001b[0m 1.9551\n", - " 5 0.5366 0.6996 0.5829 1.9694\n", - " 6 0.5389 0.7121 \u001b[35m0.5652\u001b[0m 1.9686\n", - " 7 \u001b[36m0.5206\u001b[0m 0.7091 0.5801 1.9564\n", - " 8 0.5372 0.6870 0.6092 1.9452\n", - " 9 0.5310 0.7017 0.5860 1.9514\n", - " 10 0.5246 0.7090 0.5801 1.8940\n", - " 11 0.5367 0.6780 0.5850 1.9389\n", - " 12 0.5357 0.6935 0.5858 1.9323\n", - " 13 0.5303 0.6824 0.6195 1.9322\n", - " 14 0.5306 0.6973 0.6030 1.9347\n", - " 15 0.5256 0.6843 0.6008 1.9844\n", - " 16 0.5531 0.6870 0.5912 1.9416\n", - " 17 0.5581 0.7024 0.5834 1.9664\n", - " 18 0.5498 0.7103 0.5787 1.9678\n", - " 19 0.5507 0.6913 0.5958 1.9741\n", - " 20 0.5521 0.6933 0.6074 2.4156\n", - "2 sideeffect 1.626663969882702 1.7446190864221762 0.5110386788336039 0.30081300813008127 2.3606811283914246\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5489\u001b[0m \u001b[32m0.7103\u001b[0m \u001b[35m0.5788\u001b[0m 1.7751\n", - " 2 \u001b[36m0.5353\u001b[0m 0.6762 0.6355 2.0228\n", - " 3 \u001b[36m0.5350\u001b[0m 0.6888 0.6003 1.9493\n", - " 4 \u001b[36m0.5276\u001b[0m 0.6916 0.5926 1.8865\n", - " 5 \u001b[36m0.5272\u001b[0m 0.6912 0.5802 1.9343\n", - " 6 0.5326 0.6892 0.5879 1.9431\n", - " 7 0.5454 0.6969 0.5820 1.8945\n", - " 8 0.5446 0.7024 0.5796 2.0126\n", - " 9 0.5579 0.6956 0.6027 1.9470\n", - " 10 0.5578 0.6758 0.5911 1.9169\n", - " 11 0.5419 0.6798 0.6041 1.9411\n", - " 12 0.5543 0.6758 0.5835 1.9447\n", - " 13 0.5416 0.6839 0.5995 1.9651\n", - " 14 0.5540 0.7046 0.5841 1.9376\n", - " 15 0.5321 0.6839 0.5856 1.9520\n", - " 16 0.5370 0.6947 0.5825 2.0100\n", - " 17 0.5523 \u001b[32m0.7142\u001b[0m \u001b[35m0.5751\u001b[0m 1.9423\n", - " 18 0.5354 \u001b[32m0.7187\u001b[0m 0.5764 1.9233\n", - " 19 0.5485 0.6885 0.6085 1.9885\n", - " 20 0.5354 0.6758 0.6286 1.9077\n", - "3 sideeffect 2.126663969882702 2.173942726439983 0.5110386788336039 0.30081300813008127 2.3606811283914246\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5691\u001b[0m \u001b[32m0.6758\u001b[0m \u001b[35m0.5865\u001b[0m 1.8874\n", - " 2 \u001b[36m0.5612\u001b[0m \u001b[32m0.6993\u001b[0m \u001b[35m0.5855\u001b[0m 2.0043\n", - " 3 \u001b[36m0.5431\u001b[0m \u001b[32m0.7026\u001b[0m 0.5911 2.0189\n", - " 4 \u001b[36m0.5389\u001b[0m \u001b[32m0.7070\u001b[0m \u001b[35m0.5812\u001b[0m 2.1660\n", - " 5 0.5470 0.6967 0.5913 2.2712\n", - " 6 0.5600 0.6788 0.6029 2.3870\n", - " 7 0.5567 0.6758 0.6035 1.8900\n", - " 8 0.5871 0.6758 0.5927 1.9346\n", - " 9 0.5763 0.6758 0.5926 2.0515\n", - " 10 0.5848 0.6758 0.5904 1.9870\n", - " 11 0.5851 0.6758 0.5957 1.9202\n", - " 12 0.5681 0.6768 0.6190 2.0681\n", - " 13 0.5680 0.6758 0.6305 1.9984\n", - " 14 0.5643 0.6758 0.5980 1.9622\n", - " 15 0.5773 0.6758 0.6037 1.9285\n", - " 16 0.5662 0.6758 0.5943 1.9357\n", - " 17 0.5894 0.6758 0.6096 1.9389\n", - " 18 0.5613 0.6875 0.6015 2.0127\n", - " 19 0.5545 0.6758 0.6106 1.9310\n", - " 20 0.5566 0.7047 0.5836 1.9099\n", - "4 sideeffect 2.7814314427191853 2.7456054416817004 1.0902269985564448 1.2643782613206946 2.7747090781526134\n", - "sideeffect 0.5562862885438371 0.5491210883363401 0.21804539971128895 0.25287565226413894 0.5549418156305227\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.5661\u001b[0m \u001b[32m0.7170\u001b[0m \u001b[35m0.5458\u001b[0m 1.9763\n", - " 2 \u001b[36m0.5005\u001b[0m 0.6947 0.5649 1.9595\n", - " 3 \u001b[36m0.4736\u001b[0m \u001b[32m0.7184\u001b[0m \u001b[35m0.5391\u001b[0m 1.9442\n", - " 4 \u001b[36m0.4605\u001b[0m \u001b[32m0.7363\u001b[0m \u001b[35m0.5102\u001b[0m 2.0052\n", - " 5 \u001b[36m0.4575\u001b[0m \u001b[32m0.7657\u001b[0m \u001b[35m0.4827\u001b[0m 1.9814\n", - " 6 \u001b[36m0.4556\u001b[0m 0.7249 0.5214 1.9555\n", - " 7 \u001b[36m0.4524\u001b[0m 0.7428 0.5014 1.9402\n", - " 8 0.4551 \u001b[32m0.7712\u001b[0m \u001b[35m0.4775\u001b[0m 2.1328\n", - " 9 0.4560 0.7659 0.4830 2.3639\n", - " 10 0.4597 0.7611 0.4859 1.9532\n", - " 11 0.4573 0.7621 0.4964 1.9342\n", - " 12 0.4586 0.7428 0.5186 1.9720\n", - " 13 0.4663 0.7627 0.4860 2.5719\n", - " 14 0.4653 0.7392 0.5188 2.3725\n", - " 15 0.4679 0.7561 0.4911 1.9703\n", - " 16 0.4632 0.7608 0.4891 2.4229\n", - " 17 0.4682 0.7494 0.4942 2.1244\n", - " 18 0.4700 0.7396 0.5215 1.9202\n", - " 19 0.4730 0.7384 0.5125 1.9610\n", - " 20 0.4734 0.7594 0.4935 1.9410\n", - "0 offsideeffect 0.7022339340916424 0.6299070537554166 0.5991838359709365 0.619532777606257 0.5801291317336417\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4697\u001b[0m \u001b[32m0.7561\u001b[0m \u001b[35m0.4970\u001b[0m 1.9079\n", - " 2 0.4729 0.7547 0.5035 1.9468\n", - " 3 0.4778 0.7469 0.5130 1.9442\n", - " 4 0.4703 0.7171 0.5318 1.9914\n", - " 5 0.4756 \u001b[32m0.7627\u001b[0m 0.5047 1.9808\n", - " 6 0.4710 0.7205 0.5297 1.9697\n", - " 7 0.4710 0.7595 \u001b[35m0.4935\u001b[0m 2.0329\n", - " 8 0.4799 0.6993 0.5677 1.9828\n", - " 9 0.4805 0.7317 0.5251 1.9254\n", - " 10 0.4791 \u001b[32m0.7640\u001b[0m 0.5011 1.8055\n", - " 11 0.4751 0.7090 0.5432 1.7774\n", - " 12 0.4699 0.7427 0.5180 1.7907\n", - " 13 0.4790 0.7630 0.5040 1.7645\n", - " 14 0.4849 0.7412 0.5138 1.7621\n" + "Running fold0 for sideeffect...\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " 15 0.4781 0.7252 0.5289 1.7613\n", - " 16 \u001b[36m0.4677\u001b[0m 0.7349 0.5213 1.7573\n", - " 17 0.4731 0.7583 0.5048 1.8166\n", - " 18 0.4853 0.7604 0.4955 1.8098\n", - " 19 0.4789 0.7616 0.4985 1.7639\n", - " 20 0.4838 0.7552 0.5007 1.7664\n", - "1 offsideeffect 1.376255295711216 1.3196404247148268 1.1343416204412782 1.0401358443964186 1.315597909851331\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4868\u001b[0m \u001b[32m0.7496\u001b[0m \u001b[35m0.5067\u001b[0m 1.8134\n", - " 2 0.4869 \u001b[32m0.7514\u001b[0m \u001b[35m0.5015\u001b[0m 1.7575\n", - " 3 0.4876 \u001b[32m0.7573\u001b[0m 0.5046 1.7535\n", - " 4 \u001b[36m0.4782\u001b[0m 0.7534 0.5020 1.7525\n", - " 5 \u001b[36m0.4735\u001b[0m \u001b[32m0.7651\u001b[0m \u001b[35m0.4935\u001b[0m 1.7669\n", - " 6 \u001b[36m0.4684\u001b[0m \u001b[32m0.7652\u001b[0m \u001b[35m0.4903\u001b[0m 1.7635\n", - " 7 \u001b[36m0.4658\u001b[0m 0.7384 0.5166 1.9508\n", - " 8 \u001b[36m0.4586\u001b[0m 0.7531 0.5330 1.9282\n", - " 9 0.4750 0.7595 0.4909 1.9345\n", - " 10 0.4666 0.7458 0.5046 1.9461\n", - " 11 \u001b[36m0.4572\u001b[0m 0.7634 0.4947 1.9534\n", - " 12 0.4626 \u001b[32m0.7796\u001b[0m \u001b[35m0.4730\u001b[0m 1.9606\n", - " 13 0.4600 0.7591 0.4964 1.9803\n", - " 14 0.4723 0.7658 0.4827 1.8888\n", - " 15 0.4683 0.7251 0.5205 2.0204\n", - " 16 0.4635 0.7529 0.5108 1.9082\n", - " 17 0.4730 0.7734 0.4736 1.9219\n", - " 18 0.4686 0.7666 0.4983 2.0971\n", - " 19 0.4840 0.7135 0.5567 2.0675\n", - " 20 0.4715 0.6795 0.5681 2.1841\n", - "2 offsideeffect 1.8810307109532358 2.052603284133125 1.1546833616943295 1.0504270865493464 2.1851631272426353\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4674\u001b[0m \u001b[32m0.7758\u001b[0m \u001b[35m0.4873\u001b[0m 2.0970\n", - " 2 \u001b[36m0.4508\u001b[0m 0.7482 0.5066 1.9136\n", - " 3 0.4811 0.7491 0.5080 2.1551\n", - " 4 0.4807 0.7513 0.5096 2.0353\n", - " 5 0.4600 0.7509 0.4962 1.9234\n", - " 6 0.4603 0.7736 0.4910 2.0255\n", - " 7 0.4656 0.7704 0.4908 1.9850\n", - " 8 0.4637 0.7570 0.5100 1.9799\n", - " 9 0.4789 0.7103 0.5353 1.9484\n", - " 10 0.4639 0.7472 0.5081 1.9526\n", - " 11 0.4860 0.7518 0.5090 1.9683\n", - " 12 0.4907 0.7578 0.4958 2.0262\n", - " 13 0.4720 0.7371 0.5045 1.9842\n", - " 14 0.4679 0.7659 0.4883 2.0039\n", - " 15 0.4822 0.7524 0.5034 2.0838\n", - " 16 0.4775 0.7415 0.5264 1.9705\n", - " 17 0.4802 0.7569 0.5076 2.0119\n", - " 18 0.4881 0.7729 0.4975 1.9715\n", - " 19 0.4715 0.7212 0.5178 1.9566\n", - " 20 0.4774 0.7590 0.4909 1.9625\n", - "3 offsideeffect 2.5442253635754506 2.6542026469235807 1.6719965613819108 1.4594010497066994 2.888899624497815\n", - " epoch train_loss valid_acc valid_loss dur\n", - "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m0.4777\u001b[0m \u001b[32m0.7578\u001b[0m \u001b[35m0.4928\u001b[0m 2.2456\n", - " 2 \u001b[36m0.4740\u001b[0m \u001b[32m0.7607\u001b[0m 0.5043 2.1342\n", - " 3 \u001b[36m0.4671\u001b[0m 0.7515 0.4977 1.9554\n", - " 4 0.4811 \u001b[32m0.7706\u001b[0m 0.4967 2.0481\n", - " 5 0.4887 0.7316 0.5137 1.9594\n", - " 6 0.4879 0.7533 0.5149 2.0612\n", - " 7 0.5081 0.7259 0.5536 2.0438\n", - " 8 0.5151 0.7669 0.5123 1.9921\n", - " 9 0.5023 0.7370 0.5138 2.2733\n", - " 10 0.5066 0.7584 0.5164 1.9480\n", - " 11 0.4900 0.7365 0.5134 1.9470\n", - " 12 0.5088 0.6758 0.5386 1.9412\n", - " 13 0.5184 0.7546 0.5303 1.9819\n", - " 14 0.4865 \u001b[32m0.7726\u001b[0m 0.5129 1.9715\n", - " 15 0.5155 0.7346 0.5403 1.9627\n", - " 16 0.5081 0.7586 0.5208 1.9596\n", - " 17 0.4948 0.7473 0.5100 1.9471\n", - " 18 0.5056 0.6758 0.5474 1.9672\n", - " 19 0.5561 0.7538 0.5587 1.9508\n", - " 20 0.5310 0.7340 0.5383 1.9739\n", - "4 offsideeffect 3.304566062983737 3.3895360338269507 2.3450646492253364 2.213312124222961 3.496783441925201\n", - "offsideeffect 0.6609132125967474 0.6779072067653902 0.46901292984506726 0.4426624248445922 0.6993566883850402\n" + "ename": "RuntimeError", + "evalue": "result type Float can't be cast to the desired output type Long", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mmodelPicklePath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpathPickles\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m\"model_params/model_params_fold\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr_hidden_layers_params\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msimilarity\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".p\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdo_train_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodelPicklePath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;31m# this is actually a pylint bug:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# https://github.com/PyCQA/pylint/issues/1085\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNeuralNetClassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 848\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 849\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mpartial_fit\u001b[0;34m(self, X, y, classes, **fit_params)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_train_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(self, X, y, epochs, **fit_params)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0myi_res\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0myi\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0my_train_is_ph\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_batch_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0myi_res\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 739\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 740\u001b[0m \u001b[0mtrain_batch_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'train_loss'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 664\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 665\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mstep_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 659\u001b[0m \u001b[0mstep_accumulator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_train_step_accumulator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstep_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 661\u001b[0;31m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_single\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 662\u001b[0m \u001b[0mstep_accumulator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstore_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 663\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loss'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mtrain_step_single\u001b[0;34m(self, Xi, yi, **fit_params)\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 604\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mXi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 605\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, *args, **kwargs)\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0meps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;31m# pylint: disable=signature-differs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mget_loss\u001b[0;34m(self, y_pred, y_true, X, training)\u001b[0m\n\u001b[1;32m 1097\u001b[0m \"\"\"\n\u001b[1;32m 1098\u001b[0m \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1099\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1101\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m reduction=self.reduction)\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m 2112\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Target size ({}) must be the same as input size ({})\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2114\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy_with_logits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction_enum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: result type Float can't be cast to the desired output type Long" ] } ], @@ -1347,8 +402,8 @@ "do_prepare_data = False\n", "do_train_model = True\n", "kfold_nsplits = 5\n", - "similaritiesToRun = df_paperIndividualScores['Similarity']\n", - "# similaritiesToRun = [\"enzyme\"]\n", + "# similaritiesToRun = df_paperIndividualScores['Similarity']\n", + "similaritiesToRun = [\"sideeffect\"]\n", "\n", "for similarity in similaritiesToRun:\n", " input_fea = pathInput+DS1_path+\"/\" + similarity + \"_Jacarrd_sim.csv\"\n", @@ -1356,13 +411,13 @@ " dataPicklePath = pathPickles+\"data_X_y_\" + similarity + \"_Jaccard.p\"\n", "\n", " # Define model\n", - " D_in, H1, H2, D_out, drop = X.shape[1], 400, 300, 2, 0.5\n", + " D_in, H1, H2, D_out, drop = X.shape[1], 300, 400, 1, 0.5\n", " str_hidden_layers_params = \"-H1-\" + str(H1) + \"-H2-\" + str(H2)\n", - " model = NDD(D_in, H1, H2, D_out, drop)\n", " callbacks = []\n", " \n", " # Prepare data if not available\n", " if do_prepare_data:\n", + " print(\"Preparing \" + similarity + \" data...\")\n", " X,y = prepare_data(input_fea, input_lab, seperate = False)\n", "\n", " with open(dataPicklePath, 'wb') as f:\n", @@ -1372,19 +427,29 @@ " with open(dataPicklePath, 'rb') as f:\n", " X, y = pickle.load(f)\n", " \n", + "\n", + " y = np.reshape(y, (y.shape[0], 1))\n", + " \n", " X = X.astype(np.float32)\n", - " y = y.astype(np.int64) \n", + " y = y.astype(np.int64) \n", + "\n", + " \n", + "# y_cat = np_utils.to_categorical(y)\n", " \n", " AUROC, AUPR, F1, Rec, Prec = 0,0,0,0,0\n", " kFoldSplit = getStratifiedKFoldSplit(X,y,n_splits=kfold_nsplits)\n", " for i, indices in enumerate(kFoldSplit):\n", + " print(\"Running fold\" + str(i) + \" for \" + similarity +\"...\")\n", + " \n", " train_index = indices[0]\n", " test_index = indices[1]\n", " X_train, X_test = X[train_index], X[test_index]\n", " y_train, y_test = y[train_index], y[test_index]\n", + "# y_train, y_test = y_cat[train_index], y_cat[test_index]\n", " \n", " # Create Network Classifier\n", - " net = getNDDClassifier()\n", + " Xy_test = skorch.dataset.Dataset(X_test, y_test)\n", + " net = getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test)\n", " \n", " # Fit and save OR load model\n", " modelPicklePath = pathPickles+\"model_params/model_params_fold\" + str(i) + \"_\" + str_hidden_layers_params+ \"_\" + similarity + \".p\"\n", @@ -1428,59 +493,27 @@ }, { "cell_type": "code", - "execution_count": 830, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Similarity AUC AUPR F-measure Recall Precision\n", - "0 chem 0.631 0.455 0.527 0.899 0.373\n", - "1 target 0.787 0.642 0.617 0.721 0.540\n", - "2 transporter 0.682 0.568 0.519 0.945 0.358\n", - "3 enzyme 0.734 0.599 0.552 0.579 0.529\n", - "4 pathway 0.767 0.623 0.587 0.650 0.536\n", - "5 indication 0.802 0.654 0.632 0.740 0.551\n", - "6 sideeffect 0.778 0.601 0.619 0.748 0.528\n", - "7 offsideeffect 0.782 0.606 0.617 0.764 0.517\n" - ] - } - ], + "outputs": [], "source": [ "print(df_paperIndividualScores)" ] }, { "cell_type": "code", - "execution_count": 831, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Similarity AUC AUPR F-measure Recall Precision\n", - "0 chem 0.501 0.379 0.018 0.009 0.625\n", - "1 target 0.731 0.720 0.621 0.572 0.707\n", - "2 transporter 0.639 0.563 0.480 0.409 0.614\n", - "3 enzyme 0.671 0.630 0.535 0.487 0.643\n", - "4 pathway 0.706 0.691 0.587 0.560 0.660\n", - "5 indication 0.781 0.788 0.692 0.672 0.740\n", - "6 sideeffect 0.556 0.549 0.218 0.253 0.555\n", - "7 offsideeffect 0.661 0.678 0.469 0.443 0.699\n" - ] - } - ], + "outputs": [], "source": [ "print(df_replicatedIndividualScores)" ] }, { "cell_type": "code", - "execution_count": 832, + "execution_count": null, "metadata": { "scrolled": false }, @@ -1494,155 +527,18 @@ }, { "cell_type": "code", - "execution_count": 833, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
AUCAUPRF-measureRecallPrecision
00.1300.0765.090000e-010.890-0.252
10.056-0.078-4.000000e-030.149-0.167
20.0430.0053.900000e-020.536-0.256
30.063-0.0311.700000e-020.092-0.114
40.061-0.0681.110223e-160.090-0.124
50.021-0.134-6.000000e-020.068-0.189
60.2220.0524.010000e-010.495-0.027
70.121-0.0721.480000e-010.321-0.182
\n", - "
" - ], - "text/plain": [ - " AUC AUPR F-measure Recall Precision\n", - "0 0.130 0.076 5.090000e-01 0.890 -0.252\n", - "1 0.056 -0.078 -4.000000e-03 0.149 -0.167\n", - "2 0.043 0.005 3.900000e-02 0.536 -0.256\n", - "3 0.063 -0.031 1.700000e-02 0.092 -0.114\n", - "4 0.061 -0.068 1.110223e-16 0.090 -0.124\n", - "5 0.021 -0.134 -6.000000e-02 0.068 -0.189\n", - "6 0.222 0.052 4.010000e-01 0.495 -0.027\n", - "7 0.121 -0.072 1.480000e-01 0.321 -0.182" - ] - }, - "execution_count": 833, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "df_diff" ] }, { "cell_type": "code", - "execution_count": 834, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 834, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZoAAAD4CAYAAADVTSCGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3debwcVZ338c+XECSQEAgIQliCEIGAEE0IIFuAyICjLA4KGMSAkAfZRAcHHHlcRh0zD46MKBgiA0FEYWQNDLJFWSUkIWRniwEkwLA6QCBAcu/v+aPOhUp7l763u25XX75vXvVK9alTp37VCf3rc+p0lSICMzOzoqzR6ADMzKxvc6IxM7NCOdGYmVmhnGjMzKxQTjRmZlaoNRsdQDP60VbH9Lmpei9qVaNDKMR39nup0SHU3TrnTml0CIUYsNnejQ6hEKveeUa1trHypaVVfeb03+jDNR+rCO7RmJlZodyjMTMru9aWRkdQEycaM7Oya2nuoW0nGjOzkotobXQINXGiMTMru1YnGjMzK5J7NGZmVihPBjAzs0K5R2NmZkUKzzozM7NCNflkgKa/M4CkqZKOaHQcZmaFidbqlpJyj8bMrOyafDJA0/VoJB0rab6keZIuT8X7SPqTpKX53o2kb0ialep/L5UNk/SIpIslLZR0haRxku6T9LikMQ05MTOzjjR5j6apEo2kHYFvAftHxC7AV9OmTYG9gE8Dk1LdA4HhwBhgJDBK0j6p/rbAT4Gdge2BL6T9zwT+uYNjT5Q0W9LsmcsfL+DszMw60LKquqUKkg6S9KikJZLObmf7YEk3pi/ziyQdV2v4TZVogP2BqyPiJYCIeCWVXx8RrRGxGNgklR2YloeAOWQJZXja9kRELIjsvg6LgOkREcACYFh7B46IKRExOiJGjxk4vL0qZmbFaG2tbumCpH7ABcDBwAjgaEkjKqqdAixOX+bHAv8uaa1awm+2azQC2nsuw9sVddr+/FFEXLRaA9KwivqtudetNN97YmZ9XETdrtGMAZZExFIASVcChwKL84cDBkkSMBB4BahpfnWz9WimA5+XtCGApCGd1L0VOF7SwFR3qKSNeyFGM7P6qvIaTX6IPy0TK1oaCjyde70sleX9HNgBeJZslOerUeNdPZvq23tELJL0Q+AuSS1kw2Id1b1N0g7A/VliZjlwDNDc0zfM7P2nyt/RRMQUoLNHsLb3BM7KUaK/A+aSXarYBrhd0j0R8VpVQbSjqRINQERcBlzWyfaBufWfkl30r7RTrs6E3PqT+W1mZqVQvxlly4Atcq83J+u55B0HTErXrZdIeoLsGvfMnh602YbOzMzef1pWVrd0bRYwXNLW6QL/UcC0ijp/AQ4AkLQJsB2wtJbwm65HY2b2vlOnW9BExCpJp5Jdw+4HXJIuSZyUtk8Gvg9MlbSAbKjtrLaZvj3lRGNmVnZ1/DFmRNwM3FxRNjm3/izZT0PqxonGzKzsmvymmk40ZmZl50RjZmZFiuou9JeWE42ZWdmV+IaZ1XCi6YGn9U6jQ6i7eStrmlRSWif8YVCjQ6i73T/+7UaHUIjNBnZ2o4/3OQ+dmZlZodyjMTOzQrlHY2ZmhXKPxszMCrWqprv0N5wTjZlZ2blHY2ZmhfI1GjMzK5R7NGZmVij3aMzMrFDu0ZiZWaGafNZZ6Z+wKWl9SSf3wnHGSvpE0ccxM+u2iOqWkip9ogHWB6pONMr05LzGAk40ZlY+ra3VLSXVDENnk4BtJM0F/gjsDGwA9AfOiYgbJA0Dfp+27wEcJmkccBbwLPA48HZEnCrpg8BkYMvU/hnAM8BJQIukY4DTIuKeXjo/M7POlTiJVKMZEs3ZwE4RMVLSmsA6EfGapI2AGZKmpXrbAcdFxMmSNgP+L/Bx4HXgD8C8VO+nwHkRca+kLYFbI2IHSZOB5RHx4/aCkDQRmAiw75BRjBj04YJO18ysgicD9CoB/yppH6AVGApskrY9FREz0voY4K6IeAVA0u+Aj6Rt44ARktraXE9Sl/eSj4gpwBSAk4d9vryDoWbW97S0NDqCmjRbohkPfBAYFRErJT0JrJ22vZGrp8odc9YA9oiIFfnCXOIxMyuXJh86a4bJAK8DbT2OwcALKcnsB2zVwT4zgX0lbZCG2/4ht+024NS2F5JGtnMcM7PyaPLJAKVPNBHxMnCfpIXASGC0pNlkvZtHOtjnGeBfgQeAO4DFwKtp8+mpjfmSFpNNAgC4EThc0lxJexd2QmZm3RWt1S0l1RRDZxHxhSqq7VTx+jcRMSX1aK4j68kQES8BR7ZzjMfIZrSZmZVKtDb3ZeGmSDQ99N00xXltsiRzfYPjMTPrmRIPi1WjzyaaiDiz0TGYmdWFZ52ZmVmh3KMxM7NCOdGYmVmhSnzDzGqUfnqzmdn7Xh1/RyPpIEmPSloi6ewO6oxNP/VYJOmuWsN3j8bMrOzqNL1ZUj/gAuCTwDJglqRpEbE4V2d94ELgoIj4i6SNaz2uE00PrKS5u7Ht2XLN9RodQiHWVd/7Jz7xM680OoRCnDWlb55XXdRv1tkYYElELAWQdCVwKNmP2tt8Abg2Iv4CEBEv1HpQD52ZmZVctLZWtUiaKGl2bplY0dRQ4Onc62WpLO8jwAaS7pT0oKRja42/733dMzPra6ocOsvfZb4D7d09uLLxNYFRwAHAAOB+STPS3VN6xInGzKzs6ncfs2XAFrnXm5M9HLKyzksR8QbwhqS7gV2AHicaD52ZmZVda1S3dG0WMFzS1pLWAo4CplXUuQHYW9KaktYBdgMeriV892jMzMpuVX0mA0TEKkmnArcC/YBLImKRpJPS9skR8bCkW4D5ZA+YvDgiFtZyXCcaM7Oyq+MjACLiZuDmirLJFa/PBc6t1zGdaMzMys6PCTAzsyKF73VmZmaFco/GzMwK1eSJpsfTmyWtL+nkegZTBEn/3OgYzMxq0tJS3VJStfyOZn3gbxJNumlbwymzBtDtRFOWczAzA4jWqGopq1oSzSRgm3Qr6VmS/ijpN8ACAEnXp/vkLMrfb0fSckk/lDRP0gxJm6Tyz0lamMrvTmUTJN0g6ZZ0W+vv5Nr5eqq/UNIZqWyYpIclXQjMAf4TGJBivCLVOUbSzFR2UVtSSXH9i6QHgD1qeF/MzOqrfj/YbIhaEs3ZwJ8jYiTwDbK7gn4rIkak7cdHxChgNHC6pA1T+brAjIjYBbgbODGVfxv4u1R+SO44Y4DxwEjgc5JGSxoFHEf2i9XdgRMlfSzV3w74VUR8LCKOA1ZExMiIGC9pB+BIYM8Ud0tquy2uhRGxW0TcW3my+ZvVPfL60h6/aWZm3VbH59E0Qj1vQTMzIp7IvT5d0jxgBtm9dYan8neAm9L6g8CwtH4fMFXSiWS/WG1ze0S8HBErgGuBvdJyXUS8ERHLU/neqf5TETGjgxgPILtZ3CxJc9PrD6dtLcA1HZ1cREyJiNERMXr7QR/uqJqZWf01eY+mnrPO3mhbkTQWGAfsERFvSroTWDttXhnx7nNJW9piiIiTJO0G/D0wV9LIVKfy3QvavwPp38TRDgGXRcQ329n2VkSU92qamb1/lTiJVKOWHs3rwKAOtg0G/pqSzPZkw1udkrRNRDwQEd8GXuK9O4x+UtIQSQOAw8h6PncDh0laR9K6wOHAPR00vVJS/7Q+HTii7Ylxqd2tuj5VM7PGiZbWqpay6nGPJiJelnSfpIXACuD53OZbgJMkzQceJRs+68q5koaT9TqmA/PIrsvcC1wObAv8JiJmA0iaCsxM+14cEQ9JGtZOu1OA+ZLmpOs05wC3pRlpK4FTgKeqP3Mzs17W5D0avTeKVT6SJgCjI+LURseSd+Kwz5X3Teuh5bGy0SEUoi8+yvm8T7/Z6BAKscGUeY0OoRCr3nmms6H+qrx63LiqPnMGX3pHzccqQt/7v9DMrK9p8h5NqRNNREwFpjY4DDOzxirv5ZeqlDrRmJkZxKrmzjRONGZmZdfcecaJxsys7Mp8H7NqONGYmZWdezRmZlYk92jehwbU9RZx5fBcH/0dTb8++Hf1g5sGNzqEQuw4xDfp6JB7NGZmVqRY1egIauNEY2ZWcuEejZmZFcqJxszMiuQejZmZFcqJxszMChUtpbwpc9WcaMzMSs49GjMzK1S0NnePpu/9ms3MrI+J1uqWakg6SNKjkpZIOruTertKapF0RK3xO9GYmZVchKpauiKpH3ABcDAwAjha0ogO6v0bcGs94neiMTMruTr2aMYASyJiaUS8A1wJHNpOvdOAa4AX6hG/E42ZWcm1tqiqRdJESbNzy8SKpoYCT+deL0tl75I0FDgcmFyv+BueaCQdI2mmpLmSLpLUT9JyST+UNE/SDEmbpLpzc8sKSftKelzSB9P2NdK440aSpkr6haQ/Slqa6l4i6WFJU3PHP1DS/ZLmSPqdpIENeivMzNoVrapuiZgSEaNzy5SKptobX6u8NfR/AGdFREu94m9oopG0A3AksGdEjARagPHAusCMiNgFuBs4ESAiRqZ6/xeYDfwJ+HXaB2AcMC8iXkqvNwD2B74G3AicB+wIfFTSSEkbAecA4yLi46nNr3cQ67vfFBa+/ud6vg1mZp2qNtFUYRmwRe715sCzFXVGA1dKehI4ArhQ0mG1xN/o6c0HAKOAWZIABpCNCb4D3JTqPAh8sm0HScOBc4H9I2KlpEuAG8iy8PHApbn2b4yIkLQAeD4iFqQ2FgHDyN7kEcB96fhrAfe3F2j6ZjAF4PRhRzb3wyHMrKlE/T5xZgHDJW0NPAMcBXxh9WPF1m3rafTnpoi4vpaDNjrRCLgsIr65WqF0ZsS7b20LKU5J6wL/BZwYEc8CRMTTkp6XtD+wG+/1bgDeTn+25tbbXq+Z2r49Io6u72mZmdVPvX5HExGrJJ1KNpusH3BJRCySdFLaXrfrMnmNTjTTgRsknRcRL0gaAgzqpP6lwKURcU9F+cVkQ2iXd3NccQZwgaRtI2KJpHWAzSPise6chJlZkaqZulx9W3EzcHNFWbsJJiIm1OOYDb1GExGLya6R3CZpPnA7sGl7dSVtRTZeeHxuQsDotHkaMJDVh82qOf6LwATgt+n4M4Dte3IuZmZFaWlRVUtZNbpHQ0RcBVxVUTwwt/1q4Or0sqPEuAvZJIBHcvtNyK0/CezUwbY/ALv2KHgzs15Qzx5NIzQ80dQq3ULhK6x+bcbMrM/wvc4aLCImRcRWEXFvo2MxMytCRHVLWTV9j8bMrK9r9h6NE42ZWcm1tDb34JMTjZlZyZV5WKwaTjRmZiXX6llnZmZWJE9vNjOzQnno7H2oyiemNpXXWt5qdAiFWHfN/o0OwarUUu2ziN+HPHRmZmaF8qwzMzMrVJOPnDnRmJmVnYfOzMysUJ51ZmZmhWr2aRJONGZmJRe4R2NmZgVa5aEzMzMrkns0ZmZWKF+jMTOzQjV7j6b0PzeVNEHSZrnXT0raqJExmZn1ptYql7IqfaIBJgCbdVXJzKyvakFVLWXV64lG0jBJj0i6TNJ8SVdLWkfStyXNkrRQ0hRljgBGA1dImitpQGrmNElzJC2QtH1qd4Gk9dN+L0s6NpVfLmlcOu49ab85kj6R235oLr4rJB3Sy2+LmVmHWlXdUlaN6tFsB0yJiJ2B14CTgZ9HxK4RsRMwAPh0RFwNzAbGR8TIiFiR9n8pIj4O/AI4M5XdB+wJ7AgsBfZO5bsDM4AXgE+m/Y4Ezk/bLwaOA5A0GPgEcHNlwJImSpotafai1/9cr/fBzKxLraiqpawalWiejoj70vqvgb2A/SQ9IGkBsD9ZwujItenPB4Fhaf0eYJ+0/AL4qKShwCsRsRzoD/wytf87YARARNwFbCtpY+Bo4JqIWFV5wIiYEhGjI2L0joO26el5m5l1W1S5lFWjEk3lexLAhcAREfFR4JfA2p3s/3b6s4X3Zs7dTdaL2Ru4E3gROIIsAQF8DXge2IVsOG6tXHuXA+PJejaXdvtszMwK5MkAPbOlpD3S+tHAvWn9JUkDyRJEm9eBQV01GBFPAxsBwyNiaWrzTN5LNIOB5yKiFfgi0C+3+1TgjNTOop6ckJlZUVqlqpayalSieRj4kqT5wBCyoa5fAguA64FZubpTgckVkwE68gDwWFq/BxjKe0nswnTMGcBHgDfadoqI51NM7s2YWem0VLmUVaN+sNkaESdVlJ2TltVExDXANbmiYblts4GxuddfzK3/iVwijYjHgZ1z7XyzbUXSOsBw4LfdOw0zs+LVc0aZpIOAn5KN6lwcEZMqto8HzkovlwNfiYh5tRyzGX5HUyhJ44BHgJ9FxKuNjsfMrFK9Zp1J6gdcABxMNiHqaEkjKqo9AeybZgV/H5hSa/y93qOJiCeBnXr7uB2JiDuALRsdh5lZR+o4o2wMsCRdx0bSlcChwOJ3j5WNBrWZAWxe60Hf9z0aM7Oyq/YHm/nf+6VlYkVTQ4Gnc6+XpbKOfBn4fa3x+6aaZmYlV+3U5YiYQudDXe2Nr7XbYZK0H1mi2avKw3fIicbMrORa6jcZYBmwRe715sCzlZUk7Ux215SDI+LlWg/qoTMzs5Kr4w82ZwHDJW0taS3gKGBavoKkLcnuvvLFiHisnTa6zT0aM7OSq9ev/iNilaRTgVvJpjdfEhGLJJ2Utk8Gvg1sCFyo7EegqyJidC3HdaLpgVdjZaNDqLsXV77e6BAK0U99r9P+oTU7uztT83rxrf9tdAilFXX8HU1E3EzFjYNTgmlbPwE4oX5HdKIxMyu9Mt/HrBpONGZmJVfm28tUw4nGzKzkyvxQs2o40ZiZlZyHzszMrFBONGZmVqgyPz2zGk40ZmYl52s0ZmZWKM86MzOzQrU2+eCZE42ZWck1+2SAwu7PIelPXddarf5YSTel9UMknd3D4/5zLXGYmZVNVLmUVWGJJiI+UcO+0yqfY90NqyWaWuIwMyuDOt69uSGK7NEsT3+OlXSnpKslPSLpCqVbgko6KJXdC3w2t+8EST9P65tIuk7SvLR8IpVfL+lBSYvaniInaRIwQNJcSVdUxCFJ50paKGmBpCO7is/MrAxWKapayqq3rtF8DNiR7AE79wF7SpoN/BLYH1gCXNXBvucDd0XE4ZL6AQNT+fER8YqkAcAsSddExNmSTo2Ike2081lgJLALsFHa5+6O4gPuze+cktlEgN2GjGT4wK27/SaYmfVEeVNIdXrrHuozI2JZRLQCc4FhwPbAExHxeEQE8OsO9t0f+AVARLRExKup/HRJ84AZZE+MG95FDHsBv01tPA/cBezaSXyriYgpETE6IkY7yZhZb2r2obPe6tG8nVtvyR23R4la0lhgHLBHRLwp6U6gq4d0dDYc1lF8ZmYN1+zTmxv5VKhHgK0lbZNeH91BvenAVwAk9ZO0HjAY+GtKMtsDu+fqr5TUv5127gaOTG18ENgHmFmPEzEzK5JnnfVQRLxFds3jv9NkgKc6qPpVYD9JC4AHya6l3AKsKWk+8H2y4bM2U4D5bZMBcq4D5gPzgD8A/xQR/1Ov8zEzK0qzD50puzxi3fHFrT7b59602SuWNTqEQnxorcGNDqHudl5zw0aHUIgr/zq30SEU4vlXH6l5FuvXhh1V1WfOeU9eWcoZs74WYWZWcmXurVTDicbMrOSi1FdguuZEY2ZWcu7RmJlZoZp9erMTjZlZyTV3mnGiMTMrvVVNnmqcaMzMSs6TAd6HPv3OgEaHUHcD1xnW6BAK8c2NXm50CHW33nYvNDqEQiy9Z9tGh1BangxgZmaFco/GzMwK1ew9mkbeVNPMzKrQElHVUo30wMlHJS2RdHY72yXp/LR9vqSP1xq/E42ZWcm1ElUtXUkPj7wAOBgYARwtaURFtYPJnu81nOzGx7+oNX4nGjOzkosq/6vCGGBJRCyNiHeAK4FDK+ocCvwqMjOA9SVtWkv8TjRmZiVX7WMCJE2UNDu3TKxoaijwdO71slTW3Trd4skAZmYlV+0taCJiCtkzuTrS3mMEKhuvpk63ONGYmZVcHac3LwO2yL3eHHi2B3W6xUNnZmYlV8dZZ7OA4ZK2lrQWcBQwraLONODYNPtsd+DViHiulvjdozEzK7l63b05IlZJOhW4FegHXBIRiySdlLZPBm4GPgUsAd4Ejqv1uL2WaCRdDPwkIhZXlE8ARkfEqT1s97fAjsClwO/JZlEEcERE/Lkb7YwF3omIP/UkDjOzotTzB5sRcTNZMsmXTc6tB3BKHQ/Ze4kmIk6od5uSPgR8IiK2Sq/PBm6IiO/0oLmxwHLAicbMSqXZb0FTyDUaSetK+m9J8yQtlHSkpDsljU7bj5P0mKS7gD1z+31Q0jWSZqVlz1x7l6SyhyS1zfu+DdhY0lxJ3wHOAE6Q9Me03zGSZqbtF6UfK7X9MnZOim+6pGHAScDXUt29i3hfzMx6ol4/2GyUono0BwHPRsTfA0gaDHwlrW8KfA8YBbwK/BF4KO33U+C8iLhX0pZk44g7AN8C/hARx0taH5gp6Q7gEOCmiBiZ2hawPCJ+LGkH4Ehgz4hYKelCYLyk3wO/BPaJiCckDYmIVyRNbtu3vRNK89EnApyw3hjGreM7zZpZ74gqby9TVkUlmgXAjyX9G1kiuCfLAQDsBtwZES8CSLoK+EjaNg4Ykau7nqRBwIHAIZLOTOVrA1sCKzqJ4QCyZDYrtTcAeAHYHbg7Ip4AiIhXqjmh/Pz0qzYd39x/62bWVFpK3FupRiGJJiIekzSKbObCjyTdVlmlg13XAPaIiNUSSOqp/ENEPFpRPqyTMARcFhHfrNjnkE6Ob2ZWOmUeFqtGUddoNgPejIhfAz8G8nf/fAAYK2lDSf2Bz+W23QacmmtnZFq9FTgtJRwkfayKMKYDR0jaOO0zRNJWwP3AvpK2bitP9V8HBnXvTM3MihcRVS1lVdQPNj9Kdh1lLtn1lR+0bUg//Pku2Qf+HcCc3H6nA6PTrakXk12gB/g+0B+YL2lhet2pNI36HOA2SfOB24FN05DdROBaSfOAq9IuNwKHezKAmZVNs08GUJmzYFn1xWs0d35gZaNDKETffJRzn/vnB8AX71mn0SEU4sa/3NTevcO6Zezm46r6S79z2R01H6sIvjOAmVnJVftQs7JyojEzK7kyD4tVw4nGzKzknGjMzKxQzX4t3YnGzKzk3KMxM7NCNftNNZ1ozMxKriXq+aCA3udE0wOPrdXoCOpv0cq+93sTgN+9sGmjQ6i7ff9neaNDKMTs5Y81OoTS8jUaMzMrlK/RmJlZoXyNxszMCtXqoTMzMyuSezRmZlYozzozM7NCeejMzMwK5aEzMzMrlHs0ZmZWKPdozMysUC3R0ugQarJGT3eUdLqkhyVdIekDku6QNFfSkR3UP0nSse2UD5O0sLfi6KSdYZK+0NM4zMyKEhFVLWVVS4/mZODgiHhC0u5A/4gY2VHliJhcw7HqFkcnhgFfAH5Tz+DMzGrV7LegqapHI+nrkham5QxJk4EPA9MknQX8GhiZehLbSJokabGk+ZJ+nNr4rqQz0/ooSfMk3Q+ckjtOP0nnSpqV9v0/uW3fyJV/L5V1FccoSXdJelDSrZI2Tfttm3o+8yTNkbQNMAnYO+37tZrfWTOzOunzPRpJo4DjgN0AAQ8AxwAHAftFxEuSHgDOjIhPSxoCHA5sHxEhaf12mr0UOC0i7pJ0bq78y8CrEbGrpA8A90m6DRieljEphmmS9omIkyR1FEd/4HLg0Ih4MQ2l/RA4HrgCmBQR10lamyzhnt22bwfvw0RgIsAhQ8YweuC2Xb11ZmZ10VuzztLn91VkIzxPAp+PiL9W1NkC+BXwIaAVmBIRP+2s3Wp6NHsB10XEGxGxHLgW2LuT+q8BbwEXS/os8GZFkIOB9SPirlR0eW7zgcCxkuaSJbQNyRLMgWl5CJgDbJ/KO7MdsBNwe2rvHGBzSYOAoRFxHUBEvBURb3bSDqnelIgYHRGjnWTMrDdFlf/VwdnA9IgYDkxPryutAv4xInYAdgdOkTSis0aruUaj7kQZEaskjQEOAI4CTgX2r2ivo3dEZD2dW1crlP4O+FFEXNSNUAQsiog9KtparxttmJk1XC/eguZQYGxavwy4EzgrXyEingOeS+uvS3oYGAos7qjRano0dwOHSVpH0rpkw2L3dFRZ0kBgcETcDJwBrHZhPiL+F3hV0l6paHxu863AV9KwF5I+ko55K3B8ahtJQyVt3EXcjwIflLRH2qe/pB0j4jVgmaTDUvkHJK0DvA4M6vLdMDPrZdVeo5E0UdLs3DKxm4faJCWStoTS6eespGHAx8hGoDrUZY8mIuZImgrMTEUXR8RDUocdnUHADenah4D2LqwfB1wi6U2yJNLmYrKxwTnKDvAicFhE3CZpB+D+dNzlZNeJXugk7nckHQGcn4br1gT+A1gEfBG4SNK/ACuBzwHzgVWS5gFTI+K8Tt4WM7NeU+01moiYAkzprI6kO8iur1T6VndiSl/8rwHOSF/gO65b5pkKZfX9rcb3uTft9lX/0+gQCnFovz74KOfWvvko58/00Uc5P/e/i7t1+aE9GwzctqrPnL8uX1LTsSQ9CoyNiOfSLN07I2K7dur1B24Cbo2In3TVbo9/sGlmZr2jlahqqYNpwJfS+peAGyorpNGm/wQeribJgBONmVnp9eLvaCYBn5T0OPDJ9BpJm0m6OdXZk+zyw/7pd4dzJX2qs0Z9rzMzs5LrrVlnEfEy2YzhyvJngU+l9Xvp5mxkJxozs5LzYwLMzKxQzT5py4nGzKzk/DwaMzMrlHs0ZmZWqGa/RuMfbJacpInp1759Rl88J+ib59UXzwn67nmVlX9HU37dvVdRM+iL5wR987z64jlB3z2vUnKiMTOzQjnRmJlZoZxoyq8vjiP3xXOCvnleffGcoO+eVyl5MoCZmRXKPRozMyuUE42ZmRXKiaaBJB0uKSRtn16PlXRTRZ2p6UmhbY+jniTpcUkLJc2UdHAjYq/Ug3O5U9KjkuZJuk/Sdu2Uz5I08m+P1u3YWnK3M5+bHj/7vpJ7DxZKulHS+nVuf4Kkn6f170o6s07t5uP+XXrseq1tjpZ0fifbN5N0da3Hsfc40TTW0cC9wFFV1v8+sCmwU0TsBHyG7NHZZdDdc+4hGiUAAAXKSURBVAEYHxG7AJcB57ZTfmFFeU+tiIiRueXJOrRZKEn96txk23uwE/AKcEqd2y9KPu53gJPyG5Xp1udYRMyOiNM72f5sRBzRs3CtPU40DZKet70n8GWq+HBO3+ROBE6LiLcBIuL5iPivQgOtQnfPpR13A9u2U34/MLSG0KqWelLnSbpb0sOSdpV0beo9/iBX75jUk5wr6aK2hCDpF5JmS1ok6Xu5+pMkLZY0X9KPU9m7Pbv0enn6c6ykP0r6DbCgs+PVaLX3VdI3Uu9xfkXsx6ayeZIuT2WfkfSApIck3SFpkzrEU617gG0lDUt/RxcCc4AtJB0o6X5Jc1LPZ2CKd1dJf0rnMFPSoHxvW9K+uZ7uQ2n7MEkL0/a1JV0qaUHavl8qn5D+fdyS/o38v158H5qO73XWOIcBt0TEY5JekfTxLupvC/wlIl7rhdi6q7vnUukzpA/WCgcB19ccHQyQNDetPxERh3dQ752I2EfSV8keYTuK7Nv/nyWdB2wMHAnsGREr0wfdeOBXwLci4pWUCKZL2hlYBhwObB8RUeVw1RiyHusTknbo5Hg9kuI7gOxRvEg6EBiejitgmqR9gJeBb6VjvyRpSGriXmD3dD4nAP8E/GNP4+lG3GsCBwO3pKLtgOMi4mRJGwHnAOMi4g1JZwFflzQJuAo4MiJmSVoPWFHR9JnAKRFxX0pOb1VsPwUgIj6qbFj4NkkfSdtGAh8D3gYelfSziHi6rifeRzjRNM7RwH+k9SvT65s6qFv2Oeg9PZcrJK0AngROqyhfF+gHdDdptWdFRFRzrWda+nMBsCgingOQtBTYAtiLLPnMkgQwAHgh7fN5SRPJ/p/aFBgBLCb74LpY0n/T8XuSNzMinkjrB3RyvO5qS7bDgAeB21P5gWl5KL0eSJZ4dgGujoiXACLilbR9c+AqSZsCawFtsRYl/yXhHrIEuRnwVETMSOW7k73f96X3aS2yXtt2wHMRMSudw2sAqU6b+4CfSLoCuDYillVs3wv4Wdr/EUlPAW2JZnpEvJraXAxsBTjRtMOJpgEkbQjsD+wkKcg+UIPsm+oGFdWHAC8BS4AtJQ2KiNd7M97O9PBc2oyPiNntNDsemEf2vPILgM8WEPelZN9Gn42Ituedv53+bM2tt71ek+wb/2UR8c2KtrYm+2a8a0T8VdJUYO2IWCVpDFnCOAo4ley9WkUatlb2qbZWrrk38k23d7weWhERIyUNJkt4pwDnp2P8KCIuqjin02n/C87PgJ9ExDRJY4Hv1iG2zvzNl4SUCCrfp9sj4uiKejvTxZe0iJiUvgR8CpghaRyr92o6e2Rx/t9IC/487ZCv0TTGEcCvImKriBgWEVuQfTMcAmyWhkyQtBXZN8u5EfEm2be58yWtlbZvKumYxpzCu7p9LtU0GhEryYZDdm9ro54i4rh0kflTXdd+13TgCEkbA0gaks5rPbIPvlfTNYuD0/aBwOCIuBk4g2yoBbIe3Ki0fijQv5vH67H0Dfx04ExJ/YFbgeNz1zSGpuNNJ+ulbdh27NTEYOCZtP6lWmKpoxnAnpK2hex6ZhreeoTs3+CuqXxQGoJ7l6RtImJBRPwbMBvYvqLtu8m++JDa3BJ4tNCz6YOcgRvjaLJv63nXkH3rPQa4VNLawErghLbuOdkH7w+AxZLeIvtw+3bvhNyhnp5LlyJihaR/J+stfLlO8fZYRCyWdA7ZOP0aZOd0SkTMkPQQsAhYSjYcA9mMwBvS+Qv4Wir/ZSqfSfaBnv923uXxgKdqPI+HJM0DjoqIy1Mivz/1FJYDx0TEIkk/BO6S1EI2tDaBrAfzO0nPkH3Ab11LLPUQES9KmgD8VtIHUvE56ZrhkcDPJA0guz4zrmL3M9IF/hayoc7fkw19trkQmCxpAVlPdEJEvF0xvGZd8C1ozMysUB46MzOzQjnRmJlZoZxozMysUE40ZmZWKCcaMzMrlBONmZkVyonGzMwK9f8B6KuRBKwpouYAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "from seaborn import heatmap\n", "heatmap(df_diff, yticklabels=df_paperIndividualScores[\"Similarity\"])" @@ -1650,84 +546,27 @@ }, { "cell_type": "code", - "execution_count": 835, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 835, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAD4CAYAAADRuPC7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3debxdVX338c+XDJCQEIgMZU6EMAuRhAgGMECkYKuQFssQSoFqikyCD1YcHout1lh8RFExBAqhCIIyRkQGUQgEAgkhI4PSABJjCQEKBEKGe3/PH3td2Dnc4dx7z75nn5vv29d+5Zy111p77cP1/M5aa++1FRGYmZnV2kb1boCZmfVODjBmZlYIBxgzMyuEA4yZmRXCAcbMzArRt94NaEQHbjeu111699Lq1+rdhEI8MHzLejeh5ra964p6N6EQF4z+Sr2bUIgfPH+DulvH2hVLqvrO6bflB7t9rFpyD8bMzArhHoyZWdk1N9W7BV3iAGNmVnZN6+rdgi5xgDEzK7mI5no3oUscYMzMyq7ZAcbMzIrgHoyZmRXCk/xmZlYI92DMzKwI4avIzMysEA06yd/wd/JLmibpuHq3w8ysMNFc3VYy7sGYmZVdg07yN1wPRtIpkhZImi/p2pR8qKSHJS3J92YkfVHS7JT/GyltmKSnJV0paZGk6ySNlzRT0h8kjanLiZmZtaVBezANFWAk7Q18FTg8IvYDPp92bQscDPw1MDnlPRIYAYwBRgKjJB2a8u8K/ADYF9gDOCmVvwBodUlXSZMkzZE0Z/nbywo4OzOzNjStq24rmYYKMMDhwE0RsQIgIl5N6bdFRHNEPAlsk9KOTNsTwFyyQDIi7XsuIhZGtv7CYuC+iAhgITCstQNHxNSIGB0Ro7ceuF0Bp2Zm1obm5uq2kmm0ORgBrT0XYXVFnpZ/vx0Rl69XgTSsIn9z7n0zjfeZmFkvF+E5mJ5wH/B3kj4AIGloO3nvBk6XNCjl3V7S1j3QRjOz2mrQOZiG+rUeEYslfQt4QFIT2fBXW3nvkbQn8IgkgJXAyUBj/hQwsw1XCYe/qtFQAQYgIq4Brmln/6Dc6x+QTeZX2ieX59Tc6+fz+8zMSqGEvZNqNFyAMTPb4DStrXcLusQBxsys7DxEZmZmhfAQmZmZFcI9GDMzK4QDjJmZFSE8yW9mZoXwHMyG44W3l9e7CTU3oO/G9W5CIXZZ8HS9m1BzY0dOqncTCjFA/erdhPLyEJmZmRXCPRgzMyuEezBmZlYI92DMzKwQ62r3MDFJR5Gt0dgHuDIiJlfs/yIwMb3tC+wJbBURr0p6HniTbNHgdRExur1jOcCYmZVdjXowkvoAPwY+DiwFZkuanh7WmB0q4mLg4pT/k8D5uYc7AhzW8tDHjjTa82DMzDY8tXui5Rjg2YhYEhFrgBuAY9rJfyLws6422wHGzKzsavfAse2BF3Pvl6a095E0EDgKuDnfEuAeSY9L6vB6eQ+RmZmVXZVXkaUv/fwX/9SImJrP0kqx1h5DD/BJYGbF8NjYiFiWng58r6SnI2JGW+1xgDEzK7sq52BSMJnaTpalwI659zsAy9rIewIVw2MRsSz9u1zSrWRDbm0GGA+RmZmV3bp11W0dmw2MkDRcUn+yIDK9MpOkIcDHgNtzaZtKGtzyGjgSWNTewUrfg5G0OXBSRFxW8HHGAWsi4uEij2Nm1mnR1ihWZ6uJdZLOBu4mu0z5qohYLOmMtH9KyjoBuCci3soV3wa4VRJkseP6iLirveOVPsAAmwNnAlUFGGVnr4hOX9c3DlgJOMCYWbnU8E7+iLgTuLMibUrF+2nAtIq0JcB+nTlWIwSYycAukuYBvwP2BbYA+gFfi4jbJQ0Dfp32HwQcK2k88CWy8cU/AKsj4mxJWwFTgJ1S/ecBfwLOAJoknQycExEP9tD5mZm1z0vFFOZCYJ+IGCmpLzAwIt6QtCUwS1LL+OHuwGkRcaak7YD/C+xPdtfpb4H5Kd8PgEsi4iFJOwF3R8SekqYAKyPiu601In91xmYD/oKB/bco6HTNzCp4qZgeIeDfJR0KNJNdv71N2vdCRMxKr8cAD7RcXifpF8Buad94YK80jgiwWcvEVXvyV2dsu/letRkQNTOrRlNTvVvQJY0WYCYCWwGjImJtWhdnk7QvPxnV2rXeLTYCDoqIVfnEXMAxMyuXBh0ia4TLlN8EWnoYQ4DlKbgcBuzcRpnHgI9J2iINq/1tbt89wNktbySNbOU4ZmblUbulYnpU6QNMRLwCzJS0CBgJjJY0h6w30+rjCiPiT8C/A48CvwGeBF5Pu89NdSyQ9CTZ5D7AL4EJkuZJOqSwEzIz66zaLRXToxpiiCwiTqoi2z4V76+PiKmpB3MrWc+FtAro8a0c4/dkV6iZmZVKNDfmtG9DBJguuihdqrwJWXC5rc7tMTPrmhIOf1Wj1waYiLig3m0wM6sJX0VmZmaFcA/GzMwK4QBjZmaFqNFilz3NAcbMrOzcgzEzs0L4MuUNx9rmqh7s01BWvPF6x5ka0OCNB9a7CTV344i19W5CISY/t2W9m1BevorMzMyKEB4iMzOzQniIzMzMClHCdcaq4QBjZlZ27sGYmVkh1nmS38zMiuAhMjMzK4SHyMzMrAi+TNnMzIrhHoyZmRWiQQPMRl0tKGlzSWfWsjFFkPSVerfBzKxbmpqq20qmywEG2Bx4X4CR1KcbddaMMhsBnQ4wZTkHMzOAaI6qtrLpToCZDOwiaZ6k2ZJ+J+l6YCGApNskPS5psaRJLYUkrZT0LUnzJc2StE1K/7SkRSl9Rko7VdLtku6S9Iykf8nV84WUf5Gk81LaMElPSboMmAv8JzAgtfG6lOdkSY+ltMtbgklq179KehQ4qBufi5lZbTVHdVvJdCfAXAj8d0SMBL4IjAG+GhF7pf2nR8QoYDRwrqQPpPRNgVkRsR8wA/hsSv868Jcp/VO544wBJgIjgU9LGi1pFHAa8BHgQOCzkj6c8u8O/FdEfDgiTgNWRcTIiJgoaU/geGBsandTqrulXYsi4iMR8VDlyUqaJGmOpDnvrOmdKw+bWUk1N1e3VUHSUekH+7OSLmwjz7j0I3yxpAc6UzavlpP8j0XEc7n350qakF7vCIwAXgHWAHek9MeBj6fXM4Fpkn4O3JKr596IeAVA0i3AwUAAt0bEW7n0Q4DpwAsRMauNNh4BjAJmSwIYACxP+5qAm9s6uYiYCkwF2HKz3cr3U8HMeq8a9U7SiM2Pyb53l5J9F06PiCdzeTYHLgOOiog/Stq62rKVahlg3so1cBwwHjgoIt6WdD+wSdq9NuLd5382tbQhIs6Q9BHgr4B5kkamPJWfbACqph2tEHBNRHy5lX3vRET5ZsnMzGo3/DUGeDYilgBIugE4BsgHiZOAWyLijwARsbwTZdfTnSGyN4HBbewbAryWgsseZMNY7ZK0S0Q8GhFfB1aQ9XoAPi5pqKQBwLFkPZ0ZwLGSBkraFJgAPNhG1Wsl9Uuv7wOOy0XkoZJ27vhUzczqJ5qaq9ryQ/lpm1RR1fbAi7n3S1Na3m7AFpLuT/Pop3Si7Hq63IOJiFckzZS0CFgFvJTbfRdwhqQFwDNAW0NWeRdLGkHWy7gPmE827/IQcC2wK3B9RMwBkDQNeCyVvTIinpA0rJV6pwILJM1N8zBfA+5JV5itBc4CXqj+zM3MeliVPZj8UH4bWhv9qay8L9lUwhFk0wiPSJpVZdn3VdRlEXFSG+mrgaPb2Dco9/om4Kb0+m8q86Z5kuURcXYr9XwP+F5F2vPAPhVpXwK+lHt/I3Bje+0yMyuTGl6CvJT3RocAdgCWtZJnRZrjfitd1btflWXX050hMjMz6wm1u0x5NjBC0nBJ/YETyC6OyrsdOERSX0kDya7WfarKsusp9VIxETENmFbnZpiZ1VeN1rqMiHWSzgbuBvoAV0XEYklnpP1TIuIpSXcBC9KRr4yIRQCtlW3veKUOMGZmBrGudqspR8SdwJ0VaVMq3l8MXFxN2fY4wJiZlV1jrtbvAGNmVnZlXGesGg4wZmZl5x6MmZkVwT2YDcgmffvXuwk1t6bfuno3oRBvr11d7ybU3McWv13vJhRiwEYvdpxpQ+UejJmZFSEa9PefA4yZWcmFezBmZlYIBxgzMyuCezBmZlYIBxgzMytENLX3jMXycoAxMys592DMzKwQ0ewejJmZFcA9GDMzK0SEezBmZlYA92DMzKwQzQ16FdlG9W6ApJMlPSZpnqTLJfWRtFLStyTNlzRL0jYp77zctkrSxyT9QdJWaf9Gkp6VtKWkaZJ+Iul3kpakvFdJekrStNzxj5T0iKS5kn4haVCdPgozs1ZFs6rayqauAUbSnsDxwNiIGAk0AROBTYFZEbEfMAP4LEBEjEz5/i8wB3gY+GkqAzAemB8RK9L7LYDDgfOBXwKXAHsDH5I0UtKWwNeA8RGxf6rzC220dZKkOZLmvLX61Vp+DGZm7WrUAFPvIbIjgFHAbEkAA4DlwBrgjpTnceDjLQUkjSB7VvThEbFW0lXA7cD3gdOBq3P1/zIiQtJC4KWIWJjqWAwMA3YA9gJmpuP3Bx5praERMRWYCrDD0H0a8+EMZtaQokG/ceodYARcExFfXi9RuiDi3Y+0idROSZsCPwc+GxHLACLiRUkvSToc+Ajv9WYAWh4G0px73fK+b6r73og4sbanZWZWO2XsnVSj3nMw9wHHSdoaQNJQSTu3k/9q4OqIeLAi/UqyobKfR0RTJ44/Cxgradd0/IGSdutEeTOzwkWoqq1s6hpgIuJJsjmQeyQtAO4Ftm0tbwo8xwGn5yb6R6fd04FBrD88Vs3xXwZOBX6Wjj8L2KMr52JmVpSmJlW1lY2iUQf3clKguSQiDumJ4/XGOZg3VvfOx/CuaWrQRwG2Y9hm29S7CYUYsFHvexQ5wBP/M7Pb3/zP7HF0Vd85uz/961JFmXrPwXSbpAuBz7H+3IuZWa/hOZg6iYjJEbFzRDxU77aYmRUhorqtbBq+B2Nm1ts1ag/GAcbMrOSamhtzsKkxW21mtgGp5RCZpKMkPZOW1bqwnXwHSGqSdFwu7XlJC9NVvHM6OpZ7MGZmJddco3tcJPUBfky2OspSslVUpqdbRirzfQe4u5VqDsstx9Uu92DMzEquhjdajgGejYglEbEGuAE4ppV85wA3ky3d1WUOMGZmJVftEFl+Ud60Taqoanvgxdz7pSntXZK2ByYAU1prCtmN8Y+3Uvf7eIisC3rDzamVNt9403o3oRDLVva+la+be+HfH0DfjfrUuwmlVe0QWX5R3ja0VlHlH9T3gS9FRFNaBDhvbEQsS8t73Svp6YiY0dbBHGDMzEquhleRLQV2zL3fAVhWkWc0cEMKLlsCn5C0LiJuyy0yvFzSrWRDbm0GGA+RmZmVXFS5VWE2MELScEn9gRPI1nJ871gRwyNiWEQMA24CzoyI2yRtKmkwvLuy/ZHAovYO5h6MmVnJ1eoqsohYJ+lssqvD+gBXRcRiSWek/a3Nu7TYBrg19Wz6AtdHxF3tHc8Bxsys5Gq5FH9E3AncWZHWamCJiFNzr5cA+3XmWA4wZmYl11zvBnSRA4yZWclFqxd/lZ8DjJlZya0r4dMqq+EAY2ZWcu7BmJlZITwHY2ZmhWjUHkzpb7SUdKqk7XLvn5e0ZT3bZGbWk5qr3Mqm9AEGOBXYrqNMZma9VROqaiubHg8wkoZJelrSNZIWSLpJ0kBJX5c0W9IiSVOVOY5sXZzr0gNuBqRqzpE0Nz34Zo9U70JJm6dyr0g6JaVfK2l8Ou6DqdxcSR/N7T8m177rJH2qhz8WM7M2Nau6rWzq1YPZHZgaEfsCbwBnAj+KiAMiYh9gAPDXEXETMAeYGBEjI2JVKr8iIvYHfgJckNJmAmOBvYElwCEp/UBgFtlzDT6eyh0PXJr2XwmcBiBpCPBRKu5yTfveXQb7rdWv1epzMDPrUDOqaiubegWYFyNiZnr9U+Bg4DBJj0paCBxOFijackv693FgWHr9IHBo2n4CfCg91+DViFgJ9AOuSPX/AtgLICIeAHZNy0+fCNwcEesqDxgRUyNidESM3nTjLbp63mZmnVbDxS57VL0CTOVnEcBlwHER8SHgCmCTdsqvTv828d6VcDPIei2HAPcDLwPHkQUegPOBl8jW0hkN9M/Vdy0wkawnc3Wnz8bMrECe5O+cnSQdlF6fCDyUXq+QNIgsMLR4ExjcUYUR8SLZswtGpEXZHiIbPmsJMEOAP0dEM/D3ZCuJtpgGnJfqWdyVEzIzK0qzVNVWNvUKME8B/yBpATCUbEjrCmAhcBvZMwtaTAOmVEzyt+VR4Pfp9YNkjwJtCV6XpWPOAnYD3mopFBEvpTa592JmpdNU5VY26unH/0oaBtyRJvNLQdJAsuC2f0S83lH+7bfYu4zDnd3SR41wxXrn9cZHJn9wyLb1bkIhhvQbWO8mFGL2shnd7lr8bLuJVX3nnLjsulJ1Y3rnt0onSBoPPA38sJrgYmbW0xr1KrIeXyomIp4HStN7iYjfADvVux1mZm1p1CETr0VmZlZyZbyJshoOMGZmJVfGS5Cr4QBjZlZyTe7BmJlZEdyDMTOzQjjAbED+d/VbHWdqMPtuMbzeTShEUzTq/zXb9k7T6o4zNaB9BvqpHG0JD5GZmVkRGvVnkgOMmVnJlXEZmGo4wJiZlZzvgzEzs0I06hDZBr8WmZlZ2dXyeTCSjpL0jKRnJV3Yyv5j0uPs56Wn+B5cbdlKDjBmZiVXqydaSuoD/Bg4muypvidK2qsi233AfhExEjid7LHy1ZZdjwOMmVnJNau6rQpjgGcjYklErAFuAI7JZ4iIlfHec1w25b3Y1WHZSg4wZmYlV8MHjm0PvJh7vzSlrUfSBElPA78i68VUXTbPAcbMrOSaiao2SZPSvEnLNqmiqtb6Oe8bXYuIWyNiD+BY4N86UzbPV5GZmZVctRP4ETEVmNpOlqXAjrn3OwDL2qlvhqRdJG3Z2bJQYA9G0sOdzD9O0h3p9aequUKhjXq+0p12mJmVTa0m+YHZwAhJwyX1B04ApuczSNpVktLr/YH+wCvVlK1UWA8mIj7ajbLT6aDh7fgK8O+1aIeZWRnU6j6YiFgn6WzgbqAPcFVELJZ0Rto/Bfhb4BRJa4FVwPFp0r/Vsu0dr7AAI2llRAySNA64CFhB9qjkx4GTIyIkHQV8P+2bmyt7KjA6Is6WtA0wBfhg2v25iHhY0m1k3bVNgB9ExFRJk4EBkuYBiyNiYq4dAv6D7BK7AL4ZETe2176iPhszs85Yp9p9HUXEncCdFWlTcq+/A3yn2rLt6ak5mA8De5ON180ExkqaA1wBHA48C9zYRtlLgQciYkK6DntQSj89Il6VNACYLenmiLhQ0tnp+u1KfwOMBPYDtkxlZrTVPuChfOE0WTYJoH+/ofTtO7jTH4KZWVc06q/dnrqK7LGIWBoRzcA8YBiwB/BcRPwh9RZ+2kbZw4GfAEREU0S8ntLPlTQfmEXWkxnRQRsOBn6W6ngJeAA4oJ32rScipkbE6IgY7eBiZj2plnfy96Se6sHkH2DRlDtulwJzGtYaDxwUEW9Lup9sqKzdYl1on5lZ3TU3aB+mnvfBPA0Ml7RLen9iG/nuAz4H2VIFkjYDhgCvpeCyB3BgLv9aSf1aqWcGcHyqYyvgUOCxWpyImVmRangVWY+qW4CJiHfI5jR+Jekh4IU2sn4eOEzSQrIJ+L2Bu4C+khaQ3QQ0K5d/KrBA0nUV9dwKLADmA78F/jki/qdW52NmVpRGHSKTL5bqvE0HDut1H1pvfWTyH99eXu8m1Fwf9c4FOEYP7p1/g7e8ML3bT3M5f9gJVX3nXPL8DaV6coznGszMSq6MvZNqOMCYmZVclHKGpWMOMGZmJecejJmZFaJRL1N2gDEzK7nGDC8OMGZmpbeuQUOMA4yZWcl5kn8D8ndbj6p3E2ruztfaXXW7YV2zSWvrnja23bZ+td5NKMS+/72w3k0oLU/ym5lZIdyDMTOzQrgHY2ZmhWhq0CW9HGDMzErO98GYmVkhPAdjZmaF8ByMmZkVwkNkZmZWCA+RmZlZIXwVmZmZFaJRh8h67Nmrkq6UtFcr6adK+lE36v2ZpAWSzpe0h6R5kp6QtEsn6xkn6aNdbYeZWVGaq9zKpsd6MBHxmVrXKekvgI9GxM7p/YXA7RHxL12obhywEni4di00M+u+Rp2DKaQHI2lTSb+SNF/SIknHS7pf0ui0/zRJv5f0ADA2V24rSTdLmp22sbn6rkppT0g6JhW5B9g69Vr+BTgP+Iyk36VyJ0t6LO2/XFKflH6UpLmpffdJGgacAZyf8h5SxOdiZtYVzURVW9kU1YM5ClgWEX8FIGkI8Ln0elvgG8Ao4HXgd8ATqdwPgEsi4iFJOwF3A3sCXwV+GxGnS9oceEzSb4BPAXdExMhUt4CVEfFdSXsCxwNjI2KtpMuAiZJ+DVwBHBoRz0kaGhGvSprSUra1E5I0CZgEcNDQD7P74OG1/LzMzNoUDTrJX9QczEJgvKTvSDokIl7P7fsIcH9EvBwRa4Abc/vGAz+SNA+YDmwmaTBwJHBhSr8f2ATYqYM2HEEWxGanckcAHwQOBGZExHMAEVHV2ucRMTUiRkfEaAcXM+tJTURVWzXSCM4zkp5N0wqV+/eQ9Iik1ZIuqNj3vKSFaaRnTkfHKqQHExG/lzQK+ATwbUn3VGZpo+hGwEERsSqfmHomfxsRz1SkD2unGQKuiYgvV5T5VDvHNzMrnVoNf6Vpgh8DHweWkv0Anx4RT+ayvQqcCxzbRjWHRcSKao5X1BzMdsDbEfFT4LvA/rndjwLjJH1AUj/g07l99wBn5+ppeVrU3cA5KdAg6cNVNOM+4DhJW6cyQyXtDDwCfEzS8Jb0lP9NYHDnztTMrHgRUdVWhTHAsxGxJI0g3QAck88QEcsjYjawtrvtLmqI7ENk8yTzyOZPvtmyIyL+DFxE9kX/G2Burty5wOh02fGTZBPvAP8G9AMWSFqU3rcrReSvAfdIWgDcC2wbES+TzaXcImk+7w3R/RKY4El+Myubaif5JU2SNCe3Taqoanvgxdz7pSmtWkH2nfp4K3W/T1FDZHeT9TryxuX2Xw1c3Uq5FWQT85Xpq4B/aiX9eWCf3PuLKvbfyPpzPC3pvwZ+XZH2e2Df95+NmVl9VXuZckRMBaa2k0WtVl+9sRGxLI0M3Svp6YiY0VbmHrvR0szMuqYpoqqtCkuBHXPvdwCWVduOiFiW/l0O3Eo25NYmBxgzs5Kr4X0ws4ERkoZL6g+cQHbFbofS/YiDW16TXd27qL0yXovMzKzkanUVWUSsk3Q22RRGH+CqiFgs6Yy0f0paIWUOsBnQLOk8YC9gS+DWdK1VX+D6iLirveM5wJiZlVwtb7SMiDuBOyvSpuRe/w/Z0FmlN4D9OnMsBxgzs5Ir4zIw1XCAMTMruUZd7NIBxsys5JqijIvxd8wBpgsWrl5e7ybU3CZ9+9e7CYWY3K+qFS0ayoTX/6LeTSjEEVv2zr/BWmjUxS4dYMzMSs5zMGZmVgjPwZiZWSGaPURmZmZFcA/GzMwK4avIzMysEB4iMzOzQniIzMzMCuEejJmZFcI9GDMzK0RTNNW7CV3S5QeOSTpX0lOSrpO0saTfpOfZv++Rxyn/GZJOaSV9mKR2H1pTy3a0U88wSSd1tR1mZkWJiKq2sulOD+ZM4OiIeE7SgUC/iBjZVub88wZqrFPtaMcw4CTg+lo2zsysuxp1qZiqejCSviBpUdrOkzQF+CAwXdKXgJ8CI1PPYRdJkyU9KWmBpO+mOi6SdEF6PUrSfEmPAGfljtNH0sWSZqey/5Tb98Vc+jdSWkftGCXpAUmPS7pb0rap3K6ppzNf0lxJuwCTgUNS2fO7/cmamdVIr+3BSBoFnAZ8BBDwKHAycBRwWESskPQocEFE/LWkocAEYI+ICEmbt1Lt1cA5EfGApItz6f8IvB4RB0jaGJgp6R5gRNrGpDZMl3RoRJwhqa129AOuBY6JiJfTkNm3gNOB64DJEXGrpE3IAu2FLWXb+BwmAZMAdtpsV7Ya2DtXtDWz8unNV5EdDNwaEW8BSLoFOKSd/G8A7wBXSvoVcEd+p6QhwOYR8UBKuhY4Or0+EthX0nHp/RCywHJk2p5I6YNS+ox22rE7sA9wb3qGdB/gz5IGA9tHxK0AEfFOalc7VUFETAWmAoze9pDG/K9tZg2pN19F1v43b4WIWCdpDHAEcAJwNnB4RX1tfVoi69ncvV6i9JfAtyPi8k40RcDiiDiooq7NOlGHmVndNepSMdXMwcwAjpU0UNKmZMNfD7aVWdIgYEhE3AmcB6w34R4R/wu8LunglDQxt/tu4HNpeAtJu6Vj3g2cnupG0vaStu6g3c8AW0k6KJXpJ2nviHgDWCrp2JS+saSBwJvA4A4/DTOzHtZr52AiYq6kacBjKenKiHiinSGlwcDtaW5DQGsT5qcBV0l6myx4tLiS7GquucoO8DJwbETcI2lP4JF03JVk80BtPloyItakobZL07BcX+D7wGLg74HLJf0rsBb4NLAAWCdpPjAtIi5p52MxM+sxjToHozJGvbLrjXMwL695vd5NKMTwAR11dBvPhI165wUmv+V/692EQtz+xzs6Nc3Qmi0G7VrVd85rK5/t9rFqyXfym5mVXKPeB+MAY2ZWco060uQAY2ZWco16FZkDjJlZyTXqJL8DjJlZyTXqEFmXV1M2M7OeEVX+rxqSjpL0jKRnJV3Yyn5JujTtXyBp/2rLVnKAMTMruVrdaCmpD/BjsuW59gJOlLRXRbajeW/9x0nATzpRdj0OMGZmJdccUdVWhTHAsxGxJCLWADcAx1TkOQb4r8jMAjZPK9FXU3Y9noPpgjl/frDHbmaSNCkttNlr9MZzgt55Xj15Tp/viYMkjfbfat2aP1X1nZNf9T2ZWnGe2wMv5t4vJVspnw7ybF9l2fW4B1N+kzrO0nB64zlB7zyv3nhO0EvPKyKmRsTo3FYZRFsLVJVdn7byVFN2Pe7BmJltOJYCO+be7wAsqzJP/yrKrsc9GDOzDcdsYISk4ZL6kz1SZXpFnunAKelqsgPJHgL552/OJMEAAAavSURBVCrLrsc9mPJrmHHiTuiN5wS987x64zlB7z2vdqXndZ1Ntop9H+CqiFgs6Yy0fwpwJ/AJ4FngbbLV79ss297xvJqymZkVwkNkZmZWCAcYMzMrhANMHUmaICkk7ZHej5N0R0WeaenJnC2PfZ4s6Q+SFkl6TNLR9Wh7pS6cy/1pyYn5kmZK2r2V9NmSRr7/aJ1uW5OkebltWHfrbDS5z2CRpF9K2rzG9Z8q6Ufp9UWSLqhRvfl2/yI93ry7dY6WdGk7+7eTdFN3j2MOMPV2IvAQ2dUY1fg3YFtgn4jYB/gk2SOqy6Cz5wIwMSL2A64BLm4l/bKK9K5aFREjc9vzNaizUGlZjlpq+Qz2AV4Fzqpx/UXJt3sNcEZ+Z7rSqVPfYxExJyLObWf/sog4rmvNtTwHmDqRNAgYC/wjVXwpp19unwXOiYjVABHxUkT8vNCGVqGz59KKGcCuraQ/Qnb3cOFSz+kSSTMkPSXpAEm3pN7iN3P5Tk49x3mSLm8JBJJ+ImmOpMWSvpHLP1nSk2nRwO+mtHd7cun9yvTvOEm/k3Q9sLC943XTep+rpC+m3uKCirafktLmS7o2pX1S0qOSnpD0G0nb1KA91XoQ2FXSsPTf6DJgLrCjpCMlPSJpburpDErtPUDSw+kcHpM0ON+7lvSxXM/2ibR/mKRFaf8mkq6WtDDtPyyln5r+Pu5KfyP/0YOfQ8PwZcr1cyxwV0T8XtKryq1Y2oZdgT9GxBs90LbO6uy5VPok6Qu1wlHAbd1uHQyQNC+9fi4iJrSRb01EHCrp88DtwCiyX/v/LekSYGvgeGBsRKxNX3ATgf8CvhoRr6YAcJ+kfcluWJsA7BERUeWw1BiyHupzkvZs53hdktp3BPCf6f2RZIsajiG7U3u6pEOBV4CvpmOvkDQ0VfEQcGA6n88A/wz8n662pxPt7ku2yOJdKWl34LSIOFPSlsDXgPER8ZakLwFfkDQZuBE4PiJmS9oMWFVR9QXAWRExMwWldyr2nwUQER9SNvx7j6Td0r6RwIeB1cAzkn4YES9i73KAqZ8Tge+n1zek93e0kbfs15J39Vyuk7QKeB44pyJ9U7Jr7TsbrFqzKiKqmctpuWlsIbA43VyGpCVkdzAfTBZ0ZksCGAAsT2X+Ttk6UH3JhjH3Ap4k+8K6UtKvaPszyXssIp5Lr49o53id1RJkhwGPA/em9CPT9kR6P4gs4OwH3BQRKwAi4tW0fwfgRmWLH/YHWtpalPyPgwfJAuN2wAtpIUaAA8k+75npc+pP1kvbHfhzRMxO5/AGQMrTYibwPUnXAbdExNKK/QcDP0zln5b0AtASYO6LiNdTnU8CO7P+Wl0bPAeYOpD0AeBwYB9JQfZFGmS/TLeoyD4UWEF209NOkgZHxJs92d72dPFcWkyMiDmtVDsRmA9MJlse/G8KaPfVZL8+l0XEJ1Ly6vRvc+51y/u+ZL/wr4mIL1fUNZzsl/ABEfGapGnAJunGtDFkgeIE4Gyyz2odaXha2bdZ/1x1b+Wrbu14XbQqIkZKGkIW6M4CLk3H+HZEXF5xTufS+g+bHwLfi4jpksYBF9Wgbe1534+DFAAqP6d7I+LEinz70sGPs4iYnIL/J4BZksazfi+mvUUm838jTfj79H08B1Mfx5Eth71zRAyLiB3JfgkOBbZLQyNI2pnsl+S8iHib7NfbpcqWaUDStpJOrs8pvKvT51JNpRGxlmzY48CWOmopIk5Lk8ef6Dj3u+4DjpO0NYCkoem8NiP7wns9zUkcnfYPAoZExJ3AeWRDKpD12Eal18cA/Tp5vC5Lv7jPBS6Q1I/sruzTc3MW26fj3UfWK/tAy7FTFUOAP6XX/9CdttTQLGCspF0hm69Mw1hPk/0NHpDSB6ehtndJ2iUiFkbEd4A5wB4Vdc8g+8FDqnMn4JlCz6YXccStjxPJfp3n3Uz2K/dk4GpJmwBrgc+0dMPJvnC/CTwp6R2yL7Wv90yT29TVc+lQRKyS9P/Iegf/WKP2dllEPCnpa2Tj8BuRndNZETFL0hPAYmAJ2bALZFf43Z7OX8D5Kf2KlP4Y2Rd5/td4h8cDXujmeTwhaT5wQkRcmwL4I6lnsBI4OS0f8i3gAUlNZENop5L1WH4h6U9kX+zDu9OWWoiIlyWdCvxM0sYp+WtpTvB44IeSBpDNv4yvKH5emrhvIhvS/DXZEGeLy4ApkhaS9TxPjYjVFcNo1gYvFWNmZoXwEJmZmRXCAcbMzArhAGNmZoVwgDEzs0I4wJiZWSEcYMzMrBAOMGZmVoj/D1S7M3H1yXrCAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "heatmap(df_diff_abs, yticklabels=df_paperIndividualScores[\"Similarity\"])" ] }, { "cell_type": "code", - "execution_count": 836, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 836, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAD4CAYAAAAUymoqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3deZwdVZ338c+XJJiQhISAwbC2QmQxQoSAQAAjibgLODiAoATQDLIJDqNBGZdRxszII6I+CAElEVDZJaIDgShbJBvZE0B82CEDhCgQCJB0/54/6jTeNN19b3Xfpbr5vn3Vq+ueqjr1q0us3z3n1KKIwMzMLI9NGh2AmZn1PE4eZmaWm5OHmZnl5uRhZma5OXmYmVlufRsdQE90xA6f7HWXqP295dVGh1ATt957fqNDqDoNGNzoEGri1XNPaXQINTH44lvU3TrWr364onNOv63e1e19VcotDzMzy80tDzOzomtpbnQEb+LkYWZWdM0bGh3Bmzh5mJkVXERLo0N4EycPM7Oia3HyMDOzvNzyMDOz3DxgbmZmubnlYWZmeYWvtjIzs9wKOGDe4+8wlzRN0pGNjsPMrGaipbKpjtzyMDMrugIOmPe4loekz0taKmmJpCtS8cGS/izp4dJWiKR/kzQ/rf+dVNYk6QFJl0laLukqSRMkzZb0kKR9G3JgZmYdKWDLo0clD0nvAb4BHBIRewJfTotGAAcCnwCmpHUPBUYC+wKjgb0lHZzW3xm4ENgD2BX4bNr+bODrHex7kqQFkhY8uvaxGhydmVkHmjdUNtVRj0oewCHAdRGxGiAi1qTy30ZES0SsBLZOZYemaRGwkCxJjEzLHomIZZHd878CmBURASwDmtrbcURMjYgxETGmadCONTg0M7MOtLRUNtVRTxvzENDec+1fa7NO69/vR8QlG1UgNbVZv6Xkcws97zsxs14uwmMe3TUL+GdJWwJIGtbJurcCJ0oalNbdVtLwOsRoZlZdBRzz6FG/siNihaTzgDslNZN1SXW07kxJuwH3SgJYCxwHFC+Fm5l1poD3efSo5AEQEdOB6Z0sH1QyfyHZwHhbo0rWmVgy/2jpMjOzQvDjSczMLLfm9Y2O4E2cPMzMis7dVmZmlpu7rczMLDe3PMzMLDcnDzMzyys8YG5mZrl5zKN3mL/2kUaHUHVbbrp5o0OoiYE7fazRIVTdJ96xV6NDqIndN+nsgRE913nVqMTdVmZmlptbHmZmlptbHmZmlptbHmZmltuG+r7oqRJOHmZmReeWh5mZ5eYxDzMzy80tDzMzy80tDzMzy80tDzMzy62AV1tt0ugAypE0VNIpddjPOEkH1Ho/Zma5RVQ21VHhkwcwFKg4eSjTleMaBzh5mFnxtLRUNlVA0lmSVkhaLunXkvpLGibpNkkPpb9blKunJySPKcBOkhZLukDSLEkLJS2TdBiApCZJ90u6CFgIbC/pJEl/kXSHpEsl/TSt+3ZJ10uan6axkpqAk4Gz0n4OatCxmpm9WZWSh6RtgTOAMRExCugDHA1MBmZFxEhgVvrcqZ4w5jEZGBURoyX1BTaLiBclbQXMkTQjrbcLcEJEnCJpG+Dfgb2Al4A/AkvSehcCF0TEPZJ2AG6NiN0kXQysjYjz2wtC0iRgEsDQzUYw8G298wmgZlZA1R0w7wsMkLQe2Ax4GjiHrPcFYDpwB/C1cpX0JAL+U9LBQAuwLbB1WvZYRMxJ8/sCd0bEGgBJ1wLvTssmALtLaq1zc0mDy+04IqYCUwG2Gzaqvp2LZvbW1txc0WqlP3KTqencBUBEPCXpfOBxYB0wMyJmSto6IlaldVZJGl5uXz0teRwLvB3YOyLWS3oU6J+WvVyyntpuWGITYP+IWFdaWJJMzMyKpcLxjNIfue1JYxmHAe8E/g5cK+m4roTUE8Y8XgJaWwZDgGdT4vggsGMH28wDPiBpi9TV9U8ly2YCp7V+kDS6nf2YmRVH9QbMJwCPRMRzEbEeuIHsQqFnJI0ASH+fLVdR4ZNHRDwPzJa0HBgNjJG0gKwV8kAH2zwF/CcwF7gdWAm8kBafkepYKmkl2UA5wO+AIzxgbmaFEy2VTeU9DuwnaTNl3S3jgfuBGcDxaZ3jgZvKVdQjuq0i4rMVrDaqzedfRcTU1PK4kazFQUSsBo5qZx9/AfbobqxmZtUWLdUZZo2IuZKuI7sqdQOwiKybaxBwjaSTyBLMZ8rV1SOSRxd9W9IEsjGRmcBvGxyPmVnXVPHZVhHxLeBbbYpfI2uFVKzXJo+IOLvRMZiZVUWFV1vVU69NHmZmvYafqmtmZrk5eZiZWW51fuhhJZw8zMyKzi0PMzPLrUqX6laTk0cXNBfwV0B33f+3xxsdQk2MGNT7HmB54fBXGh1CTWz54f7lV3qr8tVWZmaWVxTwB6uTh5lZ0bnbyszMcqvu+zyqwsnDzKzo3PIwM7PcNnjA3MzM8nK3lZmZ5eZuKzMzy8uX6pqZWX5ueZiZWW4FTB5dfoe5pKGSTqlmMLUg6euNjsHMrFuamyub6qjLyQMYCrwpeUjq0406q0aZTYDcyaMox2BmBtk7zCuZ6qk7yWMKsJOkxZLmS/qTpF8BywAk/VbSfZJWSJrUupGktZLOk7RE0hxJW6fyz0hansrvSmUTJd0k6RZJD0r6Vkk9X0nrL5d0ZiprknS/pIvIXvD+c2BAivGqtM5xkualsktaE0WK6z8kzQX278b3YmZWXS1R2VRH3RnzmAyMiojRksYBv0+fH0nLT4yINZIGAPMlXR8RzwMDgTkR8Q1J/w18Efge8E3gwxHxlKShJfvZFxgFvJLq+T0QwAnA+wEBcyXdCfwN2AU4ISJOgSwpRcToNL8bcBQwNiLWpyRzLPDLFNfyiPhmewebEuAkgM0HvIPNNt2iG1+dmVkOvfxqq3kliQPgDElHpPntgZHA88DrwM2p/D7gQ2l+NjBN0jXADSX13JaSDpJuAA4kSx43RsTLJeUHATOAxyJiTgcxjgf2JktCAAOAZ9OyZuD6jg4uIqYCUwFGDN29eKNXZtZ7FXDAvJrJ4+XWmdQSmQDsHxGvSLoDaH1Y//qIN96p2NwaQ0ScLOn9wMeBxZJGp3XafmtB1tooG0c7BEyPiHPaWfZqRBTvGQBmZgVMHt0Z83gJGNzBsiHA31Li2BXYr1xlknaKiLmp22g1WWsF4EOShqXur8PJWih3AYdL2kzSQOAI4O4Oql4vqV+anwUcKWl42ucwSTuWP1Qzs8aJ5paKpnrqcssjIp6XNFvScmAd8EzJ4luAkyUtBR4EOupGKvUDSSPJWgezgCXAaOAe4ApgZ+BXEbEAQNI0YF7a9rKIWCSpqZ16pwJLJS2MiGMlnQvMTFdirQdOBR6r/MjNzOqsgC0P/aMHqXgkTQTGRMRpjY6lVG8c83h+3YuNDqEm3jGw913YcM+7hjc6hJrY8sO975XBAAPPu7azbvaKvHDChIrOOUMuv73b+6qU7zA3Myu6ArY8Cp08ImIaMK3BYZiZNVbxrtQtdvIwMzOIDcXLHk4eZmZFV7zc4eRhZlZ09X5uVSWcPMzMis4tDzMzy8stj15iQN+3NTqEqhvSf2CjQ6iJNa+ubXQIVXfcqt55n8fwK19rdAg1cd15VajELQ8zM8srNjQ6gjfrzrOtzMysDqKlsqkS6S2w10l6IL3/aP/0nL/bJD2U/pZ9NIOTh5lZ0bVUOFXmQuCWiNgV2BO4n+z9TLMiYiTZswUnl6vEycPMrOCq1fKQtDlwMNlbVomI1yPi78BhwPS02nSyJ5h3ysnDzKzgKk0ekiZJWlAyTWpT1buA54DLJS2SdFl6rcXWEbEKIP0te1WGB8zNzAoumit7WG7pG0870BfYCzg9IuZKupAKuqja45aHmVnBVXHA/EngyYiYmz5fR5ZMnpE0AiD9fbaD7d/g5GFmVnDRooqmsvVE/C/whKRdUtF4YCUwAzg+lR0P3FSuLndbmZkVXKWX4VbodOAqSZsCDwMnkDUkrpF0EvA48JlylTh5mJkVXET1XhAYEYuBMe0sGp+nHicPM7OCq3LLoyqcPMzMCq6lwqut6qnhA+aSjpM0T9JiSZdI6iNpraTzJC2RNEfS1mndxSXTOkkfSLfTvz0t30TSXyVtJWmapJ9J+pOkh9O6v0i3408r2f+hku6VtFDStZIGNeirMDNrV7UGzKupoclD0m7AUcDYiBgNNAPHAgOBORGxJ3AX8EWAiBid1vt3YAHwZ+DKtA3ABGBJRKxOn7cADgHOAn4HXAC8B3ivpNGStgLOBSZExF6pzq90EOsbN9+8+Orq9lYxM6uJIiaPRndbjQf2BuZLAhhAdn3x68DNaZ37gA+1biBpJPAD4JCIWC/pF2SXlf0IOBG4vKT+30VESFoGPBMRy1IdK4AmYDtgd2B22v+mwL3tBVp68827tnpf8R6ub2a9VhTwjNPo5CFgekScs1GhdHbEG19XMynOdBv9NcAXI+JpgIh4QtIzkg4B3s8/WiEArS8IaCmZb/3cN9V9W0QcU93DMjOrnnq3KirR6DGPWcCRkoYDpMcC79jJ+pcDl0fE3W3KLyPrvromIppz7H8OMFbSzmn/m0l6d47tzcxqLkIVTfXU0OQRESvJxhxmSloK3AaMaG/dlFSOBE4sGTRvvVZ5BjCIjbusKtn/c8BE4Ndp/3OAXbtyLGZmtdLcrIqmemp0txURcTVwdZviQSXLryN7/gp0nOz2JBsof6Bku4kl848CozpY9kdgny4Fb2ZWB/VuVVSi4cmjuyRNBr7ExmMdZma9hsc8aiAipkTEjhFxT6NjMTOrhYjKpnrq8S0PM7PerogtDycPM7OCa24pXieRk4eZWcH5JkEzM8utxVdbmZlZXr5U18zMcnO3VS/RnOsJKD3D8P5DGx1CTTy47slGh1B162J9o0OoiU3p0+gQCsvdVmZmlpuvtjIzs9wK2Gvl5GFmVnTutjIzs9x8tZWZmeXW0ugA2uHkYWZWcIFbHmZmltMGd1uZmVlebnmYmVluHvMwM7PcitjyKN5ti21Imihpm5LPj0raqpExmZnVU0uFUz0VPnkAE4Ftyq1kZtZbNaOKpnqqe/KQ1CTpAUnTJS2VdJ2kzSR9U9J8ScslTVXmSGAMcJWkxZIGpGpOl7RQ0jJJu6Z6l0kamrZ7XtLnU/kVkiak/d6dtlso6YCS5YeVxHeVpE/V+WsxM+tQiyqb6qlRLY9dgKkRsQfwInAK8NOI2CciRgEDgE9ExHXAAuDYiBgdEevS9qsjYi/gZ8DZqWw2MBZ4D/AwcFAq3w+YAzwLfChtdxTw47T8MuAEAElDgAOAP7QNWNIkSQskLVj76ppqfQ9mZmW1oIqmempU8ngiIman+SuBA4EPSporaRlwCFkS6MgN6e99QFOavxs4OE0/A94raVtgTUSsBfoBl6b6rwV2B4iIO4GdJQ0HjgGuj4gNbXcYEVMjYkxEjBnUf1hXj9vMLLeocKqnRiWPtscZwEXAkRHxXuBSoH8n27+W/jbzjyvG7iJrbRwE3AE8BxxJllQAzgKeAfYk6wrbtKS+K4BjyVogl+c+GjOzGvKA+T/sIGn/NH8McE+aXy1pENlJv9VLwOByFUbEE8BWwMiIeDjVeTb/SB5DgFUR0QJ8DjZ688w04MxUz4quHJCZWa20SBVN9dSo5HE/cLykpcAwsm6mS4FlwG+B+SXrTgMubjNg3pG5wF/S/N3AtvwjMV2U9jkHeDfwcutGEfFMismtDjMrnOYKp0pJ6iNpkaSb0+dhkm6T9FD6u0W5Ohp1k2BLRJzcpuzcNG0kIq4Hri8paipZtgAYV/L5cyXzf6YkOUbEQ8AeJfWc0zojaTNgJPDrfIdhZlZ7NbiS6stkP5g3T58nA7MiYoqkyenz1zqroCfc51FTkiYADwA/iYgXGh2PmVlb1bzaStJ2wMfJrjRtdRgwPc1PBw4vV0/dWx4R8Sgwqt777UhE3A7s0Og4zMw6UumVVJImAZNKiqZGxNQ2q/0I+CobjyVvHRGrACJiVbr6tFN+tpWZWcFV2m2VEkXbZPEGSZ8Ano2I+ySN605MTh5mZgVXxctwxwKfkvQxstshNpd0JfCMpBGp1TGC7KbqTr3lxzzMzIquWZVN5UTEORGxXUQ0AUcDf4yI44AZwPFpteOBm8rV5ZaHmVnB1eEGwCnANZJOAh4HPlNuAycPM7OCq0XyiIg7yJ7GQUQ8D4zPs72TRxesXvdio0OoutFb9c4Lzl4f8qbHlPV4a15f2+gQamJ0v975b7AaCvgKcycPM7Oi82tozcwstzyPHqkXJw8zs4Kr94ueKuHkYWZWcO62MjOz3Jw8zMwst3q/JbASTh5mZgXnMQ8zM8vNV1uZmVluLQXsuHLyMDMruCIOmNfsqbqS/pxz/XEl79P9VHoVYlf2+/XuxGFmVjRR4VRPNUseEXFAN7adERFTurj5RsmjO3GYmRVBS4VTPdWy5bE2/R0n6Q5J10l6QNJVkpSWfSSV3QN8umTbiZJ+mua3lnSjpCVpOiCV/1bSfZJWpFcvImkKMEDSYklXtYlDkn4gabmkZZKOKhefmVkRbFBUNNVTvcY83ge8B3gamA2MlbQAuBQ4BPgrcHUH2/4YuDMijpDUBxiUyk+MiDWSBgDzJV0fEZMlnRYRo9up59PAaGBPYKu0zV0dxQfcU7px6buBN+03jL59S1//a2ZWO8UbLq/fmwTnRcSTEdECLAaagF2BRyLioYgI4MoOtj0E+BlARDRHxAup/AxJS4A5wPbAyDIxHAj8OtXxDHAnsE8n8W0kIqZGxJiIGOPEYWb1VMRuq3q1PF4rmW8u2W+XEmp6cfsEYP+IeEXSHWTv4+10sy7EZ2bWcEW8VLeR7zB/AHinpJ3S52M6WG8W8CUASX0kbQ4MAf6WEseuwH4l66+X1K+deu4Cjkp1vB04GJhXjQMxM6ult9TVVuVExKtkYwi/TwPmj3Ww6peBD0paBtxHNjZxC9BX0lLgu2RdV62mAktbB8xL3AgsBZYAfwS+GhH/W63jMTOrlSJ2WykbbrA8Bm7W1Ou+tAlbjWp0CDWxct2qRodQdS299P+z/zJw90aHUBNffezKbl+9eVbT0RX9R7/g0d/U7UpR9+2bmRVcEe8wd/IwMyu4KOCAuZOHmVnBueVhZma5FfFSXScPM7OCK17qcPIwMyu8DQVMH04eZmYF5wHzXmLS8P3Kr9TD3LT2wUaHUBMLP7xFo0Oouk3692l0CDWx03XzGx1CTXy1CnV4wNzMzHJzy8PMzHJzy8PMzHJrLuAjaZw8zMwKzvd5mJlZbh7zMDOz3DzmYWZmuRWx26qRbxI0M7MKRIX/K0fS9pL+JOl+SSskfTmVD5N0m6SH0t+yN0g5eZiZFVxzREVTBTYA/xoRu5G9vvtUSbsDk4FZETGS7NXfk8tV5ORhZlZwLURFUzkRsSoiFqb5l4D7gW2Bw4DpabXpwOHl6qpb8pB0WcpwbcsnSvppN+r9taSlks6StKukxZIWSdopZz3jJB3Q1TjMzGql0neYS5okaUHJNKmjOiU1Ae8D5gJbR8QqyBIMMLxcTHUbMI+IL1S7TknvAA6IiB3T58nATRHxrS5UNw5YC/y5ehGamXVfpZfqRsRUYGq59SQNAq4HzoyIF6X8rz6vSctD0kBJv5e0RNJySUdJukPSmLT8BEl/kXQnMLZku7dLul7S/DSNLanvF6lskaTD0iYzgeGptfEt4EzgC5L+lLY7TtK8tPwSSX1S+UckLUzxzUoZ+GTgrLTuQbX4XszMuqJa3VYAkvqRJY6rIuKGVPyMpBFp+Qjg2XL11Krl8RHg6Yj4eApmCPClksC+A+wNvAD8CViUtrsQuCAi7pG0A3ArsBvwDeCPEXGipKHAPEm3A58Cbo6I0aluAWsj4nxJuwFHAWMjYr2ki4BjJf0PcClwcEQ8ImlYRKyRdHHrtu0dUGr+TQIYP2wM7x2cq1fMzKzLokqPJ0nnyJ8D90fED0sWzQCOB6akvzeVq6tWyWMZcL6k/yI7ud9d0ix6P3BHRDwHIOlq4N1p2QRg95J1N5c0GDgU+JSks1N5f2AHYF0nMYwnS1DzU30DyLLpfsBdEfEIQESsqeSASpuDZzUdXbyLrs2s12qu3n0eY4HPAcskLU5lXydLGtdIOgl4HPhMuYpqkjwi4i+S9gY+Bnxf0sy2q3Sw6SbA/hGxUVJI2fKfIuLBNuVNnYQhYHpEnNNmm091sn8zs8Kp1k2CEXEP2bmxPePz1FWrMY9tgFci4krgfGCvksVzgXGStkx9b6UZbiZwWkk9o9PsrcDpKYkg6X0VhDELOFLS8LTNMEk7AvcCH5D0ztbytP5LwOB8R2pmVnsRUdFUT7W6VPe9ZOMSi8nGK77XuiBdBvZtspP47cDCku3OAMakS29Xkg1iA3wX6AcslbQ8fe5URKwEzgVmSloK3AaMSN1lk4AbJC0Brk6b/A44wgPmZlY01RwwrxbVO1v1Br1xzMOvoe05eu9raB9vdAg18dwLD+a/DraNcdtNqOicc8eTt3d7X5XygxHNzArOL4MyM7PcivhUXScPM7OCc/IwM7Pcijg27eRhZlZwbnmYmVlufoe5mZnl1hzFe4u5k0cXLFj/XKNDqLpBffs3OoSa+PLs3nefxydf653/rb4+ZOtGh1BYHvMwM7PcPOZhZma5eczDzMxya3G3lZmZ5eWWh5mZ5earrczMLDd3W5mZWW7utjIzs9zc8jAzs9zc8jAzs9yao7nRIbxJl99hLukMSfdLukrS2yTdnt7/fVQH658s6fPtlDel95LXJY5O6mmS9NmuxmFmVisRUdFUT91peZwCfDQiHpG0H9AvIkZ3tHJEXNyNfVUtjk40AZ8FflXN4MzMuquIjyepqOUh6SuSlqfpTEkXA+8CZkj6GnAlMDr94t9J0hRJKyUtlXR+quPbks5O83tLWiLpXuDUkv30kfQDSfPTtv9SsuzfSsq/k8rKxbG3pDsl3SfpVkkj0nY7pxbKEkkLJe0ETAEOStue1e1v1sysSnpky0PS3sAJwPsBAXOB44CPAB+MiNWS5gJnR8QnJA0DjgB2jYiQNLSdai8HTo+IOyX9oKT8JOCFiNhH0tuA2ZJmAiPTtG+KYYakgyPiZEkdxdEPuAI4LCKeS91Y5wEnAlcBUyLiRkn9yZLo5NZtO/geJgGTAHYesgvvGLhtua/OzKwqeurVVgcCN0bEywCSbgAO6mT9F4FXgcsk/R64uXShpCHA0Ii4MxVdAXw0zR8K7CHpyPR5CFnSODRNi1L5oFR+Vydx7AKMAm6TBNAHWCVpMLBtRNwIEBGvprg6qQoiYiowFeCgbccX77+kmfVaPfVqq87Pqm1ExAZJ+wLjgaOB04BD2tTX0TchshbJrRsVSh8Gvh8Rl+QIRcCKiNi/TV2b56jDzKzhivh4kkrGPO4CDpe0maSBZF1Sd3e0sqRBwJCI+ANwJrDR4HVE/B14QdKBqejYksW3Al9KXU5Ienfa563AialuJG0raXiZuB8E3i5p/7RNP0nviYgXgSclHZ7K3yZpM+AlYHDZb8PMrM565JhHRCyUNA2Yl4oui4hFnXTzDAZuSmMJAtobfD4B+IWkV8gSQ6vLyK56WqhsB88Bh0fETEm7Afem/a4lG3d5tpO4X0/dXz9OXWV9gR8BK4DPAZdI+g9gPfAZYCmwQdISYFpEXNDJ12JmVjdFHPNQEV9vWHS9cczjheZXGh1CTbyv/zaNDqHqeutraJ/ol6uHvMc46/Eru31gWwzauaJzzt/W/rVuX6LvMDczK7gi3ufh5GFmVnBF7CFy8jAzK7giXm3l5GFmVnBFHDB38jAzK7gidlt1+am6ZmZWH1Hh/yoh6SOSHpT0V0mTuxqTk4eZWcFV6yZBSX2A/0v2SKjdgWMk7d6VmJw8zMwKriWioqkC+wJ/jYiHI+J14DfAYV2JyWMeXXD3U7PqdiOOpEnpoYy9Rm88Juidx9Ubjwl63nFteP2pis45pU//Tqa2Oc5tgSdKPj9J9sT03NzyKL5J5VfpcXrjMUHvPK7eeEzQS48rIqZGxJiSqW2CbC8JdWk03snDzOyt40lg+5LP2wFPd6UiJw8zs7eO+cBISe+UtCnZazNmdKUij3kUX4/pl82hNx4T9M7j6o3HBL33uDqV3rd0GtnTzPsAv4iIFV2py0/VNTOz3NxtZWZmuTl5mJlZbk4eDSTpCEkhadf0eZykm9usMy29EbH1VbpTJD0kabmkeZI+2ojY2+rCsdyRHpGwRNJsSbu0Uz5f0ug37y13bM2SFpdMTd2ts6cp+Q6WS/qdpKFVrn+ipJ+m+W9LOrtK9ZbGfW16ZXR36xwj6cedLN9G0nXd3U9v5+TRWMcA95Bd8VCJ7wIjgFERMQr4JMV573reYwE4NiL2BKYDP2in/KI25V21LiJGl0yPVqHOmkqPkaim1u9gFLAGOLXK9ddKadyvAyeXLlQm13ksIhZExBmdLH86Io7sWrhvHU4eDSJpEDAWOIkKTrjpF9cXgdMj4jWAiHgmIq6paaAVyHss7bgL2Lmd8nvJ7oitudTiuUDSXZLul7SPpBtSK+97Jesdl1p8iyVd0nqSl/QzSQskrZD0nZL1p0haKWmppPNT2RstsPR5bfo7TtKfJP0KWNbZ/rppo+9V0r+lVt7SNrF/PpUtkXRFKvukpLmSFkm6XdLWVYinUncDO0tqSv+NLgIWAttLOlTSvZIWphbKoBTvPpL+nI5hnqTBpa1iSR8oaZEuSsubJC1Py/tLulzSsrT8g6l8Yvr3cUv6N/LfdfweCsGX6jbO4cAtEfEXSWsk7VVm/Z2BxyPixTrEllfeY2nrk6STZRsfAX7b7ehggKTFaf6RiDiig/Vej4iDJX0ZuAnYm+xX+v+TdAEwHDgKGBsR69PJ61jgl8A3ImJNOrnPkrQH2Q1ZRwC7RkRU2FW0L1nL8hFJu3Wyvy5J8Y0Hfp4+HwqMTPsVMEPSwcDzwDfSvldLGpaquAfYLx3PF4CvAv/a1XhyxN2X7GF+t6SiXYATIuIUSVsB5wITIuJlSV8DviJpCnA1cFREzJe0ObCuTdVnA6dGxOyUcF5ts/xUgGyNRvcAAAO9SURBVIh4r7Iu2ZmS3p2WjQbeB7wGPCjpJxHxBG8RTh6NcwzwozT/m/T55g7WLfr11F09lqskrQMeBU5vUz6Q7Dr0vImoPesiopKxk9abpZYBKyJiFYCkh8nuyj2QLKHMlwQwAHg2bfPPyp4r1Jesa3F3YCXZyegySb+n4++k1LyIeCTNj+9kf3m1JtAm4D7gtlR+aJoWpc+DyJLJnsB1EbEaICLWpOXbAVdLGgFsCrTGWiulif9usqS3DfBYRMxJ5fuRfd+z0/e0KVnrahdgVUTMT8fwIkBap9Vs4IeSrgJuiIgn2yw/EPhJ2v4BSY8BrcljVkS8kOpcCezIxs+N6tWcPBpA0pbAIcAoSUF2kgyyX5RbtFl9GLAa+Cuwg6TBEfFSPePtTBePpdWxEbGgnWqPBZYAU8geH/3pGsR9Odmvxqcj4mOp+LX0t6VkvvVzX7Jf5tMj4pw2db2T7BfsPhHxN0nTgP7phqx9yZLA0cBpZN/VBlKXsbIz1aYl1b1cWnV7++uidRExWtIQsiR2KvDjtI/vR8QlbY7pDNr/0fIT4IcRMUPSOODbVYitM29K/Onk3vZ7ui0ijmmz3h6U+eEVEVNSYv8YMEfSBDZufXT2QMLSfyPNvMXOpx7zaIwjgV9GxI4R0RQR25P9ghsGbJO6K5C0I9kvwMUR8QrZr64fK3usAJJGSDquMYfwhtzHUkmlEbGerCtiv9Y6qikiTkgDsR8rv/YbZgFHShoOIGlYOq7NyU5mL6QxgI+m5YOAIRHxB+BMsm4OyFpae6f5w4B+OffXZemX8hnA2ZL6kd1pfGLJGMG2aX+zyFpTW7buO1UxBHgqzR/fnViqaA4wVtLOkI0Ppq6lB8j+De6Tygen7q83SNopIpZFxH8BC4Bd29R9F9mPGVKdOwAP1vRoeoi3VKYskGPIflWXup7s1+lxwOWS+gPrgS+0No3JTqbfA1ZKepXshPXN+oTcoa4eS1kRsU7S/yH7VX9SleLtsohYKelcsn7vTciO6dSImCNpEbACeJisKwSyK+FuSscv4KxUfmkqn0d2ki79FV12f8Bj3TyORZKWAEdHxBUpOd+bftGvBY6LiBWSzgPulNRM1q01kaylca2kp8hO2u/sTizVEBHPSZoI/FrS21LxuWkM7ijgJ5IGkI13TGiz+ZlpELyZrJvxf8i6HVtdBFwsaRlZi3FiRLzWpmvrLcmPJzEzs9zcbWVmZrk5eZiZWW5OHmZmlpuTh5mZ5ebkYWZmuTl5mJlZbk4eZmaW2/8HHRw7u/Rv3h8AAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "heatmap(df_diff_percent, yticklabels=df_paperIndividualScores[\"Similarity\"])" ] }, { "cell_type": "code", - "execution_count": 837, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.057754824999999996" - ] - }, - "execution_count": 837, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error\n", "mean_squared_error(df_paperIndividualScores[diff_metrics],\n", diff --git a/notebooks/02_AA_Skorch_DDI.ipynb b/notebooks/02_AA_Skorch_DDI.ipynb new file mode 100644 index 0000000..111f7ab --- /dev/null +++ b/notebooks/02_AA_Skorch_DDI.ipynb @@ -0,0 +1,581 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](https://scikit-learn.org/stable/_images/grid_search_workflow.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 1230, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": 1231, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import pickle\n", + "\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import StratifiedKFold\n", + "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef, precision_recall_curve, auc\n", + "\n", + "from keras.utils import np_utils\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from torch.optim import SGD\n", + "\n", + "import skorch\n", + "from skorch import NeuralNetClassifier\n", + "from skorch.callbacks import EpochScoring\n", + "from skorch.callbacks import TensorBoard\n", + "from skorch.helper import predefined_split" + ] + }, + { + "cell_type": "code", + "execution_count": 1232, + "metadata": {}, + "outputs": [], + "source": [ + "# import configurations (file paths, etc.)\n", + "import yaml\n", + "try:\n", + " from yaml import CLoader as Loader, CDumper as Dumper\n", + "except ImportError:\n", + " from yaml import Loader, Dumper\n", + " \n", + "configFile = '../cluster/data/medinfmk/ddi/config/config.yml'\n", + "\n", + "with open(configFile, 'r') as ymlfile:\n", + " cfg = yaml.load(ymlfile, Loader=Loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 1233, + "metadata": {}, + "outputs": [], + "source": [ + "pathInput = cfg['filePaths']['dirRaw']\n", + "pathOutput = cfg['filePaths']['dirProcessed']\n", + "# path to store python binary files (pickles)\n", + "# in order not to recalculate them every time\n", + "pathPickles = cfg['filePaths']['dirProcessedFiles']['dirPickles']\n", + "pathRuns = cfg['filePaths']['dirProcessedFiles']['dirRuns']\n", + "pathPaperScores = cfg['filePaths']['dirRawFiles']['paper-individual-metrics-scores']\n", + "datasetDirs = cfg['filePaths']['dirRawDatasets']\n", + "DS1_path = str(datasetDirs[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 1234, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(input_fea, input_lab, seperate=False):\n", + " offside_sim_path = input_fea\n", + " drug_interaction_matrix_path = input_lab\n", + " drug_fea = np.loadtxt(offside_sim_path,dtype=float,delimiter=\",\")\n", + " interaction = np.loadtxt(drug_interaction_matrix_path,dtype=int,delimiter=\",\")\n", + " \n", + " train = []\n", + " label = []\n", + " tmp_fea=[]\n", + " drug_fea_tmp = []\n", + " \n", + " for i in range(0, (interaction.shape[0]-1)):\n", + " for j in range((i+1), interaction.shape[1]):\n", + " label.append(interaction[i,j])\n", + " drug_fea_tmp_1 = list(drug_fea[i])\n", + " drug_fea_tmp_2 = list(drug_fea[j])\n", + " if seperate:\n", + " tmp_fea = (drug_fea_tmp_1,drug_fea_tmp_2)\n", + " else:\n", + " tmp_fea = drug_fea_tmp_1 + drug_fea_tmp_2\n", + " train.append(tmp_fea)\n", + "\n", + " return np.array(train), np.array(label)" + ] + }, + { + "cell_type": "code", + "execution_count": 1235, + "metadata": {}, + "outputs": [], + "source": [ + "def transfer_array_format(data):\n", + " formated_matrix1 = []\n", + " formated_matrix2 = []\n", + " for val in data:\n", + " formated_matrix1.append(val[0])\n", + " formated_matrix2.append(val[1])\n", + " return np.array(formated_matrix1), np.array(formated_matrix2)" + ] + }, + { + "cell_type": "code", + "execution_count": 1236, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_labels(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " y = encoder.transform(labels).astype(np.int32)\n", + " if categorical:\n", + " y = np_utils.to_categorical(y)\n", + "# print(y)\n", + " return y, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1237, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_names(labels, encoder=None, categorical=True):\n", + " if not encoder:\n", + " encoder = LabelEncoder()\n", + " encoder.fit(labels)\n", + " if categorical:\n", + " labels = np_utils.to_categorical(labels)\n", + " return labels, encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 1238, + "metadata": {}, + "outputs": [], + "source": [ + "def getStratifiedKFoldSplit(X,y,n_splits):\n", + " skf = StratifiedKFold(n_splits=n_splits)\n", + " return skf.split(X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": 1239, + "metadata": {}, + "outputs": [], + "source": [ + "class NDD(nn.Module):\n", + " def __init__(self, D_in=1096, H1=300, H2=400, D_out=2, drop=0.5):\n", + " super(NDD, self).__init__()\n", + " # an affine operation: y = Wx + b\n", + " self.fc1 = nn.Linear(D_in, H1) # Fully Connected\n", + " self.fc2 = nn.Linear(H1, H2)\n", + " self.fc3 = nn.Linear(H2, D_out)\n", + " self.drop = nn.Dropout(drop)\n", + " self._init_weights()\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.fc1(x))\n", + " x = self.drop(x)\n", + " x = F.relu(self.fc2(x))\n", + " x = self.drop(x)\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def _init_weights(self):\n", + " for m in self.modules():\n", + " if(isinstance(m, nn.Linear)):\n", + " m.weight.data.normal_(0, 0.05)\n", + " m.bias.data.uniform_(-1,0)" + ] + }, + { + "cell_type": "code", + "execution_count": 1240, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDFSingleMetric(df, sim_type, metric, value):\n", + " df.loc[df['Similarity'] == sim_type, metric ] = round(value,3)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1241, + "metadata": {}, + "outputs": [], + "source": [ + "def updateSimilarityDF(df, sim_type, AUROC, AUPR, F1, Rec, Prec):\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUC', AUROC)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'AUPR', AUPR)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'F-measure', F1)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Recall', Rec)\n", + " df = updateSimilarityDFSingleMetric(df, sim_type, 'Precision', Prec)\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 1242, + "metadata": {}, + "outputs": [], + "source": [ + "def getNetParamsStr(net, str_hidden_layers_params, net_params_to_print=[\"max_epochs\", \"batch_size\"]):\n", + " net_params = [val for sublist in [[x,net.get_params()[x]] for x in net_params_to_print] for val in sublist]\n", + " net_params_str = '-'.join(map(str, flattened))\n", + " return(net_params_str+str_hidden_layers_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 1243, + "metadata": {}, + "outputs": [], + "source": [ + "def writeReplicatedIndividualScoresCSV(net, df, destination, str_hidden_layers_params):\n", + " filePath = destination + \"replicatedIndividualScores_\" + getNetParamsStr(net, str_hidden_layers_params) + \".csv\"\n", + " df.to_csv(path_or_buf = filePath, index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 1244, + "metadata": {}, + "outputs": [], + "source": [ + "def getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test):\n", + " model = NDD(D_in, H1, H2, D_out, drop)\n", + " \n", + " net = NeuralNetClassifier(\n", + " model,\n", + "# criterion=nn.CrossEntropyLoss,\n", + " criterion=nn.BCEWithLogitsLoss,\n", + " max_epochs=20,\n", + " optimizer=SGD,\n", + " optimizer__lr=0.01,\n", + " optimizer__momentum=0.9, \n", + " optimizer__weight_decay=1e-6, \n", + " optimizer__nesterov=True, \n", + " batch_size=200,\n", + " callbacks=callbacks,\n", + " # Shuffle training data on each epoch\n", + " iterator_train__shuffle=True,\n", + " device=device,\n", + " train_split=predefined_split(Xy_test),\n", + " )\n", + " return net" + ] + }, + { + "cell_type": "code", + "execution_count": 1245, + "metadata": {}, + "outputs": [], + "source": [ + "def avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits):\n", + " AUROC /= kfold_nsplits\n", + " AUPR /= kfold_nsplits\n", + " F1 /= kfold_nsplits\n", + " Rec /= kfold_nsplits\n", + " Prec /= kfold_nsplits\n", + " return AUROC, AUPR, F1, Rec, Prec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run" + ] + }, + { + "cell_type": "code", + "execution_count": 1246, + "metadata": {}, + "outputs": [], + "source": [ + "df_paperIndividualScores = pd.read_csv(pathPaperScores)\n", + "\n", + "df_replicatedIndividualScores = df_paperIndividualScores.copy()\n", + "\n", + "for col in df_replicatedIndividualScores.columns:\n", + " if col != 'Similarity':\n", + " df_replicatedIndividualScores[col].values[:] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 1247, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "soft = nn.Softmax(dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 1248, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preparing sideeffect data...\n", + "Running fold0 for sideeffect...\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Classification metrics can't handle a mix of multilabel-indicator and binary targets", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mmodelPicklePath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpathPickles\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m\"model_params/model_params_fold\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr_hidden_layers_params\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0;34m\"_\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msimilarity\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\".p\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdo_train_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodelPicklePath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/classifier.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;31m# this is actually a pylint bug:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;31m# https://github.com/PyCQA/pylint/issues/1085\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNeuralNetClassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 848\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 849\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mpartial_fit\u001b[0;34m(self, X, y, classes, **fit_params)\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_train_begin'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 806\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 807\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 808\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(self, X, y, epochs, **fit_params)\u001b[0m\n\u001b[1;32m 760\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"valid_batch_count\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_batch_count\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 762\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'on_epoch_end'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mon_epoch_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 763\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 764\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/net.py\u001b[0m in \u001b[0;36mnotify\u001b[0;34m(self, method_name, **cb_kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcb_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks_\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mcb_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;31m# pylint: disable=unused-argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/callbacks/scoring.py\u001b[0m in \u001b[0;36mon_epoch_end\u001b[0;34m(self, net, dataset_train, dataset_valid, **kwargs)\u001b[0m\n\u001b[1;32m 410\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcache_net_infer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_caching\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcached_net\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 412\u001b[0;31m \u001b[0mcurrent_score\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_scoring\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcached_net\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 413\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_record_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/skorch/callbacks/scoring.py\u001b[0m in \u001b[0;36m_scoring\u001b[0;34m(self, net, X_test, y_test)\u001b[0m\n\u001b[1;32m 119\u001b[0m instead of running inference again, if available.\"\"\"\n\u001b[1;32m 120\u001b[0m \u001b[0mscorer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_scoring\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscoring_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_is_best_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, estimator, X, y_true, sample_weight)\u001b[0m\n\u001b[1;32m 167\u001b[0m stacklevel=2)\n\u001b[1;32m 168\u001b[0m return self._score(partial(_cached_call, None), estimator, X, y_true,\n\u001b[0;32m--> 169\u001b[0;31m sample_weight=sample_weight)\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_factory_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m_score\u001b[0;34m(self, method_caller, estimator, X, y_true, sample_weight)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m return self._sign * self._score_func(y_true, y_pred,\n\u001b[0;32m--> 212\u001b[0;31m **self._kwargs)\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36maccuracy_score\u001b[0;34m(y_true, y_pred, normalize, sample_weight)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;31m# Compute accuracy for each possible representation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0my_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_targets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0mcheck_consistent_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0my_type\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'multilabel'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36m_check_targets\u001b[0;34m(y_true, y_pred)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_type\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m raise ValueError(\"Classification metrics can't handle a mix of {0} \"\n\u001b[0;32m---> 90\u001b[0;31m \"and {1} targets\".format(type_true, type_pred))\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# We can't have more than one value on y_type => The set is no more needed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Classification metrics can't handle a mix of multilabel-indicator and binary targets" + ] + } + ], + "source": [ + "do_prepare_data = True\n", + "do_train_model = True\n", + "kfold_nsplits = 5\n", + "# similaritiesToRun = df_paperIndividualScores['Similarity']\n", + "similaritiesToRun = [\"sideeffect\"]\n", + "\n", + "for similarity in similaritiesToRun:\n", + " input_fea = pathInput+DS1_path+\"/\" + similarity + \"_Jacarrd_sim.csv\"\n", + " input_lab = pathInput+DS1_path+\"/drug_drug_matrix.csv\"\n", + " dataPicklePath = pathPickles+\"data_X_y_\" + similarity + \"_Jaccard.p\"\n", + "\n", + " # Define model\n", + " D_in, H1, H2, D_out, drop = X.shape[1], 300, 400, 2, 0.5\n", + " str_hidden_layers_params = \"-H1-\" + str(H1) + \"-H2-\" + str(H2)\n", + " callbacks = []\n", + " \n", + " # Prepare data if not available\n", + " if do_prepare_data:\n", + " print(\"Preparing \" + similarity + \" data...\")\n", + " X,y = prepare_data(input_fea, input_lab, seperate = False)\n", + "\n", + " with open(dataPicklePath, 'wb') as f:\n", + " pickle.dump([X, y], f)\n", + "\n", + " # Load X,y and split in to train, test\n", + " with open(dataPicklePath, 'rb') as f:\n", + " X, y = pickle.load(f)\n", + " \n", + " X = X.astype(np.float32)\n", + " y = y.astype(np.int64) \n", + " \n", + " y_cat = np_utils.to_categorical(y)\n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = 0,0,0,0,0\n", + " kFoldSplit = getStratifiedKFoldSplit(X,y,n_splits=kfold_nsplits)\n", + " for i, indices in enumerate(kFoldSplit):\n", + " print(\"Running fold\" + str(i) + \" for \" + similarity +\"...\")\n", + " \n", + " train_index = indices[0]\n", + " test_index = indices[1]\n", + " X_train, X_test = X[train_index], X[test_index]\n", + "# y_train, y_test = y[train_index], y[test_index]\n", + " y_train, y_test = y_cat[train_index], y_cat[test_index]\n", + " \n", + " # Create Network Classifier\n", + " Xy_test = skorch.dataset.Dataset(X_test, y_test)\n", + " net = getNDDClassifier(D_in, H1, H2, D_out, drop, Xy_test)\n", + " \n", + " # Fit and save OR load model\n", + " modelPicklePath = pathPickles+\"model_params/model_params_fold\" + str(i) + \"_\" + str_hidden_layers_params+ \"_\" + similarity + \".p\"\n", + " if do_train_model:\n", + " net.fit(X_train, y_train)\n", + " net.save_params(f_params=modelPicklePath)\n", + " else:\n", + " net.initialize() # This is important!\n", + " net.load_params(f_params=modelPicklePath)\n", + "\n", + " # Make predictions\n", + " y_pred = net.predict(X_test)\n", + " lr_probs = soft(net.forward(X_test))[:,1]\n", + " lr_precision, lr_recall, _ = precision_recall_curve(y_test, lr_probs)\n", + "\n", + " AUROC += roc_auc_score(y_test, y_pred)\n", + " AUPR += auc(lr_recall, lr_precision)\n", + " F1 += f1_score(y_test, y_pred)\n", + " Rec += recall_score(y_test, y_pred)\n", + " Prec += precision_score(y_test, y_pred)\n", + " \n", + " print(i, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " \n", + " AUROC, AUPR, F1, Rec, Prec = avgMetrics(AUROC, AUPR, F1, Rec, Prec, kfold_nsplits)\n", + " print(similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + " # Fill replicated metrics\n", + " updateSimilarityDF(df_replicatedIndividualScores, similarity, AUROC, AUPR, F1, Rec, Prec)\n", + " \n", + "# Write CSV\n", + "writeReplicatedIndividualScoresCSV(net, df_replicatedIndividualScores, pathRuns, str_hidden_layers_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compare to Paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(df_paperIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(df_replicatedIndividualScores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "diff_metrics = ['AUC', 'AUPR', 'F-measure', 'Recall', 'Precision']\n", + "df_diff = df_paperIndividualScores[diff_metrics] - df_replicatedIndividualScores[diff_metrics]\n", + "df_diff_abs = df_diff.abs()\n", + "df_diff_percent = (df_diff_abs / df_paperIndividualScores[diff_metrics]) * 100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from seaborn import heatmap\n", + "heatmap(df_diff, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_abs, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heatmap(df_diff_percent, yticklabels=df_paperIndividualScores[\"Similarity\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import mean_squared_error\n", + "mean_squared_error(df_paperIndividualScores[diff_metrics],\n", + " df_replicatedIndividualScores[diff_metrics])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/inactive/TorchNDD.ipynb b/notebooks/inactive/TorchNDD.ipynb index 65a6e2e..2c6a5d3 100644 --- a/notebooks/inactive/TorchNDD.ipynb +++ b/notebooks/inactive/TorchNDD.ipynb @@ -1492,7 +1492,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.3" } }, "nbformat": 4,