/// @brief Returns the square of a value.
///
/// Uses a plain multiplication rather than a general power routine.
///
/// @param x Value to square.
/// @return x multiplied by itself.
double squared(const double x) {
  return x * x;
}
17double dis_from_bnd(
const double x,
const double amin,
const double amax) {
38 const size_t n = data_in.size();
39 if (!data_in.empty()) {
40 m_dim = data_in[0].size();
44 for (
size_t i = 0; i < n; i++) m_ind[i] = i;
46 m_root = build_tree_for_range(0, n - 1, 0);
56 if (u < l)
return nullptr;
58 if ((u - l) <= bucketsize) {
60 for (
size_t i = 0; i <
m_dim; i++) {
61 node->box[i] = spread_in_coordinate(i, l, u);
67 node->left = node->right =
nullptr;
76 double maxspread = 0.0;
77 for (
size_t i = 0; i <
m_dim; i++) {
78 if (!parent || (parent->cut_dim == i)) {
79 node->box[i] = spread_in_coordinate(i, l, u);
81 node->box[i] = parent->box[i];
83 const double spread = node->box[i][1] - node->box[i][0];
84 if (spread > maxspread) {
92 for (
int k = l; k <= u; k++) {
93 sum +=
m_data[m_ind[k]][c];
95 const double average = sum /
static_cast<double>(u - l + 1);
96 int m = select_on_coordinate_value(c, average, l, u);
102 node->left = build_tree_for_range(l, m, node);
103 node->right = build_tree_for_range(m + 1, u, node);
106 node->box = node->left->box;
107 node->cut_val = node->left->box[c][1];
108 node->cut_val_left = node->cut_val_right = node->cut_val;
109 }
else if (!node->left) {
110 node->box = node->right->box;
111 node->cut_val = node->right->box[c][1];
112 node->cut_val_left = node->cut_val_right = node->cut_val;
114 node->cut_val_right = node->right->box[c][0];
115 node->cut_val_left = node->left->box[c][1];
116 node->cut_val = 0.5 * (node->cut_val_left + node->cut_val_right);
121 for (
size_t i = 0; i <
m_dim; i++) {
122 node->box[i][1] = std::max(node->left->box[i][1],
123 node->right->box[i][1]);
124 node->box[i][0] = std::min(node->left->box[i][0],
125 node->right->box[i][0]);
132std::array<double, 2> KDTree::spread_in_coordinate(
const int c,
const int l,
136 double smin =
m_data[m_ind[l]][c];
141 for (i = l + 2; i <= u; i += 2) {
142 double lmin =
m_data[m_ind[i - 1]][c];
143 double lmax =
m_data[m_ind[i]][c];
144 if (lmin > lmax) std::swap(lmin, lmax);
145 if (smin > lmin) smin = lmin;
146 if (smax < lmax) smax = lmax;
151 if (smin > last) smin =
last;
152 if (smax < last) smax =
last;
157int KDTree::select_on_coordinate_value(
int c,
double alpha,
int l,
int u) {
163 if (
m_data[m_ind[lb]][c] <= alpha) {
166 std::swap(m_ind[lb], m_ind[ub]);
171 return m_data[m_ind[lb]][c] <=
alpha ? lb : lb - 1;
175 const unsigned int nn,
176 std::vector<KDTreeResult>& result)
const {
178 std::priority_queue<KDTreeResult> res;
179 double r2 = std::numeric_limits<double>::max();
180 m_root->search_n(-1, 0, nn, r2, qv, *
this, res);
182 while (!res.empty()) {
183 result.push_back(res.top());
190 const unsigned int ndecorrel,
191 const unsigned int nn,
192 std::vector<KDTreeResult>& result)
const {
194 std::priority_queue<KDTreeResult> res;
195 double r2 = std::numeric_limits<double>::max();
196 m_root->search_n(idx, ndecorrel, nn, r2,
m_data[idx], *
this, res);
198 while (!res.empty()) {
199 result.push_back(res.top());
206 std::vector<KDTreeResult>& result)
const {
209 m_root->search_r(-1, 0, r2, qv, *
this, result);
214 const unsigned int ndecorrel,
216 std::vector<KDTreeResult>& result)
const {
219 m_root->search_r(idx, ndecorrel, r2,
m_data[idx], *
this, result);
228 if (left)
delete left;
229 if (right)
delete right;
232void KDTreeNode::search_n(
const int idx0,
const int nd,
233 const unsigned int nn,
double& r2,
234 const std::vector<double>& qv,
const KDTree& tree,
235 std::priority_queue<KDTreeResult>& res)
const {
237 if (!left && !right) {
239 process_terminal_node_n(idx0, nd, nn, r2, qv, tree, res);
242 KDTreeNode *ncloser =
nullptr;
243 KDTreeNode *nfarther =
nullptr;
246 double qval = qv[cut_dim];
248 if (qval < cut_val) {
251 extra = cut_val_right - qval;
255 extra = qval - cut_val_left;
258 if (ncloser) ncloser->search_n(idx0, nd, nn, r2, qv, tree, res);
259 if ((nfarther) && (squared(extra) < r2)) {
261 if (nfarther->box_in_search_range(r2, qv)) {
262 nfarther->search_n(idx0, nd, nn, r2, qv, tree, res);
267void KDTreeNode::search_r(
const int idx0,
const int nd,
const double r2,
268 const std::vector<double>& qv,
const KDTree& tree,
269 std::vector<KDTreeResult>& res)
const {
271 if (!left && !right) {
273 process_terminal_node_r(idx0, nd, r2, qv, tree, res);
280 double qval = qv[cut_dim];
282 if (qval < cut_val) {
285 extra = cut_val_right - qval;
289 extra = qval - cut_val_left;
292 if (ncloser) ncloser->search_r(idx0, nd, r2, qv, tree, res);
293 if ((nfarther) && (squared(extra) < r2)) {
295 if (nfarther->box_in_search_range(r2, qv)) {
296 nfarther->search_r(idx0, nd, r2, qv, tree, res);
301inline bool KDTreeNode::box_in_search_range(
const double r2,
302 const std::vector<double>& qv)
const {
306 const size_t dim = qv.size();
308 for (
size_t i = 0; i < dim; i++) {
309 dis2 += squared(dis_from_bnd(qv[i], box[i][0], box[i][1]));
310 if (dis2 > r2)
return false;
315void KDTreeNode::process_terminal_node_n(
const int idx0,
const int nd,
316 const unsigned int nn,
double& r2,
const std::vector<double>& qv,
317 const KDTree& tree, std::priority_queue<KDTreeResult>& res)
const {
319 const size_t dim = tree.m_dim;
320 const auto& data = tree.m_data;
322 for (
int i = m_l; i <= m_u; i++) {
323 const int idx = tree.m_ind[i];
324 bool early_exit =
false;
326 for (
size_t k = 0; k < dim; k++) {
327 dis += squared(data[idx][k] - qv[k]);
333 if (early_exit)
continue;
336 if (idx0 >= 0 && (abs(idx - idx0) < nd))
continue;
339 if (res.size() < nn) {
346 if (res.size() == nn) r2 = res.top().dis;
361void KDTreeNode::process_terminal_node_r(
const int idx0,
const int nd,
362 const double r2,
const std::vector<double>& qv,
const KDTree& tree,
363 std::vector<KDTreeResult>& res)
const {
365 const size_t dim = tree.m_dim;
366 const auto& data = tree.m_data;
368 for (
int i = m_l; i <= m_u; i++) {
369 const int idx = tree.m_ind[i];
370 bool early_exit =
false;
372 for (
size_t k = 0; k < dim; k++) {
373 dis += squared(data[idx][k] - qv[k]);
379 if (early_exit)
continue;
382 if (idx0 >= 0 && (abs(idx - idx0) < nd))
continue;
387 res.push_back(std::move(e));
KDTreeNode(int dim)
Constructor.
const KDTreeArray & m_data
void n_nearest_around_point(const unsigned int idx, const unsigned int ndecorrel, const unsigned int nn, std::vector< KDTreeResult > &result) const
void r_nearest(const std::vector< double > &qv, const double r2, std::vector< KDTreeResult > &result) const
void n_nearest(const std::vector< double > &qv, const unsigned int nn, std::vector< KDTreeResult > &result) const
void r_nearest_around_point(const unsigned int idx, const unsigned int ndecorrel, const double r2, std::vector< KDTreeResult > &result) const
bool operator<(const KDTreeResult &e1, const KDTreeResult &e2)
std::vector< std::vector< double > > KDTreeArray