Use of org.deeplearning4j.clustering.sptree.SpTree in the deeplearning4j project (by deeplearning4j).
The gradient method of the BarnesHutTsne class.
/**
 * Computes the gradient of the embedding {@code Y} using the Barnes-Hut
 * approximation, combining positive (edge) and negative (non-edge) forces.
 *
 * <p>Side effects: initializes {@code yIncs} (zeros) and {@code gains} (ones)
 * with the shape of {@code Y} when they are still {@code null}, and replaces
 * {@code tree} with a tree built from the current {@code Y}.
 *
 * @return a {@link Gradient} whose {@code Y_GRAD} entry holds
 *         {@code posF - negF / sumQ}
 */
@Override
public Gradient gradient() {
    if (yIncs == null) {
        yIncs = zeros(Y.shape());
    }
    if (gains == null) {
        gains = ones(Y.shape());
    }

    // Accumulator that computeNonEdgeForces adds to; used below to normalize negF.
    AtomicDouble sumQ = new AtomicDouble(0);

    /* Calculate gradient based on barnes hut approximation with positive and negative forces */
    INDArray posF = Nd4j.create(Y.shape());
    INDArray negF = Nd4j.create(Y.shape());

    // Rebuild the tree from the current Y on every call instead of reusing a
    // cached instance: a tree constructed from an earlier Y would describe an
    // outdated embedding once Y has been updated between gradient evaluations.
    // NOTE(review): assumes SpTree captures its structure at construction time
    // and that no other code depends on the same tree instance surviving
    // across calls -- confirm against SpTree and the callers of this class.
    tree = new SpTree(Y);

    tree.computeEdgeForces(rows, cols, vals, N, posF);
    for (int n = 0; n < N; n++) {
        // Each point's non-edge forces are written into its own slice of negF.
        tree.computeNonEdgeForces(n, theta, negF.slice(n), sumQ);
    }

    // In-place: negF is divided by sumQ, then subtracted from posF; dC aliases posF.
    // NOTE(review): if sumQ is still 0 here (e.g. N == 0) this divides by zero.
    INDArray dC = posF.subi(negF.divi(sumQ));

    Gradient ret = new DefaultGradient();
    ret.gradientForVariable().put(Y_GRAD, dC);
    return ret;
}
Aggregations