From 623462dd29f1ac73a7906f2d568483c0310ed8ca Mon Sep 17 00:00:00 2001 From: clementine Date: Mon, 23 Oct 2023 00:40:28 +0100 Subject: [PATCH] pre-commit --- neuralplayground/agents/domine_2023.py | 423 +++++++++++++----- .../class_Graph_generation.py | 74 ++- .../domine_2023_extras/class_config.yaml | 2 +- .../class_grid_run_config.py | 4 +- .../agents/domine_2023_extras/class_models.py | 10 +- .../class_plotting_utils.py | 200 +++++++-- .../agents/domine_2023_extras/class_run.py | 101 ++++- .../agents/domine_2023_extras/class_test.py | 39 +- .../agents/domine_2023_extras/class_utils.py | 5 +- .../agents/jax_optimised_version.py | 321 ++++++++----- 10 files changed, 856 insertions(+), 323 deletions(-) diff --git a/neuralplayground/agents/domine_2023.py b/neuralplayground/agents/domine_2023.py index 2a4caade..5a565f73 100644 --- a/neuralplayground/agents/domine_2023.py +++ b/neuralplayground/agents/domine_2023.py @@ -16,7 +16,9 @@ from neuralplayground.agents.agent_core import AgentCore os.environ["KMP_DUPLICATE_LIB_OK"] = "True" -from neuralplayground.agents.domine_2023_extras.class_Graph_generation import sample_padded_grid_batch_shortest_path +from neuralplayground.agents.domine_2023_extras.class_Graph_generation import ( + sample_padded_grid_batch_shortest_path, +) from neuralplayground.agents.domine_2023_extras.class_grid_run_config import GridConfig from neuralplayground.agents.domine_2023_extras.class_models import get_forward_function from neuralplayground.agents.domine_2023_extras.class_plotting_utils import ( @@ -24,18 +26,22 @@ plot_input_target_output, plot_message_passing_layers, plot_xy, + plot_curves, +) +from neuralplayground.agents.domine_2023_extras.class_utils import ( + rng_sequence_from_rng, + set_device, ) -from neuralplayground.agents.domine_2023_extras.class_utils import rng_sequence_from_rng, set_device from sklearn.metrics import matthews_corrcoef, roc_auc_score -class Domine2023(AgentCore,): - - - def __init__ ( # autogenerated +class Domine2023( + AgentCore, +): + def __init__( # autogenerated self, - #agent_name: str = "SR", - experiment_name = 'smaller size generalisation graph with no position feature', + # agent_name: str = "SR", + experiment_name="smaller size generalisation graph with no position feature", train_on_shortest_path: bool = True, resample: bool = True, wandb_on: bool = False, @@ -43,15 +49,15 @@ def __init__ ( # autogenerated feature_position: bool = False, weighted: bool = True, num_hidden: int = 100, - num_layers : int = 2, + num_layers: int = 2, num_message_passing_steps: int = 3, learning_rate: float = 0.001, num_training_steps: int = 10, batch_size: int = 4, - nx_min: int = 4, + nx_min: int = 4, nx_max: int = 7, - batch_size_test: int= 4, - nx_min_test: int = 4 , + batch_size_test: int = 4, + nx_min_test: int = 4, nx_max_test: int = 7, **mod_kwargs, ): @@ -74,7 +80,7 @@ def __init__ ( # autogenerated self.num_training_steps = num_training_steps # cconfig.num_training_steps # @param - self.batch_size =batch_size + self.batch_size = batch_size self.nx_min = nx_min self.nx_max = nx_max @@ -82,9 +88,9 @@ def __init__ ( # autogenerated # Could be explained during sleep self.batch_size_test = batch_size_test self.nx_min_test = nx_min_test # This is thought of the state density - self.nx_max_test = nx_max_test # This is thought of the state density + self.nx_max_test = nx_max_test # This is thought of the state density self.batch_size = batch_size - self.nx_min = nx_min # This is thought of the state density + self.nx_min = nx_min # This is 
thought of the state density self.nx_max = nx_max self.arena_x_limits = mod_kwargs["arena_y_limits"] @@ -102,8 +108,9 @@ def __init__ ( # autogenerated if self.wandb_on: dateTimeObj = datetime.now() wandb.init( - project="graph-brain", entity="graph-brain", - name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + project="graph-brain", + entity="graph-brain", + name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"), ) self.wandb_logs = {} save_path = wandb.run.dir @@ -113,9 +120,16 @@ def __init__ ( # autogenerated else: dateTimeObj = datetime.now() save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir(os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) + os.mkdir( + os.path.join( + save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + ) + ) self.save_path = os.path.join( - os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) + os.path.join( + save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + ) + ) self.reset() self.saving_run_parameters() @@ -123,16 +137,27 @@ def __init__ ( # autogenerated rng = jax.random.PRNGKey(self.seed) self.rng_seq = rng_sequence_from_rng(rng) - if self.train_on_shortest_path: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) else: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) - forward = get_forward_function(self.num_hidden, self.num_layers, self.num_message_passing_steps) + forward = get_forward_function( + self.num_hidden, self.num_layers, self.num_message_passing_steps + ) net_hk = hk.without_apply_rng(hk.transform(forward)) params = net_hk.init(rng, self.graph) self.params = params @@ -146,51 +171,61 @@ def compute_loss(params, inputs, targets): self._compute_loss = jax.jit(compute_loss) - def update_step(params,opt_state): + def update_step(params, opt_state): loss, grads = jax.value_and_grad(compute_loss)( params, self.graph, self.targets ) # jits inside of value_and_grad - updates, opt_state = optimizer.update(grads, opt_state,params) + updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) - return params, opt_state ,loss + return params, opt_state, loss self._update_step = jax.jit(update_step) def evaluate(params, inputs, target, Target_Value): outputs = net_hk.apply(params, inputs) if Target_Value: - roc_auc = roc_auc_score(np.squeeze(target), np.squeeze(outputs[0].nodes)) + roc_auc = roc_auc_score( + np.squeeze(target), np.squeeze(outputs[0].nodes) + ) else: roc_auc = False - MCC = matthews_corrcoef(np.squeeze(target), round(np.squeeze(outputs[0].nodes))) + MCC = matthews_corrcoef( + np.squeeze(target), round(np.squeeze(outputs[0].nodes)) + ) return outputs, roc_auc, MCC self._evaluate = evaluate - def saving_run_parameters(self): - path = os.path.join(self.save_path, "run.py") HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py") shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_Graph_generation.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py") + HERE = os.path.join( + 
Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_utils.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_plotting_utils.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_config_run.yaml") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_config.yaml") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_config.yaml" + ) shutil.copyfile(HERE, path) - def reset(self,a=1): + def reset(self, a=1): self.obs_history = [] # Initialize observation history to update weights later self.grad_history = [] self.global_steps = 0 @@ -210,69 +245,114 @@ def update(self): rng = next(self.rng_seq) if self.train_on_shortest_path: graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test - ) + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, + ) rng = next(self.rng_seq) if self.resample: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) else: - graph_test, target_test= sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test + graph_test, target_test = sample_padded_grid_batch_shortest_path( + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, + ) + target_test = np.reshape( + graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) ) - target_test = np.reshape(graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1)) rng = next(self.rng_seq) - #target_test_wse = target_test - graph_test.nodes[:, 0] + # target_test_wse = target_test - graph_test.nodes[:, 0] if self.resample: - self.graph, self.targets= sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + self.graph, self.targets = sample_padded_grid_batch_shortest_path( + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) - self.targets = np.reshape(self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)) #self.graph.nodes[:,0] - #target_wse = self.targets - self.graph.nodes[:, 0] + self.targets = np.reshape( + self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) + ) # self.graph.nodes[:,0] + # target_wse = self.targets - self.graph.nodes[:, 0] if self.feature_position: - target_test_wse = target_test - np.reshape(graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1)) - target_wse = self.targets - np.reshape(self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)) + target_test_wse = target_test - np.reshape( + graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) + ) + target_wse = self.targets - 
np.reshape( + self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) + ) else: target_test_wse = target_test - graph_test.nodes[:] target_wse = self.targets - self.graph.nodes[:] # Train - self.params,self.opt_state, loss = self._update_step(self.params,self.opt_state ) + self.params, self.opt_state, loss = self._update_step( + self.params, self.opt_state + ) self.losses.append(loss) - outputs_train, roc_auc_train, MCC_train = self._evaluate(self.params, self.graph, self.targets, True) + outputs_train, roc_auc_train, MCC_train = self._evaluate( + self.params, self.graph, self.targets, True + ) self.roc_aucs_train.append(roc_auc_train) self.MCCs_train.append(MCC_train) # Train without end start in the target loss_wse = self._compute_loss(self.params, self.graph, target_wse) self.losses_wse.append(loss_wse) - outputs_train_wse, roc_auc_train_wse, MCC_train_wse = self._evaluate(self.params, self.graph, target_wse, False) + outputs_train_wse, roc_auc_train_wse, MCC_train_wse = self._evaluate( + self.params, self.graph, target_wse, False + ) self.MCCs_train_wse.append(MCC_train_wse) # Test - loss_test = self._compute_loss(self.params,graph_test, target_test) + loss_test = self._compute_loss(self.params, graph_test, target_test) self.losses_test.append(loss_test) - outputs_test, roc_auc_test, MCC_test = self._evaluate(self.params, graph_test, target_test, True) + outputs_test, roc_auc_test, MCC_test = self._evaluate( + self.params, graph_test, target_test, True + ) self.roc_aucs_test.append(roc_auc_test) self.MCCs_test.append(MCC_test) # Test without end start in the target loss_test_wse = self._compute_loss(self.params, graph_test, target_test_wse) self.losses_test_wse.append(loss_test_wse) - outputs_test_wse, roc_auc_test_wse, MCC_test_wse = self._evaluate(self.params, graph_test, target_test_wse, False) + outputs_test_wse, roc_auc_test_wse, MCC_test_wse = self._evaluate( + self.params, graph_test, target_test_wse, False + ) self.MCCs_test_wse.append(MCC_test_wse) # Log - wandb_logs = {"loss": loss, "losses_test": loss_test, "roc_auc_test": roc_auc_test, "roc_auc": roc_auc_train} + wandb_logs = { + "loss": loss, + "losses_test": loss_test, + "roc_auc_test": roc_auc_test, + "roc_auc": roc_auc_train, + } if self.wandb_on: wandb.log(wandb_logs) self.global_steps = self.global_steps + 1 if self.global_steps % self.log_every == 0: - print(f"Training step {self.global_steps}: loss = {loss} , loss_test = {loss_test}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}") + print( + f"Training step {self.global_steps}: loss = {loss} , loss_test = {loss_test}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}" + ) return def print_and_plot(self): @@ -280,34 +360,61 @@ def print_and_plot(self): rng = next(self.rng_seq) if self.train_on_shortest_path: graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test - ) + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, + ) else: rng = next(self.rng_seq) graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, + ) + target_test = np.reshape( + graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) ) - target_test = 
np.reshape(graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1)) if self.feature_position: - target_test_wse = target_test - np.reshape(graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1)) - target_wse = self.targets - np.reshape(self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)) + target_test_wse = target_test - np.reshape( + graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) + ) + target_wse = self.targets - np.reshape( + self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) + ) else: target_test_wse = target_test - graph_test.nodes[:] target_wse = self.targets - self.graph.nodes[:] - - outputs_test, roc_auc_test, MCC_test = self._evaluate(self.params, graph_test, target_test, True) - outputs_test_wse, roc_auc_test_wse, MCC_test_wse = self._evaluate(self.params, graph_test, target_test_wse, False) - outputs, roc_auc, MCC = self._evaluate(self.params, self.graph, self.targets, True) - outputs_wse, roc_auc_wse, MCC_wse = self._evaluate(self.params, self.graph, target_wse , False) + outputs_test, roc_auc_test, MCC_test = self._evaluate( + self.params, graph_test, target_test, True + ) + outputs_test_wse, roc_auc_test_wse, MCC_test_wse = self._evaluate( + self.params, graph_test, target_test_wse, False + ) + outputs, roc_auc, MCC = self._evaluate( + self.params, self.graph, self.targets, True + ) + outputs_wse, roc_auc_wse, MCC_wse = self._evaluate( + self.params, self.graph, target_wse, False + ) # SAVE PARAMETER (NOT WE SAVE THE FILES SO IT SHOULD BE THERE AS WELL ) if self.wandb_on: with open("readme.txt", "w") as f: f.write("readme") with open(os.path.join(self.save_path, "Constant.txt"), "w") as outfile: - outfile.write("num_message_passing_steps" + str(self.num_message_passing_steps) + "\n") + outfile.write( + "num_message_passing_steps" + + str(self.num_message_passing_steps) + + "\n" + ) outfile.write("Learning_rate:" + str(self.learning_rate) + "\n") outfile.write("num_training_steps:" + str(self.num_training_steps)) outfile.write("roc_auc" + str(roc_auc_test)) @@ -315,21 +422,87 @@ def print_and_plot(self): outfile.write("roc_auc_wse" + str(roc_auc_test_wse)) outfile.write("MCC_wse" + str(MCC_test_wse)) - # PLOTTING THE LOSS and AUC ROC - plot_xy(self.losses, os.path.join(self.save_path, "Losses.pdf"), "Losses") - plot_xy(self.losses_test, os.path.join(self.save_path, "Losses_test.pdf"), "Losses_test") + # PLOTTING THE LOSS and AUC RO + plot_curves( + [self.losses, self.losses_test, self.losses_wse, self.losses_test_wse], + os.path.join(self.save_path, "Losses.pdf"), + "All_Losses", + legend_labels=["loss", "loss test", "loss_wse", "loss_test_wse"], + ) - plot_xy(self.losses_wse, os.path.join(self.save_path, "Losses_wse.pdf"), "Losses_wse") - plot_xy(self.losses_test_wse, os.path.join(self.save_path, "Losses_test_wse.pdf"), "Losses_test_wse") + plot_curves( + [ + np.log(self.losses), + np.log(self.losses_test), + np.log(self.losses_wse), + np.log(self.losses_test_wse), + ], + os.path.join(self.save_path, "Log_Losses.pdf"), + "All_log_Losses", + legend_labels=[ + "log_loss", + "log_loss test", + "log_loss_wse", + "log_loss_test_wse", + ], + ) - plot_xy(self.roc_aucs_test, os.path.join(self.save_path, "auc_roc_test.pdf"), "auc_roc_test") - plot_xy(self.roc_aucs_train, os.path.join(self.save_path, "auc_roc_train.pdf"), "auc_roc_train") + plot_xy(self.losses, os.path.join(self.save_path, "Losses_train.pdf"), "Losses") + plot_xy( + self.losses_test, + os.path.join(self.save_path, "Losses_test.pdf"), + "Losses_test", + ) + plot_xy( + 
self.losses_wse, + os.path.join(self.save_path, "Losses_wse.pdf"), + "Losses_wse", + ) + plot_xy( + self.losses_test_wse, + os.path.join(self.save_path, "Losses_test_wse.pdf"), + "Losses_test_wse", + ) - plot_xy(self.MCCs_train, os.path.join(self.save_path, "MCC_train.pdf"), "MCC_train") - plot_xy(self.MCCs_test, os.path.join(self.save_path, "MCC_test.pdf"), "MCC_test") + plot_curves( + [self.roc_aucs_test, self.roc_aucs_train], + os.path.join(self.save_path, "auc_rocs.pdf"), + "All_auc_roc", + legend_labels=["auc_roc_test", "auc_roc_train"], + ) + plot_xy( + self.roc_aucs_test, + os.path.join(self.save_path, "auc_roc_test.pdf"), + "auc_roc_test", + ) + plot_xy( + self.roc_aucs_train, + os.path.join(self.save_path, "auc_roc_train.pdf"), + "auc_roc_train", + ) - plot_xy(self.MCCs_train_wse, os.path.join(self.save_path, "MCC_train_wse.pdf"), "MCC_train_wse") - plot_xy(self.MCCs_test_wse, os.path.join(self.save_path, "MCC_test_wse.pdf"), "MCC_test_wse") + plot_curves( + [self.MCCs_train, self.MCCs_test, self.MCCs_train_wse, self.MCCs_test_wse], + os.path.join(self.save_path, "MCCs.pdf"), + "All_MCCs", + legend_labels=["MCC", "MCC test", "MCC_wse", "MCC_test_wse"], + ) + plot_xy( + self.MCCs_train, os.path.join(self.save_path, "MCC_train.pdf"), "MCC_train" + ) + plot_xy( + self.MCCs_test, os.path.join(self.save_path, "MCC_test.pdf"), "MCC_test" + ) + plot_xy( + self.MCCs_train_wse, + os.path.join(self.save_path, "MCC_train_wse.pdf"), + "MCC_train_wse", + ) + plot_xy( + self.MCCs_test_wse, + os.path.join(self.save_path, "MCC_test_wse.pdf"), + "MCC_test_wse", + ) # PLOTTING ACTIVATION FOR TEST AND THE TARGET OF THE THING ( NOTE THAT IS WAS TRANED ON THE ALL THING) plot_input_target_output( @@ -340,6 +513,7 @@ def print_and_plot(self): 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_test.pdf"), + "in_out_targ_test", ) plot_message_passing_layers( list(graph_test.nodes.sum(-1)), @@ -350,7 +524,11 @@ def print_and_plot(self): 3, self.num_message_passing_steps, self.edge_lables, - os.path.join(self.save_path, "message_passing_graph_test.pdf"), + os.path.join( + self.save_path, + "message_passing_graph_test.pdf", + ), + "message_passing_graph_test", ) plot_input_target_output( @@ -361,6 +539,7 @@ def print_and_plot(self): 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_test_wse.pdf"), + "in_out_targ_test_wse", ) # Train @@ -373,6 +552,7 @@ def print_and_plot(self): 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_train.pdf"), + "in_out_targ_train", ) plot_input_target_output( @@ -383,9 +563,9 @@ def print_and_plot(self): 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_train_wse.pdf"), + "in_out_targ_train_wse", ) - # graph_test, target_test = sample_padded_grid_batch_shortest_path( # rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test # ) @@ -413,20 +593,20 @@ def print_and_plot(self): # target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_test.pdf"), "Target_test", self.edge_lables # ) - # PLOTTING ACTIVATION OF THE FIRST 2 GRAPH OF THE BATCHe plot_input_target_output( - list( self.graph.nodes.sum(-1)), + list(self.graph.nodes.sum(-1)), target_wse.sum(-1), outputs_wse[0].nodes.tolist(), self.graph, 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_train_wse.pdf"), + "in_out_targ_train_wse", ) plot_message_passing_layers( - list( self.graph.nodes.sum(-1)), + list(self.graph.nodes.sum(-1)), outputs_wse[1], target_wse.sum(-1), outputs_wse[0].nodes.tolist(), @@ -434,21 
+614,26 @@ def print_and_plot(self): 3, self.num_message_passing_steps, self.edge_lables, - os.path.join(self.save_path, "message_passing_graph_train_wse.pdf"), + os.path.join( + self.save_path, + "message_passing_graph_train_wse.pdf", + ), + "message_passing_graph_train_wse", ) plot_input_target_output( - list( self.graph.nodes.sum(-1)), + list(self.graph.nodes.sum(-1)), self.targets.sum(-1), outputs[0].nodes.tolist(), self.graph, 4, self.edge_lables, os.path.join(self.save_path, "in_out_targ_train.pdf"), + "in_out_targ_train", ) plot_message_passing_layers( - list( self.graph.nodes.sum(-1)), + list(self.graph.nodes.sum(-1)), outputs[1], self.targets.sum(-1), outputs[0].nodes.tolist(), @@ -457,6 +642,7 @@ def print_and_plot(self): self.num_message_passing_steps, self.edge_lables, os.path.join(self.save_path, "message_passing_graph_train.pdf"), + "message_passing_graph_train", ) # plot_message_passing_layers_units(outputs[1], target_test.sum(-1), outputs[0].nodes.tolist(),graph_test,config.num_hidden,config.num_message_passing_steps,edege_lables,os.path.join(save_path, 'message_passing_hidden_unit.pdf')) @@ -479,52 +665,55 @@ def print_and_plot(self): # target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_train.pdf"), "Target", self.edge_lables # ) print('End') + if __name__ == "__main__": from neuralplayground.arenas import Simple2D # @title Graph net functions parser = argparse.ArgumentParser() parser.add_argument( - "--config_path", - metavar="-C", - default="domine_2023_extras/class_config.yaml", - help="path to base configuration file.", -) + "--config_path", + metavar="-C", + default="domine_2023_extras/class_config.yaml", + help="path to base configuration file.", + ) args = parser.parse_args() set_device() config_class = GridConfig config = config_class(args.config_path) - # Init environment arena_x_limits = [-100, 100] arena_y_limits = [-100, 100] - #env = Simple2D + # env = Simple2D # time_step_size=time_step_size, # agent_step_size=agent_step_size, # arena_x_limits=arena_x_limits, # arena_y_limits=arena_y_limits, # ) - agent = Domine2023( experiment_name=config.experiment_name, - train_on_shortest_path= config.train_on_shortest_path, - resample= config.resample, # @param - wandb_on= config.wandb_on, - seed= config.seed, - feature_position= config.feature_position, - weighted= config.weighted, - num_hidden= config.num_hidden, # @param - num_layers= config.num_layers, # @param - num_message_passing_steps= config.num_message_passing_steps, # @param - learning_rate= config.learning_rate , # @param - num_training_steps= config.num_training_steps, # @param - batch_size= config.batch_size, - nx_min= config.nx_min, - nx_max= config. 
nx_max, - batch_size_test= config.batch_size_test, - nx_min_test= config.nx_min_test, - nx_max_test= 7,arena_y_limits=arena_y_limits, arena_x_limits=arena_x_limits + agent = Domine2023( + experiment_name=config.experiment_name, + train_on_shortest_path=config.train_on_shortest_path, + resample=config.resample, # @param + wandb_on=config.wandb_on, + seed=config.seed, + feature_position=config.feature_position, + weighted=config.weighted, + num_hidden=config.num_hidden, # @param + num_layers=config.num_layers, # @param + num_message_passing_steps=config.num_message_passing_steps, # @param + learning_rate=config.learning_rate, # @param + num_training_steps=config.num_training_steps, # @param + batch_size=config.batch_size, + nx_min=config.nx_min, + nx_max=config.nx_max, + batch_size_test=config.batch_size_test, + nx_min_test=config.nx_min_test, + nx_max_test=7, + arena_y_limits=arena_y_limits, + arena_x_limits=arena_x_limits, ) for n in range(config.num_training_steps): @@ -535,9 +724,9 @@ def print_and_plot(self): # The other alternative is to see that we have multiple env that we resample every time # TODO: Make juste an env type (so that is accomodates for not only 2 d env// different transmats) # TODO: Make The plotting in the general plotting utilse -#if __name__ == "__main__": +# if __name__ == "__main__": # x = Domine2023() - # x = x.replace(obs_history=[1, 2], num_hidden=2) - # x.num_hidden = 5 - # - # x.update() \ No newline at end of file +# x = x.replace(obs_history=[1, 2], num_hidden=2) +# x.num_hidden = 5 +# +# x.update() diff --git a/neuralplayground/agents/domine_2023_extras/class_Graph_generation.py b/neuralplayground/agents/domine_2023_extras/class_Graph_generation.py index 364fc70b..b4443c01 100644 --- a/neuralplayground/agents/domine_2023_extras/class_Graph_generation.py +++ b/neuralplayground/agents/domine_2023_extras/class_Graph_generation.py @@ -12,7 +12,14 @@ def get_grid_adjacency(n_x, n_y, atol=1e-1): def sample_padded_grid_batch_shortest_path( - rng, batch_size, feature_position, weighted, nx_min, nx_max, ny_min=None, ny_max=None + rng, + batch_size, + feature_position, + weighted, + nx_min, + nx_max, + ny_min=None, + ny_max=None, ): rng_seq = rng_sequence_from_rng(rng) """Sample a batch of grid graphs with variable sizes. 
@@ -43,37 +50,53 @@ def sample_padded_grid_batch_shortest_path( for n_x, n_y in zip(n_xs, n_ys): nx_graph = get_grid_adjacency(n_x, n_y) - senders, receivers, node_positions, edge_displacements, n_node, n_edge, global_context = grid_networkx_to_graphstuple( - nx_graph - ) + ( + senders, + receivers, + node_positions, + edge_displacements, + n_node, + n_edge, + global_context, + ) = grid_networkx_to_graphstuple(nx_graph) weights = add_weighted_edge(edge_displacements, n_edge, 1) for i, j in nx_graph.edges(): - nx_graph[i][j]['weight'] = weights[i] + nx_graph[i][j]["weight"] = weights[i] - i_start_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval= n_x) + i_start_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_x) i_start_2 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_y) - i_end_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval= n_x) + i_end_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_x) i_end_2 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_y) - start = tuple(np.concatenate( (i_start_1,i_start_2), axis=0 )) - end = tuple(np.concatenate( (i_end_1,i_end_2), axis=0 )) + start = tuple(np.concatenate((i_start_1, i_start_2), axis=0)) + end = tuple(np.concatenate((i_end_1, i_end_2), axis=0)) - nodes_on_shortest_path_indexes_not_weighted = nx.shortest_path(nx_graph, start, end) - nodes_on_shortest_path_indexes = nx.shortest_path(nx_graph, start, end, weight='weight') + nodes_on_shortest_path_indexes_not_weighted = nx.shortest_path( + nx_graph, start, end + ) + nodes_on_shortest_path_indexes = nx.shortest_path( + nx_graph, start, end, weight="weight" + ) # make it a node feature of the input graph if a node is a start/end node input_node_features = jnp.zeros((n_node, 1)) node_number_start = (i_start_1) * n_y + (i_start_2) node_number_end = (i_end_1) * n_y + (i_end_2) - input_node_features = input_node_features.at[node_number_start, 0].set(1) # set start node feature - input_node_features = input_node_features.at[node_number_end, 0].set(1) # set end node feature + input_node_features = input_node_features.at[node_number_start, 0].set( + 1 + ) # set start node feature + input_node_features = input_node_features.at[node_number_end, 0].set( + 1 + ) # set end node feature if feature_position: - input_node_features = jnp.concatenate((input_node_features, node_positions), axis=1) + input_node_features = jnp.concatenate( + (input_node_features, node_positions), axis=1 + ) if weighted: - edge_displacement = jnp.concatenate((edge_displacements,weights), axis=1) + edge_displacement = jnp.concatenate((edge_displacements, weights), axis=1) graph = jraph.GraphsTuple( nodes=input_node_features, senders=senders, @@ -98,16 +121,21 @@ def sample_padded_grid_batch_shortest_path( graphs.append(graph) nodes_on_shortest_labels = jnp.zeros((n_node, 1)) for i in nodes_on_shortest_path_indexes: - l=np.argwhere(np.all((node_positions - np.asarray(i)) == 0, axis=1)) - nodes_on_shortest_labels = nodes_on_shortest_labels.at[l[0,0]].set(1) + l = np.argwhere(np.all((node_positions - np.asarray(i)) == 0, axis=1)) + nodes_on_shortest_labels = nodes_on_shortest_labels.at[l[0, 0]].set(1) target.append(nodes_on_shortest_labels) # set start node feature targets = jnp.concatenate(target) target_pad = jnp.zeros(((max_n - len(targets)), 1)) padded_target = jnp.concatenate((targets, target_pad), axis=0) graph_batch = jraph.batch(graphs) - padded_graph_batch = jraph.pad_with_graphs(graph_batch, n_node=max_n, n_edge=max_e, 
n_graph=len(graphs) + 1) + padded_graph_batch = jraph.pad_with_graphs( + graph_batch, n_node=max_n, n_edge=max_e, n_graph=len(graphs) + 1 + ) - return padded_graph_batch, jnp.asarray(padded_target), + return ( + padded_graph_batch, + jnp.asarray(padded_target), + ) def grid_networkx_to_graphstuple(nx_graph): @@ -116,7 +144,9 @@ def grid_networkx_to_graphstuple(nx_graph): node_positions = jnp.array(nx_graph.nodes) node_to_inds = {n: i for i, n in enumerate(nx_graph.nodes)} senders_receivers = [(node_to_inds[s], node_to_inds[r]) for s, r in nx_graph.edges] - edge_displacements = jnp.array([np.array(r) - np.array(s) for s, r in nx_graph.edges]) + edge_displacements = jnp.array( + [np.array(r) - np.array(s) for s, r in nx_graph.edges] + ) senders, receivers = zip(*senders_receivers) n_node = node_positions.shape[0] n_edge = edge_displacements.shape[0] @@ -131,7 +161,9 @@ def grid_networkx_to_graphstuple(nx_graph): ) -def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple, number_graph_batch) -> nx.Graph: +def convert_jraph_to_networkx_graph( + jraph_graph: jraph.GraphsTuple, number_graph_batch +) -> nx.Graph: nodes, edges, receivers, senders, _, _, _ = jraph_graph node_padd = 0 edges_padd = 0 diff --git a/neuralplayground/agents/domine_2023_extras/class_config.yaml b/neuralplayground/agents/domine_2023_extras/class_config.yaml index 3b6f56f3..a880f88b 100644 --- a/neuralplayground/agents/domine_2023_extras/class_config.yaml +++ b/neuralplayground/agents/domine_2023_extras/class_config.yaml @@ -11,7 +11,7 @@ num_hidden: 500 # @param num_layers: 2 # @param num_message_passing_steps: 4 # @param learning_rate: 0.001 # @param -num_training_steps: 300 # @param +num_training_steps: 10 # @param # Env Stuff diff --git a/neuralplayground/agents/domine_2023_extras/class_grid_run_config.py b/neuralplayground/agents/domine_2023_extras/class_grid_run_config.py index 30cae69e..df66167a 100644 --- a/neuralplayground/agents/domine_2023_extras/class_grid_run_config.py +++ b/neuralplayground/agents/domine_2023_extras/class_grid_run_config.py @@ -1,6 +1,8 @@ from typing import Dict, Union -from neuralplayground.agents.domine_2023_extras.class_config_template import ConfigTemplate +from neuralplayground.agents.domine_2023_extras.class_config_template import ( + ConfigTemplate, +) from config_manager import base_configuration diff --git a/neuralplayground/agents/domine_2023_extras/class_models.py b/neuralplayground/agents/domine_2023_extras/class_models.py index c566c346..e3356e8f 100644 --- a/neuralplayground/agents/domine_2023_extras/class_models.py +++ b/neuralplayground/agents/domine_2023_extras/class_models.py @@ -18,7 +18,8 @@ def _forward(x): # Map features to desired feature size. x = jraph.GraphMapFeatures( - embed_edge_fn=hk.Linear(output_size=num_hidden), embed_node_fn=hk.Linear(output_size=num_hidden) + embed_edge_fn=hk.Linear(output_size=num_hidden), + embed_node_fn=hk.Linear(output_size=num_hidden), )(x) # Apply rounds of message passing. @@ -29,7 +30,8 @@ def _forward(x): # Map features to desired feature size. 
x = jraph.GraphMapFeatures( - embed_edge_fn=hk.Linear(output_size=edge_output_size), embed_node_fn=hk.Linear(output_size=node_output_size) + embed_edge_fn=hk.Linear(output_size=edge_output_size), + embed_node_fn=hk.Linear(output_size=node_output_size), )(x) return x, message_passing @@ -40,5 +42,7 @@ def _forward(x): def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes): update_edge_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=edge_mlp_sizes)) update_node_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=node_mlp_sizes)) - x = jraph.GraphNetwork(update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)(x) + x = jraph.GraphNetwork( + update_edge_fn=update_edge_fn, update_node_fn=update_node_fn + )(x) return x diff --git a/neuralplayground/agents/domine_2023_extras/class_plotting_utils.py b/neuralplayground/agents/domine_2023_extras/class_plotting_utils.py index ec110116..158f64a7 100644 --- a/neuralplayground/agents/domine_2023_extras/class_plotting_utils.py +++ b/neuralplayground/agents/domine_2023_extras/class_plotting_utils.py @@ -1,14 +1,20 @@ # @title Make rng sequence generator import matplotlib.pyplot as plt import networkx as nx -from neuralplayground.agents.domine_2023_extras.class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad +from neuralplayground.agents.domine_2023_extras.class_utils import ( + convert_jraph_to_networkx_graph, + get_activations_graph_n, + get_node_pad, +) -def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, save_path): +def plot_input_target_output( + inputs, targets, outputs, graph, n, edege_lables, save_path, title +): # minim 2 otherwise it breaks rows = ["{}".format(row) for row in ["Input", "Target", "Outputs"]] fig, axes = plt.subplots(3, n) - fig.set_size_inches(10, 10) + fig.set_size_inches(8, 8) for i in range(n): nx_graph = convert_jraph_to_networkx_graph(graph, i) pos = nx.spring_layout(nx_graph, iterations=100, seed=39775) @@ -27,23 +33,64 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s nx_graph[l][j]["weight"] = round(graph.edges[node_padd + u][2], 2) u = u + 1 labels = nx.get_edge_attributes(nx_graph, "weight") - nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[0, i]) - nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[1, i]) - nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[2, i]) + nx.draw_networkx_edge_labels( + nx_graph, pos=pos, edge_labels=labels, ax=axes[0, i] + ) + nx.draw_networkx_edge_labels( + nx_graph, pos=pos, edge_labels=labels, ax=axes[1, i] + ) + nx.draw_networkx_edge_labels( + nx_graph, pos=pos, edge_labels=labels, ax=axes[2, i] + ) # labels = nx.get_edge_attributes(nx_graph, 'weight') - nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[0, i]) - nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", - ax=axes[1, i]) - nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", - ax=axes[2, i]) + nx.draw( + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=input, + font_color="white", + ax=axes[0, i], + ) + nx.draw( + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=target, + font_color="white", + ax=axes[1, i], + ) + nx.draw( + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=output, + font_color="white", + ax=axes[2, i], + ) for 
axes, row in zip(axes[:, 0], rows): axes.set_ylabel(row, rotation=0, size="large") + plt.suptitle(title) plt.savefig(save_path) -def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n, n_message_passing, edege_lables, save_path): + +def plot_message_passing_layers( + inputs, + activations, + targets, + outputs, + graph, + n, + n_message_passing, + edege_lables, + save_path, + title, +): # minim 2 otherwise it breaks fig, axes = plt.subplots(n_message_passing + 3, n) - fig.set_size_inches(10, 10) + fig.set_size_inches(8, 8) for j in range(n): nx_graph = convert_jraph_to_networkx_graph(graph, j) pos = nx.spring_layout(nx_graph, iterations=100, seed=39775) @@ -56,44 +103,73 @@ def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n, labels = nx.get_edge_attributes(nx_graph, "weight") for i in range(n_message_passing + 3): if edege_lables: - nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j]) + nx.draw_networkx_edge_labels( + nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j] + ) if i == (n_message_passing + 2): axes[i, j].title.set_text("input") input = get_activations_graph_n(inputs, graph, j) nx.draw( - nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j] + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=input, + font_color="white", + ax=axes[i, j], ) elif i == (n_message_passing + 1): axes[i, j].title.set_text("target") target = get_activations_graph_n(targets, graph, j) nx.draw( - nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[i, j] + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=target, + font_color="white", + ax=axes[i, j], ) elif i == (n_message_passing): axes[i, j].title.set_text("output") output = get_activations_graph_n(outputs, graph, j) nx.draw( - nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[i, j] + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=output, + font_color="white", + ax=axes[i, j], ) else: activation = activations[i] - axes[i, j].title.set_text("graph_" + str(j) + "message_passing_" + str(i)) - input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, j) + axes[i, j].title.set_text( + "graph_" + str(j) + "message_passing_" + str(i) + ) + input = get_activations_graph_n( + activation.nodes[:, j].tolist(), graph, j + ) nx.draw( - nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j] + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=input, + font_color="white", + ax=axes[i, j], ) + plt.suptitle(title) plt.savefig(save_path) - - def plot_graph_grid_activations( - node_colour, - graph, - save_path, - title, - edege_lables, - number_graph_batch=0, + node_colour, + graph, + save_path, + title, + edege_lables, + number_graph_batch=0, ): nx_graph = convert_jraph_to_networkx_graph(graph, number_graph_batch) output = get_activations_graph_n(node_colour, graph, number_graph_batch=0) @@ -109,12 +185,28 @@ def plot_graph_grid_activations( u = u + 1 labels = nx.get_edge_attributes(nx_graph, "weight") nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=ax) - nx.draw(nx_graph, pos=pos, with_labels=True, node_size=500, node_color=output, font_color="white", ax=ax) + nx.draw( + nx_graph, + pos=pos, + with_labels=True, + node_size=500, + node_color=output, + font_color="white", + 
ax=ax, + ) plt.savefig(save_path) def plot_message_passing_layers_units( - activations, targets, outputs, graph, number_hidden, n_message_passing, edege_lables, save_path + activations, + targets, + outputs, + graph, + number_hidden, + n_message_passing, + edege_lables, + save_path, + title, ): # minim 2 otherwise it breaks fig, axes = plt.subplots(n_message_passing, number_hidden) @@ -130,13 +222,25 @@ def plot_message_passing_layers_units( for i in range(n_message_passing): for j in range(number_hidden): activation = activations[i] - axes[i, j].title.set_text("first_graph_unit_" + str(j) + "message_passing_" + str(i)) + axes[i, j].title.set_text( + "first_graph_unit_" + str(j) + "message_passing_" + str(i) + ) # We select the first graph only input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, 0) - nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", - ax=axes[i, j]) + nx.draw( + nx_graph, + pos=pos, + with_labels=True, + node_size=200, + node_color=input, + font_color="white", + ax=axes[i, j], + ) if edege_lables: - nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j]) + nx.draw_networkx_edge_labels( + nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j] + ) + plt.suptitle(title) plt.savefig(save_path) @@ -145,4 +249,30 @@ def plot_xy(auc_roc, path, title): ax = fig.add_subplot(111) ax.title.set_text(title) ax.plot(auc_roc) - plt.savefig(path) \ No newline at end of file + plt.savefig(path) + + +import matplotlib.pyplot as plt + + +def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=None): + fig, ax = plt.subplots(figsize=(8, 6)) + ax.set_title(title) + + if x_label: + ax.set_xlabel(x_label) + if y_label: + ax.set_ylabel(y_label) + + colors = ["b", "g", "r", "c", "m", "y", "k"] + + for i, curve in enumerate(curves): + label = legend_labels[i] if legend_labels else None + color = colors[i % len(colors)] + ax.plot(curve, label=label, color=color) + + if legend_labels: + ax.legend() + + plt.savefig(path) + plt.show() diff --git a/neuralplayground/agents/domine_2023_extras/class_run.py b/neuralplayground/agents/domine_2023_extras/class_run.py index 27596007..c2b4c358 100644 --- a/neuralplayground/agents/domine_2023_extras/class_run.py +++ b/neuralplayground/agents/domine_2023_extras/class_run.py @@ -44,7 +44,11 @@ def run(config_path, config): edege_lables = False if config.wandb_on: dateTimeObj = datetime.now() - wandb.init(project="graph-brain", entity="graph-brain", name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M")) + wandb.init( + project="graph-brain", + entity="graph-brain", + name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"), + ) wandb_logs = {} save_path = wandb.run.dir os.mkdir(os.path.join(save_path, "results")) @@ -52,8 +56,16 @@ def run(config_path, config): else: dateTimeObj = datetime.now() save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir(os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) - save_path = os.path.join(os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) + os.mkdir( + os.path.join( + save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + ) + ) + save_path = os.path.join( + os.path.join( + save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + ) + ) # SAVING Trainning Files path = os.path.join(save_path, "run.py") HERE = os.path.join(Path(os.getcwd()).resolve(), "run.py") @@ -76,17 +88,29 @@ def 
run(config_path, config): shutil.copyfile(HERE, path) # This is the function that does the forward pass of the model - forward = get_forward_function(config.num_hidden, config.num_layers, config.num_message_passing_steps) + forward = get_forward_function( + config.num_hidden, config.num_layers, config.num_message_passing_steps + ) net_hk = hk.without_apply_rng(hk.transform(forward)) rng = jax.random.PRNGKey(config.seed) rng_seq = rng_sequence_from_rng(rng) if config.train_on_shortest_path: graph, targets = sample_padded_grid_batch_shortest_path( - rng, config.batch_size, config.feature_position, config.weighted, config.nx_min, config.nx_max + rng, + config.batch_size, + config.feature_position, + config.weighted, + config.nx_min, + config.nx_max, ) else: graph, targets = sample_padded_grid_batch_shortest_path( - rng, config.batch_size, config.feature_position, config.weighted, config.nx_min, config.nx_max + rng, + config.batch_size, + config.feature_position, + config.weighted, + config.nx_min, + config.nx_max, ) targets = graph.nodes params = net_hk.init(rng, graph) @@ -121,7 +145,12 @@ def evaluate(model, params, inputs, target): roc_aucs_test = [] rng = next(rng_seq) graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, config.batch_size_test, config.feature_position, config.weighted, config.nx_min_test, config.nx_max_test + rng, + config.batch_size_test, + config.feature_position, + config.weighted, + config.nx_min_test, + config.nx_max_test, ) for n in range(config.num_training_steps): rng = next(rng_seq) @@ -129,29 +158,50 @@ def evaluate(model, params, inputs, target): if config.resample: if train_on_shortest_path: graph, targets = sample_padded_grid_batch_shortest_path( - rng, config.batch_size, config.feature_position, config.weighted, config.nx_min, config.nx_max + rng, + config.batch_size, + config.feature_position, + config.weighted, + config.nx_min, + config.nx_max, ) else: graph, targets = sample_padded_grid_batch_shortest_path( - rng, config.batch_size, config.feature_position, config.weighted, config.nx_min, config.nx_max + rng, + config.batch_size, + config.feature_position, + config.weighted, + config.nx_min, + config.nx_max, ) targets = graph.nodes # Train - loss, grads = jax.value_and_grad(compute_loss)(params, net_hk, graph, targets) # jits inside of value_and_grad + loss, grads = jax.value_and_grad(compute_loss)( + params, net_hk, graph, targets + ) # jits inside of value_and_grad params = update_step(grads, opt_state, params) losses.append(loss) - outputs_train, roc_auc_train, MCC_train = evaluate(net_hk, params, graph, targets) + outputs_train, roc_auc_train, MCC_train = evaluate( + net_hk, params, graph, targets + ) roc_aucs_train.append(roc_auc_train) MCCs_train.append(MCC_train) # Matthews correlation coefficient # Test # model should basically learn to do nothing from this loss_test = compute_loss(params, net_hk, graph_test, target_test) losses_test.append(loss_test) - outputs_test, roc_auc_test, MCC_test = evaluate(net_hk, params, graph_test, target_test) + outputs_test, roc_auc_test, MCC_test = evaluate( + net_hk, params, graph_test, target_test + ) roc_aucs_test.append(roc_auc_test) MCCs_test.append(MCC_test) # Log - wandb_logs = {"loss": loss, "losses_test": loss_test, "roc_auc_test": roc_auc_test, "roc_auc": roc_auc_train} + wandb_logs = { + "loss": loss, + "losses_test": loss_test, + "roc_auc_test": roc_auc_test, + "roc_auc": roc_auc_train, + } if config.wandb_on: wandb.log(wandb_logs) if n % log_every == 0: @@ -169,7 +219,11 @@ def 
evaluate(model, params, inputs, target): with open("readme.txt", "w") as f: f.write("readme") with open(os.path.join(save_path, "Constant.txt"), "w") as outfile: - outfile.write("num_message_passing_steps" + str(config.num_message_passing_steps) + "\n") + outfile.write( + "num_message_passing_steps" + + str(config.num_message_passing_steps) + + "\n" + ) outfile.write("Learning_rate:" + str(config.learning_rate) + "\n") outfile.write("num_training_steps:" + str(config.num_training_steps)) outfile.write("roc_auc" + str(roc_auc)) @@ -179,7 +233,9 @@ def evaluate(model, params, inputs, target): plot_xy(losses, os.path.join(save_path, "Losses.pdf"), "Losses") plot_xy(losses_test, os.path.join(save_path, "Losses_test.pdf"), "Losses_test") plot_xy(roc_aucs_test, os.path.join(save_path, "auc_roc_test.pdf"), "auc_roc_test") - plot_xy(roc_aucs_train, os.path.join(save_path, "auc_roc_train.pdf"), "auc_roc_train") + plot_xy( + roc_aucs_train, os.path.join(save_path, "auc_roc_train.pdf"), "auc_roc_train" + ) plot_xy(MCCs_train, os.path.join(save_path, "MCC_train.pdf"), "MCC_train") plot_xy(MCCs_test, os.path.join(save_path, "MCC_test.pdf"), "MCC_test") @@ -221,7 +277,13 @@ def evaluate(model, params, inputs, target): "Inputs node assigments", edege_lables, ) - plot_graph_grid_activations(target_test.sum(-1), graph_test, os.path.join(save_path, "Target.pdf"), "Target", edege_lables) + plot_graph_grid_activations( + target_test.sum(-1), + graph_test, + os.path.join(save_path, "Target.pdf"), + "Target", + edege_lables, + ) plot_graph_grid_activations( outputs[0].nodes.tolist(), @@ -240,7 +302,12 @@ def evaluate(model, params, inputs, target): 2, ) plot_graph_grid_activations( - target_test.sum(-1), graph_test, os.path.join(save_path, "Target_2.pdf"), "Target", edege_lables, 2 + target_test.sum(-1), + graph_test, + os.path.join(save_path, "Target_2.pdf"), + "Target", + edege_lables, + 2, ) return losses, roc_auc diff --git a/neuralplayground/agents/domine_2023_extras/class_test.py b/neuralplayground/agents/domine_2023_extras/class_test.py index d734d709..2aaa2fc0 100644 --- a/neuralplayground/agents/domine_2023_extras/class_test.py +++ b/neuralplayground/agents/domine_2023_extras/class_test.py @@ -44,18 +44,30 @@ def shortest_path(): for n_x, n_y in zip(n_xs, n_ys): nx_graph = get_grid_adjacency(n_x, n_y) - senders, receivers, node_positions, edge_displacements, n_node, n_edge, global_context = grid_networkx_to_graphstuple( - nx_graph - ) + ( + senders, + receivers, + node_positions, + edge_displacements, + n_node, + n_edge, + global_context, + ) = grid_networkx_to_graphstuple(nx_graph) i_end = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_node) i_start = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_node) # make it a node feature of the input graph if a node is a start/end node input_node_features = jnp.zeros((n_node, 1)) - input_node_features = input_node_features.at[i_start, 0].set(1) # set start node feature - input_node_features = input_node_features.at[i_end, 0].set(1) # set end node feature + input_node_features = input_node_features.at[i_start, 0].set( + 1 + ) # set start node feature + input_node_features = input_node_features.at[i_end, 0].set( + 1 + ) # set end node feature if feature_position: - input_node_features = jnp.concatenate((input_node_features, node_positions), axis=1) + input_node_features = jnp.concatenate( + (input_node_features, node_positions), axis=1 + ) # edge_displacement= add_weighted_edge(edge_displacements,n_edge,10) edge_displacement = 
edge_displacements graph_weighted = jraph.GraphsTuple( @@ -79,7 +91,9 @@ def shortest_path(): nx_graph_weighted = convert_jraph_to_networkx_graph(graph_weighted, 0) min_edge_weight = 0.5 for i, j in nx_graph_weighted.edges(): - nx_graph_weighted[i][j]["weight"] = np.max([0.5 * np.random.rand() + 1.0, min_edge_weight]) + nx_graph_weighted[i][j]["weight"] = np.max( + [0.5 * np.random.rand() + 1.0, min_edge_weight] + ) nodes_on_shortest_path_indexes_weighted = nx.shortest_path( nx_graph_weighted, int(i_start[0]), int(i_end[0]), weight="weight" ) @@ -97,7 +111,9 @@ def grid_networkx_to_graphstuple(nx_graph): node_positions = jnp.array(nx_graph.nodes) node_to_inds = {n: i for i, n in enumerate(nx_graph.nodes)} senders_receivers = [(node_to_inds[s], node_to_inds[r]) for s, r in nx_graph.edges] - edge_displacements = jnp.array([np.array(r) - np.array(s) for s, r in nx_graph.edges]) + edge_displacements = jnp.array( + [np.array(r) - np.array(s) for s, r in nx_graph.edges] + ) senders, receivers = zip(*senders_receivers) n_node = node_positions.shape[0] @@ -118,7 +134,8 @@ def add_weighted_edge(edge_displacement, n_edge, sigma_on_edge_weight_noise): for l in range(2): if not edge_displacement[k][l] == 0: edge_displacement = edge_displacement.at[k, l].set( - edge_displacement[k][l] + sigma_on_edge_weight_noise * np.random.rand() + edge_displacement[k][l] + + sigma_on_edge_weight_noise * np.random.rand() ) return edge_displacement @@ -127,7 +144,9 @@ def get_grid_adjacency(n_x, n_y, atol=1e-1): return nx.grid_2d_graph(n_x, n_y) # Get directed grid graph -def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple, number_graph_batch) -> nx.Graph: +def convert_jraph_to_networkx_graph( + jraph_graph: jraph.GraphsTuple, number_graph_batch +) -> nx.Graph: nodes, edges, receivers, senders, _, _, _ = jraph_graph node_padd = 0 edges_padd = 0 diff --git a/neuralplayground/agents/domine_2023_extras/class_utils.py b/neuralplayground/agents/domine_2023_extras/class_utils.py index b29b0b2f..9ab4678b 100644 --- a/neuralplayground/agents/domine_2023_extras/class_utils.py +++ b/neuralplayground/agents/domine_2023_extras/class_utils.py @@ -21,7 +21,10 @@ def rng_sequence_from_rng(rng): rng, _ = jax.random.split(rng) yield rng -def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple, number_graph_batch) -> nx.Graph: + +def convert_jraph_to_networkx_graph( + jraph_graph: jraph.GraphsTuple, number_graph_batch +) -> nx.Graph: nodes, edges, receivers, senders, _, _, _ = jraph_graph node_padd = 0 edges_padd = 0 diff --git a/neuralplayground/agents/jax_optimised_version.py b/neuralplayground/agents/jax_optimised_version.py index bb6718dd..5ae904e9 100644 --- a/neuralplayground/agents/jax_optimised_version.py +++ b/neuralplayground/agents/jax_optimised_version.py @@ -16,7 +16,9 @@ from neuralplayground.agents.agent_core import AgentCore os.environ["KMP_DUPLICATE_LIB_OK"] = "True" -from neuralplayground.agents.domine_2023_extras.class_Graph_generation import sample_padded_grid_batch_shortest_path +from neuralplayground.agents.domine_2023_extras.class_Graph_generation import ( + sample_padded_grid_batch_shortest_path, +) from neuralplayground.agents.domine_2023_extras.class_grid_run_config import GridConfig from neuralplayground.agents.domine_2023_extras.class_models import get_forward_function from neuralplayground.agents.domine_2023_extras.class_plotting_utils import ( @@ -25,13 +27,16 @@ plot_message_passing_layers, plot_xy, ) -from neuralplayground.agents.domine_2023_extras.class_utils import 
rng_sequence_from_rng, set_device +from neuralplayground.agents.domine_2023_extras.class_utils import ( + rng_sequence_from_rng, + set_device, +) from sklearn.metrics import matthews_corrcoef, roc_auc_score from flax import struct class Domine2023(AgentCore, struct.PytreeNode): - experiment_name: str = 'smaller size generalisation graph with no position feature' + experiment_name: str = "smaller size generalisation graph with no position feature" train_on_shortest_path: bool = True resample: bool = True wandb_on: bool = False @@ -51,10 +56,10 @@ class Domine2023(AgentCore, struct.PytreeNode): nx_max_test: int = 7 obs_history: list = [] - def __init__ ( # autogenerated + def __init__( # autogenerated self, - #agent_name: str = "SR", - experiment_name = 'smaller size generalisation graph with no position feature', + # agent_name: str = "SR", + experiment_name="smaller size generalisation graph with no position feature", train_on_shortest_path: bool = True, resample: bool = True, wandb_on: bool = False, @@ -62,15 +67,15 @@ def __init__ ( # autogenerated feature_position: bool = False, weighted: bool = True, num_hidden: int = 100, - num_layers : int = 2, + num_layers: int = 2, num_message_passing_steps: int = 3, learning_rate: float = 0.001, num_training_steps: int = 10, batch_size: int = 4, - nx_min: int = 4, + nx_min: int = 4, nx_max: int = 7, - batch_size_test: int= 4, - nx_min_test: int = 4 , + batch_size_test: int = 4, + nx_min_test: int = 4, nx_max_test: int = 7, **mod_kwargs, ): @@ -93,21 +98,26 @@ def __init__ ( # autogenerated self.num_training_steps = num_training_steps # cconfig.num_training_steps # @param - self.batch_size =batch_size + self.batch_size = batch_size self.nx_min = nx_min self.nx_max = nx_max - # This can be tought of the brain making different rep of different granularity # Could be explained during sleep self.batch_size_test = batch_size_test # cconfig.batch_size_test - self.nx_min_test = nx_min_test # cconfig.nx_min_test # This is thought of the state density - self.nx_max_test = nx_max_test # config.nx_max_test # This is thought of the state density + self.nx_min_test = ( + nx_min_test # cconfig.nx_min_test # This is thought of the state density + ) + self.nx_max_test = ( + nx_max_test # config.nx_max_test # This is thought of the state density + ) self.batch_size = batch_size # c config.batch_size self.nx_min = nx_min # c config.nx_min # This is thought of the state density self.nx_max = nx_max # c config.nx_max # This is thought of the state density - self.arena_x_limits = mod_kwargs["arena_y_limits"] # cmod_kwargs["arena_x_limits"] + self.arena_x_limits = mod_kwargs[ + "arena_y_limits" + ] # cmod_kwargs["arena_x_limits"] self.arena_y_limits = mod_kwargs["arena_y_limits"] self.room_width = np.diff(self.arena_x_limits)[0] self.room_depth = np.diff(self.arena_y_limits)[0] @@ -122,8 +132,9 @@ def __init__ ( # autogenerated if self.wandb_on: dateTimeObj = datetime.now() wandb.init( - project="graph-brain", entity="graph-brain", - name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + project="graph-brain", + entity="graph-brain", + name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"), ) self.wandb_logs = {} save_path = wandb.run.dir @@ -133,9 +144,16 @@ def __init__ ( # autogenerated else: dateTimeObj = datetime.now() save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir(os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) + os.mkdir( + os.path.join( + save_path, "Grid_shortest_path" + 
dateTimeObj.strftime("%d%b_%H_%M") + ) + ) self.save_path = os.path.join( - os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))) + os.path.join( + save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M") + ) + ) self.reset() self.saving_run_parameters() @@ -143,17 +161,28 @@ def __init__ ( # autogenerated rng = jax.random.PRNGKey(self.seed) self.rng_seq = rng_sequence_from_rng(rng) - if self.train_on_shortest_path: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) else: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) - forward = get_forward_function(self.num_hidden, self.num_layers, self.num_message_passing_steps) + forward = get_forward_function( + self.num_hidden, self.num_layers, self.num_message_passing_steps + ) net_hk = hk.without_apply_rng(hk.transform(forward)) params = net_hk.init(rng, self.graph) self.params = params @@ -161,88 +190,84 @@ def __init__ ( # autogenerated opt_state = optimizer.init(self.params) self.opt_state = opt_state - - def compute_loss(params, inputs, targets): # not jitted because it will get jitted in jax.value_and_grad outputs = net_hk.apply(params, inputs) return jnp.mean((outputs[0].nodes - targets) ** 2) # using MSE - - self._compute_loss = jax.jit(compute_loss) - def update_step(params,opt_state): + def update_step(params, opt_state): loss, grads = jax.value_and_grad(compute_loss)( params, self.graph, self.targets ) # jits inside of value_and_grad - updates, opt_state = optimizer.update(grads, opt_state,params) + updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) - return params, opt_state ,loss + return params, opt_state, loss self._update_step = jax.jit(update_step) def evaluate(params, inputs, target): outputs = net_hk.apply(params, inputs) roc_auc = roc_auc_score(np.squeeze(target), np.squeeze(outputs[0].nodes)) - MCC = matthews_corrcoef(np.squeeze(target), round(np.squeeze(outputs[0].nodes))) + MCC = matthews_corrcoef( + np.squeeze(target), round(np.squeeze(outputs[0].nodes)) + ) return outputs, roc_auc, MCC self._evaluate = evaluate @jit def compute_loss(self, inputs, targets): - forward = get_forward_function(self.num_hidden, self.num_layers, self.num_message_passing_steps) + forward = get_forward_function( + self.num_hidden, self.num_layers, self.num_message_passing_steps + ) net_hk = hk.without_apply_rng(hk.transform(forward)) outputs = net_hk.apply(self.params, inputs) return jnp.mean((outputs[0].nodes - targets) ** 2) # using MSE def saving_run_parameters(self): - path = os.path.join(self.save_path, "run.py") HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py") shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_Graph_generation.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_utils.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py") + HERE 
= os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_plotting_utils.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py" + ) shutil.copyfile(HERE, path) path = os.path.join(self.save_path, "class_config_run.yaml") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023_extras/class_config.yaml") + HERE = os.path.join( + Path(os.getcwd()).resolve(), "domine_2023_extras/class_config.yaml" + ) shutil.copyfile(HERE, path) - - def set_obs_history(self,obs_history): + def set_obs_history(self, obs_history): new_self = Domine2023( - self.seed, - self.num_hidden, - self.num_layers, - ..., - obs_history = obs_history + self.seed, self.num_hidden, self.num_layers, ..., obs_history=obs_history ) return new_self - def reset_obs_history(self, a=1): new_self = Domine2023( - self.seed, - self.num_hidden, - self.num_layers, - ..., - obs_history = [] + self.seed, self.num_hidden, self.num_layers, ..., obs_history=[] ) return new_self - - def reset(self,a=1): + def reset(self, a=1): self.obs_history = [] # Initialize observation history to update weights later self.grad_history = [] self.global_steps = 0 @@ -255,45 +280,72 @@ def reset(self,a=1): return def update(self): - rng = next(self.rng_seq) graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, ) rng = next(self.rng_seq) # Sample a new batch of graph every itterations if self.resample: if self.train_on_shortest_path: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) else: self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max + rng, + self.batch_size, + self.feature_position, + self.weighted, + self.nx_min, + self.nx_max, ) self.targets = self.graph.nodes # Train - self.params,self.opt_state, loss = self._update_step(self.params,self.opt_state ) + self.params, self.opt_state, loss = self._update_step( + self.params, self.opt_state + ) self.losses.append(loss) - outputs_train, roc_auc_train, MCC_train = self._evaluate(self.params, self.graph, self.targets) + outputs_train, roc_auc_train, MCC_train = self._evaluate( + self.params, self.graph, self.targets + ) self.roc_aucs_train.append(roc_auc_train) - self.MCCs_train.append(MCC_train) #Matthews correlation coefficient + self.MCCs_train.append(MCC_train) # Matthews correlation coefficient # Test # model should basically learn to do nothing from this - loss_test = self._compute_loss(self.params,graph_test, target_test) + loss_test = self._compute_loss(self.params, graph_test, target_test) self.losses_test.append(loss_test) - outputs_test, roc_auc_test, MCC_test = self._evaluate(self.params, graph_test, target_test) + outputs_test, roc_auc_test, MCC_test = self._evaluate( + self.params, graph_test, target_test + ) self.roc_aucs_test.append(roc_auc_test) self.MCCs_test.append(MCC_test) # Log - wandb_logs = {"loss": loss, 
"losses_test": loss_test, "roc_auc_test": roc_auc_test, "roc_auc": roc_auc_train} + wandb_logs = { + "loss": loss, + "losses_test": loss_test, + "roc_auc_test": roc_auc_test, + "roc_auc": roc_auc_train, + } if self.wandb_on: wandb.log(wandb_logs) self.global_steps = self.global_steps + 1 if self.global_steps % self.log_every == 0: - print(f"Training step {self.global_steps}: loss = {loss} , loss_test = {loss_test}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}") + print( + f"Training step {self.global_steps}: loss = {loss} , loss_test = {loss_test}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}" + ) return def print_and_plot(self): @@ -301,7 +353,12 @@ def print_and_plot(self): # EVALUATE rng = next(self.rng_seq) graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test + rng, + self.batch_size_test, + self.feature_position, + self.weighted, + self.nx_min_test, + self.nx_max_test, ) # graph_test= self.graph # target_test = self.targets @@ -318,7 +375,11 @@ def print_and_plot(self): with open("readme.txt", "w") as f: f.write("readme") with open(os.path.join(self.save_path, "Constant.txt"), "w") as outfile: - outfile.write("num_message_passing_steps" + str(self.num_message_passing_steps) + "\n") + outfile.write( + "num_message_passing_steps" + + str(self.num_message_passing_steps) + + "\n" + ) outfile.write("Learning_rate:" + str(self.learning_rate) + "\n") outfile.write("num_training_steps:" + str(self.num_training_steps)) outfile.write("roc_auc" + str(roc_auc)) @@ -326,11 +387,27 @@ def print_and_plot(self): # PLOTTING THE LOSS and AUC ROC plot_xy(self.losses, os.path.join(self.save_path, "Losses.pdf"), "Losses") - plot_xy(self.losses_test, os.path.join(self.save_path, "Losses_test.pdf"), "Losses_test") - plot_xy(self.roc_aucs_test, os.path.join(self.save_path, "auc_roc_test.pdf"), "auc_roc_test") - plot_xy(self.roc_aucs_train, os.path.join(self.save_path, "auc_roc_train.pdf"), "auc_roc_train") - plot_xy(self.MCCs_train, os.path.join(self.save_path, "MCC_train.pdf"), "MCC_train") - plot_xy(self.MCCs_test, os.path.join(self.save_path, "MCC_test.pdf"), "MCC_test") + plot_xy( + self.losses_test, + os.path.join(self.save_path, "Losses_test.pdf"), + "Losses_test", + ) + plot_xy( + self.roc_aucs_test, + os.path.join(self.save_path, "auc_roc_test.pdf"), + "auc_roc_test", + ) + plot_xy( + self.roc_aucs_train, + os.path.join(self.save_path, "auc_roc_train.pdf"), + "auc_roc_train", + ) + plot_xy( + self.MCCs_train, os.path.join(self.save_path, "MCC_train.pdf"), "MCC_train" + ) + plot_xy( + self.MCCs_test, os.path.join(self.save_path, "MCC_test.pdf"), "MCC_test" + ) # PLOTTING ACTIVATION OF THE FIRST 2 GRAPH OF THE BATCH plot_input_target_output( @@ -371,10 +448,14 @@ def print_and_plot(self): self.edege_lables, ) plot_graph_grid_activations( - target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target.pdf"), "Target", self.edege_lables + target_test.sum(-1), + graph_test, + os.path.join(self.save_path, "Target.pdf"), + "Target", + self.edege_lables, ) - graph_test= self.graph + graph_test = self.graph target_test = self.targets outputs, roc_auc, MCC = self._evaluate(self.params, graph_test, target_test) @@ -417,34 +498,37 @@ def print_and_plot(self): self.edege_lables, ) plot_graph_grid_activations( - target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target.pdf"), "Target", self.edege_lables + target_test.sum(-1), + graph_test, + 
os.path.join(self.save_path, "Target.pdf"), + "Target", + self.edege_lables, ) - - # plot_graph_grid_activations( - # outputs[0].nodes.tolist(), + # plot_graph_grid_activations( + # outputs[0].nodes.tolist(), # graph_test, - # os.path.join(self.save_path, "outputs_2.pdf"), - # "Predicted Node Assignments with GCN", - # self.edege_lables, - #2, - #) - #plot_graph_grid_activations( - # list(graph_test.nodes.sum(-1)), - # graph_test, - # os.path.join(self.save_path, "Inputs_2.pdf"), - #"Inputs node assigments", - #self.edege_lables, - #2, - #) + # os.path.join(self.save_path, "outputs_2.pdf"), + # "Predicted Node Assignments with GCN", + # self.edege_lables, + # 2, + # ) + # plot_graph_grid_activations( + # list(graph_test.nodes.sum(-1)), + # graph_test, + # os.path.join(self.save_path, "Inputs_2.pdf"), + # "Inputs node assigments", + # self.edege_lables, + # 2, + # ) # # - print('End') + print("End") # plot_graph_grid_activations( - # - # target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_2.pdf"), "Target", self.edege_lables, 2 - # ) + # + # target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_2.pdf"), "Target", self.edege_lables, 2 + # ) # return @@ -454,11 +538,11 @@ def print_and_plot(self): # @title Graph net functions parser = argparse.ArgumentParser() parser.add_argument( - "--config_path", - metavar="-C", - default="domine_2023_extras/class_config.yaml", - help="path to base configuration file.", -) + "--config_path", + metavar="-C", + default="domine_2023_extras/class_config.yaml", + help="path to base configuration file.", + ) args = parser.parse_args() set_device() @@ -477,24 +561,27 @@ def print_and_plot(self): arena_y_limits=arena_y_limits, ) - agent = Domine2023( experiment_name=config.experiment_name, - train_on_shortest_path= config.train_on_shortest_path, - resample= config.resample, # @param - wandb_on= config.wandb_on, - seed= config.seed, - feature_position= config.feature_position, - weighted= config.weighted, - num_hidden= config.num_hidden, # @param - num_layers= config.num_layers, # @param - num_message_passing_steps= config.num_message_passing_steps, # @param - learning_rate= config.learning_rate , # @param - num_training_steps= config.num_training_steps, # @param - batch_size= config.batch_size, - nx_min= config.nx_min, - nx_max= config. nx_max, - batch_size_test= config.batch_size_test, - nx_min_test= config.nx_min_test, - nx_max_test= 7,arena_y_limits=arena_y_limits, arena_x_limits=arena_x_limits + agent = Domine2023( + experiment_name=config.experiment_name, + train_on_shortest_path=config.train_on_shortest_path, + resample=config.resample, # @param + wandb_on=config.wandb_on, + seed=config.seed, + feature_position=config.feature_position, + weighted=config.weighted, + num_hidden=config.num_hidden, # @param + num_layers=config.num_layers, # @param + num_message_passing_steps=config.num_message_passing_steps, # @param + learning_rate=config.learning_rate, # @param + num_training_steps=config.num_training_steps, # @param + batch_size=config.batch_size, + nx_min=config.nx_min, + nx_max=config.nx_max, + batch_size_test=config.batch_size_test, + nx_min_test=config.nx_min_test, + nx_max_test=7, + arena_y_limits=arena_y_limits, + arena_x_limits=arena_x_limits, ) for n in range(config.num_training_steps): @@ -510,4 +597,4 @@ def print_and_plot(self): x = x.replace(obs_history=[1, 2], num_hidden=2) x.num_hidden = 5 - x.update() \ No newline at end of file + x.update()
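
Notes for reviewers follow; none of this is part of the commit itself.

For anyone tracing the reflowed hunks in class_Graph_generation.py: the helpers above (get_grid_adjacency, grid_networkx_to_graphstuple, and the shortest-path labelling in shortest_path) round-trip between networkx grid graphs and jraph tuples, with the nodes on one shortest path serving as the per-node supervision target. Below is a minimal, self-contained sketch of that pipeline for a single unpadded graph; the real sample_padded_grid_batch_shortest_path additionally pads, batches, and optionally perturbs edge weights, and the start/end nodes here are chosen arbitrarily for illustration.

import jax.numpy as jnp
import jraph
import networkx as nx
import numpy as np

# Undirected grid graph, as in get_grid_adjacency.
nx_graph = nx.grid_2d_graph(3, 3)
node_to_inds = {n: i for i, n in enumerate(nx_graph.nodes)}

# One (sender, receiver) index pair per edge, as in grid_networkx_to_graphstuple.
senders_receivers = [(node_to_inds[s], node_to_inds[r]) for s, r in nx_graph.edges]
senders, receivers = zip(*senders_receivers)

# Mark the nodes on one shortest path as the prediction target.
path = nx.shortest_path(nx_graph, source=(0, 0), target=(2, 2))
targets = np.zeros((nx_graph.number_of_nodes(), 1), dtype=np.float32)
targets[[node_to_inds[n] for n in path]] = 1.0

graph = jraph.GraphsTuple(
    nodes=jnp.array(list(nx_graph.nodes), dtype=jnp.float32),  # (x, y) positions
    edges=None,
    senders=jnp.array(senders),
    receivers=jnp.array(receivers),
    globals=None,
    n_node=jnp.array([nx_graph.number_of_nodes()]),
    n_edge=jnp.array([nx_graph.number_of_edges()]),
)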
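The rng_sequence_from_rng generator patched in class_utils.py is what feeds rng = next(self.rng_seq) throughout both agents. Only the split-and-yield body is visible in the hunk; the surrounding loop below is an assumption, kept faithful to the shown lines.

import jax

def rng_sequence_from_rng(rng):
    # Re-split on every draw so each consumer gets a fresh key.
    while True:
        rng, _ = jax.random.split(rng)
        yield rng

rng_seq = rng_sequence_from_rng(jax.random.PRNGKey(0))
key = next(rng_seq)  # one fresh PRNGKey per call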
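Both agent files wrap the same Haiku-plus-optax closure pattern that the reflow makes easier to read: compute_loss closes over the transformed network, update_step differentiates it, and only the jitted wrappers are stored on self. One caution the diff does not change: update_step closes over self.graph and self.targets, and jax.jit bakes closed-over array values in as constants at trace time, so resampling the batch later has no effect on the compiled update. The sketch below therefore passes the batch as arguments instead, with a toy MLP standing in for get_forward_function; all names are illustrative, not the repo's API.

import haiku as hk
import jax
import jax.numpy as jnp
import optax

def forward(x):
    return hk.nets.MLP([32, 1])(x)

net = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(42)
x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))
params = net.init(rng, x)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

def compute_loss(params, inputs, targets):
    # Same MSE form as the agents' loss, on a toy regression target.
    outputs = net.apply(params, inputs)
    return jnp.mean((outputs - targets) ** 2)

@jax.jit
def update_step(params, opt_state, inputs, targets):
    # Batch enters as arguments, so a resampled batch recompiles nothing
    # and is actually used, unlike a batch captured by closure.
    loss, grads = jax.value_and_grad(compute_loss)(params, inputs, targets)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = update_step(params, opt_state, x, y)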
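Finally, the tail of jax_optimised_version.py exercises the flax.struct idiom the file is moving toward: x.replace(...) returns a new instance, while the subsequent x.num_hidden = 5 would raise under flax.struct, because struct.PytreeNode subclasses are frozen dataclasses. A minimal sketch of the intended immutable-update pattern; Agent is a stand-in for the real class, and note the default_factory, since a bare obs_history: list = [] class default (as in the hunk above) is rejected by dataclasses at class-definition time.

import dataclasses

from flax import struct

class Agent(struct.PytreeNode):
    num_hidden: int = 100
    obs_history: list = dataclasses.field(default_factory=list)

a = Agent()
b = a.replace(num_hidden=2, obs_history=[1, 2])  # returns a new, updated Agent
assert a.num_hidden == 100  # the original instance is untouched
# a.num_hidden = 5  # would raise FrozenInstanceError: fields are immutable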