My KDTree Implementation - it's super effective!
New page
<code><syntaxhighlight>
/*
** KDTree.java by Julian Kent
** Licenced under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/
** For additional usage rights please contact jkflying@gmail.com
**
** Example usage, together with benchmarking code against Rednaxela's Gen2 Tree, is given in the (commented-out) main method
*/
package jk.mega;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
//import ags.utils.*;
/**
 * Bucketed k-d tree mapping fixed-dimension points to payloads, supporting
 * nearest-neighbour and k-nearest-neighbour queries.
 *
 * <p>Points are stored in leaf buckets. Each time a leaf's size reaches a multiple
 * of {@code _bucketSize}, the leaf tries to split at the midpoint of its widest
 * bounding-box dimension. All distances reported by this class are
 * <em>squared</em> Euclidean distances (no square root is taken).
 *
 * <p>The class performs no synchronization of its own.
 *
 * @param <T> type of the payload stored with each point
 */
public class KDTree<T>{
    //use a big bucketSize so that we have fewer node bounds (for more cache hits) and better splits
    private static final int _bucketSize = 64;
    private final int _dimensions;
    //number of live nodes; also the next free index into nodeMinMaxBounds
    private int _nodes;
    private Node root;
    //prevent GC from having to collect _bucketSize*dimensions*8 bytes each time a leaf splits
    private double[] mem_recycle;
    //the starting values for bounding boxes, for easy access
    private final double[] bounds_template;
    //one big self-expanding array to keep all the node bounding boxes so that they stay in cache
    // node bounds available at:
    //low: 2 * _dimensions * node.index
    //high: 2 * _dimensions * node.index + _dimensions
    private ContiguousDoubleArrayList nodeMinMaxBounds;
    /*
    public static void main(String[] args){
    int dims = 12;
    int size = 20000;
    int testsize = 200;
    int k = 40;
    int iterations = 3;
    System.out.println(
    "Config:\n"
    + "No JIT Warmup\n"
    + "Tested on random data.\n"
    + "Training and testing points shared across iterations.\n"
    + "Searches interleaved.");
    System.out.println("Num points: " + size);
    System.out.println("Num searches: " + testsize);
    System.out.println("Dimensions: " + dims);
    System.out.println("Num Neighbours: " + k);
    System.out.println();
    ArrayList<double[]> locs = new ArrayList<double[]>(size);
    for(int i = 0; i < size; i++){
    double[] loc = new double[dims];
    for(int j = 0; j < dims; j++)
    loc[j] = Math.random();
    locs.add(loc);
    }
    ArrayList<double[]> testlocs = new ArrayList<double[]>(testsize);
    for(int i = 0; i < testsize; i++){
    double[] loc = new double[dims];
    for(int j = 0; j < dims; j++)
    loc[j] = Math.random();
    testlocs.add(loc);
    }
    for(int r = 0; r < iterations; r++){
    long t1 = System.nanoTime();
    KDTree<double[]> t = new KDTree<double[]>(dims);// This tree
    for(int i = 0; i < size; i++){
    t.addPoint(locs.get(i),locs.get(i));
    }
    long t2 = System.nanoTime();
    KdTree<double[]> rt = new KdTree.Euclidean<double[]>(dims,null); //Rednaxela Gen2
    for(int i = 0; i < size; i++){
    rt.addPoint(locs.get(i),locs.get(i));
    }
    long t3 = System.nanoTime();
    long jtn = 0;
    long rtn = 0;
    long mjtn = 0;
    long mrtn = 0;
    double dist1 = 0, dist2 = 0;
    for(int i = 0; i < testsize; i++){
    long t4 = System.nanoTime();
    dist1 += t.nearestNeighbours(testlocs.get(i),k).iterator().next().distance;
    long t5 = System.nanoTime();
    dist2 += rt.nearestNeighbor(testlocs.get(i),k,true).iterator().next().distance;
    long t6 = System.nanoTime();
    long t7 = System.nanoTime();
    jtn += t5 - t4 - (t7 - t6);
    rtn += t6 - t5 - (t7 - t6);
    mjtn = Math.max(mjtn,t5 - t4 - (t7 - t6));
    mrtn = Math.max(mrtn,t6 - t5 - (t7 - t6));
    }
    System.out.println("Accuracy: " + (Math.abs(dist1-dist2) < 1e-10?"100%":"BROKEN!!!"));
    if(Math.abs(dist1-dist2) > 1e-10){
    System.out.println("dist1: " + dist1 + " dist2: " + dist2);
    }
    long jts = t2 - t1;
    long rts = t3 - t2;
    System.out.println("Iteration: " + (r+1) + "/" + iterations);
    System.out.println("This tree add avg: " + jts/size + " ns");
    System.out.println("Reds tree add avg: " + rts/size + " ns");
    System.out.println("This tree knn avg: " + jtn/testsize + " ns");
    System.out.println("Reds tree knn avg: " + rtn/testsize + " ns");
    System.out.println("This tree knn max: " + mjtn + " ns");
    System.out.println("Reds tree knn max: " + mrtn + " ns");
    System.out.println();
    }
    }
    */

