From 5f7495a4c0a8472543206ae0cb39aad4134e2443 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Wed, 22 Feb 2023 10:52:46 -0600 Subject: [PATCH] FIX: Groupby plots will now group timeseries subplots by (#624) * ADD: Groupby plotting capability. * ENH: Changing author name * FIX: Plots in examples * FIX: GroupBy will now override TimeSeriesDisplay default time axes * ADD: None as default value for group_by. * FIX: Groupby now groups by year in subplots --------- Co-authored-by: Robert Jackson Co-authored-by: AdamTheisen Co-authored-by: Robert Jackson Co-authored-by: Robert Jackson --- act/plotting/plot.py | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/act/plotting/plot.py b/act/plotting/plot.py index 153b47e939..b9201aa6f1 100644 --- a/act/plotting/plot.py +++ b/act/plotting/plot.py @@ -311,6 +311,10 @@ def __init__(self, display, units): """ self.display = display self._groupby = {} + self.mapping = {} + self.xlims = {} + self.units = units + self.isTimeSeriesDisplay = hasattr(self.display, 'time_height_scatter') num_groups = 0 datastreams = list(display._obj.keys()) for key in datastreams: @@ -350,6 +354,7 @@ def plot_group(self, func_name, dsname=None, **kwargs): if dsname == key: self.display._obj = {} for k, ds in self._groupby[key]: + num_years = len(np.unique(ds.time.dt.year)) self.display._obj[key + '_%d' % k] = ds if i >= np.prod(subplot_shape): i = 0 @@ -363,9 +368,27 @@ def plot_group(self, func_name, dsname=None, **kwargs): kwargs["subplot_index"] = subplot_index if "time_rng" in args: kwargs["time_rng"] = (ds.time.values.min(), ds.time.values.max()) - func(dsname=key + '_%d' % k, - **kwargs) - + if num_years > 1 and self.isTimeSeriesDisplay: + first_year = ds.time.dt.year[0] + for yr, ds1 in ds.groupby('time.year'): + if ds1.time.dt.year[0] % 4 == 0: + days_in_year = 366 + else: + days_in_year = 365 + year_diff = ds1.time.dt.year - first_year + time_diff = np.array( + [np.timedelta64(x * days_in_year, 'D') for x in year_diff.values]) + ds1['time'] = ds1.time - time_diff + self.display._obj[key + '%d_%d' % (k, yr)] = ds1 + func(dsname=key + '%d_%d' % (k, yr), label=str(yr), **kwargs) + self.mapping[key + '%d_%d' % (k, yr)] = subplot_index + self.xlims[key + '%d_%d' % (k, yr)] = (ds1.time.values.min(), ds1.time.values.max()) + del self.display._obj[key + '_%d' % k] + else: + func(dsname=key + '_%d' % k, **kwargs) + self.mapping[key + '_%d' % k] = subplot_index + if self.isTimeSeriesDisplay: + self.xlims[key + '_%d' % k] = (ds.time.values.min(), ds.time.values.max()) i = i + 1 if wrap_around is False and i < np.prod(subplot_shape): @@ -387,16 +410,11 @@ def plot_group(self, func_name, dsname=None, **kwargs): except AttributeError: pass - # Set to min and max for each time period if time series display - # Only the TimeSeriesDisplay has the time_height_scatter function - # So, check for that - if hasattr(self.display, 'time_height_scatter'): - key_list = list(self.display._obj.keys()) - if i >= len(key_list): - continue - ds = self.display._obj[key_list[i]] - time_min = ds.time.values.min() - time_max = ds.time.values.max() + if self.isTimeSeriesDisplay: + key_list = list(self.display._obj.keys()) + for k in key_list: + time_min, time_max = self.xlims[k] + subplot_index = self.mapping[k] self.display.set_xrng([time_min, time_max], subplot_index) self.display._obj = old_obj