kd樹的簡單實現

LZEP2015發表於2014-10-23

 kd樹在特徵匹配中有著重要的作用。關於kd樹的原理這裡不再詳述,下面給出kd樹的構建過程:程式從標準輸入讀取6個二維點(共12個浮點數),每一層選取方差最大的維度,以該維度上的中位數點作為節點,再遞迴地構建左右子樹。

       

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
/* Number of coordinates stored for every point in the tree. */
#define K 2

/* Scalar type of one point coordinate. */
typedef float ElemType;
/* Integer aliases for ranges, split dimensions and sample counts. */
typedef int RangeInt;
typedef int SplitInt;
typedef int SampleInt;

/* One kd-tree node.
 * Node_Data - the K coordinates of the point stored at this node.
 * Range     - split value; Create_KD_Tree stores Node_Data[Split] for inner
 *             nodes and -1 for single-point leaves.
 * Split     - index of the splitting dimension; -1 for single-point leaves.
 * Left/Right/Parent - child and parent links (NULL when absent). */
struct Kd_Tree
{
  ElemType Node_Data[K];
  ElemType Range;
  SplitInt Split;
  struct Kd_Tree* Left;
  struct Kd_Tree* Right;
  struct Kd_Tree* Parent;
};
typedef struct Kd_Tree Kd_Tree;
typedef struct Kd_Tree *_Kd_Tree;


/* Choose the split dimension for a kd-tree node: the coordinate index
 * (0..K-1) whose values have the largest (population) variance over the
 * first nSample rows of Data_Set.
 * Ties keep the lowest dimension; if nSample <= 0, or every variance is
 * zero, dimension 0 is returned.
 * Space_Range is accepted for interface compatibility and is not used.
 * Per-dimension statistics are kept in locals, so nothing is allocated and
 * there is no dependence on nSample >= K. */
int Find_Max_Range_Item(vector< vector<ElemType> > &Data_Set,SplitInt *Space_Range,SampleInt& nSample)
{
	(void)Space_Range;
	int index=0;
	if(nSample<=0)
		return index;

	double max_var=0.0;
	for(int d=0;d<K;d++)
	{
		// Mean of dimension d over the samples.
		double mean=0.0;
		for(int i=0;i<nSample;i++)
			mean+=Data_Set[i][d];
		mean/=nSample;

		// Variance of dimension d around that mean.
		double var=0.0;
		for(int i=0;i<nSample;i++)
		{
			double diff=Data_Set[i][d]-mean;
			var+=diff*diff;
		}
		var/=nSample;

		// Strict '>' keeps the lowest dimension on ties.
		if(var>max_var)
		{
			max_var=var;
			index=d;
		}
	}
	return index;
}
void swap(vector< vector<ElemType> > &Data_Set,int i,int j)  
{
    vector<ElemType> tmp(K);
	for(int t=0;t<K;t++) 
	{
		tmp[t]=Data_Set[i][t];
		Data_Set[i][t]=Data_Set[j][t];
		Data_Set[j][t]=tmp[t];
	}
  
}  
int partition(vector< vector<ElemType> > &Data_Set,int Range_Item_value,int left,int right,int nSample)  
{  
    ElemType base_element=Data_Set[left][Range_Item_value];  
    int i=left;  
    int j=right+1;  
    for(;;)  
    {  
        while(Data_Set[++i][Range_Item_value]<base_element&&(i+1 <nSample));  
        while(Data_Set[--j][Range_Item_value]>base_element&&(j-1>-1));  
        if(i<j) swap(Data_Set,i,j);  
        else break;  
    }  
    swap(Data_Set,left,j);  
    return  j;  
}  

/* Quicksort rows [left, right] of Data_Set in ascending order of coordinate
 * Range_Item_Value. The left part is sorted recursively; the right part is
 * handled by the next loop iteration (same call order as plain recursion). */
void _quick_sort(vector< vector<ElemType> > &Data_Set,int Range_Item_Value,int left,int right,int nSample)
{
	while(left<right)
	{
		const int split_at=partition(Data_Set,Range_Item_Value,left,right,nSample);
		_quick_sort(Data_Set,Range_Item_Value,left,split_at-1,nSample);
		left=split_at+1;
	}
}
/* Sort the first nSample rows of Data_Set in place, ascending by coordinate
 * Range_Item_Value. */
