Use of org.deeplearning4j.clustering.vptree.VPTree in project deeplearning4j by deeplearning4j.
Class TreeModelUtils, method checkTree.
/**
 * Makes sure {@code vpTree} exists, creating it on first use.
 * One {@link DataPoint} is created per word returned by {@code vocabCache.words()},
 * pairing the word's vocabulary index with its vector from {@code lookupTable}.
 * The method is synchronized, so the tree is built under the instance lock.
 */
protected synchronized void checkTree() {
    // tree already built earlier: leave it untouched
    if (vpTree != null) {
        return;
    }

    List<DataPoint> dataPoints = new ArrayList<>();
    for (String token : vocabCache.words()) {
        DataPoint entry = new DataPoint(vocabCache.indexOf(token), lookupTable.vector(token));
        dataPoints.add(entry);
    }
    vpTree = new VPTree(dataPoints);
}
Use of org.deeplearning4j.clustering.vptree.VPTree in project deeplearning4j by deeplearning4j.
Class BarnesHutTsne, method computeGaussianPerplexity.
/**
 * Convert data to probability
 * co-occurrences (aka calculating the kernel).
 * <p>
 * For every row of {@code d} a {@link VPTree} is queried for {@code k + 1} results
 * (with {@code k = (int) (3 * u)}), and a per-row beta is tuned by bisection until the
 * entropy reported by {@code computeGaussianKernel} is within {@code tolerance} of
 * {@code log(u)} (or 200 attempts have been made). The normalised kernel row of the
 * final beta is then stored in the fields {@code rows} (cumulative row offsets,
 * {@code k} entries per row), {@code cols} (neighbour indices) and {@code vals}
 * (probabilities). The field {@code N} is set to {@code d.rows()}.
 *
 * @param d the data to convert
 * @param u the perplexity of the model
 * @return the probabilities of co-occurrence (the {@code vals} field)
 * @throws IllegalStateException if the perplexity exceeds the derived neighbour count {@code k}
 */
public INDArray computeGaussianPerplexity(final INDArray d, double u) {
    N = d.rows();
    final int k = (int) (3 * u);
    if (u > k) {
        throw new IllegalStateException(
                "Illegal k value " + k + ": must not be less than perplexity " + u);
    }

    rows = zeros(1, N + 1);
    cols = zeros(1, N * k);
    vals = zeros(1, N * k);

    // every row contributes exactly k entries, so the offsets grow by k per row
    for (int n = 0; n < N; n++) {
        rows.putScalar(n + 1, rows.getDouble(n) + k);
    }

    final INDArray beta = ones(N, 1);
    final double logU = FastMath.log(u);
    VPTree tree = new VPTree(d, simiarlityFunction, invert);

    log.info("Calculating probabilities of data similarities...");
    for (int i = 0; i < N; i++) {
        if (i % 500 == 0) {
            log.info("Handled " + i + " records");
        }

        // sentinels meaning "no lower/upper bound found yet"
        double betaMin = -Double.MAX_VALUE;
        double betaMax = Double.MAX_VALUE;

        List<DataPoint> results = new ArrayList<>();
        tree.search(new DataPoint(i, d.slice(i)), k + 1, results, new ArrayList<Double>());

        double betas = beta.getDouble(i);
        INDArray cArr = VPTree.buildFromData(results);
        Pair<INDArray, Double> pair = computeGaussianKernel(cArr, beta.getDouble(i), k);
        double hDiff = pair.getSecond() - logU;
        int tries = 0;
        boolean found = false;

        //binary search
        while (!found && tries < 200) {
            if (hDiff < tolerance && -hDiff < tolerance) {
                found = true;
            } else {
                if (hDiff > 0) {
                    // entropy too high: raise beta, doubling until an upper bound is known
                    betaMin = betas;
                    if (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) {
                        betas *= 2;
                    } else {
                        betas = (betas + betaMax) / 2.0;
                    }
                } else {
                    // entropy too low: lower beta, halving until a lower bound is known
                    betaMax = betas;
                    if (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) {
                        betas /= 2.0;
                    } else {
                        betas = (betas + betaMin) / 2.0;
                    }
                }
                pair = computeGaussianKernel(cArr, betas, k);
                hDiff = pair.getSecond() - logU;
                tries++;
            }
        }

        // Read the kernel row only after the search: 'pair' now holds the evaluation for
        // the final beta. Capturing it before the loop would keep the row of the initial
        // beta and silently discard the outcome of the bisection.
        INDArray currP = pair.getFirst();
        currP.divi(currP.sum(Integer.MAX_VALUE));

        INDArray indices = Nd4j.create(1, k + 1);
        for (int j = 0; j < indices.length(); j++) {
            if (j >= results.size()) {
                break;
            }
            indices.putScalar(j, results.get(j).getIndex());
        }

        // the first search result is skipped (l + 1); presumably it is the query point itself
        for (int l = 0; l < k; l++) {
            cols.putScalar(rows.getInt(i) + l, indices.getDouble(l + 1));
            vals.putScalar(rows.getInt(i) + l, currP.getDouble(l));
        }
    }
    return vals;
}
Aggregations