From 772116dfed74d934de8dad9ac0754e500a989cfa Mon Sep 17 00:00:00 2001 From: kleeman Date: Wed, 4 Sep 2019 01:36:29 -0700 Subject: [PATCH] Add a nearest neighbor model. --- include/albatross/NearestNeighbor | 20 +++ include/albatross/serialize/NearestNeighbor | 20 +++ .../albatross/src/cereal/nearest_neighbor.hpp | 49 ++++++++ .../src/evaluation/cross_validation_utils.hpp | 16 ++- .../albatross/src/models/nearest_neighbor.hpp | 115 ++++++++++++++++++ include/albatross/src/models/null_model.hpp | 9 +- tests/test_cross_validation.cc | 30 +++-- tests/test_models.h | 15 ++- tests/test_serialize.cc | 1 + 9 files changed, 252 insertions(+), 23 deletions(-) create mode 100644 include/albatross/NearestNeighbor create mode 100644 include/albatross/serialize/NearestNeighbor create mode 100644 include/albatross/src/cereal/nearest_neighbor.hpp create mode 100644 include/albatross/src/models/nearest_neighbor.hpp diff --git a/include/albatross/NearestNeighbor b/include/albatross/NearestNeighbor new file mode 100644 index 00000000..ad6ed361 --- /dev/null +++ b/include/albatross/NearestNeighbor @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef ALBATROSS_NEAREST_NEIGHBOR_MODEL_H +#define ALBATROSS_NEAREST_NEIGHBOR_MODEL_H + +#include "Core" + +#include + +#endif diff --git a/include/albatross/serialize/NearestNeighbor b/include/albatross/serialize/NearestNeighbor new file mode 100644 index 00000000..268126ef --- /dev/null +++ b/include/albatross/serialize/NearestNeighbor @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef ALBATROSS_SERIALIZE_NEAREST_NEIGHBOR_H +#define ALBATROSS_SERIALIZE_NEAREST_NEIGHBOR_H + +#include "Core" + +#include "../src/cereal/nearest_neighbor.hpp" + +#endif diff --git a/include/albatross/src/cereal/nearest_neighbor.hpp b/include/albatross/src/cereal/nearest_neighbor.hpp new file mode 100644 index 00000000..55eb403c --- /dev/null +++ b/include/albatross/src/cereal/nearest_neighbor.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2019 Swift Navigation Inc. + * Contact: Swift Navigation + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_ +#define ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_ + +namespace albatross { + +template class NearestNeighborModel; + +template struct NearestNeighborFit; + +} // namespace albatross + +namespace cereal { + +template +inline void +save(Archive &archive, + const albatross::Fit> &fit, + const std::uint32_t) { + archive(cereal::make_nvp("training_features", fit.training_data.features)); + archive(cereal::make_nvp("training_targets", fit.training_data.targets)); +} + +template +inline void +load(Archive &archive, + albatross::Fit> &fit, + const std::uint32_t) { + std::vector features; + archive(cereal::make_nvp("training_features", features)); + albatross::MarginalDistribution targets; + archive(cereal::make_nvp("training_targets", targets)); + fit.training_data = RegressionDataset(features, targets); +} + +} // namespace cereal + +#endif /* ALBATROSS_CEREAL_NEAREST_NEIGHBOR_HPP_ */ diff --git a/include/albatross/src/evaluation/cross_validation_utils.hpp b/include/albatross/src/evaluation/cross_validation_utils.hpp index 4030a9fe..c9a82029 100644 --- a/include/albatross/src/evaluation/cross_validation_utils.hpp +++ b/include/albatross/src/evaluation/cross_validation_utils.hpp @@ -93,15 +93,25 @@ inline MarginalDistribution concatenate_marginal_predictions( Eigen::VectorXd variance(n); Eigen::Index number_filled = 0; // Put all the predicted means back in order. + bool has_covariance = false; for (const auto &pair : indexer) { assert(preds.at(pair.first).size() == pair.second.size()); set_subset(preds.at(pair.first).mean, pair.second, &mean); - set_subset(preds.at(pair.first).covariance.diagonal(), pair.second, - &variance); + if (preds.at(pair.first).has_covariance()) { + has_covariance = true; + set_subset(preds.at(pair.first).covariance.diagonal(), pair.second, + &variance); + } else { + assert(!has_covariance); + } number_filled += static_cast(pair.second.size()); } assert(number_filled == n); - return MarginalDistribution(mean, variance.asDiagonal()); + if (has_covariance) { + return MarginalDistribution(mean, variance.asDiagonal()); + } else { + return MarginalDistribution(mean); + } } template + * + * This source is subject to the license found in the file 'LICENSE' which must + * be distributed together with this source. All other rights reserved. + * + * THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, + * EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_ +#define ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_ + +namespace albatross { + +template class NearestNeighborModel; + +template struct NearestNeighborFit; + +template struct Fit> { + + Fit() : training_data(){}; + + Fit(const RegressionDataset &dataset) : training_data(dataset){}; + + bool operator==(const Fit> &other) const { + return training_data == other.training_data; + } + + RegressionDataset training_data; +}; + +template +class NearestNeighborModel + : public ModelBase> { + +public: + NearestNeighborModel() : distance_metric(){}; + + std::string get_name() const { return "nearest_neighbor_model"; }; + + template + Fit> + _fit_impl(const std::vector &features, + const MarginalDistribution &targets) const { + return Fit>( + RegressionDataset(features, targets)); + } + + template + auto fit_from_prediction(const std::vector &features, + const JointDistribution &prediction) const { + const NearestNeighborModel m(*this); + MarginalDistribution marginal_pred( + prediction.mean, prediction.covariance.diagonal().asDiagonal()); + Fit> fit = { + RegressionDataset(features, marginal_pred)}; + FitModel>> + fit_model(m, fit); + return fit_model; + } + + template + MarginalDistribution + _predict_impl(const std::vector &features, + const Fit> &fit, + PredictTypeIdentity &&) const { + const Eigen::Index n = static_cast(features.size()); + Eigen::VectorXd mean = Eigen::VectorXd::Zero(n); + mean.fill(NAN); + Eigen::VectorXd variance = Eigen::VectorXd::Zero(n); + variance.fill(NAN); + + for (std::size_t i = 0; i < features.size(); ++i) { + const auto min_index = + index_with_min_distance(features[i], fit.training_data.features); + mean[i] = fit.training_data.targets.mean[min_index]; + variance[i] = fit.training_data.targets.get_diagonal(min_index); + } + + if (fit.training_data.targets.has_covariance()) { + return MarginalDistribution(mean, variance.asDiagonal()); + } else { + return MarginalDistribution(mean); + } + } + +private: + template + std::size_t + index_with_min_distance(const FeatureType &ref, + const std::vector &features) const { + assert(features.size() > 0); + + std::size_t min_index = 0; + double min_distance = distance_metric(ref, features[0]); + + for (std::size_t i = 1; i < features.size(); ++i) { + const double dist = distance_metric(ref, features[i]); + if (dist < min_distance) { + min_index = i; + min_distance = dist; + } + } + return min_index; + } + + DistanceMetric distance_metric; +}; + +} // namespace albatross + +#endif // ALBATROSS_SRC_MODELS_NEAREST_NEIGHBOR_MODEL_HPP_ diff --git a/include/albatross/src/models/null_model.hpp b/include/albatross/src/models/null_model.hpp index 803afb5b..42203d8b 100644 --- a/include/albatross/src/models/null_model.hpp +++ b/include/albatross/src/models/null_model.hpp @@ -32,10 +32,6 @@ class NullModel : public ModelBase { std::string get_name() const { return "null_model"; }; - /* - * The Gaussian Process Regression model derives its parameters from - * the covariance functions. - */ ParameterStore get_params() const override { return params_; } void unchecked_set_param(const std::string &name, @@ -43,8 +39,6 @@ class NullModel : public ModelBase { params_[name] = param; } - // If the implementing class doesn't have a fit method for this - // FeatureType but the CovarianceFunction does. template Fit _fit_impl(const std::vector &features, const MarginalDistribution &targets) const { @@ -87,5 +81,4 @@ class NullModel : public ModelBase { } // namespace albatross -#endif /* THIRD_PARTY_ALBATROSS_INCLUDE_ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_ \ - */ +#endif // ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_ diff --git a/tests/test_cross_validation.cc b/tests/test_cross_validation.cc index 007f2683..af202f80 100644 --- a/tests/test_cross_validation.cc +++ b/tests/test_cross_validation.cc @@ -53,16 +53,22 @@ TYPED_TEST_P(RegressionModelTester, test_logo_predict_variants) { auto dataset = this->test_case.get_dataset(); auto model = this->test_case.get_model(); - // Here we assume that the test case is linear, then split - // it using a group function which will not preserve order - // and make sure that cross validation properly reassembles - // the predictions - LeaveOneGroupOut logo(group_by_interval); - const auto prediction = model.cross_validate().predict(dataset, logo); - - EXPECT_TRUE(is_monotonic_increasing(prediction.mean())); - - expect_predict_variants_consistent(prediction); + // The nearest neighbor approach is not capable of modelling linear + // trends and in turn fails this test. + if (!std::is_same>::value) { + // Here we assume that the test case is linear, then split + // it using a group function which will not preserve order + // and make sure that cross validation properly reassembles + // the predictions + LeaveOneGroupOut logo( + group_by_interval); + const auto prediction = model.cross_validate().predict(dataset, logo); + + EXPECT_TRUE(is_monotonic_increasing(prediction.mean())); + + expect_predict_variants_consistent(prediction); + } } TYPED_TEST_P(RegressionModelTester, test_loo_predict_variants) { @@ -110,7 +116,9 @@ TYPED_TEST_P(RegressionModelTester, test_score_variants) { // Here we make sure the cross validated mean absolute error is reasonable. // Note that because we are running leave one out cross validation, the // RMSE for each fold is just the absolute value of the error. - if (!std::is_same::value) { + if (!std::is_same::value && + !std::is_same>::value) { EXPECT_LE(cv_scores.mean(), 0.1); } } diff --git a/tests/test_models.h b/tests/test_models.h index ff9832df..fe50a02d 100644 --- a/tests/test_models.h +++ b/tests/test_models.h @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -181,6 +182,17 @@ class MakeNullModel { } }; +class MakeNearestNeighborModel { +public: + NearestNeighborModel get_model() const { + return NearestNeighborModel(); + } + + RegressionDataset get_dataset() const { + return make_toy_linear_data(); + } +}; + template class RegressionModelTester : public ::testing::Test { public: @@ -189,7 +201,8 @@ class RegressionModelTester : public ::testing::Test { typedef ::testing::Types + MakeRansacAdaptedGaussianProcess, MakeNullModel, + MakeNearestNeighborModel> ExampleModels; TYPED_TEST_CASE_P(RegressionModelTester); diff --git a/tests/test_serialize.cc b/tests/test_serialize.cc index a3ca7f33..eaf23156 100644 --- a/tests/test_serialize.cc +++ b/tests/test_serialize.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include