void Sort_Data_Set(vector< vector<ElemType> > &Data_Set,SampleInt& nSample,int Range_Item_Value)
{
	const int first_row=0;
	const int last_row=nSample-1;
	_quick_sort(Data_Set,Range_Item_Value,first_row,last_row,nSample);
}
/*確定split域:對於所有描述子資料(特徵向量),統計它們在每個維上的資料方差。
以SURF特徵為例,描述子為64維,可計算64個方差。挑選出最大值,對應的域就是split
域的值。資料方差大表明沿該座標軸方向上的資料分散得比較開,在這個方向上進行資料
分割有著較好的解析度。*/
bool Create_KD_Tree(_Kd_Tree* pTree,_Kd_Tree *tree,vector< vector<ElemType> > &Data_Set,SplitInt *Space_Range,SampleInt& nSample)
{
	if(Data_Set.size()==0)
	{
	   printf("finish\n");
	   return false;
	}
	
	int Range_Item_Value=Find_Max_Range_Item(Data_Set,Space_Range,nSample);
	printf("Item=%d\n",Range_Item_Value);
    Sort_Data_Set(Data_Set,nSample,Range_Item_Value);
    
    printf("\n");
    for(int i=0;i<nSample;i++)
    {
    	for(int j=0;j<K;j++)
        {
     	printf("%f ",Data_Set[i][j]);
        }
        printf("\n");
    }
    printf("\n");
    int mid_nSample=nSample/2;
    printf("mid=%d\n",mid_nSample);
    int Left_nSample=mid_nSample;
    int Right_nSample=nSample-Left_nSample-1;
    //int* left_Space_Range=(int*)malloc(Left_nSample*sizeof(int));
    //int* right_Space_Range=(int*)malloc(Right_nSample*sizeof(int));
    //memcpy(left_Space_Range,Space_Range,Left_nSample*sizeof(int));
    //memcpy(right_Space_Range,Space_Range+Left_nSample,Right_nSample*sizeof(int));
    vector< vector<ElemType> > Left_Data_Set(Left_nSample,vector<ElemType>(K));
    vector< vector<ElemType> > Right_Data_Set(Right_nSample,vector<ElemType>(K));
    
    for(int i=0;i<Left_nSample;i++)
    {
   	  for(int j=0;j<K;j++)
    	Left_Data_Set[i][j]=Data_Set[i][j];
    }
    
    for(int i=0;i<Right_nSample;i++)
    {
   	  for(int j=0;j<K;j++)
    	Right_Data_Set[i][j]=Data_Set[i+Left_nSample+1][j];
    }
    
    (*tree)->Split=Range_Item_Value;
	(*tree)->Range=Data_Set[mid_nSample][Range_Item_Value];
	for(int i=0;i<K;i++)
	  (*tree)->Node_Data[i]=Data_Set[mid_nSample][i];
	printf("data=%f\n",Data_Set[mid_nSample][0]);
	if(Left_nSample>1)
	{
	  (*tree)->Left=(_Kd_Tree)malloc(sizeof(Kd_Tree));
	  Create_KD_Tree(tree,&((*tree)->Left),Left_Data_Set,Space_Range,Left_nSample);	
	}
	else if(Left_nSample==1)
    {
      (*tree)->Left=(_Kd_Tree)malloc(sizeof(Kd_Tree));
      for(int i=0;i<K;i++)
	  (*tree)->Left->Node_Data[i]=Data_Set[Left_nSample-1][i];
	  (*tree)->Left->Range=-1.0;
	  (*tree)->Left->Split=-1.0;
	  (*tree)->Left->Left=NULL;
	  (*tree)->Left->Right=NULL;
	  (*tree)->Left->Parent=(*tree);
    }
    else
	{
	  (*tree)->Left=NULL;
    }
    
	if(Right_nSample>1)
	{
	  (*tree)->Right=(_Kd_Tree)malloc(sizeof(Kd_Tree));
      Create_KD_Tree(tree,&((*tree)->Right),Right_Data_Set,Space_Range,Right_nSample);
    } 
    else if(Right_nSample==1)
    {
      (*tree)->Right=(_Kd_Tree)malloc(sizeof(Kd_Tree));
      for(int i=0;i<K;i++)
      (*tree)->Right->Node_Data[i]=Data_Set[nSample-Right_nSample][i];
	  (*tree)->Right->Node_Data[K];
	  (*tree)->Right->Range=-1.0;
	  (*tree)->Right->Split=-1.0;
	  (*tree)->Right->Left=NULL;
	  (*tree)->Right->Right=NULL;
	  (*tree)->Right->Parent=(*tree);	
    }
    else
	{
      (*tree)->Right=NULL;	
    }
    (*tree)->Parent=(*pTree);
	 
     //if(left_Space_Range)   free(left_Space_Range);
     //if(right_Space_Range)  free(right_Space_Range);
     return true;
	
}
/* Print every node's coordinates in pre-order (node, left subtree, right
 * subtree), one "(x y ... )" line per node. The walk uses an explicit stack;
 * the right child is pushed first so the left child is printed first. */
void Print_KD_Tree(_Kd_Tree tree)
{
	vector<_Kd_Tree> pending;
	pending.push_back(tree);
	while(!pending.empty())
	{
		_Kd_Tree current=pending.back();
		pending.pop_back();
		if(current==NULL)
			continue;
		printf("(");
		for(int d=0;d<K;++d)
			printf("%f ",current->Node_Data[d]);
		printf(")\n");
		pending.push_back(current->Right);
		pending.push_back(current->Left);
	}
}
int main()
{
   //create kd Tree
   _Kd_Tree pTree=NULL;
   _Kd_Tree Tree=(_Kd_Tree)malloc(sizeof(Kd_Tree));
   int nSample=6;
   SplitInt* Space_Range=(SplitInt*)malloc(K*sizeof(SplitInt));
   for(int i=0;i<K;i++)
     Space_Range[i]=i; 
   printf("Start Input data set!\n");
   vector< vector<ElemType> > Data_Set(nSample,vector<ElemType>(K));
   for(int i=0;i<nSample;i++)
     for(int j=0;j<K;j++)
     {
     	scanf("%f",&Data_Set[i][j]);
     }
   printf("Finish Input data set!\n");
   printf("Start create kd_tree!\n");
   Create_KD_Tree(&pTree,&Tree,Data_Set,Space_Range,nSample);
   printf("Finish Create kd_tree!\n");
   printf("Start Print kd_tree!\n");
   Print_KD_Tree(Tree);
   printf("Finish Print kd_tree!\n");
   if(Tree)  printf("success\n");
}




相關文章