    /**
     * Creates an empty tree for points with the given number of coordinates.
     *
     * @param dimensions number of coordinates per point; must be at least 1
     * @throws IllegalArgumentException if {@code dimensions < 1}
     */
    public KDTree(int dimensions){
        //a zero-dimensional tree would crash in split(); a negative one cannot allocate
        if(dimensions < 1)
            throw new IllegalArgumentException("dimensions must be at least 1, got " + dimensions);
        _dimensions = dimensions;
        nodeMinMaxBounds = new ContiguousDoubleArrayList(2 * dimensions);
        mem_recycle = new double[_bucketSize*dimensions];
        bounds_template = new double[2*_dimensions];
        Arrays.fill(bounds_template,0,_dimensions,Double.POSITIVE_INFINITY);
        Arrays.fill(bounds_template,_dimensions,2*_dimensions,Double.NEGATIVE_INFINITY);
        //and.... start!
        root = new Node();
    }

    /**
     * Returns the number of tree nodes (stems plus leaves) currently allocated.
     */
    public int nodes(){
        return _nodes;
    }

    /**
     * Adds a point and its payload to the tree. The coordinates are copied into
     * the tree's own storage, so the caller may reuse {@code location} afterwards.
     *
     * @param location coordinates of the point; length must equal the tree's dimensions
     * @param payload  value to associate with the point (may be null)
     * @return the total number of points now stored in the tree
     * @throws IllegalArgumentException if {@code location.length} differs from the tree's dimensions
     */
    public int addPoint(double[] location, T payload){
        //a longer array would be copied whole and corrupt the contiguous leaf storage
        checkDimensions(location);
        Node addNode = root;
        //descend to the target leaf, growing every stem's bounds and point count on the way
        while(addNode.pointLocations == null){
            addNode.expandBounds(location);
            addNode = location[addNode.splitDim] < addNode.splitVal ? addNode.less : addNode.more;
        }
        addNode.expandBounds(location);
        int nodeSize = addNode.add(location,payload);
        //try splitting again once every time the node passes a _bucketSize multiple
        if(nodeSize % _bucketSize == 0)
            addNode.split();
        return root.entries;
    }

    /**
     * Finds the stored point closest to {@code searchLocation}.
     *
     * @param searchLocation query coordinates; length must equal the tree's dimensions
     * @return the squared Euclidean distance and payload of the closest point; on an
     *         empty tree the distance is {@code +Infinity} and the payload is null
     * @throws IllegalArgumentException if {@code searchLocation.length} differs from the tree's dimensions
     */
    public SearchResult<T> nearestNeighbour(double[] searchLocation){
        checkDimensions(searchLocation);
        ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
        Node searchNode = descendToLeaf(searchLocation, stack);
        double minDist = Double.POSITIVE_INFINITY;
        T minValue = null;
        //Find the closest point in this Node and use as a solution
        for(int j = searchNode.entries; j-- > 0;){
            double distance = searchNode.pointDist(searchLocation,j);
            if(distance < minDist){
                minDist = distance;
                minValue = searchNode.pointPayloads.get(j);
            }
        }
        //backtrack, skipping any subtree whose bounding box is no closer than the best so far
        while(!stack.isEmpty()){
            searchNode = stack.pop();
            if(searchNode.pointRectDist(searchLocation) < minDist){
                if(searchNode.pointLocations == null){
                    pushChildren(searchNode, searchLocation, stack);
                }
                else{
                    for(int j = searchNode.entries; j-- > 0;){
                        double distance = searchNode.pointDist(searchLocation,j);
                        if(distance < minDist){
                            minDist = distance;
                            minValue = searchNode.pointPayloads.get(j);
                        }
                    }
                }
            }
        }
        return new SearchResult<T>(minDist,minValue);
    }

