Search in sources :

Example 1 with DataPoint

use of org.deeplearning4j.clustering.sptree.DataPoint in project deeplearning4j by deeplearning4j.

the class TreeModelUtils method checkTree.

protected synchronized void checkTree() {
    // build new tree if it wasn't created before
    if (vpTree == null) {
        List<DataPoint> points = new ArrayList<>();
        for (String word : vocabCache.words()) {
            points.add(new DataPoint(vocabCache.indexOf(word), lookupTable.vector(word)));
        }
        vpTree = new VPTree(points);
    }
}
Also used : DataPoint(org.deeplearning4j.clustering.sptree.DataPoint) VPTree(org.deeplearning4j.clustering.vptree.VPTree)
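
For context, a minimal standalone sketch of the same pattern, assuming Nd4j is on the classpath; the class name VpTreeSketch, the vectors, and the query point are made up for illustration, while the DataPoint, VPTree and search calls are the ones shown in the examples on this page:

import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.nd4j.linalg.factory.Nd4j;

public class VpTreeSketch {
    public static void main(String[] args) {
        // wrap each vector in a DataPoint keyed by its index
        List<DataPoint> points = new ArrayList<>();
        points.add(new DataPoint(0, Nd4j.create(new double[] {0.0, 0.0})));
        points.add(new DataPoint(1, Nd4j.create(new double[] {1.0, 0.5})));
        points.add(new DataPoint(2, Nd4j.create(new double[] {5.0, 5.0})));

        // build the tree once, then reuse it for repeated nearest-neighbour queries
        VPTree vpTree = new VPTree(points);

        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        vpTree.search(new DataPoint(3, Nd4j.create(new double[] {0.9, 0.4})), 2, results, distances);

        for (int i = 0; i < results.size(); i++)
            System.out.println("neighbour index " + results.get(i).getIndex() + " at distance " + distances.get(i));
    }
}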

Example 2 with DataPoint

use of org.deeplearning4j.clustering.sptree.DataPoint in project deeplearning4j by deeplearning4j.

the class TreeModelUtils method wordsNearest.

@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    checkTree();
    List<DataPoint> add = new ArrayList<>();
    List<Double> distances = new ArrayList<>();
    // query the VP-tree for the top nearest points; see the sketch below for dropping the query's own point
    vpTree.search(new DataPoint(0, words), top, add, distances);
    Collection<String> ret = new ArrayList<>();
    for (DataPoint e : add) {
        String word = vocabCache.wordAtIndex(e.getIndex());
        ret.add(word);
    }
    return ret;
}
Also used : DataPoint(org.deeplearning4j.clustering.sptree.DataPoint)
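
When the query vector corresponds to a word that is already in the tree, the search returns that word as its own nearest neighbour. A sketch of a variant that asks for one extra result and drops the query's own index; the method name wordsNearestExcludingSelf is hypothetical, and the code assumes it sits in the same class so vpTree, vocabCache and lookupTable are in scope:

public Collection<String> wordsNearestExcludingSelf(String word, int top) {
    checkTree();
    int selfIndex = vocabCache.indexOf(word);
    List<DataPoint> add = new ArrayList<>();
    List<Double> distances = new ArrayList<>();
    // ask for one extra neighbour so the query point itself can be skipped
    vpTree.search(new DataPoint(selfIndex, lookupTable.vector(word)), top + 1, add, distances);
    Collection<String> ret = new ArrayList<>();
    for (DataPoint e : add) {
        if (e.getIndex() == selfIndex)
            continue;
        ret.add(vocabCache.wordAtIndex(e.getIndex()));
        if (ret.size() == top)
            break;
    }
    return ret;
}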

Example 3 with DataPoint

use of org.deeplearning4j.clustering.sptree.DataPoint in project deeplearning4j by deeplearning4j.

the class BarnesHutTsne method computeGaussianPerplexity.

/**
     * Convert data to probability
     * co-occurrences (aka calculating the kernel)
     * @param d the data to convert
     * @param u the perplexity of the model
     * @return the probabilities of co-occurrence
     */
public INDArray computeGaussianPerplexity(final INDArray d, double u) {
    N = d.rows();
    final int k = (int) (3 * u);
    if (u > k)
        throw new IllegalStateException("Illegal k value " + k + ": k must be at least the perplexity " + u);
    rows = zeros(1, N + 1);
    cols = zeros(1, N * k);
    vals = zeros(1, N * k);
    for (int n = 0; n < N; n++) rows.putScalar(n + 1, rows.getDouble(n) + k);
    final INDArray beta = ones(N, 1);
    final double logU = FastMath.log(u);
    VPTree tree = new VPTree(d, simiarlityFunction, invert);
    log.info("Calculating probabilities of data similarities...");
    for (int i = 0; i < N; i++) {
        if (i % 500 == 0)
            log.info("Handled " + i + " records");
        double betaMin = -Double.MAX_VALUE;
        double betaMax = Double.MAX_VALUE;
        List<DataPoint> results = new ArrayList<>();
        tree.search(new DataPoint(i, d.slice(i)), k + 1, results, new ArrayList<Double>());
        double betas = beta.getDouble(i);
        INDArray cArr = VPTree.buildFromData(results);
        Pair<INDArray, Double> pair = computeGaussianKernel(cArr, beta.getDouble(i), k);
        INDArray currP = pair.getFirst();
        double hDiff = pair.getSecond() - logU;
        int tries = 0;
        boolean found = false;
        //binary search
        while (!found && tries < 200) {
            if (hDiff < tolerance && -hDiff < tolerance)
                found = true;
            else {
                if (hDiff > 0) {
                    betaMin = betas;
                    if (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE)
                        betas *= 2;
                    else
                        betas = (betas + betaMax) / 2.0;
                } else {
                    betaMax = betas;
                    if (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE)
                        betas /= 2.0;
                    else
                        betas = (betas + betaMin) / 2.0;
                }
                pair = computeGaussianKernel(cArr, betas, k);
                // keep the row probabilities in sync with the latest beta
                currP = pair.getFirst();
                hDiff = pair.getSecond() - logU;
                tries++;
            }
        }
        currP.divi(currP.sum(Integer.MAX_VALUE));
        INDArray indices = Nd4j.create(1, k + 1);
        for (int j = 0; j < indices.length(); j++) {
            if (j >= results.size())
                break;
            indices.putScalar(j, results.get(j).getIndex());
        }
        for (int l = 0; l < k; l++) {
            cols.putScalar(rows.getInt(i) + l, indices.getDouble(l + 1));
            vals.putScalar(rows.getInt(i) + l, currP.getDouble(l));
        }
    }
    return vals;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataPoint(org.deeplearning4j.clustering.sptree.DataPoint) VPTree(org.deeplearning4j.clustering.vptree.VPTree) ArrayList(java.util.ArrayList) AtomicDouble(com.google.common.util.concurrent.AtomicDouble)
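
The binary search in this method adjusts beta (the precision of the per-point Gaussian kernel) until the entropy of the conditional distribution matches log(perplexity). A stripped-down sketch of just that bisection, with entropyForBeta standing in for the entropy returned by computeGaussianKernel; the class name, helper, and the toy entropy curve in main are placeholders, not part of BarnesHutTsne:

import java.util.function.DoubleUnaryOperator;

public class BetaSearchSketch {

    // Bisection on beta: entropy decreases as beta grows, so tighten the bracket toward log(perplexity).
    static double findBeta(DoubleUnaryOperator entropyForBeta, double perplexity, double tolerance) {
        double logU = Math.log(perplexity);
        double beta = 1.0;
        double betaMin = -Double.MAX_VALUE;
        double betaMax = Double.MAX_VALUE;
        for (int tries = 0; tries < 200; tries++) {
            double hDiff = entropyForBeta.applyAsDouble(beta) - logU;
            if (Math.abs(hDiff) < tolerance)
                break;
            if (hDiff > 0) {
                // entropy too high: raise beta (narrower kernel)
                betaMin = beta;
                beta = betaMax == Double.MAX_VALUE ? beta * 2 : (beta + betaMax) / 2.0;
            } else {
                // entropy too low: lower beta (wider kernel)
                betaMax = beta;
                beta = betaMin == -Double.MAX_VALUE ? beta / 2.0 : (beta + betaMin) / 2.0;
            }
        }
        return beta;
    }

    public static void main(String[] args) {
        // toy monotone entropy curve standing in for computeGaussianKernel's entropy output
        double beta = findBeta(b -> 5.0 / (1.0 + b), 30.0, 1e-5);
        System.out.println("beta matching perplexity 30: " + beta);
    }
}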

Example 4 with DataPoint

use of org.deeplearning4j.clustering.sptree.DataPoint in project deeplearning4j by deeplearning4j.

the class VPTree method buildFromPoints.

private Node buildFromPoints(int lower, int upper) {
    if (upper == lower)
        return null;
    Node ret = new Node(lower, 0);
    if (upper - lower > 1) {
        // choose a random vantage point and swap it to the front of the sub-range [lower, upper)
        int randomPoint = MathUtils.randomNumberBetween(lower, upper - 1);
        DataPoint basePoint = items.get(randomPoint);
        items.set(randomPoint, items.get(lower));
        items.set(lower, basePoint);
        // Partition the remaining points around their median distance to the vantage point
        double[] distances = new double[upper - lower - 1];
        double[] sortedDistances = new double[upper - lower - 1];
        for (int i = lower + 1; i < upper; ++i) {
            distances[i - lower - 1] = getDistance(basePoint, items.get(i));
            sortedDistances[i - lower - 1] = distances[i - lower - 1];
        }
        Arrays.sort(sortedDistances);
        final double medianDistance = sortedDistances[sortedDistances.length / 2];
        List<DataPoint> leftPoints = new ArrayList<>(sortedDistances.length);
        List<DataPoint> rightPoints = new ArrayList<>(sortedDistances.length);
        for (int i = lower + 1; i < upper; i++) {
            if (distances[i - lower - 1] < medianDistance) {
                leftPoints.add(items.get(i));
            } else {
                rightPoints.add(items.get(i));
            }
        }
        // write the closer half back first, then the farther half, all within [lower + 1, upper)
        for (int i = 0; i < leftPoints.size(); ++i) {
            items.set(lower + 1 + i, leftPoints.get(i));
        }
        for (int i = 0; i < rightPoints.size(); ++i) {
            items.set(lower + 1 + leftPoints.size() + i, rightPoints.get(i));
        }
        // the boundary between the halves is the split passed to the children
        int median = lower + 1 + leftPoints.size();
        ret.setThreshold(medianDistance);
        ret.setIndex(lower);
        ret.setLeft(buildFromPoints(lower + 1, median));
        ret.setRight(buildFromPoints(median, upper));
    }
    return ret;
}
Also used : DataPoint(org.deeplearning4j.clustering.sptree.DataPoint) ArrayList(java.util.ArrayList)
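
A toy illustration of the partition step on its own, independent of the tree code: distances from a vantage point are computed, the median distance is taken, and the points split into an inner and an outer half. The numbers are arbitrary and the class name is invented:

import java.util.Arrays;

public class MedianSplitSketch {
    public static void main(String[] args) {
        // distances from a chosen vantage point to the remaining points (arbitrary example values)
        double[] distances = {0.4, 2.1, 0.9, 3.3, 1.7};
        double[] sorted = Arrays.copyOf(distances, distances.length);
        Arrays.sort(sorted);
        double medianDistance = sorted[sorted.length / 2];

        // points closer than the median distance go to the inner ("left") child, the rest to the outer child
        StringBuilder inside = new StringBuilder();
        StringBuilder outside = new StringBuilder();
        for (int i = 0; i < distances.length; i++) {
            if (distances[i] < medianDistance)
                inside.append(i).append(' ');
            else
                outside.append(i).append(' ');
        }
        System.out.println("median distance: " + medianDistance);
        System.out.println("inside:  " + inside);
        System.out.println("outside: " + outside);
    }
}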

Example 5 with DataPoint

use of org.deeplearning4j.clustering.sptree.DataPoint in project deeplearning4j by deeplearning4j.

the class VPTree method search.

public void search(Node node, DataPoint target, int k, PriorityQueue<HeapItem> pq) {
    if (node == null)
        return;
    DataPoint get = items.get(node.getIndex());
    double distance = getDistance(get, target);
    if (distance < tau) {
        if (pq.size() == k)
            pq.next();
        pq.add(new HeapItem(node.index, distance), distance);
        if (pq.size() == k)
            tau = pq.peek().getDistance();
    }
    if (node.getLeft() == null && node.getRight() == null)
        return;
    if (distance < node.getThreshold()) {
        if (distance - tau <= node.getThreshold()) {
            // if there can still be neighbors inside the ball, recursively search left child first
            search(node.getLeft(), target, k, pq);
        }
        if (distance + tau >= node.getThreshold()) {
            // if there can still be neighbors outside the ball, recursively search right child
            search(node.getRight(), target, k, pq);
        }
    } else {
        if (distance + tau >= node.getThreshold()) {
            // if there can still be neighbors outside the ball, recursively search right child first
            search(node.getRight(), target, k, pq);
        }
        if (distance - tau <= node.getThreshold()) {
            // if there can still be neighbors inside the ball, recursively search left child
            search(node.getLeft(), target, k, pq);
        }
    }
}
Also used : HeapItem(org.deeplearning4j.clustering.sptree.HeapItem) DataPoint(org.deeplearning4j.clustering.sptree.DataPoint)
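
The recursive search above sits behind the public entry point already used in Example 3. A small sketch that builds a tree straight from a matrix of row vectors and queries it; the "euclidean" similarity name and invert = false mirror the constructor arguments seen in Example 3 but are assumptions about reasonable defaults, and the random data is purely illustrative:

import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class VpTreeMatrixSketch {
    public static void main(String[] args) {
        // each row is one point: 20 random points in 5 dimensions
        INDArray data = Nd4j.rand(20, 5);
        VPTree tree = new VPTree(data, "euclidean", false);

        // ask for k + 1 neighbours so the query row's own entry can be ignored, as in Example 3
        int k = 3;
        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        tree.search(new DataPoint(0, data.slice(0)), k + 1, results, distances);

        for (DataPoint p : results)
            System.out.println("neighbour row " + p.getIndex());
    }
}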

Aggregations

DataPoint (org.deeplearning4j.clustering.sptree.DataPoint) 6
ArrayList (java.util.ArrayList) 3
VPTree (org.deeplearning4j.clustering.vptree.VPTree) 2
AtomicDouble (com.google.common.util.concurrent.AtomicDouble) 1
HeapItem (org.deeplearning4j.clustering.sptree.HeapItem) 1
Test (org.junit.Test) 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 1