-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_prediction.cc
123 lines (100 loc) · 3.97 KB
/
test_prediction.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*
* Copyright (C) 2019 Swift Navigation Inc.
* Contact: Swift Navigation <[email protected]>
*
* This source is subject to the license found in the file 'LICENSE' which must
* be distributed together with this source. All other rights reserved.
*
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
*/
#include <albatross/Core>
#include <gtest/gtest.h>
namespace albatross {
/// Empty placeholder feature type shared by every test model in this file.
/// The models never look inside a feature; only the number of features
/// passed to fit/predict matters.
struct X {
  // Deliberately has no members.
};
/// Test model that implements only the Eigen::VectorXd (mean) prediction
/// overload.
///
/// Fitting ignores its inputs and returns a value-initialized fit, and
/// prediction returns an all-zero vector with one entry per feature.
class MeanOnlyModel : public ModelBase<MeanOnlyModel> {
public:
  /// Ignores both the features and the targets; the resulting fit is empty.
  Fit<MeanOnlyModel> _fit_impl(const std::vector<X> &,
                               const MarginalDistribution &) const {
    return {};
  }

  /// Returns a zero vector sized to the number of features.
  /// The PredictTypeIdentity tag presumably selects this overload when a
  /// VectorXd (mean) prediction is requested — TODO confirm in ModelBase.
  Eigen::VectorXd _predict_impl(const std::vector<X> &features,
                                const Fit<MeanOnlyModel> &,
                                PredictTypeIdentity<Eigen::VectorXd>) const {
    const Eigen::Index count = cast::to_index(features.size());
    Eigen::VectorXd prediction(count);
    prediction.setZero();
    return prediction;
  }
};
// A model that implements only the mean prediction must still expose
// `predict(...).mean()` as an Eigen::VectorXd holding the model's output
// (all zeros, one per feature), and must handle an empty feature set.
TEST(test_prediction, test_mean_only) {
  MeanOnlyModel m;
  std::vector<X> xs = {{}, {}};
  // Spell out the type: `auto` would capture an unevaluated Eigen
  // expression rather than a concrete vector.
  const Eigen::VectorXd zeros =
      Eigen::VectorXd::Zero(cast::to_index(xs.size()));
  MarginalDistribution targets(zeros);
  const auto fit_model = m.fit(xs, targets);
  const auto prediction = fit_model.predict(xs);

  auto mean = prediction.mean();
  // The bool(...) wrapper also parenthesizes the comma inside
  // std::is_same<...>, so the macro receives a single argument.
  EXPECT_TRUE(bool(std::is_same<Eigen::VectorXd, decltype(mean)>::value));
  // MeanOnlyModel predicts exactly one zero per feature.
  EXPECT_EQ(mean.size(), cast::to_index(xs.size()));
  EXPECT_TRUE(mean.isZero());

  std::vector<X> empty = {};
  EXPECT_EQ(fit_model.predict(empty).mean().size(), 0);
}
/// Test model that implements only the MarginalDistribution prediction
/// overload.
///
/// Fitting ignores its inputs and returns a value-initialized fit. The
/// prediction's mean is all zeros, one entry per feature.
class MarginalOnlyModel : public ModelBase<MarginalOnlyModel> {
public:
  /// Ignores both the features and the targets; the resulting fit is empty.
  Fit<MarginalOnlyModel> _fit_impl(const std::vector<X> &,
                                   const MarginalDistribution &) const {
    return {};
  }