    /**
     * Finds up to {@code K} stored points closest to {@code searchLocation}.
     *
     * <p>Results are ordered from farthest to nearest (the same order as Rednaxela's
     * Gen2 tree). Fewer than {@code K} results are returned when the tree holds fewer
     * than {@code K} points at a finite squared distance. Points whose squared distance
     * is infinite or NaN are never returned.
     *
     * @param searchLocation query coordinates; length must equal the tree's dimensions
     * @param K              maximum number of neighbours to return; must be non-negative
     * @return the neighbours as (squared distance, payload) pairs, farthest first
     * @throws IllegalArgumentException if {@code K < 0} or {@code searchLocation.length}
     *         differs from the tree's dimensions
     */
    public ArrayList<SearchResult<T>> nearestNeighbours(double[] searchLocation, int K){
        checkDimensions(searchLocation);
        if(K < 0)
            throw new IllegalArgumentException("K must be non-negative, got " + K);
        //PrioQueue cannot hold zero elements (offer would remove from an empty list)
        if(K == 0)
            return new ArrayList<SearchResult<T>>(0);
        ArrayDeque<Node> stack = new ArrayDeque<Node>(50);
        Node searchNode = descendToLeaf(searchLocation, stack);
        //priorities are negated distances, so the queue keeps the K smallest distances
        PrioQueue<T> results = new PrioQueue<T>(K);
        offerLeaf(searchNode, searchLocation, results);
        //backtrack, skipping any subtree whose bounding box is no closer than the current K-th best
        while(!stack.isEmpty()){
            searchNode = stack.pop();
            if(searchNode.pointRectDist(searchLocation) < -results.peekPrio()){
                if(searchNode.pointLocations == null)
                    pushChildren(searchNode, searchLocation, stack);
                else
                    offerLeaf(searchNode, searchLocation, results);
            }
        }
        ArrayList<SearchResult<T>> returnResults = new ArrayList<SearchResult<T>>(results.elements.size());
        //walk the queue in reverse so results go from farthest to nearest
        for(int i = results.elements.size(); i-- > 0;){
            PrioQueue<T>.Element e = results.elements.get(i);
            //real entries always have priority > -Infinity (offer uses a strict '>'),
            //so -Infinity marks a placeholder slot that was never filled
            if(e.priority == Double.NEGATIVE_INFINITY)
                continue;
            returnResults.add(new SearchResult<T>(-e.priority,e.contents));
        }
        return returnResults;
    }

    // Rejects coordinate arrays whose length does not match the tree's dimensionality.
    private void checkDimensions(double[] location){
        if(location.length != _dimensions)
            throw new IllegalArgumentException("expected " + _dimensions + " coordinates, got " + location.length);
    }

    // Walks from the root to the leaf that would contain searchLocation, pushing each untaken sibling onto stack.
    private Node descendToLeaf(double[] searchLocation, ArrayDeque<Node> stack){
        Node node = root;
        while(node.pointLocations == null){
            if(searchLocation[node.splitDim] < node.splitVal){
                stack.push(node.more);
                node = node.less;
            }
            else{
                stack.push(node.less);
                node = node.more;
            }
        }
        return node;
    }

    // Pushes both children of a stem so that the child on searchLocation's side is popped first.
    private void pushChildren(Node stem, double[] searchLocation, ArrayDeque<Node> stack){
        if(searchLocation[stem.splitDim] < stem.splitVal){
            stack.push(stem.more);
            stack.push(stem.less);
        }
        else{
            stack.push(stem.less);
            stack.push(stem.more);
        }
    }

