Example 1 with SpTree

Use of org.deeplearning4j.clustering.sptree.SpTree in project deeplearning4j by deeplearning4j.

In the class BarnesHutTsne, method gradient.

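@Override
public Gradient gradient() {
    // Lazily initialize yIncs and gains on the first call.
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    // Accumulator passed to computeNonEdgeForces and used below as the divisor.
    AtomicDouble sumQ = new AtomicDouble(0);

    /* Calculate gradient based on Barnes-Hut approximation with positive and negative forces */
    INDArray posF = Nd4j.create(Y.shape());
    INDArray negF = Nd4j.create(Y.shape());
    if (tree == null) {
        tree = new SpTree(Y);
    }

    // Positive (edge) forces are written into posF.
    tree.computeEdgeForces(rows, cols, vals, N, posF);
    // Negative (non-edge) forces are computed one point at a time into the matching slice of negF.
    for (int n = 0; n < N; n++) {
        tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
    }

    // dC = posF - negF / sumQ; divi and subi modify negF and posF in place.
    INDArray dC = posF.subi(negF.divi(sumQ));
    Gradient ret = new DefaultGradient();
    ret.gradientForVariable().put(Y_GRAD, dC);
    return ret;
}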
Also used: Gradient (org.deeplearning4j.nn.gradient.Gradient), DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient), AtomicDouble (com.google.common.util.concurrent.AtomicDouble), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataPoint (org.deeplearning4j.clustering.sptree.DataPoint), SpTree (org.deeplearning4j.clustering.sptree.SpTree)
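
To make the arithmetic of the final step (posF minus negF divided by sumQ) concrete, here is a minimal plain-Java sketch. It is not the SpTree implementation: it assumes the usual Barnes-Hut t-SNE formulation, in which the repulsive term for a pair of points uses the unnormalized Student-t similarity q = 1 / (1 + ||y_i - y_j||^2), and it computes that term exactly over all pairs instead of approximating it with a tree. The class and method names (TsneGradientSketch, exactNonEdgeForces, combine) are hypothetical and do not exist in deeplearning4j.

```java
// Hypothetical illustration only; none of these names come from deeplearning4j.
class TsneGradientSketch {

    /**
     * Exact O(N^2) repulsive forces (assumed formulation, no tree approximation).
     * Adds the force on each point into negF and returns the sum of all q values.
     */
    static double exactNonEdgeForces(double[][] y, double[][] negF) {
        int n = y.length;
        int d = y[0].length;
        double sumQ = 0.0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j) {
                    continue;
                }
                // Squared Euclidean distance between points i and j.
                double dist2 = 0.0;
                for (int k = 0; k < d; k++) {
                    double diff = y[i][k] - y[j][k];
                    dist2 += diff * diff;
                }
                // Unnormalized Student-t similarity.
                double q = 1.0 / (1.0 + dist2);
                sumQ += q;
                for (int k = 0; k < d; k++) {
                    negF[i][k] += q * q * (y[i][k] - y[j][k]);
                }
            }
        }
        return sumQ;
    }

    /** Element-wise dC = posF - negF / sumQ, returned as a new array. */
    static double[][] combine(double[][] posF, double[][] negF, double sumQ) {
        double[][] dC = new double[posF.length][];
        for (int i = 0; i < posF.length; i++) {
            dC[i] = new double[posF[i].length];
            for (int k = 0; k < posF[i].length; k++) {
                dC[i][k] = posF[i][k] - negF[i][k] / sumQ;
            }
        }
        return dC;
    }
}
```

For two points at (0, 0) and (1, 0) with all positive forces set to zero, this sketch gives sumQ = 1.0 and gradient rows (0.25, 0) and (-0.25, 0). Unlike the subi/divi calls in the excerpt, combine leaves its inputs unmodified.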

Aggregations

AtomicDouble (com.google.common.util.concurrent.AtomicDouble): 1
DataPoint (org.deeplearning4j.clustering.sptree.DataPoint): 1
SpTree (org.deeplearning4j.clustering.sptree.SpTree): 1
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 1
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1