2013-07-17

My KDTree Implementation - it's super effective!

New page

<code><syntaxhighlight>

/*

** KDTree.java by Julian Kent

** Licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License

** See full licensing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/

** For additional usage rights please contact jkflying@gmail.com

**

** Example usage is given in the (commented-out) main method, along with benchmarking code against Rednaxela's Gen2 Tree

*/

package jk.mega;

import java.util.ArrayDeque;

import java.util.ArrayList;

import java.util.Arrays;

//import ags.utils.*;

/**
 * A bucketed k-d tree that maps points in a fixed number of dimensions to payloads of type
 * {@code T} and answers nearest-neighbour and k-nearest-neighbour queries.
 *
 * <p>Leaves try to split on their widest dimension each time they reach a multiple of
 * {@code _bucketSize} points. Every distance reported by this class is a <em>squared</em>
 * Euclidean distance. The tree performs no synchronization of its own.
 *
 * @param <T> payload type stored alongside each point
 */
public class KDTree<T> {

    // Use a big bucket size so that there are fewer node bounds (more cache hits) and better splits.
    private static final int _bucketSize = 64;

    private final int _dimensions;

    // Number of nodes created so far; doubles as the next node's bounding-box index.
    private int _nodes;

    private final Node root;

    // Spare point buffer: a node that successfully splits donates its old location array here and
    // the next split's "less" child reuses it, so a split does not allocate _bucketSize*dimensions
    // doubles every time.
    private double[] mem_recycle;

    // The starting values for bounding boxes (+infinity lows, -infinity highs), for easy access.
    private final double[] bounds_template;

    // One big self-expanding array holding every node's bounding box so that they stay in cache.
    // Node bounds are stored at:
    //   low:  2 * _dimensions * node.index
    //   high: 2 * _dimensions * node.index + _dimensions
    private final ContiguousDoubleArrayList nodeMinMaxBounds;

    /*
    // Example usage, plus a benchmark against Rednaxela's Gen2 KdTree (a class this file does not
    // import, hence the whole method is commented out).
    public static void main(String[] args) {
        int dims = 12;
        int size = 20000;
        int testsize = 200;
        int k = 40;
        int iterations = 3;
        System.out.println(
                "Config:\n"
                + "No JIT Warmup\n"
                + "Tested on random data.\n"
                + "Training and testing points shared across iterations.\n"
                + "Searches interleaved.");
        System.out.println("Num points: " + size);
        System.out.println("Num searches: " + testsize);
        System.out.println("Dimensions: " + dims);
        System.out.println("Num Neighbours: " + k);
        System.out.println();
        ArrayList<double[]> locs = new ArrayList<double[]>(size);
        for (int i = 0; i < size; i++) {
            double[] loc = new double[dims];
            for (int j = 0; j < dims; j++)
                loc[j] = Math.random();
            locs.add(loc);
        }
        ArrayList<double[]> testlocs = new ArrayList<double[]>(testsize);
        for (int i = 0; i < testsize; i++) {
            double[] loc = new double[dims];
            for (int j = 0; j < dims; j++)
                loc[j] = Math.random();
            testlocs.add(loc);
        }
        for (int r = 0; r < iterations; r++) {
            long t1 = System.nanoTime();
            KDTree<double[]> t = new KDTree<double[]>(dims); // This tree
            for (int i = 0; i < size; i++) {
                t.addPoint(locs.get(i), locs.get(i));
            }
            long t2 = System.nanoTime();
            KdTree<double[]> rt = new KdTree.Euclidean<double[]>(dims, null); // Rednaxela Gen2
            for (int i = 0; i < size; i++) {
                rt.addPoint(locs.get(i), locs.get(i));
            }
            long t3 = System.nanoTime();
            long jtn = 0;
            long rtn = 0;
            long mjtn = 0;
            long mrtn = 0;
            double dist1 = 0, dist2 = 0;
            for (int i = 0; i < testsize; i++) {
                long t4 = System.nanoTime();
                dist1 += t.nearestNeighbours(testlocs.get(i), k).iterator().next().distance;
                long t5 = System.nanoTime();
                dist2 += rt.nearestNeighbor(testlocs.get(i), k, true).iterator().next().distance;
                long t6 = System.nanoTime();
                long t7 = System.nanoTime();
                // (t7 - t6) is presumably an estimate of nanoTime() overhead, subtracted from both.
                jtn += t5 - t4 - (t7 - t6);
                rtn += t6 - t5 - (t7 - t6);
                mjtn = Math.max(mjtn, t5 - t4 - (t7 - t6));
                mrtn = Math.max(mrtn, t6 - t5 - (t7 - t6));
            }
            System.out.println("Accuracy: " + (Math.abs(dist1 - dist2) < 1e-10 ? "100%" : "BROKEN!!!"));
            if (Math.abs(dist1 - dist2) > 1e-10) {
                System.out.println("dist1: " + dist1 + " dist2: " + dist2);
            }
            long jts = t2 - t1;
            long rts = t3 - t2;
            System.out.println("Iteration: " + (r + 1) + "/" + iterations);
            System.out.println("This tree add avg: " + jts / size + " ns");
            System.out.println("Reds tree add avg: " + rts / size + " ns");
            System.out.println("This tree knn avg: " + jtn / testsize + " ns");
            System.out.println("Reds tree knn avg: " + rtn / testsize + " ns");
            System.out.println("This tree knn max: " + mjtn + " ns");
            System.out.println("Reds tree knn max: " + mrtn + " ns");
            System.out.println();
        }
    }
    */