    // Offers every point of a leaf to the queue, keyed by its negated squared distance.
    private void offerLeaf(Node leaf, double[] searchLocation, PrioQueue<T> results){
        ArrayList<T> payloads = leaf.pointPayloads;
        for(int j = leaf.entries; j-- > 0;)
            results.offer(payloads.get(j), -leaf.pointDist(searchLocation,j));
    }

    //NB! This Priority Queue keeps things with the HIGHEST priority.
    //If you want lowest priority items kept, negate your values
    //elements is kept sorted by descending priority and always holds exactly
    //'size' entries; unfilled slots are (null, -Infinity) placeholders.
    private static class PrioQueue<S>{
        ArrayList<Element> elements;
        private double minPrio;
        PrioQueue(int size){
            elements = new ArrayList<Element>(size);
            while(size-->0){
                elements.add(new Element(null,Double.NEGATIVE_INFINITY));
            }
            minPrio = Double.NEGATIVE_INFINITY;
        }
        //uses O(log(n)) comparisons and one big shift of size O(N)
        //and is MUCH simpler than a heap --> faster JIT
        boolean offer(S value,double priority){
            //is this point worthy of joining the exalted ranks?
            if(priority > minPrio){
                //recycle object to avoid garbage collector stalls
                Element replace = elements.remove(elements.size() - 1);
                replace.update(value,priority);
                add(replace);
                return true;
            }
            return false;
        }
        void add(Element e){
            //find the right place with a binary search
            int index = searchFor(e.priority);
            //and re-insert updated value (ArrayList automatically shifts other elements up)
            elements.add(index,e);
            minPrio = elements.get(elements.size() - 1).priority;
        }
        //returns the first index whose priority is strictly lower, so equal priorities keep insertion order
        int searchFor(double priority){
            int i = elements.size()-1;
            int j = 0;
            while(i>=j){
                int index = (i+j)>>>1;
                if(elements.get(index).priority < priority)
                    i = index-1;
                else
                    j = index+1;
            }
            return j;
        }
        double peekPrio(){
            return minPrio;
        }
        /* //Methods for using it as a priority stack - leave them out for now
        void push(S value, double priority){
        Element insert = new Element(value,priority);
        add(insert);
        }
        S pop(){
        Element remove = elements.remove(elements.size() - 1);
        if(elements.size() == 0)
        minPrio = Double.NEGATIVE_INFINITY;
        else
        minPrio = elements.get(elements.size() - 1).priority;
        return remove.contents;
        }
        int size(){
        return elements.size();
        }
        void trim(double newMinPrio){
        if(newMinPrio > minPrio){
        int index = searchFor(newMinPrio);
        int size = elements.size();
        elements.subList(index,elements.size()).clear();
        if(elements.size() == 0)
        minPrio = Double.NEGATIVE_INFINITY;
        else
        minPrio = elements.get(elements.size() - 1).priority;
        }
        }
        // */
        class Element{
            S contents;
            double priority;
            Element(S con, double prio){
                contents = con;
                priority = prio;
            }
            void update(S con, double prio){
                contents = con;
                priority = prio;
            }
        }
    }

    /**
     * A single query result: the squared Euclidean distance from the query point
     * and the payload stored with the matched point.
     */
    public static class SearchResult<S>{
        double distance;
        S payload;
        SearchResult(double dist, S load){
            distance = dist;
            payload = load;
        }
        /** Returns the squared Euclidean distance from the query point. */
        public double getDistance(){
            return distance;
        }
        /** Returns the payload stored with the matched point (may be null). */
        public S getPayload(){
            return payload;
        }
    }

