TEM branch merger #120

Open · wants to merge 80 commits into base: main

Changes from all commits (80 commits)
004bff2
Add options for Windows and Unix OS in README
LukeHollingsworth Jul 31, 2023
86d403b
Merge branch 'main' of https://github.com/SainsburyWellcomeCentre/Neu…
LukeHollingsworth Aug 9, 2023
f84c0b3
Merge branch 'main' of https://github.com/SainsburyWellcomeCentre/Neu…
LukeHollingsworth Aug 15, 2023
29b678c
adding experimental runs to TEM
LukeHollingsworth Aug 18, 2023
e36c848
batch environment example working with Simple2D
Aug 18, 2023
29f1dfb
default argument of BatchEnvironment() set to DiscreteObjectEnvironme…
Aug 18, 2023
f634212
default argument of BatchEnvironment() set to DiscreteObjectEnvironme…
Aug 18, 2023
ad5cab6
merge main
ClementineDomine Aug 19, 2023
2b0f9d8
Update README.md - Centered logo
JarvisDevon Aug 19, 2023
034e685
debugging state density plot
Aug 20, 2023
60abefa
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
Aug 20, 2023
68422a1
pre-commit changes
Aug 20, 2023
55d740e
change TEM imports to not require torch install
LukeHollingsworth Aug 20, 2023
2f78dee
note on installing dependencies on zsh shell
LukeHollingsworth Aug 20, 2023
3fae384
merged main into whittington_2020
LukeHollingsworth Aug 20, 2023
a60d64b
introduce logging of training accuracies
LukeHollingsworth Aug 20, 2023
ff12f40
pre-commit changes
LukeHollingsworth Aug 20, 2023
5c4fd53
added comments to TEM run file
LukeHollingsworth Aug 20, 2023
82a34d7
merge from main
Aug 21, 2023
48058c3
batch trajectories and grids plotted
Aug 21, 2023
4cd1f7a
Simple2D & DiscreteObject examples added for BatchEnvironment
Aug 21, 2023
a8b07cf
attempting to fix large file problem
Jun 27, 2024
ccc584a
running TEM tests
Jun 28, 2024
f67b4c2
slurm updated
Jun 28, 2024
978f001
slurm updated
Jun 28, 2024
348a161
slurm change
Jun 28, 2024
aebb6fb
huge 50K run added
Jul 4, 2024
c5762df
huge 50K run added
Jul 4, 2024
723c16d
state density and history bugs sorted
Jul 5, 2024
0aab239
TEM state density bugs fixed
Jul 9, 2024
1672921
big high density run added
Jul 9, 2024
7128313
small TEM run
Jul 11, 2024
eab0cdf
state density mismatch fixed
Jul 11, 2024
36f5da1
small training run (without width 2) added
Jul 12, 2024
18c4abb
medium size run added
Jul 15, 2024
5bc718b
problem with state assignment fixed
Jul 15, 2024
ccb394e
reduced slurm memory pool
Jul 16, 2024
94b8ac8
reduced slurm memory pool
Jul 16, 2024
7de0832
updated test
Jul 16, 2024
5d33231
pre-commit run on all files
Jul 16, 2024
0d277b1
is the cluster broken or is it just me?
Jul 16, 2024
74990a6
trying cpu slurm
Jul 16, 2024
ca1c310
trying cpu slurm
Jul 16, 2024
a03ada0
trying cpu slurm
Jul 16, 2024
92de616
looped walks added
Jul 19, 2024
0e93183
looping walk
Jul 19, 2024
a681c6b
cpu slurm added
Jul 19, 2024
5b22e32
cpu slurm added
Jul 19, 2024
78cb5bb
trying to fix slurm bug
Jul 25, 2024
2c01eac
big memory run with longer walks
Jul 29, 2024
078d41f
new training config
Jul 30, 2024
ab297dc
formatted
Aug 5, 2024
cc0ad77
full var walks added
Aug 5, 2024
709fc82
trailing whitespace
Sep 3, 2024
0ead088
full length training
Sep 3, 2024
1fe31ab
recent TEM updates
Oct 7, 2024
2a95433
minor update
Oct 15, 2024
d2bdd14
test push
Oct 16, 2024
853d185
black precommit changes
Oct 29, 2024
4afe24c
precommit black
Nov 6, 2024
d423f89
pre-merge
Nov 20, 2024
3351dd8
premerge to main
Nov 22, 2024
d2cb6c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2024
3921cf6
Merge remote-tracking branch 'origin/main' into whittington_2020
Nov 27, 2024
a6911d9
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
Nov 27, 2024
2993e06
starting the cleaning process
Dec 2, 2024
62df4d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2024
bba6902
fixed arena examples bug
Dec 23, 2024
ab59665
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
Dec 23, 2024
8b062e7
retrigger checks
rodrigcd Dec 23, 2024
774a67f
SimpleDiscreteAgent added
Dec 31, 2024
3b5c1ac
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
Dec 31, 2024
67b9c53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 31, 2024
37cdc5a
fixing obs hist length
rodrigcd Jan 8, 2025
4d15210
disabling default plt show and fixing bugs on tests
rodrigcd Jan 8, 2025
6bd3651
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
c00cc09
removing tem logs from test
rodrigcd Jan 8, 2025
9a4fe58
fixing live render code for matplotlib 3.10
rodrigcd Jan 8, 2025
d872e91
limiting to python>=3.10
rodrigcd Jan 8, 2025
fad1ccd
limiting to python>=3.10
rodrigcd Jan 8, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -104,7 +104,7 @@ pip install NeuralPlayground==0.0.5

If you want to contribute to the project, get the latest development version
from GitHub, and install it in editable mode, including the "dev" dependencies:

#### Unix (Linux and macOS)
Collaborator:

Is this really necessary?

Collaborator Author:

I think so. I had some problems early on cloning the dev version of NeuralPlayground on both Windows and macOS. If this has been fixed, then this is redundant and I'll change it.

```bash
git clone https://github.com/SainsburyWellcomeCentre/NeuralPlayground/ --single-branch
cd NeuralPlayground
```
4,680 changes: 4,680 additions & 0 deletions examples/agent_examples/custom_sim/run.log

Large diffs are not rendered by default.

75 changes: 42 additions & 33 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

96 changes: 55 additions & 41 deletions examples/agent_examples/whittington_2020_run.py
@@ -14,7 +14,7 @@
from neuralplayground.experiments import Sargolini2006Data

simulation_id = "TEM_custom_sim"
save_path = os.path.join(os.getcwd(), "results_sim")
save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "results_sim")
# save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
agent_class = Whittington2020
env_class = BatchEnvironment
@@ -23,56 +23,56 @@
params = parameters.parameters()
full_agent_params = params.copy()

# Set the x and y limits for the arena
arena_x_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
]
arena_y_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
[-2.5, 2.5],
]

room_widths = [int(np.diff(arena_x_limits)[i]) for i in range(len(arena_x_limits))]
room_depths = [int(np.diff(arena_y_limits)[i]) for i in range(len(arena_y_limits))]

# Set parameters for the environment that generates observations
discrete_env_params = {
"environment_name": "DiscreteObject",
"state_density": 1,
"n_objects": params["n_x"],
"agent_step_size": 1,
"agent_step_size": 1, # Note: this must be 1 / state density
"use_behavioural_data": False,
"data_path": None,
"experiment_class": Sargolini2006Data,
}

# Set parameters for the batch environment
env_params = {
"environment_name": "BatchEnvironment",
"batch_size": 16,
@@ -81,19 +81,33 @@
"env_class": DiscreteObjectEnvironment,
"arg_env_params": discrete_env_params,
}

# If behavioural data are used, set arena limits to those from Sargolini et al. 2006, reduce state density to 1/4
state_densities = [discrete_env_params["state_density"] for _ in range(env_params["batch_size"])]
if discrete_env_params["use_behavioural_data"]:
arena_x_limits = [[-50, 50] for _ in range(env_params["batch_size"])]
arena_y_limits = [[-50, 50] for _ in range(env_params["batch_size"])]
state_densities = [0.25] * env_params["batch_size"]

room_widths = [int(np.diff(arena_x_limits)[i]) for i in range(env_params["batch_size"])]
room_depths = [int(np.diff(arena_y_limits)[i]) for i in range(env_params["batch_size"])]

# Set parameters for the agent
agent_params = {
"model_name": "Whittington2020",
"save_name": str(simulation_id)[4:],
"params": full_agent_params,
"batch_size": env_params["batch_size"],
"room_widths": room_widths,
"room_depths": room_depths,
"state_densities": [discrete_env_params["state_density"]] * env_params["batch_size"],
"use_behavioural_data": False,
"state_densities": state_densities,
"use_behavioural_data": discrete_env_params["use_behavioural_data"],
}

# Full model training consists of 20000 episodes
training_loop_params = {"n_episode": 10, "params": full_agent_params}
training_loop_params = {"n_episode": 5000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}

# Create the training simulation object
sim = SingleSim(
simulation_id=simulation_id,
agent_class=agent_class,
@@ -104,7 +118,7 @@
training_loop_params=training_loop_params,
)

# print(sim)
# Run the simulation
print("Running sim...")
sim.run_sim(save_path)
print("Sim finished.")
301 changes: 230 additions & 71 deletions examples/arena_examples/arena_examples.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/comparisons_examples/comparison_board.ipynb
@@ -783,7 +783,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.12 ('NPG-env')",
"language": "python",
"name": "python3"
},
1 change: 1 addition & 0 deletions neuralplayground/agents/__init__.py
@@ -2,5 +2,6 @@
from .agent_core import AgentCore, RandomAgent, LevyFlightAgent, RatMovementAgent
from .stachenfeld_2018 import Stachenfeld2018
from .weber_2018 import Weber2018
from .discrete_agent import SimpleDiscreteAgent

# from .whittington_2020 import Whittington2020
2 changes: 1 addition & 1 deletion neuralplayground/agents/agent_core.py
@@ -81,7 +81,7 @@ def act(self, obs, policy_func=None):
action = np.random.normal(scale=self.agent_step_size, size=(2,))

self.obs_history.append(obs)
if len(self.obs_history) >= 1000: # reset every 1000
if len(self.obs_history) >= 1000: # max length 1000
self.obs_history.pop(0)
if policy_func is not None:
return policy_func(obs)
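The revised comment clarifies that `obs_history` is capped rather than periodically cleared: each call appends the new observation and drops the oldest one once the list reaches 1000 entries. A rough equivalent, as a sketch only, is a fixed-size deque:

```python
from collections import deque

# Sketch: a bounded history similar in spirit to the append / pop(0) pattern above.
obs_history = deque(maxlen=1000)
for step in range(2500):
    obs_history.append(step)  # once full, the oldest entry is discarded automatically

assert len(obs_history) == 1000 and obs_history[0] == 1500
```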
113 changes: 113 additions & 0 deletions neuralplayground/agents/discrete_agent.py
@@ -0,0 +1,113 @@
import numpy as np

from .agent_core import AgentCore


class SimpleDiscreteAgent(AgentCore):
"""
A simplified single-environment discrete agent, loosely mirroring TEM’s
approach to picking actions and checking whether the environment
actually moved.
"""

def __init__(
self,
agent_name: str = "SimpleDiscreteAgent",
**model_kwargs,
):
"""
Parameters
----------
room_width : int
Width (in discrete states) of the environment
room_depth : int
Depth (in discrete states) of the environment
state_density : float
Number of discrete states per unit distance (usually 1 / step_size)
agent_name : str
Agent's name
"""
super().__init__(agent_name=agent_name)
self.room_width = model_kwargs["room_width"]
self.room_depth = model_kwargs["room_depth"]
self.state_density = model_kwargs["state_density"]
# Discrete actions: stay, up, down, right, left
self.poss_actions = [[0, 0], [0, 1], [0, -1], [1, 0], [-1, 0]]

# For storing trajectory
self.walk_actions = []
self.obs_history = []

# Keep track of previous observation/action so we know if the environment actually moved.
self.prev_observation = None
self.prev_action = [0, 0]
self.n_walk = 0

def reset(self):
"""
Reset the agent’s history and counters.
"""
super().reset()
self.walk_actions = []
self.obs_history = []
self.prev_observation = None
self.prev_action = [0, 0]
self.n_walk = 0

def act(self, observation, policy_func=None):
"""
Decide on the next action. If the environment did not change state
(i.e. we got the same position as before, and we tried to move),
then pick a new random action. Otherwise, record the old observation and action.

Parameters
----------
observation : list or np.ndarray
Typically [state_index, object_info, (x,y)] for a discrete environment.
The first element (observation[0]) is the discrete state index.

policy_func : callable, optional
Unused here. Included only for compatibility.

Returns
-------
action : list
Chosen discrete action [dx, dy]
"""
# If this is our first time calling act, initialise
if self.prev_observation is None:
self.prev_observation = observation
self.prev_action = self.action_policy()
return self.prev_action

# Check if environment actually moved to a new state
curr_state_idx = observation[0]
prev_state_idx = self.prev_observation[0]

if curr_state_idx == prev_state_idx and self.prev_action != [0, 0]:
# The environment didn't move from last action, so pick a new random action
new_action = self.action_policy()
else:
# The environment did move, so record old obs/action before picking the next action
self.walk_actions.append(self.prev_action)
self.obs_history.append(self.prev_observation)
self.n_walk += 1
new_action = self.action_policy()

self.prev_observation = observation
self.prev_action = new_action
return new_action

def action_policy(self):
"""
Random action policy that selects an action from [stay, up, down, right, left].
"""
idx = np.random.choice(len(self.poss_actions))
return self.poss_actions[idx]

def update(self):
"""
Update the agent's internal state after a walk is completed.
"""
self.n_walk = 0
return None
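A minimal usage sketch for `SimpleDiscreteAgent` (assumptions: the branch is installed so the new import from `neuralplayground.agents` resolves, observations follow the `[state_index, object_info, (x, y)]` format described in the docstring, and the toy loop below stands in for a real `DiscreteObjectEnvironment`):

```python
import numpy as np

from neuralplayground.agents import SimpleDiscreteAgent

# Hypothetical stand-in for a discrete environment: observations are built by hand
# as [state_index, object_info, (x, y)], matching the format named in the docstring.
agent = SimpleDiscreteAgent(room_width=10, room_depth=10, state_density=1)
agent.reset()

state_idx, position = 0, np.array([0.0, 0.0])
for _ in range(100):
    observation = [state_idx, None, position.copy()]
    action = agent.act(observation)
    # Pretend the environment applies the action; here every move succeeds.
    position = position + np.array(action)
    state_idx = int(position[0]) * 10 + int(position[1])  # toy state indexing

print(len(agent.obs_history), "observations recorded")
```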