From 4bc41b12e19621f55fde6271730e8bd8014f7ff9 Mon Sep 17 00:00:00 2001 From: Gereon Tombrink Date: Mon, 6 Jan 2025 13:48:03 +0100 Subject: [PATCH] Fix plotting bug related to single deviation plot with deviations without rotations --- trajectopy/core/plotting/mpl/results.py | 79 ++++++------------------- 1 file changed, 17 insertions(+), 62 deletions(-) diff --git a/trajectopy/core/plotting/mpl/results.py b/trajectopy/core/plotting/mpl/results.py index 558a400..49da417 100644 --- a/trajectopy/core/plotting/mpl/results.py +++ b/trajectopy/core/plotting/mpl/results.py @@ -51,9 +51,7 @@ def _stair_hist(*, l, mm: bool = False, linewidth: float = 1.5) -> None: return max(n_hist) -def plot_compact_ate_hist( - ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings() -) -> Figure: +def plot_compact_ate_hist(ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> Figure: """ Plots compact ATE histograms for the given ATEResult. The plot contains histograms for the position deviations and, if available, the rotation deviations. @@ -79,9 +77,7 @@ def plot_compact_ate_hist( return fig -def plot_rotation_ate_hist( - devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings() -) -> None: +def plot_rotation_ate_hist(devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> None: roll = np.rad2deg(devs.rot_dev_x) pitch = np.rad2deg(devs.rot_dev_y) yaw = np.rad2deg(devs.rot_dev_z) @@ -99,29 +95,11 @@ def plot_rotation_ate_hist( plt.legend(["yaw", "pitch", "roll"]) -def plot_position_ate_hist( - devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings() -): - deviations_xa = ( - devs.abs_dev.directed_pos_dev[:, 0] - if plot_settings.directed_ate - else devs.abs_dev.pos_dev[:, 0] - ) - deviations_yh = ( - devs.abs_dev.directed_pos_dev[:, 1] - if plot_settings.directed_ate - else devs.abs_dev.pos_dev[:, 1] - ) - deviations_zv = ( - devs.abs_dev.directed_pos_dev[:, 2] - if plot_settings.directed_ate - else devs.abs_dev.pos_dev[:, 2] - ) - labels = ( - ["vertical", "horizontal", "along"] - if plot_settings.directed_ate - else ["x", "y", "z"] - ) +def plot_position_ate_hist(devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()): + deviations_xa = devs.abs_dev.directed_pos_dev[:, 0] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 0] + deviations_yh = devs.abs_dev.directed_pos_dev[:, 1] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 1] + deviations_zv = devs.abs_dev.directed_pos_dev[:, 2] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 2] + labels = ["vertical", "horizontal", "along"] if plot_settings.directed_ate else ["x", "y", "z"] plt.xlabel(plot_settings.unit_str) plt.ylabel("counts") @@ -175,9 +153,7 @@ def plot_position_ate_edf( for dev in deviation_list: sorted_comb_pos_dev = np.sort(dev.pos_dev_comb) - pos_norm_cdf = np.arange(len(sorted_comb_pos_dev)) / float( - len(sorted_comb_pos_dev) - ) + pos_norm_cdf = np.arange(len(sorted_comb_pos_dev)) / float(len(sorted_comb_pos_dev)) ax_pos.plot(sorted_comb_pos_dev * plot_settings.unit_multiplier, pos_norm_cdf) @@ -194,9 +170,7 @@ def plot_rotation_ate_edf(deviation_list: List[ATEResult]) -> None: if dev.abs_dev.rot_dev is None: continue sorted_comb_rot_dev = np.sort(np.rad2deg(dev.rot_dev_comb)) - rot_norm_cdf = np.arange(len(sorted_comb_rot_dev)) / float( - len(sorted_comb_rot_dev) - ) + rot_norm_cdf = np.arange(len(sorted_comb_rot_dev)) / float(len(sorted_comb_rot_dev)) ax_rot.plot(sorted_comb_rot_dev, rot_norm_cdf) @@ -281,9 +255,7 @@ def plot_ate( Figure: Figure containing the plot. """ deviation_list = ate_results if isinstance(ate_results, list) else [ate_results] - x_label = derive_xlabel_from_sortings( - [dev.trajectory.sorting.value for dev in deviation_list] - ) + x_label = derive_xlabel_from_sortings([dev.trajectory.sorting.value for dev in deviation_list]) fig = plt.figure() @@ -323,9 +295,9 @@ def plot_ate( dev.trajectory.function_of[arc_length_sorting], np.rad2deg(dev.rot_dev_comb[arc_length_sorting]), ) + ax_rot.set_xlim(min_x, max_x) ax_pos.set_xlim(min_x, max_x) - ax_rot.set_xlim(min_x, max_x) fig.legend([dev.name for dev in deviation_list], ncol=3, loc="upper center") plt.tight_layout() @@ -379,17 +351,8 @@ def plot_rpe(rpe_results: List[RPEResult]) -> Tuple[Figure, Figure]: _rpy_legend(figure_dict) - ret_sum = ( - 1 - if any( - dev.rpe_dev.pair_distance_unit == PairDistanceUnit.METER - for dev in rpe_results - ) - else 0 - ) - if any( - dev.rpe_dev.pair_distance_unit == PairDistanceUnit.SECOND for dev in rpe_results - ): + ret_sum = 1 if any(dev.rpe_dev.pair_distance_unit == PairDistanceUnit.METER for dev in rpe_results) else 0 + if any(dev.rpe_dev.pair_distance_unit == PairDistanceUnit.SECOND for dev in rpe_results): ret_sum += 2 plt.close({1: fig_time, 2: fig_metric}.get(ret_sum)) @@ -402,9 +365,7 @@ def plot_rpe(rpe_results: List[RPEResult]) -> Tuple[Figure, Figure]: }[ret_sum] -def _plot_rpe_pos( - figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult] -) -> None: +def _plot_rpe_pos(figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]) -> None: for dev in devs: line_plot = figure_dict[dev.rpe_dev.pair_distance_unit].plot( dev.mean_pair_distances, dev.pos_dev_mean, label=dev.name @@ -426,9 +387,7 @@ def _plot_rpe_pos( _set_violin_color(violin_plot, line_plot[0].get_color()) -def _plot_rpe_rot( - figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult] -) -> None: +def _plot_rpe_rot(figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]) -> None: plot_sum = 0 for dev in devs: if not dev.has_rot_dev: @@ -481,9 +440,7 @@ def _rpy_legend(figure_dict: Dict[str, Dict[PairDistanceUnit, Axes]]): ax.legend() -def scatter_ate( - ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings() -) -> Tuple[Figure, Figure]: +def scatter_ate(ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> Tuple[Figure, Figure]: """ Plots the ATE results as a scatter plot with color-coded deviations. @@ -524,9 +481,7 @@ def _colored_scatter_plot( plt.xlabel("x [m]") plt.ylabel("y [m]") - c_list, lower_bound, upper_bound, c_bar_ticks, c_bar_ticklabels = ( - _setup_cbar_params(c_list, plot_settings) - ) + c_list, lower_bound, upper_bound, c_bar_ticks, c_bar_ticklabels = _setup_cbar_params(c_list, plot_settings) sc = plt.scatter( xyz[:, 0],