    private class Node {
        //for accessing bounding box data
        // - if trees weren't so unbalanced might be better to use an implicit heap?
        int index;
        //keep track of size of subtree
        int entries;
        //leaf (both null once this node becomes a stem)
        ContiguousDoubleArrayList pointLocations ;
        ArrayList<T> pointPayloads = new ArrayList<T>(_bucketSize);
        //stem
        Node less, more;
        int splitDim;
        double splitVal;
        private Node(){
            this(new double[_bucketSize*_dimensions]);
        }
        private Node(double[] pointMemory){
            pointLocations = new ContiguousDoubleArrayList(pointMemory);
            index = _nodes++;
            nodeMinMaxBounds.add(bounds_template);
        }
        //squared distance from location to this node's bounding box (0 if inside)
        private final double pointRectDist(double[] location){
            int minOffset = 2*index*_dimensions;
            int maxOffset = minOffset+_dimensions;
            double distance=0;
            double[] array = nodeMinMaxBounds.array;
            for(int i = _dimensions; i-- > 0; ){
                double lowDist = array[i+minOffset] - location[i];
                if(lowDist > 0)
                    distance += sqr(lowDist);
                else{
                    double highDist = location[i] - array[i+maxOffset];
                    if(highDist > 0)
                        distance += sqr(highDist);
                }
            }
            return distance;
        }
        //squared distance from location to the index-th point stored in this leaf
        private final double pointDist(double[] location, int index){
            double distance = 0;
            int offset = index*_dimensions;
            for(int i = _dimensions; i-- > 0;)
                distance += sqr(pointLocations.array[offset+i] - location[i]);
            return distance;
        }
        //counts one more point in this subtree and grows the bounding box to include it
        private void expandBounds(double[] location){
            entries++;
            int offset = index*2*_dimensions;
            for(int i = 0; i < _dimensions;i++){
                nodeMinMaxBounds.array[offset+i] = Math.min(nodeMinMaxBounds.array[offset+i],location[i]);
                nodeMinMaxBounds.array[offset+_dimensions+i] = Math.max(nodeMinMaxBounds.array[offset+_dimensions+i],location[i]);
            }
        }
        //stores the point in this leaf; expandBounds must already have counted it
        private int add(double[] location, T load){
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }
        //tries to turn this leaf into a stem split at the midpoint of its widest dimension
        private void split(){
            double diff = 0;
            int offset = index*2*_dimensions;
            for(int i = 0; i < _dimensions; i++){
                double min = nodeMinMaxBounds.array[offset+i];
                double max = nodeMinMaxBounds.array[offset+_dimensions+i];
                if(max - min > diff){
                    diff = max - min;
                    splitVal = 0.5*(max + min);
                    splitDim = i;
                }
            }
            //every point coincides, so no split can separate them: stay a leaf
            if(diff == 0)
                return;
            less = new Node(mem_recycle);//recycle that memory!
            more = new Node();
            //reduce garbage by factor of _bucketSize by recycling this array
            double[] pointLocation = new double[_dimensions];
            for(int i = 0; i < entries; i++){
                System.arraycopy(pointLocations.array,i*_dimensions,pointLocation,0,_dimensions);
                Node target = pointLocation[splitDim] < splitVal ? less : more;
                target.expandBounds(pointLocation);
                target.add(pointLocation,pointPayloads.get(i));
            }
            //test each side separately: the product of the two counts can overflow int to 0
            if(less.entries == 0 || more.entries == 0){
                //the split separated nothing, so throw it away. less and more were the two
                //most recently allocated nodes, so reclaim their indices and bounds slots
                //instead of leaking them on every failed split
                less = null;
                more = null;
                _nodes -= 2;
                nodeMinMaxBounds.size -= 4*_dimensions;
            }
            else{
                //we won't be needing that now, so keep it for the next split to reduce garbage
                mem_recycle = pointLocations.array;
                pointLocations = null;
                pointPayloads = null;
            }
        }
    }

    //growable double array that appends whole arrays at a time
    private static class ContiguousDoubleArrayList{
        double[] array;
        int size;
        ContiguousDoubleArrayList(){
            this(300);
        }
        ContiguousDoubleArrayList(int size){
            this(new double[size]);
        }
        ContiguousDoubleArrayList(double[] data){
            array = data;
        }
        ContiguousDoubleArrayList add(double[] da){
            //grow only when the data does not fit: an exact fit must not trigger a copy,
            //otherwise every full leaf bucket would be reallocated
            if(size + da.length > array.length){
                array = Arrays.copyOf(array,(array.length+da.length)*2);
            }
            System.arraycopy(da,0,array,size,da.length);
            size += da.length;
            return this;
        }
    }

    private static final double sqr(double d){
        return d*d;
    }
}
</syntaxhighlight></code>