K-D Tree 学习笔记
最近看了一下k-NN然后它说如果特征空间维数比较低的时候用K-D Tree来求k近邻比较快所以就来补一下学OI时没学的K-D Tree假装写一个学习笔记吧。
是什么?
是一个平衡二叉树
k=1的时候就是一只BST
k>1的话,每一层换一维来分割
就是用许多垂直坐标轴的超平面将一个k维空间分割
每个节点保存了一个点,它所代表的超平面就是经过这个点垂直于某个坐标轴一个超平面
每个子树代表了一个区域(代码实现中是包含子树中所有点的最小超矩形,实际上应该是划分后的那个超矩形
怎么做?
建树
我没有任何建树,下一个
复杂度\(O(kn\log n)\),一个分治...
插入
直接插入就行了,注意一路update
挺不科学的插入的话会破坏建树时的平衡性
所以要加入重构好麻烦不想写
查询
有点诡异的启发式搜索
有一个估算一个点到一个超矩形的最大/最小距离的操作
对于最近邻来说,先搜左右儿子中距离近的,并且只搜估算最近距离小于当前ans的
k近邻的话,用个大根堆,一直保持堆中有k个的元素
远的话换成远就行了QwQ
听说随机数据复杂度\(O(\log n)\)到\(O(n\sqrt{n})\) ,不会证不会证
代码实现
因为早就退役了所以我也没有做很多题来练习各种鬼畜用法的必要了扔模板就跑
带插入最近邻
#include#include #include #include #include using namespace std;const int N = 1e6+5, inf = 1e9;#define lc t[x].ch[0]#define rc t[x].ch[1]#define max0(x) max(x, 0)int n, m;int curD = 0;struct meow { int d[2]; meow() {} meow(int x, int y){d[0]=x; d[1]=y;} bool operator < (const meow &r) const { return d[curD] < r.d[curD]; } int calDist(meow &a) { return abs(d[0] - a.d[0]) + abs(d[1] - a.d[1]); }};meow a[N];struct node { int ch[2], x[2], y[2]; meow p; void update(node &a) { x[0] = min(x[0], a.x[0]); x[1] = max(x[1], a.x[1]); y[0] = min(y[0], a.y[0]); y[1] = max(y[1], a.y[1]); } void set(meow &a) { p = a; x[0] = x[1] = a.d[0]; y[0] = y[1] = a.d[1]; } int evaDist(meow &a) { int xx = a.d[0], yy = a.d[1]; return max0(x[0] - xx) + max0(xx - x[1]) + max0(y[0] - yy) + max0(yy - y[1]); }} t[N];int root;int build(int l, int r, int d) { curD = d; int x = (l+r)>>1; nth_element(a+l, a+x, a+r+1); t[x].set(a[x]); if(l < x) lc = build(l, x-1, d^1), t[x].update(t[lc]); if(x < r) rc = build(x+1, r, d^1), t[x].update(t[rc]); return x;}void insert(meow q) { t[++n].set(q); for(int x=root, D=0; x; D^=1) { t[x].update(t[n]); int &nxt = t[x].ch[q.d[D] >= t[x].p.d[D]]; if(nxt == 0) { nxt = n; break; } else x = nxt; }}int ans;void query(int x, meow q) { int nowDist = t[x].p.calDist(q), d[2]; d[0] = lc ? t[lc].evaDist(q) : inf; d[1] = rc ? t[rc].evaDist(q) : inf; int wh = d[1] <= d[0]; ans = min(ans, nowDist); if(d[wh] < ans) query(t[x].ch[wh], q); wh ^= 1; if(d[wh] < ans) query(t[x].ch[wh], q);}int main() { cin >> n >> m; int c, x, y; for(int i=1; i<=n; i++) { scanf("%d %d", &x, &y); a[i] = meow(x, y); } root = build(1, n, 0); for(int i=1; i<=m; i++) { scanf("%d %d %d", &c, &x, &y); if(c == 1) insert(meow(x, y)); else { ans = inf; query(root, meow(x, y)); printf("%d\n", ans); } }}
k远点对
每个点求一次k远点
值得注意的是会TLE所以整体用一个大根堆才行
#include#include #include #include #include #include #include using namespace std;typedef long long ll;const int N = 1e5+5;const ll inf = 1e18;#define lc t[x].ch[0]#define rc t[x].ch[1]#define max0(x) max(x, 0)inline ll sqr(ll x) {return x*x;}int n, K;int curD = 0;struct meow { ll d[2]; meow() {} meow(ll x, ll y){d[0]=x; d[1]=y;} bool operator < (const meow &r) const { //if(d[curD] == r.d[curD]) return d[curD^1] < r.d[curD^1]; return d[curD] < r.d[curD]; } ll calDist(meow &a) { //return abs(d[0] - a.d[0]) + abs(d[1] - a.d[1]); return sqr(d[0] - a.d[0]) + sqr(d[1] - a.d[1]); }};meow a[N];struct node { int ch[2], x[2], y[2]; meow p; void update(node &a) { x[0] = min(x[0], a.x[0]); x[1] = max(x[1], a.x[1]); y[0] = min(y[0], a.y[0]); y[1] = max(y[1], a.y[1]); } void set(meow &a) { p = a; x[0] = x[1] = a.d[0]; y[0] = y[1] = a.d[1]; } ll evaMaxDist(meow &a) { ll xx = a.d[0], yy = a.d[1]; return max(sqr(x[0]-xx), sqr(x[1]-xx)) + max(sqr(y[0]-yy), sqr(y[1]-yy)); }} t[N];int root;int build(int l, int r, int d) { curD = d; int x = (l+r)>>1; nth_element(a+l, a+x, a+r+1); t[x].set(a[x]); if(l < x) { lc = build(l, x-1, d^1); t[x].update(t[lc]); } if(x < r) { rc = build(x+1, r, d^1); t[x].update(t[rc]); } return x;}priority_queue , greater > ans;void query(int x, meow q) { ll nowDist = t[x].p.calDist(q), d[2]; d[0] = lc ? t[lc].evaMaxDist(q) : -inf; d[1] = rc ? t[rc].evaMaxDist(q) : -inf; int wh = d[1] >= d[0]; if(nowDist > ans.top()) ans.pop(), ans.push(nowDist); if(d[wh] > ans.top()) query(t[x].ch[wh], q); wh ^= 1; if(d[wh] > ans.top()) query(t[x].ch[wh], q);}int main() { cin >> n >> K; K <<= 1; int x, y; for(int i=1; i<=n; i++) { scanf("%d %d", &x, &y); a[i] = meow(x, y); } root = build(1, n, 0); for(int j=1; j<=K; j++) ans.push(-inf); for(int i=1; i<=n; i++) { query(root, a[i]); } cout << ans.top() << endl;}