Search in sources :

Example 1 with VPTree

use of org.deeplearning4j.clustering.vptree.VPTree in project deeplearning4j by deeplearning4j.

In class TreeModelUtils, method checkTree:

/**
 * Lazily creates the VP-tree: when no tree exists yet, one data point is
 * built per word returned by {@code vocabCache.words()} (the word's index
 * paired with its vector from {@code lookupTable}) and a new tree is
 * constructed from those points. An existing tree is left untouched.
 */
protected synchronized void checkTree() {
    // a tree built by an earlier call is reused as-is
    if (vpTree != null) {
        return;
    }

    List<DataPoint> wordPoints = new ArrayList<>();
    for (String token : vocabCache.words()) {
        wordPoints.add(new DataPoint(vocabCache.indexOf(token), lookupTable.vector(token)));
    }
    vpTree = new VPTree(wordPoints);
}
Also used : DataPoint(org.deeplearning4j.clustering.sptree.DataPoint) VPTree(org.deeplearning4j.clustering.vptree.VPTree)

Example 2 with VPTree

use of org.deeplearning4j.clustering.vptree.VPTree in project deeplearning4j by deeplearning4j.

In class BarnesHutTsne, method computeGaussianPerplexity:

/**
 * Convert data to probability
 * co-occurrences (aka calculating the kernel).
 *
 * <p>With {@code k = (int) (3 * u)}, each row of {@code d} is used to query a
 * {@code VPTree} for {@code k + 1} results. For that row, beta is then tuned
 * by bisection (at most 200 steps) until the second element returned by
 * {@code computeGaussianKernel} (presumably the entropy of the kernel row;
 * confirm against that method) is within {@code tolerance} of {@code log(u)}.
 * The kernel row belonging to the final beta is normalised to sum to one and
 * written to {@code vals}; the matching result indices are written to
 * {@code cols}; {@code rows} holds the start offset of each row's k entries.
 *
 * <p>Assigns {@code N}, {@code rows}, {@code cols} and {@code vals} as a
 * side effect.
 *
 * @param d the data to convert, one data point per row
 * @param u the perplexity of the model
 * @return the probabilities of co-occurrence (the {@code vals} array)
 * @throws IllegalStateException if {@code u} is greater than {@code k = (int) (3 * u)}
 */
public INDArray computeGaussianPerplexity(final INDArray d, double u) {
    N = d.rows();
    final int k = (int) (3 * u);
    if (u > k) {
        // The guard fires when the perplexity exceeds k, so the message says exactly that.
        throw new IllegalStateException("Illegal k value " + k + ": must not be less than perplexity " + u);
    }
    rows = zeros(1, N + 1);
    cols = zeros(1, N * k);
    vals = zeros(1, N * k);
    // rows[n + 1] = rows[n] + k: every row owns exactly k consecutive slots in cols/vals
    for (int n = 0; n < N; n++) {
        rows.putScalar(n + 1, rows.getDouble(n) + k);
    }
    final INDArray beta = ones(N, 1);
    final double logU = FastMath.log(u);
    VPTree tree = new VPTree(d, simiarlityFunction, invert);
    log.info("Calculating probabilities of data similarities...");
    for (int i = 0; i < N; i++) {
        if (i % 500 == 0) {
            log.info("Handled " + i + " records");
        }
        double betaMin = -Double.MAX_VALUE;
        double betaMax = Double.MAX_VALUE;
        List<DataPoint> results = new ArrayList<>();
        // k + 1 results are requested; the first one is skipped when filling cols below
        tree.search(new DataPoint(i, d.slice(i)), k + 1, results, new ArrayList<Double>());
        double betas = beta.getDouble(i);
        INDArray cArr = VPTree.buildFromData(results);
        Pair<INDArray, Double> pair = computeGaussianKernel(cArr, betas, k);
        INDArray currP = pair.getFirst();
        double hDiff = pair.getSecond() - logU;
        int tries = 0;
        boolean found = false;
        // binary search for the beta whose kernel matches log(u) within tolerance
        while (!found && tries < 200) {
            if (hDiff < tolerance && -hDiff < tolerance) {
                found = true;
            } else {
                if (hDiff > 0) {
                    betaMin = betas;
                    // no upper bound seen yet: grow geometrically, otherwise bisect
                    if (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) {
                        betas *= 2;
                    } else {
                        betas = (betas + betaMax) / 2.0;
                    }
                } else {
                    betaMax = betas;
                    // no lower bound seen yet: shrink geometrically, otherwise bisect
                    if (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) {
                        betas /= 2.0;
                    } else {
                        betas = (betas + betaMin) / 2.0;
                    }
                }
                pair = computeGaussianKernel(cArr, betas, k);
                // Keep the kernel row in step with the beta under test; without this the
                // row stored below would still be the one computed for the initial beta.
                currP = pair.getFirst();
                hDiff = pair.getSecond() - logU;
                tries++;
            }
        }
        // row-normalise so the stored values sum to one
        currP.divi(currP.sum(Integer.MAX_VALUE));
        INDArray indices = Nd4j.create(1, k + 1);
        // fewer than k + 1 results may come back; the remaining slots keep their initial value
        for (int j = 0; j < indices.length() && j < results.size(); j++) {
            indices.putScalar(j, results.get(j).getIndex());
        }
        for (int l = 0; l < k; l++) {
            cols.putScalar(rows.getInt(i) + l, indices.getDouble(l + 1));
            vals.putScalar(rows.getInt(i) + l, currP.getDouble(l));
        }
    }
    return vals;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataPoint(org.deeplearning4j.clustering.sptree.DataPoint) VPTree(org.deeplearning4j.clustering.vptree.VPTree) ArrayList(java.util.ArrayList) AtomicDouble(com.google.common.util.concurrent.AtomicDouble) DataPoint(org.deeplearning4j.clustering.sptree.DataPoint)

Aggregations

DataPoint (org.deeplearning4j.clustering.sptree.DataPoint)2 VPTree (org.deeplearning4j.clustering.vptree.VPTree)2 AtomicDouble (com.google.common.util.concurrent.AtomicDouble)1 ArrayList (java.util.ArrayList)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1