  /// Returns a marginal distribution with a zero mean sized to the number of
  /// features. The covariance is left to MarginalDistribution's
  /// single-argument constructor (not visible here).
  MarginalDistribution
  _predict_impl(const std::vector<X> &features, const Fit<MarginalOnlyModel> &,
                PredictTypeIdentity<MarginalDistribution>) const {
    // Materialize the vector: `auto` would hold an Eigen expression template
    // instead of an Eigen::VectorXd.
    const Eigen::VectorXd mean =
        Eigen::VectorXd::Zero(cast::to_index(features.size()));
    return MarginalDistribution(mean);
  }
};
// A model that implements only the marginal prediction must expose both
// `mean()` (Eigen::VectorXd) and `marginal()` (MarginalDistribution) on its
// predictions, with values matching the model's output, and must handle an
// empty feature set.
TEST(test_prediction, test_marginal_only) {
  MarginalOnlyModel m;
  std::vector<X> xs = {{}, {}};
  // Spell out the type: `auto` would capture an unevaluated Eigen
  // expression rather than a concrete vector.
  const Eigen::VectorXd zeros =
      Eigen::VectorXd::Zero(cast::to_index(xs.size()));
  MarginalDistribution targets(zeros);
  const auto fit_model = m.fit(xs, targets);
  const auto prediction = fit_model.predict(xs);

  auto mean = prediction.mean();
  // The bool(...) wrapper also parenthesizes the comma inside
  // std::is_same<...>, so the macro receives a single argument.
  EXPECT_TRUE(bool(std::is_same<Eigen::VectorXd, decltype(mean)>::value));
  // MarginalOnlyModel's mean is one zero per feature.
  EXPECT_EQ(mean.size(), cast::to_index(xs.size()));
  EXPECT_TRUE(mean.isZero());

  auto marginal = prediction.marginal();
  EXPECT_TRUE(
      bool(std::is_same<MarginalDistribution, decltype(marginal)>::value));
  // Two features were passed in (see `xs` above).
  EXPECT_EQ(marginal.size(), 2);

  std::vector<X> empty = {};
  EXPECT_EQ(fit_model.predict(empty).marginal().size(), 0);
}
/// Test model that implements only the JointDistribution prediction overload.
///
/// Fitting ignores its inputs and returns a value-initialized fit. The
/// prediction has a zero mean and an all-zero n x n covariance, where n is the
/// number of features.
class JointOnlyModel : public ModelBase<JointOnlyModel> {
public:
  /// Ignores both the features and the targets; the resulting fit is empty.
  Fit<JointOnlyModel> _fit_impl(const std::vector<X> &,
                                const MarginalDistribution &) const {
    return {};
  }

  /// Returns a joint distribution with a zero mean and zero covariance, both
  /// sized to the number of features.
  JointDistribution
  _predict_impl(const std::vector<X> &features, const Fit<JointOnlyModel> &,
                PredictTypeIdentity<JointDistribution>) const {
    const Eigen::Index n = cast::to_index(features.size());
    // Concrete types rather than `auto`, which would hold Eigen expression
    // templates instead of a vector and a matrix.
    const Eigen::VectorXd mean = Eigen::VectorXd::Zero(n);
    const Eigen::MatrixXd covariance = Eigen::MatrixXd::Zero(n, n);
    return JointDistribution(mean, covariance);
  }
};
// A model that implements only the joint prediction must expose `mean()`,
// `marginal()` and `joint()` on its predictions, with sizes and values matching
// the model's output, and must handle an empty feature set.
TEST(test_prediction, test_joint_only) {
  JointOnlyModel m;
  std::vector<X> xs = {{}, {}};
  // Spell out the type: `auto` would capture an unevaluated Eigen
  // expression rather than a concrete vector.
  const Eigen::VectorXd zeros =
      Eigen::VectorXd::Zero(cast::to_index(xs.size()));
  MarginalDistribution targets(zeros);
  const auto fit_model = m.fit(xs, targets);
  const auto prediction = fit_model.predict(xs);

  auto mean = prediction.mean();
  // The bool(...) wrapper also parenthesizes the comma inside
  // std::is_same<...>, so the macro receives a single argument.
  EXPECT_TRUE(bool(std::is_same<Eigen::VectorXd, decltype(mean)>::value));
  // JointOnlyModel's mean is one zero per feature.
  EXPECT_EQ(mean.size(), cast::to_index(xs.size()));
  EXPECT_TRUE(mean.isZero());

  auto marginal = prediction.marginal();
  EXPECT_TRUE(
      bool(std::is_same<MarginalDistribution, decltype(marginal)>::value));
  // Two features were passed in (see `xs` above).
  EXPECT_EQ(marginal.size(), 2);

  auto joint = prediction.joint();
  EXPECT_TRUE(bool(std::is_same<JointDistribution, decltype(joint)>::value));
  EXPECT_EQ(joint.size(), 2);

  std::vector<X> empty = {};
  EXPECT_EQ(fit_model.predict(empty).joint().size(), 0);
}
} // namespace albatross