Use of org.nd4j.linalg.indexing.SpecifiedIndex in the deeplearning4j project (by deeplearning4j).
Class Tsne, method x2p.
/**
 * Builds the symmetrised, normalised probability matrix for the given source data.
 *
 * <p>For every row {@code i} a per-row {@code beta} is tuned by bisection until the
 * first value returned by {@code hBeta} is within {@code tolerance} of
 * {@code log(perplexity)} (or the iteration cap is reached). The second value
 * returned by {@code hBeta} is stored as row {@code i} of the probability matrix,
 * skipping column {@code i}.
 *
 * @param X          input data; one record per row (the result is {@code X.rows()} square)
 * @param tolerance  maximum accepted absolute difference between the value reported by
 *                   {@code hBeta} and {@code log(perplexity)}
 * @param perplexity target perplexity; only its natural logarithm is used
 * @return {@code (P + P^T)} divided by its total sum (plus {@code 1e-6}), multiplied by 4,
 *         with every entry below {@code 1e-12} raised to {@code 1e-12}
 */
private INDArray x2p(final INDArray X, double tolerance, double perplexity) {
    // Upper bound on bisection steps per row, so a row that never converges cannot loop forever.
    final int maxTries = 50;

    int n = X.rows();
    final INDArray p = zeros(n, n);
    final INDArray beta = ones(n, 1);
    final double logU = Math.log(perplexity);

    // Per-row sum of squared components.
    INDArray sumX = pow(X, 2).sum(1);
    logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));

    // -2 * X * X^T. This is the only O(n^2 * d) product; it is computed once and reused below.
    INDArray times = X.mmul(X.transpose()).muli(-2);
    logger.debug("times shape: " + Arrays.toString(times.shape()));

    // Copying add: 'times' itself is left untouched and prodSum owns a fresh buffer.
    INDArray prodSum = times.transpose().addColumnVector(sumX);
    logger.debug("prodSum shape: " + Arrays.toString(prodSum.shape()));

    // D[i][j] = sumX[i] + sumX[j] - 2 * <x_i, x_j>, i.e. the pairwise squared distances.
    // Done in place on prodSum's fresh buffer, so no extra n x n allocation is needed.
    INDArray D = prodSum.addiRowVector(sumX.transpose());

    logger.info("Calculating probabilities of data similarities...");
    logger.debug("Tolerance: " + tolerance);
    for (int i = 0; i < n; i++) {
        if (i % 500 == 0 && i > 0) {
            logger.info("Handled [" + i + "] records out of [" + n + "]");
        }

        // Bisection bracket for beta[i]; infinite bounds mean "not bracketed on that side yet".
        double betaMin = Double.NEGATIVE_INFINITY;
        double betaMax = Double.POSITIVE_INFINITY;

        // Every column index except i: a record is never compared with itself.
        int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n));
        INDArrayIndex[] range = new INDArrayIndex[] { new SpecifiedIndex(vals) };
        INDArray row = D.slice(i).get(range);

        // NOTE(review): hBeta presumably returns (entropy, probabilities) for this row - confirm.
        Pair<Double, INDArray> pair = hBeta(row, beta.getDouble(i));
        double hDiff = pair.getFirst() - logU;
        int tries = 0;

        while (Math.abs(hDiff) > tolerance && tries < maxTries) {
            if (hDiff > 0) {
                // Value too high: raise the lower bound, then double or bisect upwards.
                betaMin = beta.getDouble(i);
                if (Double.isInfinite(betaMax)) {
                    beta.putScalar(i, beta.getDouble(i) * 2.0);
                } else {
                    beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0);
                }
            } else {
                // Value too low: lower the upper bound, then halve or bisect downwards.
                betaMax = beta.getDouble(i);
                if (Double.isInfinite(betaMin)) {
                    beta.putScalar(i, beta.getDouble(i) / 2.0);
                } else {
                    beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0);
                }
            }
            pair = hBeta(row, beta.getDouble(i));
            hDiff = pair.getFirst() - logU;
            tries++;
        }

        // Write the row back into the off-diagonal positions; p[i][i] keeps its initial 0.
        p.slice(i).put(range, pair.getSecond());
    }

    // sigma = sqrt(1 / beta)
    logger.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
    BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(1e-12));

    // Symmetrise (P + P^T), normalise by the total sum (1e-6 guards against division by zero),
    // scale by 4, and floor every entry - including the diagonal - at 1e-12.
    INDArray permute = p.transpose();
    INDArray pOut = p.add(permute);
    pOut.divi(pOut.sumNumber().doubleValue() + 1e-6);
    pOut.muli(4);
    BooleanIndexing.applyWhere(pOut, Conditions.lessThan(1e-12), new Value(1e-12));
    return pOut;
}
Aggregations