    /**
     * Creates an empty tree.
     *
     * @param dimensions number of coordinates in every location added to or searched in this tree
     * @throws IllegalArgumentException if {@code dimensions < 1}
     */
    public KDTree(int dimensions) {
        if (dimensions < 1) {
            throw new IllegalArgumentException("dimensions must be positive, got " + dimensions);
        }
        _dimensions = dimensions;
        nodeMinMaxBounds = new ContiguousDoubleArrayList(2 * dimensions);
        mem_recycle = new double[_bucketSize * dimensions];
        bounds_template = new double[2 * _dimensions];
        Arrays.fill(bounds_template, 0, _dimensions, Double.POSITIVE_INFINITY);
        Arrays.fill(bounds_template, _dimensions, 2 * _dimensions, Double.NEGATIVE_INFINITY);
        // and.... start! (must come last: Node's constructor uses the fields initialised above)
        root = new Node();
    }

    /** Returns the number of nodes (stems and leaves) currently in the tree. */
    public int nodes() {
        return _nodes;
    }

    /**
     * Adds a point to the tree.
     *
     * @param location coordinates of the point; they are copied, so the caller may reuse the array
     * @param payload value reported by searches that find this point (may be {@code null})
     * @return the total number of points in the tree after this insertion
     * @throws IllegalArgumentException if {@code location.length} differs from the tree's dimensions
     */
    public int addPoint(double[] location, T payload) {
        // Validate before touching the tree so a bad location cannot leave counts half-updated.
        checkDimensions(location);
        Node addNode = root;
        // Descend to the leaf where the point belongs, widening each stem's bounds on the way.
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            addNode = location[addNode.splitDim] < addNode.splitVal ? addNode.less : addNode.more;
        }
        addNode.expandBounds(location);
        int nodeSize = addNode.add(location, payload);
        // Try splitting again once every time the leaf passes a _bucketSize multiple.
        if (nodeSize % _bucketSize == 0) {
            addNode.split();
        }
        return root.entries;
    }

    /**
     * Finds the stored point closest to {@code searchLocation}.
     *
     * @param searchLocation query coordinates
     * @return the closest point's payload and squared Euclidean distance; for an empty tree the
     *     distance is {@link Double#POSITIVE_INFINITY} and the payload {@code null}
     * @throws IllegalArgumentException if {@code searchLocation} has the wrong dimensionality
     */
    public SearchResult<T> nearestNeighbour(double[] searchLocation) {
        checkDimensions(searchLocation);
        ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
        Node searchNode = descendToLeaf(searchLocation, stack);

        double minDist = Double.POSITIVE_INFINITY;
        T minValue = null;

        // Find the closest point in the home leaf and use it as the initial solution.
        for (int j = searchNode.entries; j-- > 0;) {
            double distance = searchNode.pointDist(searchLocation, j);
            if (distance < minDist) {
                minDist = distance;
                minValue = searchNode.pointPayloads.get(j);
            }
        }

        // Backtrack: revisit skipped subtrees whose bounding box could still hold a closer point.
        while (!stack.isEmpty()) {
            searchNode = stack.pop();
            if (searchNode.pointRectDist(searchLocation) < minDist) {
                if (searchNode.pointLocations == null) {
                    pushChildren(searchNode, searchLocation, stack);
                } else {
                    for (int j = searchNode.entries; j-- > 0;) {
                        double distance = searchNode.pointDist(searchLocation, j);
                        if (distance < minDist) {
                            minDist = distance;
                            minValue = searchNode.pointPayloads.get(j);
                        }
                    }
                }
            }
        }

        return new SearchResult<T>(minDist, minValue);
    }

    /**
     * Finds up to {@code K} stored points closest to {@code searchLocation}.
     *
     * @param searchLocation query coordinates
     * @param K maximum number of neighbours to return; 0 yields an empty list
     * @return the neighbours ordered farthest first (like Rednaxela's Gen2 tree), each with its
     *     squared Euclidean distance. The list is shorter than {@code K} when the tree holds fewer
     *     points; points whose squared distance is infinite or NaN are never reported.
     * @throws IllegalArgumentException if {@code K < 0} or {@code searchLocation} has the wrong
     *     dimensionality
     */
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K) {
        checkDimensions(searchLocation);
        if (K < 0) {
            throw new IllegalArgumentException("K must be non-negative, got " + K);
        }
        if (K == 0) {
            return new ArrayList<SearchResult<T>>(0);
        }

        ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
        Node searchNode = descendToLeaf(searchLocation, stack);

        // The queue keeps the HIGHEST priorities, so distances are negated to keep the nearest.
        PrioQueue<T> results = new PrioQueue<T>(K);
        offerLeaf(searchNode, searchLocation, results);

        // Backtrack; -peekPrio() is +infinity until K points are held, then the K-th best distance.
        while (!stack.isEmpty()) {
            searchNode = stack.pop();
            if (searchNode.pointRectDist(searchLocation) < -results.peekPrio()) {
                if (searchNode.pointLocations == null) {
                    pushChildren(searchNode, searchLocation, stack);
                } else {
                    offerLeaf(searchNode, searchLocation, results);
                }
            }
        }

        ArrayList<SearchResult<T>> returnResults =
                new ArrayList<SearchResult<T>>(results.elements.size());
        // Reverse order (farthest first, like Rednaxela Gen2); iterate forwards for closest first.
        for (int i = results.elements.size(); i-- > 0;) {
            PrioQueue<T>.Element e = results.elements.get(i);
            returnResults.add(new SearchResult<T>(-e.priority, e.contents));
        }
        return returnResults;
    }

    // Rejects locations whose length does not match the tree's dimensionality.
    private void checkDimensions(double[] location) {
        if (location.length != _dimensions) {
            throw new IllegalArgumentException(
                    "Expected " + _dimensions + " dimensions but got " + location.length);
        }
    }

    // Walks from the root to the leaf where searchLocation would be stored, pushing each sibling
    // subtree not taken so the backtracking phase can revisit it.
    private Node descendToLeaf(double[] searchLocation, ArrayDeque<Node> stack) {
        Node node = root;
        while (node.pointLocations == null) {
            if (searchLocation[node.splitDim] < node.splitVal) {
                stack.push(node.more);
                node = node.less;
            } else {
                stack.push(node.less);
                node = node.more;
            }
        }
        return node;
    }

    // Pushes both children of a stem so that the child on searchLocation's side is popped first.
    private void pushChildren(Node stem, double[] searchLocation, ArrayDeque<Node> stack) {
        if (searchLocation[stem.splitDim] < stem.splitVal) {
            stack.push(stem.more);
            stack.push(stem.less); // less will be popped first
        } else {
            stack.push(stem.less);
            stack.push(stem.more); // more will be popped first
        }
    }

    // Offers every point of a leaf to the queue, keyed by negated squared distance.
    private void offerLeaf(Node leaf, double[] searchLocation, PrioQueue<T> results) {
        ArrayList<T> payloads = leaf.pointPayloads;
        for (int j = leaf.entries; j-- > 0;) {
            results.offer(payloads.get(j), -leaf.pointDist(searchLocation, j));
        }
    }

    // Bounded collection of at most `capacity` items with the HIGHEST priorities, kept sorted in
    // descending priority order. If you want the lowest priorities kept, negate your values.
    private static class PrioQueue<S> {

        final ArrayList<Element> elements;

        private final int capacity;

        // Lowest priority still kept once the queue is full; -infinity until then, so any
        // comparable priority is accepted while there is room.
        private double minPrio = Double.NEGATIVE_INFINITY;

        PrioQueue(int size) {
            elements = new ArrayList<Element>(size); // throws IllegalArgumentException if size < 0
            capacity = size;
        }

        // Uses O(log(n)) comparisons and one big shift of size O(N),
        // and is MUCH simpler than a heap --> faster JIT (per the original author).
        boolean offer(S value, double priority) {
            // Is this point worthy of joining the exalted ranks? (NaN and -infinity never are.)
            if (capacity == 0 || !(priority > minPrio)) {
                return false;
            }
            Element entry;
            if (elements.size() < capacity) {
                entry = new Element(value, priority);
            } else {
                // Recycle the evicted lowest element to avoid garbage collector stalls.
                entry = elements.remove(elements.size() - 1);
                entry.update(value, priority);
            }
            // ArrayList.add(index, e) shifts the lower-priority elements up by one.
            elements.add(searchFor(priority), entry);
            if (elements.size() == capacity) {
                minPrio = elements.get(capacity - 1).priority;
            }
            return true;
        }

        // Binary search for the insertion index that keeps `elements` in descending order; equal
        // priorities are placed after the existing ones.
        int searchFor(double priority) {
            int i = elements.size() - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (elements.get(index).priority < priority) {
                    i = index - 1;
                } else {
                    j = index + 1;
                }
            }
            return j;
        }

        double peekPrio() {
            return minPrio;
        }

        class Element {

            S contents;

            double priority;

            Element(S con, double prio) {
                contents = con;
                priority = prio;
            }

            void update(S con, double prio) {
                contents = con;
                priority = prio;
            }
        }
    }

    /**
     * One search hit: a stored payload and its squared Euclidean distance from the query location.
     *
     * @param <S> payload type
     */
    public static class SearchResult<S> {

        double distance;

        S payload;

        SearchResult(double dist, S load) {
            distance = dist;
            payload = load;
        }
    }

    private class Node {

        // Position of this node's bounding box inside nodeMinMaxBounds.
        // - if trees weren't so unbalanced it might be better to use an implicit heap?
        int index;

        // Number of points stored in this subtree.
        int entries;

        // Leaf data; both are set to null once the node has split into a stem.
        ContiguousDoubleArrayList pointLocations;

        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);

        // Stem data; only meaningful once pointLocations is null.
        Node less, more;

        int splitDim;

        double splitVal;

        private Node() {
            this(new double[_bucketSize * _dimensions]);
        }

        private Node(double[] pointMemory) {
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
            index = _nodes++;
            nodeMinMaxBounds.add(bounds_template);
        }

        // Squared distance from location to this node's bounding box (0 if inside it).
        private double pointRectDist(double[] location) {
            int minOffset = 2 * index * _dimensions;
            int maxOffset = minOffset + _dimensions;
            double distance = 0;
            double[] array = nodeMinMaxBounds.array;
            for (int i = _dimensions; i-- > 0;) {
                double lowDist = array[i + minOffset] - location[i];
                if (lowDist > 0) {
                    distance += sqr(lowDist);
                } else {
                    double highDist = location[i] - array[i + maxOffset];
                    if (highDist > 0) {
                        distance += sqr(highDist);
                    }
                }
            }
            return distance;
        }

        // Squared Euclidean distance from location to the index-th point stored in this leaf.
        private double pointDist(double[] location, int index) {
            double distance = 0;
            int offset = index * _dimensions;
            for (int i = _dimensions; i-- > 0;) {
                distance += sqr(pointLocations.array[offset + i] - location[i]);
            }
            return distance;
        }

        // Counts one more point in this subtree and grows the bounding box to include it.
        private void expandBounds(double[] location) {
            entries++;
            int offset = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++) {
                nodeMinMaxBounds.array[offset + i] =
                        Math.min(nodeMinMaxBounds.array[offset + i], location[i]);
                nodeMinMaxBounds.array[offset + _dimensions + i] =
                        Math.max(nodeMinMaxBounds.array[offset + _dimensions + i], location[i]);
            }
        }

        // Stores a point in this leaf; expandBounds must already have counted it.
        private int add(double[] location, T load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        // Tries to turn this leaf into a stem by splitting its widest dimension at the midpoint.
        // The split is validated before any child is allocated, so a worthless split (all points on
        // one side) leaves no orphaned nodes or bounding boxes behind.
        private void split() {
            int offset = index * 2 * _dimensions;
            double[] bounds = nodeMinMaxBounds.array; // read before new Node() may reallocate it
            double diff = 0;
            int dim = -1;
            double val = 0;
            for (int i = 0; i < _dimensions; i++) {
                double min = bounds[offset + i];
                double max = bounds[offset + _dimensions + i];
                if (max - min > diff) {
                    diff = max - min;
                    val = 0.5 * (max + min);
                    dim = i;
                }
            }
            if (dim < 0) {
                return; // no dimension has a positive extent: nothing to split on
            }

            double[] locations = pointLocations.array;
            int lessCount = 0;
            for (int i = 0; i < entries; i++) {
                if (locations[i * _dimensions + dim] < val) {
                    lessCount++;
                }
            }
            if (lessCount == 0 || lessCount == entries) {
                return; // one side would be empty, so the split is worthless
            }

            splitDim = dim;
            splitVal = val;
            less = new Node(mem_recycle); // recycle that memory!
            more = new Node();

            // Reduce garbage by a factor of _bucketSize by reusing this scratch array.
            double[] pointLocation = new double[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(locations, i * _dimensions, pointLocation, 0, _dimensions);
                Node child = pointLocation[splitDim] < splitVal ? less : more;
                child.expandBounds(pointLocation);
                child.add(pointLocation, pointPayloads.get(i));
            }

            // We won't be needing this now, so keep it for the next split to reduce garbage.
            mem_recycle = locations;
            pointLocations = null;
            pointPayloads = null;
        }
    }

    // Growable double[] whose backing array is exposed directly for fast indexed access.
    private static class ContiguousDoubleArrayList {

        double[] array;

        int size;

        ContiguousDoubleArrayList() {
            this(300);
        }

        ContiguousDoubleArrayList(int size) {
            this(new double[size]);
        }

        ContiguousDoubleArrayList(double[] data) {
            array = data;
        }

        // Appends a copy of da, growing the backing array only when da would not fit.
        ContiguousDoubleArrayList add(double[] da) {
            if (size + da.length > array.length) {
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
            }
            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    private static double sqr(double d) {
        return d * d;
    }
}

</syntaxhighlight></code>

Show more