Skip to content

Commit

Permalink
Update use of matplotlib API
Browse files Browse the repository at this point in the history
  • Loading branch information
adamltyson committed Jan 9, 2025
1 parent 2dc1047 commit 8451321
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions neuralplayground/arenas/discritized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,10 @@ def render(self, history_length=30, display=True):
history = self.history[-history_length:]
ax = self.plot_trajectory(history_data=history, ax=ax)
canvas.draw()
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (4,))
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

print(image.shape)
if display:
cv2.imshow("2D_env", image)
Expand Down
6 changes: 4 additions & 2 deletions neuralplayground/arenas/simple2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,10 @@ def render(self, history_length=30, display=True):
history = self.history[-history_length:]
ax = self.plot_trajectory(history_data=history, ax=ax)
canvas.draw()
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (4,))
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

print(image.shape)
if display:
cv2.imshow("2D_env", image)
Expand Down

0 comments on commit 8451321

Please sign in to comment.