Use of org.nd4j.linalg.api.ops.impl.transforms.Sqrt in project nd4j by deeplearning4j.
From the class CrashTest, method op.
/**
 * Runs one iteration of a broad mix of ND4J operations against the given arrays.
 *
 * <p>The results of the individual calls are not inspected; the method only invokes the
 * operations one after another and prints a line once all of them have returned. Many of the
 * steps modify {@code x} in place, so the order of the statements below matters.
 *
 * <p>NOTE(review): the hard-coded {@code ones(64)} row vector and {@code ones(1024, 1)} column
 * vector suggest {@code x} is expected to be 1024 x 64 - confirm against the callers.
 *
 * @param x array that every operation is applied to; it is modified in place
 * @param y right-hand operand of the matrix multiplications
 * @param i iteration counter; used for the scalar addition and in the printed message
 */
protected void op(INDArray x, INDArray y, int i) {
    // Broadcast additions: one row vector and one column vector, both applied to x in place.
    INDArray rowOnes = Nd4j.ones(64);
    INDArray columnOnes = Nd4j.ones(1024, 1);
    x.addiRowVector(rowOnes);
    x.addiColumnVector(columnOnes);

    // Plain scalar addition, in place.
    x.addi(i * 2);

    // Full reduction to a single number.
    float total = x.sumNumber().floatValue();

    // Index reduction, executed with Integer.MAX_VALUE as the second argument.
    Nd4j.getExecutioner().exec(new IMax(x), Integer.MAX_VALUE);

    // Plain transform op; Sqrt is given x for both of its arguments.
    Nd4j.getExecutioner().exec(new Sqrt(x, x));

    // Copies of x: two with x's own ordering, one with 'c' ordering, one with 'f' ordering.
    INDArray sameOrderFirst = x.dup(x.ordering());
    INDArray sameOrderSecond = x.dup(x.ordering());
    INDArray cOrderCopy = x.dup('c');
    INDArray fOrderCopy = x.dup('f');

    // Vertical and horizontal stacking of x with all of its copies.
    INDArray stackedVertically =
            Nd4j.vstack(x, sameOrderFirst, sameOrderSecond, cOrderCopy, fOrderCopy);
    INDArray stackedHorizontally =
            Nd4j.hstack(x, sameOrderFirst, sameOrderSecond, cOrderCopy, fOrderCopy);

    // reduce3-style op taking two arrays.
    Nd4j.getExecutioner().exec(new ManhattanDistance(x, sameOrderSecond));

    // Flattening of x together with all of its copies.
    INDArray flattened =
            Nd4j.toFlattened(x, sameOrderFirst, sameOrderSecond, cOrderCopy, fOrderCopy);

    // Reductions along dimension 0 and dimension 1.
    INDArray maxAlongDim0 = x.max(0);
    INDArray maxAlongDim1 = x.max(1);

    // Index reductions along dimension 0 and dimension 1.
    INDArray argMaxAlongDim0 = Nd4j.argMax(x, 0);
    INDArray argMaxAlongDim1 = Nd4j.argMax(x, 1);

    // Softmax, softmax derivative and log-softmax, each executed on x.
    Nd4j.getExecutioner().exec(new OldSoftMax(x));
    Nd4j.getExecutioner().exec(new SoftMaxDerivative(x));
    Nd4j.getExecutioner().exec(new LogSoftMax(x));

    // BooleanIndexing: conditional replacement with a scalar.
    BooleanIndexing.replaceWhere(x, 5f, Conditions.lessThan(8f));

    // BooleanIndexing: conditional assignment between x and its first copy.
    BooleanIndexing.assignIf(x, sameOrderFirst, Conditions.greaterThan(-1000000000f));

    // Standard deviation over the whole array.
    float stdOverall = x.stdNumber().floatValue();

    // Standard deviation along dimension 0 and dimension 1.
    INDArray stdAlongDim0 = x.std(0);
    INDArray stdAlongDim1 = x.std(1);

    // BLAS dot product.
    float dotProduct = (float) Nd4j.getBlasWrapper().dot(x, sameOrderFirst);

    // Matrix multiplication for every combination of the transpose flags. An operand whose
    // flag is false is handed over as a transposed copy; one whose flag is true as a plain copy.
    for (boolean transposeA : paramsA) {
        for (boolean transposeB : paramsB) {
            INDArray left = x.dup();
            if (!transposeA) {
                left = left.transpose();
            }
            INDArray right = y.dup();
            if (!transposeB) {
                right = right.transpose();
            }
            Nd4j.gemm(left, right, transposeA, transposeB);
        }
    }

    // One more multiplication directly on x and y (no copies), specifically to cover views.
    Nd4j.gemm(x, y, false, false);

    System.out.println("Iteration passed: " + i);
}
Aggregations