Usage of `org.deeplearning4j.clustering.sptree.DataPoint` in the deeplearning4j project (by deeplearning4j).
Class `TreeModelUtils`, method `checkTree`:
/**
 * Lazily builds the VP-tree over every vocabulary word the first time it is needed.
 * Each tree point carries the word's vocabulary index and its vector from the lookup table.
 * Subsequent calls are no-ops once {@code vpTree} has been assigned.
 */
protected synchronized void checkTree() {
    if (vpTree != null) {
        // tree already built on an earlier call
        return;
    }
    List<DataPoint> vocabularyPoints = new ArrayList<>();
    for (String token : vocabCache.words()) {
        vocabularyPoints.add(new DataPoint(vocabCache.indexOf(token), lookupTable.vector(token)));
    }
    vpTree = new VPTree(vocabularyPoints);
}
Usage of `org.deeplearning4j.clustering.sptree.DataPoint` in the deeplearning4j project (by deeplearning4j).
Class `TreeModelUtils`, method `wordsNearest`:
/**
 * Returns the vocabulary words whose vectors are nearest to {@code words}, using the
 * lazily built VP-tree (see {@code checkTree()}).
 *
 * @param words the query vector
 * @param top   the number of neighbours requested from the tree search
 * @return the matching words, in the order the tree search reported them
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    checkTree();
    List<DataPoint> add = new ArrayList<>();
    List<Double> distances = new ArrayList<>();
    // The query may be any vector rather than a vocabulary entry, so no self-match is
    // dropped here and exactly `top` neighbours are requested.
    vpTree.search(new DataPoint(0, words), top, add, distances);
    Collection<String> ret = new ArrayList<>(add.size());
    for (DataPoint e : add) {
        ret.add(vocabCache.wordAtIndex(e.getIndex()));
    }
    // Return the tree-based result instead of discarding it in favour of the superclass lookup.
    return ret;
}
Usage of `org.deeplearning4j.clustering.sptree.DataPoint` in the deeplearning4j project (by deeplearning4j).
Class `BarnesHutTsne`, method `computeGaussianPerplexity`:
/**
 * Convert data to probabilities of co-occurrence (aka calculating the Gaussian kernel),
 * restricted to each point's {@code k = (int) (3 * u)} nearest neighbours.
 *
 * <p>For every row {@code i} of {@code d}, a {@link VPTree} search returns {@code k + 1}
 * hits; the first hit is skipped (presumably the query point itself). A binary search
 * over the kernel precision {@code beta} (at most 200 steps) then tunes the row until the
 * entropy returned by {@code computeGaussianKernel} is within {@code tolerance} of
 * {@code log(u)}. The kernel row for the final beta is normalised to sum to one and
 * stored in the sparse fields {@code rows}/{@code cols}/{@code vals}. This method also
 * assigns the field {@code N}.
 *
 * @param d the data to convert, one point per row
 * @param u the perplexity of the model
 * @return the probabilities of co-occurrence (the {@code vals} field), one {@code k}-long
 *         segment per point, starting at the offset stored in {@code rows}
 * @throws IllegalStateException if the perplexity {@code u} is greater than {@code k}
 */
public INDArray computeGaussianPerplexity(final INDArray d, double u) {
    N = d.rows();
    final int k = (int) (3 * u);
    if (u > k) {
        throw new IllegalStateException("Illegal k value " + k + ": perplexity " + u + " is greater than k");
    }
    rows = zeros(1, N + 1);
    cols = zeros(1, N * k);
    vals = zeros(1, N * k);
    // rows[n] is the offset of point n's k-long segment inside cols/vals
    for (int n = 0; n < N; n++) {
        rows.putScalar(n + 1, rows.getDouble(n) + k);
    }
    final INDArray beta = ones(N, 1);
    final double logU = FastMath.log(u);
    VPTree tree = new VPTree(d, simiarlityFunction, invert);
    log.info("Calculating probabilities of data similarities...");
    for (int i = 0; i < N; i++) {
        if (i % 500 == 0) {
            log.info("Handled " + i + " records");
        }
        double betaMin = -Double.MAX_VALUE;
        double betaMax = Double.MAX_VALUE;
        List<DataPoint> results = new ArrayList<>();
        // k + 1 hits because the first one is skipped when filling cols below
        tree.search(new DataPoint(i, d.slice(i)), k + 1, results, new ArrayList<Double>());
        double betas = beta.getDouble(i);
        INDArray cArr = VPTree.buildFromData(results);
        Pair<INDArray, Double> pair = computeGaussianKernel(cArr, betas, k);
        double hDiff = pair.getSecond() - logU;
        int tries = 0;
        boolean found = false;
        // binary search for the beta whose entropy matches log(perplexity)
        while (!found && tries < 200) {
            if (hDiff < tolerance && -hDiff < tolerance) {
                found = true;
            } else {
                if (hDiff > 0) {
                    betaMin = betas;
                    if (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) {
                        betas *= 2;
                    } else {
                        betas = (betas + betaMax) / 2.0;
                    }
                } else {
                    betaMax = betas;
                    if (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) {
                        betas /= 2.0;
                    } else {
                        betas = (betas + betaMin) / 2.0;
                    }
                }
                pair = computeGaussianKernel(cArr, betas, k);
                hDiff = pair.getSecond() - logU;
                tries++;
            }
        }
        // Use the kernel row computed for the final (tuned) beta, not the initial one.
        INDArray currP = pair.getFirst();
        currP.divi(currP.sum(Integer.MAX_VALUE));
        INDArray indices = Nd4j.create(1, k + 1);
        for (int j = 0; j < indices.length(); j++) {
            if (j >= results.size()) {
                break;
            }
            indices.putScalar(j, results.get(j).getIndex());
        }
        for (int l = 0; l < k; l++) {
            cols.putScalar(rows.getInt(i) + l, indices.getDouble(l + 1));
            vals.putScalar(rows.getInt(i) + l, currP.getDouble(l));
        }
    }
    return vals;
}
Usage of `org.deeplearning4j.clustering.sptree.DataPoint` in the deeplearning4j project (by deeplearning4j).
Class `VPTree`, method `buildFromPoints`:
/**
 * Recursively builds the vantage-point subtree over the index range
 * {@code [lower, upper)} of {@code items}.
 *
 * <p>A randomly chosen point of the range is swapped into position {@code lower} and
 * becomes this node's vantage point. The remaining points {@code [lower + 1, upper)} are
 * reordered by their distance to it. The distance to the point at the median index
 * becomes the node's threshold. {@code [lower + 1, median)} forms the left subtree and
 * {@code [median, upper)} the right subtree. Only the range {@code [lower, upper)} of
 * {@code items} is reordered, so sibling subtrees are left untouched.
 *
 * @param lower inclusive start index into {@code items}
 * @param upper exclusive end index into {@code items}
 * @return the subtree root, or {@code null} for an empty range
 */
