/**
 * Balltree (k-Nearest Neighbors)
 *
 * Complexity (Time):  O(n log n)
 * Complexity (Space): O(n)
 */

#include <cmath>
#include <cstddef>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

// search() ignores points closer than EPS to the query. If the including
// translation unit already provides an EPS macro it is used unchanged;
// otherwise a local fallback is defined here and removed again at the end of
// this file so it cannot leak into unrelated code.
#ifndef EPS
#define EPS 1e-9
#define BALL_TREE_OWNS_EPS
#endif

// Kept for callers that access coordinates as p.x / p.y. The code in this
// file uses .first / .second directly and does not depend on these macros.
// Caution: they rewrite every later identifier named x or y.
#define x first
#define y second

using point = std::pair<double, double>;
using pset = std::vector<point>;

// One ball of the tree: every point stored below this node lies within
// `radius` of `center`. A leaf has no children, radius 0 and holds exactly
// one point in `center`.
struct node {
  double radius = 0.0;
  point center{0.0, 0.0};

  node *left = nullptr;
  node *right = nullptr;
};


// Euclidean distance between a and b.
// Call it as ::distance(...) so that argument-dependent lookup on std::pair
// never brings std::distance into overload resolution.
double distance(const point &a, const point &b) {
  const double dx = a.first - b.first;
  const double dy = a.second - b.second;

  return std::sqrt(dx * dx + dy * dy);
}


// Finds the point of ps that is furthest from center and returns
// <distance, index> of that point (the first one on ties).
// For an empty ps the result is <-1.0, 0>.
std::pair<double, int> get_radius(const point &center, const pset &ps) {
  int furthest = 0;
  double radius = -1.0;

  for (std::size_t i = 0; i < ps.size(); ++i) {
    const double dist = ::distance(center, ps[i]);

    if (radius < dist) {
      radius = dist;
      furthest = static_cast<int>(i);
    }
  }

  return std::pair<double, int>(radius, furthest);
}


// Stores the average (centroid) of ps in center; the centroid is used as the
// center of the ball enclosing ps. For an empty ps center is set to (0, 0)
// instead of dividing by zero.
void get_center(const pset &ps, point &center) {
  center.first = center.second = 0;

  if (ps.empty())
    return;

  for (const point &p : ps) {
    center.first += p.first;
    center.second += p.second;
  }

  center.first /= static_cast<double>(ps.size());
  center.second /= static_cast<double>(ps.size());
}


// Splits ps into the points closer to ps[lind] (appended to left) and the
// points closer to ps[rind] (appended to right), where rind is the index of
// the point furthest from ps[lind]. Ties go to left. The two pivots always end
// up in their own partition, so both partitions are non-empty whenever ps has
// at least two points.
//
// lind must be a valid index into ps (build() passes the index returned by
// get_radius). An empty ps leaves both partitions untouched; a single point
// goes to left only.
void partition(const pset &ps, pset &left, pset &right, int lind) {
  if (ps.empty())
    return;

  if (ps.size() == 1) {
    left.push_back(ps[0]);
    return;
  }

  const std::size_t lpivot = static_cast<std::size_t>(lind);
  const point lmpoint = ps[lpivot];

  // Right pivot: the point furthest from the left pivot.
  std::size_t rpivot = 0;
  double greatest = -1.0;

  for (std::size_t i = 0; i < ps.size(); ++i) {
    if (i == lpivot)
      continue;

    const double dist = ::distance(lmpoint, ps[i]);

    if (dist > greatest) {
      greatest = dist;
      rpivot = i;
    }
  }

  const point rmpoint = ps[rpivot];

  left.push_back(lmpoint);
  right.push_back(rmpoint);

  for (std::size_t i = 0; i < ps.size(); ++i) {
    if (i == lpivot || i == rpivot)
      continue;

    const double ldist = ::distance(ps[i], lmpoint);
    const double rdist = ::distance(ps[i], rmpoint);

    if (ldist <= rdist)
      left.push_back(ps[i]);
    else
      right.push_back(ps[i]);
  }
}


// Builds the ball-tree recursively.
// ps: vector of points
// Returns the root, or nullptr when ps is empty. Nodes are allocated with
// new and owned by the caller; release them with destroy_tree().
node *build(const pset &ps) {
  if (ps.empty())
    return nullptr;

  node *n = new node;

  // When there's only one point in ps, a leaf node is created storing that
  // point
  if (ps.size() == 1) {
    n->center = ps[0];

    n->radius = 0.0;
    n->right = n->left = nullptr;

  // Otherwise, ps gets split into two partitions, one for each child
  } else {
    get_center(ps, n->center);
    const std::pair<double, int> rad = get_radius(n->center, ps);

    pset lpart, rpart;
    partition(ps, lpart, rpart, rad.second);

    n->radius = rad.first;
    n->left = build(lpart);
    n->right = build(rpart);
  }

  return n;
}


// Frees every node of the tree rooted at n (n may be nullptr).
void destroy_tree(node *n) {
  if (n == nullptr)
    return;

  destroy_tree(n->left);
  destroy_tree(n->right);
  delete n;
}


// Searches the ball-tree recursively.
// n: root (nullptr is treated as an empty tree)
// t: query point
// pq: initially empty multiset (will contain the distances to the k nearest
//     neighbors after execution)
// k: number of nearest neighbors; nothing is collected when k <= 0
void search(node *n, point t, std::multiset<double> &pq, int &k) {
  if (n == nullptr || k <= 0)
    return;

  const std::size_t wanted = static_cast<std::size_t>(k);

  if (n->left == nullptr && n->right == nullptr) {
    const double dist = ::distance(t, n->center);

    // (!) Only necessary when the same point needs to be ignored
    if (dist < EPS)
      return;

    if (pq.size() < wanted || dist < *pq.rbegin()) {
      pq.insert(dist);

      // Keep only the `wanted` smallest distances.
      if (pq.size() > wanted)
        pq.erase(std::prev(pq.end()));
    }

    return;
  }

  // Visit the child whose center is closer to t first, so that the current
  // k-th best distance shrinks early and prunes more of the other child.
  node *closer = n->left;
  node *farther = n->right;
  double closer_dist = closer ? ::distance(t, closer->center) : 0.0;
  double farther_dist = farther ? ::distance(t, farther->center) : 0.0;

  if (closer == nullptr ||
      (farther != nullptr && !(closer_dist <= farther_dist))) {
    std::swap(closer, farther);
    std::swap(closer_dist, farther_dist);
  }

  node *const order[2] = {closer, farther};
  const double order_dist[2] = {closer_dist, farther_dist};

  for (int i = 0; i < 2; ++i) {
    node *child = order[i];

    if (child == nullptr)
      continue;

    // A ball can only hold a better candidate if its nearest possible point
    // (center distance minus radius) is not beyond the current k-th best.
    // The check is re-evaluated for the second child because the first
    // descent may have tightened pq.
    if (pq.size() < wanted || order_dist[i] <= *pq.rbegin() + child->radius)
      search(child, t, pq, k);
  }
}

#ifdef BALL_TREE_OWNS_EPS
#undef EPS
#undef BALL_TREE_OWNS_EPS
#endif