MATLAB – k-d tree nearest-neighbour search

From XennisWiki
Jump to: navigation, search

MATLAB functions that create a k-d tree for a given point cloud and use this tree to compute the nearest neighbours of query points.

Animation of a nearest-neighbour (NN) search with a k-d tree in two dimensions

kdtreeClass.m

classdef kdtreeClass
% kdtreeClass  Node of a k-d tree (a value class that only stores data).
%
% The properties are filled in by kdtreeCreate; on construction every
% property is empty ([]).

    properties (Access = public)
        % coordinate value that separates the 'yes' and 'no' children
        threshold;
        % index of the coordinate used for the split (1 = x, 2 = y)
        direction;
        % child for points strictly below 'threshold' (node or leaf points)
        yes;
        % child for points at or above 'threshold' (node or leaf points)
        no;
    end

end

kdtreeCreate.m

function [ tree ] = kdtreeCreate( P, direction )
% Creates and returns a k-d tree
%
% Input
%	P: point cloud, one point per row (N x D)
%	direction: index of the coordinate used to partition this node,
%	           1 for x- and 2 for y-direction (default: 1)
%
% Output
%	tree: kdtreeClass node. Its 'yes' child holds the points whose
%	      coordinate in 'direction' is below 'threshold', its 'no' child the
%	      points at or above it. A child is either another kdtreeClass node
%	      or a leaf matrix of points: 0 or 1 row, or several distinct rows
%	      only when no mean split along any axis can separate them.
%
% The splitting direction cycles through the columns of P
% (x, y, x, y, ... for 2-D data).
%

% we always start with the x-direction when partitioning the intervals
if nargin < 2
    direction = 1;
end

tree = kdtreeClass;
tree.direction = direction;

P_direction = P(:, direction);
% cycle through all dimensions: 1 -> 2 -> ... -> D -> 1
nextDirection = mod(direction, size(P, 2)) + 1;

% split at the mean coordinate along the current direction
tree.threshold = mean(P_direction);

% case: yes (strictly below the threshold)
tree.yes = createChild(P(P_direction < tree.threshold, :), nextDirection);

% case: no (at or above the threshold)
tree.no = createChild(P(P_direction >= tree.threshold, :), nextDirection);

end

function child = createChild( Q, nextDirection )
% Turns the subset Q into a sub-tree, or into a leaf when it cannot be split.
if size(Q, 1) < 2
    % 0 or 1 point: plain leaf
    child = Q;
elseif isSplittable(Q)
    child = kdtreeCreate(Q, nextDirection);
else
    % Duplicate points (or values so close that their mean rounds onto an
    % edge): recursing would never shrink Q and would not terminate, so
    % keep the distinct points as a leaf. Pure duplicates give one row.
    child = unique(Q, 'rows', 'stable');
end
end

function tf = isSplittable( Q )
% True if splitting Q at the mean of some column puts points on both sides.
tf = false;
for d = 1:size(Q, 2)
    column = Q(:, d);
    t = mean(column);
    if any(column < t) && any(column >= t)
        tf = true;
        return
    end
end
end

kdtreeClassify.m

function [ p_nn ] = kdtreeClassify( tree, p )
% Use the given k-d tree to execute a nearest neighbour search, i.e. find
% the point stored in the tree that is closest (Euclidean distance) to the
% given point p
%
% Input
%    tree: k-d tree as returned by kdtreeCreate (a bare leaf matrix with one
%          point per row is accepted as well)
%    p: query point (row or column vector, one entry per dimension)
%
% Output
%    p_nn: nearest neighbour of input point p as a row vector. If several
%          points are equally close, the first one found is returned.
%          Empty if the tree contains no points.
%
% The search first descends to the leaf on p's side of every splitting
% plane and then backtracks. The other side of a node is only visited when
% the splitting plane is closer to p than the best point found so far,
% because only then can that side contain a closer point.
%

[p_nn, ~] = searchNode(tree, reshape(p, 1, []), [], Inf);

end

function [best, bestDist] = searchNode( node, p, best, bestDist )
% Recursively searches 'node'. Replaces the candidate 'best' (with squared
% distance 'bestDist') whenever a strictly closer point is found.
if ~isa(node, 'kdtreeClass')
    % leaf: check every stored point (normally 0 or 1 rows)
    for k = 1:size(node, 1)
        d = sum((node(k, :) - p) .^ 2);
        if d < bestDist
            best = node(k, :);
            bestDist = d;
        end
    end
    return
end

% signed distance from p to the splitting plane of this node
offset = p(node.direction) - node.threshold;
if offset < 0
    nearSide = node.yes;
    farSide  = node.no;
else
    nearSide = node.no;
    farSide  = node.yes;
end

[best, bestDist] = searchNode(nearSide, p, best, bestDist);

% every point on the far side is at least |offset| away from p
if offset ^ 2 < bestDist
    [best, bestDist] = searchNode(farSide, p, best, bestDist);
end
end

main.m

% Sample point cloud
Q = [2,   3;
     3,   0.5;
     4,   3;
     3.5, 2;
     1,   1.2;
     2,   1];

%figure()
%	plot(Q(:,1), Q(:,2), 'rx', 'MarkerSize', 15)
%	grid on
%	axis([0 5 0 4])

% Create the k-d tree for the point cloud (no semicolon: display the tree)
tree_Q = kdtreeCreate(Q)

% Sample points
P = [4,   1;
     1.5, 2.5;
     3,   4];

% Compute the nearest neighbours of all sample points
for k = 1:size(P, 1)
    fprintf('p=(%g,%g)\n', P(k,1), P(k,2))
    p_nn = kdtreeClassify(tree_Q, P(k,:));
    % the search may end in an empty leaf; do not index an empty result
    if isempty(p_nn)
        fprintf('\tno nearest neighbour found\n')
    else
        fprintf('\tnearest neighbour is  (%g,%g)\n', p_nn(1), p_nn(2))
    end
end

See also