最近在看机器学习的相关入门材料,看到目录里面有k-d树,想到以前做竞赛的时候一直留的坑,这里算是补上了。
K-d树的理解
k-d树的主要作用是能够在k维的样本空间中比较快速的寻找到离目标点 距离最近的k个点,这种问题在数据量小的时候我们可以通过直接暴力或者暴力加二分的简单方法去做,不过这种做法很有局限性而且没有完全利用样本空间的信息。与这种线性的方法相比,我们的k-d树要巧妙很多。
k-d树是每个节点都为k维点的二叉树。 —-维基百科
维基百科上的这种说法固然没错,不过给一开始初学的我造成了很大的麻烦。我个人感觉从代码上入手k-d树会快很多。从建树的代码上看,k-d树是一棵滚动划分各个维度的二叉树,这样每经过k轮划分后,我们的样本空间就会被划分成很多很多小的样本空间。这样一来,我们的目标点就会在其中某一个小的空间中,由于之前每次划分都会把样本空间一分为二,我们只需要沿着树找到目标点所在的那个”小空间”,在回溯上来,在这个回溯的过程中,我们就能逐渐找到”k个邻近的点”。
上面是对k-d树的建树和查找一个比较简约和形象的描述,具体的细节还要看代码,一般实现的k-d树的整个代码并不算多,不过在编写上很是巧妙。拿HDU 4347来说,我看过不少人写的k-d树,有洋洋洒洒200行的,也有压缩到近30行的(可见discuss),各有各的好处,我最后综合了一下在50行左右,可以看下面两道题的实现。
我知道竞赛题和实际应用还是很有一段距离的,这里作为一个入门,先写在这里,以后如果真正碰到了再总结吧~
HDU 4347
这题是我一开始拿来学k-d树的题,看过k-d树的概念以后,discuss里面的那份代码帮了我很多,推荐有兴趣的去看一下。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cmath>
#define LL long long
#define maxn 500005
#define lson_n n<<1
#define rson_n n<<1|1
using namespace std;
int n,k,t,idx,m,d[maxn<<2];
struct node
{
LL dis;
int p;
node(){dis=0; p=0;}
bool operator <(const node &a)const
{
return dis<a.dis;
}
}ans[15];
priority_queue<node> q;
struct point
{
int v[10];
point(){}
void setup(int k)
{
for (int i=0;i<k;i++) scanf("%d",&v[i]);
}
bool operator <(const point &a)const
{
return v[idx]<a.v[idx];
}
void PRI(int k)
{
for (int i=0;i<k;i++)
{
if (i) printf(" ");
printf("%d",v[i]);
}
}
}a[maxn];
void buildtree(int n,int l,int r,int dep)
{
if (l>r) return;
idx=dep%k;
int mid=(l+r)>>1;
nth_element(a+l,a+mid,a+r+1);
d[n]=mid;
buildtree(lson_n,l,mid-1,dep+1);
buildtree(rson_n,mid+1,r,dep+1);
}
void SWAP(int &x,int &y) { x^=y^=x^=y; }
LL sqr(LL a) { return a*a; }
void query(point A,int n,int dep)
{
if (d[n]==-1) return ;
bool f=0;
int dim=dep%k,ls=lson_n,rs=rson_n;
node cur;
cur.p=d[n];
for (int i=0;i<k;i++) cur.dis+=sqr(a[d[n]].v[i]-A.v[i]);
if (A.v[dim]> a[d[n]].v[dim] ) SWAP(ls,rs);
if (~d[ls]) query(A,ls,dep+1);
if (q.size()<m) q.push(cur),f=1;
else
{
if (cur.dis<q.top().dis)
{
q.pop(); q.push(cur);
}
if ( sqr(A.v[dim]-a[d[n]].v[dim])<q.top().dis ) f=1;
}
if (~d[rs] && f) query(A,rs,dep+1);
}
int main()
{
while (~scanf("%d%d",&n,&k))
{
memset(d,-1,sizeof(d));
for (int i=0;i<n;i++) a[i].setup(k);
scanf("%d",&t);
buildtree(1,0,n-1,0);
while (t--)
{
point A;
A.setup(k);
scanf("%d",&m);
query(A,1,0);
int h=0;
while (!q.empty())
{
ans[h++]=q.top();
q.pop();
}
printf("the closest %d points are:\n",m);
for (int i=h-1;i>=0;i--)
{
a[ans[i].p].PRI(k);
printf("\n");
}
}
}
return 0;
}
HDU 5992 (2016青岛ICPC现场K题)
这一题对我来说,是很有意义的一题,它帮助我取得了我ACM第一块银牌,也是第一块奖牌,我至今还记得和老师在青岛吃的那顿海鲜——那味道我一辈子忘不了。
这一题我估计本意是金银分界题,整个金银铜按照 5 、4 、3 分区。结果G题的log网络流不少人都没看出来,有些看出来了的TLE在了double上。当时有不少队伍就在K题上尝试了暴力加二分的写法——我们队当时依靠着这一题的优势,最终打进银牌区。
现在来看,这一题也算是比较裸的k-d树,对于整个点空间建树,用代价C在查询上做约束,一点点微小的不同。由于这一题还只要算最邻近,省去了优先队列,最终代码量很少。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define maxn 200005
#define maxm 20005
#define LL long long
#define lson_n n<<1
#define rson_n n<<1|1
using namespace std;
int T,n,m,idx,d[maxn<<2],C;
bool flag;
struct point
{
int v[2],c,id;
point(){v[0]=v[1]=c=id=0;}
bool operator <(const point &a)const
{
return v[idx]<a.v[idx];
}
}a[maxn];
struct node
{
LL dis;
int p,id;
node(){ dis=p=id=0; }
void reset(){ dis=p=id=0; }
bool operator <(const node &a)const
{
if (dis!=a.dis) return dis<a.dis;
return id<a.id;
}
}ans;
void buildtree(int n,int l,int r,int dep)
{
if (l>r) return;
idx=dep%2; int mid=(l+r)>>1;
nth_element(a+l,a+mid,a+r+1);
d[n]=mid;
buildtree(lson_n,l,mid-1,dep+1);
buildtree(rson_n,mid+1,r,dep+1);
}
LL sqr(LL a){ return a*a; }
void SWAP(int &x,int &y){ x^=y^=x^=y; }
void query(point A,int n,int dep)
{
bool f=0;
int dim=dep%2,ls=lson_n,rs=rson_n;
node cur;
cur.p=d[n]; cur.id=a[d[n]].id;
for (int i=0;i<2;i++) cur.dis+=sqr(A.v[i]-a[d[n]].v[i]);
if (A.v[dim]>a[d[n]].v[dim]) SWAP(ls,rs);
if (~d[ls]) query(A,ls,dep+1);
if (!flag && a[cur.p].c<=C) ans=cur,f=1,flag=1;
else if (cur<ans && a[cur.p].c<=C) ans=cur;
if (sqr(A.v[dim]-a[d[n]].v[dim])<ans.dis) f=1;
if (~d[rs] && (f || !flag) ) query(A,rs,dep+1);
}
int main()
{
scanf("%d",&T);
while (T--)
{
memset(d,-1,sizeof(d));
scanf("%d%d",&n,&m);
for (int i=0;i<n;i++) scanf("%d%d%d",&a[i].v[0],&a[i].v[1],&a[i].c),a[i].id=i;
buildtree(1,0,n-1,0);
while (m--)
{
point A;
scanf("%d%d%d",&A.v[0],&A.v[1],&C);
ans.reset(); flag=0;
query(A,1,0);
printf("%d %d %d\n", a[ans.p].v[0],a[ans.p].v[1],a[ans.p].c);
}
}
return 0;
}