Skip to content

Commit

Permalink
Backport PR matplotlib#29167: BUGFIX: use axes unit information in Co…
Browse files Browse the repository at this point in the history
…nnectionPatch
  • Loading branch information
QuLogic authored and meeseeksmachine committed Dec 5, 2024
1 parent 4a18447 commit b0ae5d7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
5 changes: 3 additions & 2 deletions lib/matplotlib/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from matplotlib import _api, ticker, units
from matplotlib import _api, cbook, ticker, units


_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,7 +55,8 @@ def convert(value, unit, axis):
values = np.atleast_1d(np.array(value, dtype=object))
# force an update so it also does type checking
unit.update(values)
return np.vectorize(unit._mapping.__getitem__, otypes=[float])(values)
s = np.vectorize(unit._mapping.__getitem__, otypes=[float])(values)
return s if not cbook.is_scalar_or_string(value) else s[0]

@staticmethod
def axisinfo(unit, axis):
Expand Down
14 changes: 9 additions & 5 deletions lib/matplotlib/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -4589,23 +4589,27 @@ def _get_xy(self, xy, s, axes=None):
s0 = s # For the error message, if needed.
if axes is None:
axes = self.axes
xy = np.array(xy)

# preserve mixed type input (such as str, int)
x = np.array(xy[0])
y = np.array(xy[1])

fig = self.get_figure(root=False)
if s in ["figure points", "axes points"]:
xy *= fig.dpi / 72
x = x * fig.dpi / 72
y = y * fig.dpi / 72
s = s.replace("points", "pixels")
elif s == "figure fraction":
s = fig.transFigure
elif s == "subfigure fraction":
s = fig.transSubfigure
elif s == "axes fraction":
s = axes.transAxes
x, y = xy

if s == 'data':
trans = axes.transData
x = float(self.convert_xunits(x))
y = float(self.convert_yunits(y))
x = cbook._to_unmasked_float_array(axes.xaxis.convert_units(x))
y = cbook._to_unmasked_float_array(axes.yaxis.convert_units(y))
return trans.transform((x, y))
elif s == 'offset points':
if self.xycoords == 'offset points': # prevent recursion
Expand Down
22 changes: 22 additions & 0 deletions lib/matplotlib/tests/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,28 @@ def test_connection_patch_fig(fig_test, fig_ref):
fig_ref.add_artist(con)


@check_figures_equal(extensions=["png"])
def test_connection_patch_pixel_points(fig_test, fig_ref):
xyA_pts = (.3, .2)
xyB_pts = (-30, -20)

ax1, ax2 = fig_test.subplots(1, 2)
con = mpatches.ConnectionPatch(xyA=xyA_pts, coordsA="axes points", axesA=ax1,
xyB=xyB_pts, coordsB="figure points",
arrowstyle="->", shrinkB=5)
fig_test.add_artist(con)

plt.rcParams["savefig.dpi"] = plt.rcParams["figure.dpi"]

ax1, ax2 = fig_ref.subplots(1, 2)
xyA_pix = (xyA_pts[0]*(fig_ref.dpi/72), xyA_pts[1]*(fig_ref.dpi/72))
xyB_pix = (xyB_pts[0]*(fig_ref.dpi/72), xyB_pts[1]*(fig_ref.dpi/72))
con = mpatches.ConnectionPatch(xyA=xyA_pix, coordsA="axes pixels", axesA=ax1,
xyB=xyB_pix, coordsB="figure pixels",
arrowstyle="->", shrinkB=5)
fig_ref.add_artist(con)


def test_datetime_rectangle():
# Check that creating a rectangle with timedeltas doesn't fail
from datetime import datetime, timedelta
Expand Down
15 changes: 15 additions & 0 deletions lib/matplotlib/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import matplotlib.pyplot as plt
from matplotlib.testing.decorators import check_figures_equal, image_comparison
import matplotlib.patches as mpatches
import matplotlib.units as munits
from matplotlib.category import StrCategoryConverter, UnitData
from matplotlib.dates import DateConverter
Expand Down Expand Up @@ -336,3 +337,17 @@ def test_plot_kernel():
# just a smoketest that fail
kernel = Kernel([1, 2, 3, 4, 5])
plt.plot(kernel)


def test_connection_patch_units(pd):
# tests that this doesn't raise an error
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 5))
x = pd.Timestamp('2017-01-01T12')
ax1.axvline(x)
y = "test test"
ax2.axhline(y)
arr = mpatches.ConnectionPatch((x, 0), (0, y),
coordsA='data', coordsB='data',
axesA=ax1, axesB=ax2)
fig.add_artist(arr)
fig.draw_without_rendering()

0 comments on commit b0ae5d7

Please sign in to comment.