Search in sources :

Example 1 with SpecifiedIndex

use of org.nd4j.linalg.indexing.SpecifiedIndex in project deeplearning4j by deeplearning4j.

The method x2p of the class Tsne.

/**
 * Builds the symmetrised probability matrix for the given source data.
 *
 * <p>For every row {@code i} of the pairwise distance matrix, {@code beta[i]} is tuned with a
 * bounded search (at most 50 iterations) until the first value returned by {@code hBeta},
 * minus {@code log(perplexity)}, is within {@code tolerance} of zero. The second value returned
 * by {@code hBeta} for the final {@code beta[i]} is written into row {@code i} of the result,
 * skipping column {@code i}.
 *
 * @param X          source data, one record per row
 * @param tolerance  maximum accepted absolute difference between the first {@code hBeta} value
 *                   and {@code log(perplexity)} before the search for a row stops
 * @param perplexity target perplexity; only its natural logarithm is used
 * @return {@code p + p^T}, divided by (its total sum + 1e-6), multiplied by 4, with every entry
 *         below 1e-12 replaced by 1e-12
 */
private INDArray x2p(final INDArray X, double tolerance, double perplexity) {
    int n = X.rows();
    final INDArray p = zeros(n, n);
    final INDArray beta = ones(n, 1);
    final double logU = Math.log(perplexity);

    // Pairwise distances: D[i][j] = sumX[i] + sumX[j] - 2 * <X[i], X[j]>.
    // The Gram matrix is computed once and every intermediate is reused for the next step.
    INDArray sumX = pow(X, 2).sum(1);
    logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));
    INDArray times = X.mmul(X.transpose()).muli(-2);
    logger.debug("times shape: " + Arrays.toString(times.shape()));
    // In-place add: this may write through to `times`, which is not read again afterwards.
    INDArray prodSum = times.transpose().addiColumnVector(sumX);
    logger.debug("prodSum shape: " + Arrays.toString(prodSum.shape()));
    INDArray D = prodSum.addRowVector(sumX.transpose());

    logger.info("Calculating probabilities of data similarities...");
    logger.debug("Tolerance: " + tolerance);
    for (int i = 0; i < n; i++) {
        if (i % 500 == 0 && i > 0) {
            logger.info("Handled [" + i + "] records out of [" + n + "]");
        }
        double betaMin = Double.NEGATIVE_INFINITY;
        double betaMax = Double.POSITIVE_INFINITY;

        // Every column index except i: a record is never compared with itself.
        int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n));
        INDArrayIndex[] range = new INDArrayIndex[] { new SpecifiedIndex(vals) };
        INDArray row = D.slice(i).get(range);

        Pair<Double, INDArray> pair = hBeta(row, beta.getDouble(i));
        double hDiff = pair.getFirst() - logU;
        int tries = 0;
        // Search on beta[i]: double/halve while the far bound is still infinite,
        // otherwise bisect towards the known bound.
        while (Math.abs(hDiff) > tolerance && tries < 50) {
            if (hDiff > 0) {
                // Value too high: current beta becomes the lower bound, move upwards.
                betaMin = beta.getDouble(i);
                if (Double.isInfinite(betaMax)) {
                    beta.putScalar(i, beta.getDouble(i) * 2.0);
                } else {
                    beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0);
                }
            } else {
                // Value too low: current beta becomes the upper bound, move downwards.
                betaMax = beta.getDouble(i);
                if (Double.isInfinite(betaMin)) {
                    beta.putScalar(i, beta.getDouble(i) / 2.0);
                } else {
                    beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0);
                }
            }
            pair = hBeta(row, beta.getDouble(i));
            hDiff = pair.getFirst() - logU;
            tries++;
        }
        // Column i is not part of `range`, so the diagonal of p keeps its initial zero.
        p.slice(i).put(range, pair.getSecond());
    }

    logger.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
    // Replace NaN entries with a tiny positive value before symmetrising.
    BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(1e-12));

    // Symmetrise, normalise by the total sum (1e-6 guards against a zero sum) and scale by 4.
    INDArray permute = p.transpose();
    INDArray pOut = p.add(permute);
    pOut.divi(pOut.sumNumber().doubleValue() + 1e-6);
    pOut.muli(4);
    // Floor every entry (including the diagonal) at 1e-12.
    BooleanIndexing.applyWhere(pOut, Conditions.lessThan(1e-12), new Value(1e-12));
    return pOut;
}
Also used : SpecifiedIndex(org.nd4j.linalg.indexing.SpecifiedIndex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) Value(org.nd4j.linalg.indexing.functions.Value)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1, INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex): 1, SpecifiedIndex (org.nd4j.linalg.indexing.SpecifiedIndex): 1, Value (org.nd4j.linalg.indexing.functions.Value): 1