Skip to content

Commit

Permalink
Fix plotting bug related to single deviation plot with deviations wit…
Browse files Browse the repository at this point in the history
…hout rotations
  • Loading branch information
gereon-t committed Jan 6, 2025
1 parent c817a9a commit 4bc41b1
Showing 1 changed file with 17 additions and 62 deletions.
79 changes: 17 additions & 62 deletions trajectopy/core/plotting/mpl/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def _stair_hist(*, l, mm: bool = False, linewidth: float = 1.5) -> None:
return max(n_hist)


def plot_compact_ate_hist(
ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()
) -> Figure:
def plot_compact_ate_hist(ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> Figure:
"""
Plots compact ATE histograms for the given ATEResult.
The plot contains histograms for the position deviations and, if available, the rotation deviations.
Expand All @@ -79,9 +77,7 @@ def plot_compact_ate_hist(
return fig


def plot_rotation_ate_hist(
devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()
) -> None:
def plot_rotation_ate_hist(devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> None:
roll = np.rad2deg(devs.rot_dev_x)
pitch = np.rad2deg(devs.rot_dev_y)
yaw = np.rad2deg(devs.rot_dev_z)
Expand All @@ -99,29 +95,11 @@ def plot_rotation_ate_hist(
plt.legend(["yaw", "pitch", "roll"])


def plot_position_ate_hist(
devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()
):
deviations_xa = (
devs.abs_dev.directed_pos_dev[:, 0]
if plot_settings.directed_ate
else devs.abs_dev.pos_dev[:, 0]
)
deviations_yh = (
devs.abs_dev.directed_pos_dev[:, 1]
if plot_settings.directed_ate
else devs.abs_dev.pos_dev[:, 1]
)
deviations_zv = (
devs.abs_dev.directed_pos_dev[:, 2]
if plot_settings.directed_ate
else devs.abs_dev.pos_dev[:, 2]
)
labels = (
["vertical", "horizontal", "along"]
if plot_settings.directed_ate
else ["x", "y", "z"]
)
def plot_position_ate_hist(devs: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()):
deviations_xa = devs.abs_dev.directed_pos_dev[:, 0] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 0]
deviations_yh = devs.abs_dev.directed_pos_dev[:, 1] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 1]
deviations_zv = devs.abs_dev.directed_pos_dev[:, 2] if plot_settings.directed_ate else devs.abs_dev.pos_dev[:, 2]
labels = ["vertical", "horizontal", "along"] if plot_settings.directed_ate else ["x", "y", "z"]

plt.xlabel(plot_settings.unit_str)
plt.ylabel("counts")
Expand Down Expand Up @@ -175,9 +153,7 @@ def plot_position_ate_edf(

for dev in deviation_list:
sorted_comb_pos_dev = np.sort(dev.pos_dev_comb)
pos_norm_cdf = np.arange(len(sorted_comb_pos_dev)) / float(
len(sorted_comb_pos_dev)
)
pos_norm_cdf = np.arange(len(sorted_comb_pos_dev)) / float(len(sorted_comb_pos_dev))
ax_pos.plot(sorted_comb_pos_dev * plot_settings.unit_multiplier, pos_norm_cdf)


Expand All @@ -194,9 +170,7 @@ def plot_rotation_ate_edf(deviation_list: List[ATEResult]) -> None:
if dev.abs_dev.rot_dev is None:
continue
sorted_comb_rot_dev = np.sort(np.rad2deg(dev.rot_dev_comb))
rot_norm_cdf = np.arange(len(sorted_comb_rot_dev)) / float(
len(sorted_comb_rot_dev)
)
rot_norm_cdf = np.arange(len(sorted_comb_rot_dev)) / float(len(sorted_comb_rot_dev))
ax_rot.plot(sorted_comb_rot_dev, rot_norm_cdf)


Expand Down Expand Up @@ -281,9 +255,7 @@ def plot_ate(
Figure: Figure containing the plot.
"""
deviation_list = ate_results if isinstance(ate_results, list) else [ate_results]
x_label = derive_xlabel_from_sortings(
[dev.trajectory.sorting.value for dev in deviation_list]
)
x_label = derive_xlabel_from_sortings([dev.trajectory.sorting.value for dev in deviation_list])

fig = plt.figure()

Expand Down Expand Up @@ -323,9 +295,9 @@ def plot_ate(
dev.trajectory.function_of[arc_length_sorting],
np.rad2deg(dev.rot_dev_comb[arc_length_sorting]),
)
ax_rot.set_xlim(min_x, max_x)

ax_pos.set_xlim(min_x, max_x)
ax_rot.set_xlim(min_x, max_x)

fig.legend([dev.name for dev in deviation_list], ncol=3, loc="upper center")
plt.tight_layout()
Expand Down Expand Up @@ -379,17 +351,8 @@ def plot_rpe(rpe_results: List[RPEResult]) -> Tuple[Figure, Figure]:

_rpy_legend(figure_dict)

ret_sum = (
1
if any(
dev.rpe_dev.pair_distance_unit == PairDistanceUnit.METER
for dev in rpe_results
)
else 0
)
if any(
dev.rpe_dev.pair_distance_unit == PairDistanceUnit.SECOND for dev in rpe_results
):
ret_sum = 1 if any(dev.rpe_dev.pair_distance_unit == PairDistanceUnit.METER for dev in rpe_results) else 0
if any(dev.rpe_dev.pair_distance_unit == PairDistanceUnit.SECOND for dev in rpe_results):
ret_sum += 2

plt.close({1: fig_time, 2: fig_metric}.get(ret_sum))
Expand All @@ -402,9 +365,7 @@ def plot_rpe(rpe_results: List[RPEResult]) -> Tuple[Figure, Figure]:
}[ret_sum]


def _plot_rpe_pos(
figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]
) -> None:
def _plot_rpe_pos(figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]) -> None:
for dev in devs:
line_plot = figure_dict[dev.rpe_dev.pair_distance_unit].plot(
dev.mean_pair_distances, dev.pos_dev_mean, label=dev.name
Expand All @@ -426,9 +387,7 @@ def _plot_rpe_pos(
_set_violin_color(violin_plot, line_plot[0].get_color())


def _plot_rpe_rot(
figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]
) -> None:
def _plot_rpe_rot(figure_dict: Dict[PairDistanceUnit, Axes], devs: List[RPEResult]) -> None:
plot_sum = 0
for dev in devs:
if not dev.has_rot_dev:
Expand Down Expand Up @@ -481,9 +440,7 @@ def _rpy_legend(figure_dict: Dict[str, Dict[PairDistanceUnit, Axes]]):
ax.legend()


def scatter_ate(
ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()
) -> Tuple[Figure, Figure]:
def scatter_ate(ate_result: ATEResult, plot_settings: MPLPlotSettings = MPLPlotSettings()) -> Tuple[Figure, Figure]:
"""
Plots the ATE results as a scatter plot with color-coded deviations.
Expand Down Expand Up @@ -524,9 +481,7 @@ def _colored_scatter_plot(
plt.xlabel("x [m]")
plt.ylabel("y [m]")

c_list, lower_bound, upper_bound, c_bar_ticks, c_bar_ticklabels = (
_setup_cbar_params(c_list, plot_settings)
)
c_list, lower_bound, upper_bound, c_bar_ticks, c_bar_ticklabels = _setup_cbar_params(c_list, plot_settings)

sc = plt.scatter(
xyz[:, 0],
Expand Down

0 comments on commit 4bc41b1

Please sign in to comment.