Skip to content

Commit

Permalink
Merge pull request #355 from European-XFEL/thumbnails
Browse files Browse the repository at this point in the history
Use xarray's plotting methods to generate image thumbnails
  • Loading branch information
JamesWrigley authored Nov 27, 2024
2 parents 6566fa2 + d6dace5 commit 8f1decb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
13 changes: 11 additions & 2 deletions damnit/ctxsupport/ctxrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,11 @@ def generate_thumbnail(image):
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
vmin = np.nanquantile(image, 0.01)
vmax = np.nanquantile(image, 0.99)
ax.imshow(image, vmin=vmin, vmax=vmax, extent=(0, 1, 1, 0))
if isinstance(image, np.ndarray):
ax.imshow(image, vmin=vmin, vmax=vmax)
else:
# Use DataArray's own plotting method
image.plot(ax=ax, vmin=vmin, vmax=vmax, add_colorbar=False)
ax.axis('tight')
ax.axis('off')
ax.margins(0, 0)
Expand Down Expand Up @@ -518,7 +522,12 @@ def summarise(self, name):
if data.ndim == 0:
return data
elif data.ndim == 2:
return generate_thumbnail(np.nan_to_num(data))
if isinstance(data, np.ndarray):
data = np.nan_to_num(data)
else:
data = data.fillna(0)

return generate_thumbnail(data)
else:
return f"{data.dtype}: {data.shape}"

Expand Down
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Fixed:

- Added back grid lines for plots of `DataArray`'s (!334).
- Fixed thumbnails of 2D `DataArray`'s to match what is displayed when the
variable is plotted (!355).

## [0.1.4]

Expand Down
12 changes: 9 additions & 3 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,18 @@ def foo(run): return xr.DataArray([1, 2, 3])

figure_code = """
import numpy as np
import xarray as xr
from damnit_ctx import Variable
from matplotlib import pyplot as plt
@Variable("2D array")
@Variable("2D ndarray")
def twodarray(run):
return np.random.rand(1000, 1000)
@Variable("2D xarray")
def twodxarray(run):
return xr.DataArray(np.random.rand(100, 100))
@Variable(title="Axes")
def axes(run):
_, ax = plt.subplots()
Expand Down Expand Up @@ -359,8 +364,9 @@ def figure(run):
assert f["axes/data"].ndim == 3

# Test that the summaries are the right size
twodarray_png = Image.open(io.BytesIO(f[".reduced/twodarray"][()]))
assert np.asarray(twodarray_png).shape == (THUMBNAIL_SIZE, THUMBNAIL_SIZE, 4)
for var in ["twodarray", "twodxarray"]:
png = Image.open(io.BytesIO(f[f".reduced/{var}"][()]))
assert np.asarray(png).shape == (THUMBNAIL_SIZE, THUMBNAIL_SIZE, 4)

figure_png = Image.open(io.BytesIO(f[".reduced/figure"][()]))
assert max(np.asarray(figure_png).shape) == THUMBNAIL_SIZE
Expand Down

0 comments on commit 8f1decb

Please sign in to comment.