diff --git a/neuralplayground/arenas/discritized_objects.py b/neuralplayground/arenas/discritized_objects.py index 0720778..ca8707b 100644 --- a/neuralplayground/arenas/discritized_objects.py +++ b/neuralplayground/arenas/discritized_objects.py @@ -413,8 +413,10 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape(f.canvas.get_width_height()[::-1] + (4,)) + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + print(image.shape) if display: cv2.imshow("2D_env", image) diff --git a/neuralplayground/arenas/simple2d.py b/neuralplayground/arenas/simple2d.py index d98e421..9529e74 100644 --- a/neuralplayground/arenas/simple2d.py +++ b/neuralplayground/arenas/simple2d.py @@ -354,8 +354,10 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape(f.canvas.get_width_height()[::-1] + (4,)) + image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) + print(image.shape) if display: cv2.imshow("2D_env", image)