private Node buildFromPoints(int lower, int upper) {
    if (upper == lower) {
        return null;
    }
    Node ret = new Node(lower, 0);
    if (upper - lower > 1) {
        // search() uses items.get(node.getIndex()) as the vantage point, so the randomly
        // chosen vantage point must be moved to `lower`.
        int randomPoint = MathUtils.randomNumberBetween(lower, upper - 1);
        DataPoint displaced = items.get(lower);
        items.set(lower, items.get(randomPoint));
        items.set(randomPoint, displaced);
        DataPoint basePoint = items.get(lower);

        // Order the rest of this range by distance to the vantage point. Distances are
        // computed once up front rather than repeatedly inside the comparator.
        int count = upper - lower - 1;
        final double[] distances = new double[count];
        Integer[] order = new Integer[count];
        for (int i = 0; i < count; i++) {
            distances[i] = getDistance(basePoint, items.get(lower + 1 + i));
            order[i] = i;
        }
        Arrays.sort(order, new java.util.Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(distances[a], distances[b]);
            }
        });
        List<DataPoint> sorted = new ArrayList<>(count);
        for (Integer offset : order) {
            sorted.add(items.get(lower + 1 + offset));
        }
        for (int i = 0; i < count; i++) {
            items.set(lower + 1 + i, sorted.get(i));
        }

        // Partition around the median distance
        int median = (upper + lower) >>> 1;
        ret.setThreshold(getDistance(basePoint, items.get(median)));
        ret.setIndex(lower);
        ret.setLeft(buildFromPoints(lower + 1, median));
        ret.setRight(buildFromPoints(median, upper));
    }
    return ret;
}
Usage of `org.deeplearning4j.clustering.sptree.DataPoint` in the deeplearning4j project (by deeplearning4j).
Class `VPTree`, method `search`:
/**
 * Recursively gathers candidates for the {@code k} points nearest to {@code target}
 * from the subtree rooted at {@code node}.
 *
 * <p>A node's point is admitted to {@code pq} when it is closer than {@code tau}. Once
 * {@code pq} holds {@code k} items, {@code tau} is set to the distance of
 * {@code pq.peek()}. {@code tau} is then used to skip children that cannot contain a
 * closer point. NOTE(review): this presumes {@code pq} is a max-priority queue keyed by
 * distance, so that {@code peek()}/{@code next()} address the furthest candidate.
 */
public void search(Node node, DataPoint target, int k, PriorityQueue<HeapItem> pq) {
    if (node == null) {
        return;
    }
    final DataPoint vantage = items.get(node.getIndex());
    final double dist = getDistance(vantage, target);

    if (dist < tau) {
        // make room when the queue already holds k candidates
        if (pq.size() == k) {
            pq.next();
        }
        pq.add(new HeapItem(node.index, dist), dist);
        if (pq.size() == k) {
            tau = pq.peek().getDistance();
        }
    }

    final boolean leaf = node.getLeft() == null && node.getRight() == null;
    if (leaf) {
        return;
    }

    // tau is deliberately re-read after the first recursive call, which may have shrunk it
    final double radius = node.getThreshold();
    if (dist >= radius) {
        // target outside the ball: visit the outer (right) child first
        if (dist + tau >= radius) {
            search(node.getRight(), target, k, pq);
        }
        if (dist - tau <= radius) {
            search(node.getLeft(), target, k, pq);
        }
    } else {
        // target inside the ball: visit the inner (left) child first
        if (dist - tau <= radius) {
            search(node.getLeft(), target, k, pq);
        }
        if (dist + tau >= radius) {
            search(node.getRight(), target, k, pq);
        }
    }
}
Aggregations