Example 6 with Variance

Use of org.nd4j.linalg.api.ops.impl.accum.Variance in project nd4j by deeplearning4j.

From the class CudaGridExecutioner, method pushToGrid.

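/**
 * This method adds an op to the GridOp queue.
 *
 * @param descriptor descriptor wrapping the op and its dimensions
 * @param flush whether to flush the queue before the op is executed
 */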
protected void pushToGrid(OpDescriptor descriptor, boolean flush) {
    // we should just add op to queue here
    // deviceQueues.get().add(descriptor);
    // FIXME: following code should be removed, since it's just executing supers instead of batching
    execCounter.incrementAndGet();
    Op op = descriptor.getOp();
    int[] dimensions = descriptor.getDimensions();
    if (op instanceof TransformOp) {
        TransformOp t = (TransformOp) op;
        if (flush)
            flushQueue();
        // logger.info("Sending TransformOp to CudaExecutioner");
        super.invoke(t);
    } else if (op instanceof Variance) {
        Variance acc = (Variance) op;
        if (flush)
            flushQueue();
        super.naiveExec(acc, dimensions);
    } else if (op instanceof Accumulation) {
        Accumulation acc = (Accumulation) op;
        if (flush)
            flushQueue();
        // logger.info("Sending AccumulationOp to CudaExecutioner: {}", Arrays.toString(dimensions));
        super.naiveExec(acc, dimensions);
    } else if (op instanceof ScalarOp) {
        ScalarOp sc = (ScalarOp) op;
        if (flush)
            flushQueue();
        // logger.info("Sending ScalarOp to CudaExecutioner");
        super.invoke(sc);
    } else if (op instanceof BroadcastOp) {
        BroadcastOp broadcastOp = (BroadcastOp) op;
        if (flush)
            flushQueue();
        // logger.info("Sending BroadcastOp to CudaExecutioner");
        if (dimensions != null) {
            super.exec(broadcastOp, dimensions);
        } else {
            super.invoke(broadcastOp);
        }
    } else if (op instanceof IndexAccumulation) {
        IndexAccumulation indexAccumulation = (IndexAccumulation) op;
        if (flush)
            flushQueue();
        // logger.info("Sending IndexAccumulationOp to CudaExecutioner");
        super.exec(indexAccumulation, dimensions);
    } else if (op instanceof MetaOp) {
        // logger.info("Executing MetaOp");
        metaCounter.incrementAndGet();
        exec((MetaOp) op);
    } else if (op instanceof GridOp) {
        // logger.info("Executing GridOp");
        exec((GridOp) op);
    }
}
Also used : InvertedPredicateMetaOp(org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp) PredicateMetaOp(org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp) ReduceMetaOp(org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp) PostulateMetaOp(org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp) Variance(org.nd4j.linalg.api.ops.impl.accum.Variance)
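
In the listing above, the Variance check comes before the Accumulation check. If Variance is a subtype of Accumulation (an assumption here, since the listing does not show the class hierarchy), that order matters: an instanceof chain takes the first matching branch, so the more specific type has to be tested first or its branch is never reached. The following is a minimal, self-contained sketch of that effect. The types Op, AccumulationOp, VarianceOp and SumOp and the class DispatchOrderSketch are hypothetical stand-ins, not the nd4j classes.

```java
// Hypothetical stand-ins for the op hierarchy; these are NOT the nd4j types.
interface Op {}
interface AccumulationOp extends Op {}
final class VarianceOp implements AccumulationOp {}
final class SumOp implements AccumulationOp {}

final class DispatchOrderSketch {

    // Specific type first, mirroring the order used in pushToGrid.
    static String route(Op op) {
        if (op instanceof VarianceOp) {
            return "variance";
        } else if (op instanceof AccumulationOp) {
            return "accumulation";
        }
        return "unhandled";
    }

    // General type first: a VarianceOp is also an AccumulationOp,
    // so the second branch can never be taken.
    static String routeGeneralFirst(Op op) {
        if (op instanceof AccumulationOp) {
            return "accumulation";
        } else if (op instanceof VarianceOp) {
            return "variance";
        }
        return "unhandled";
    }

    public static void main(String[] args) {
        System.out.println(route(new VarianceOp()));             // variance
        System.out.println(route(new SumOp()));                  // accumulation
        System.out.println(routeGeneralFirst(new VarianceOp())); // accumulation
    }
}
```

The cast inside each branch serves a second purpose in Java: overloads are resolved at compile time from the static type of the argument, so casting to the specific type before the call is what would select a type-specific overload, if one exists.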

Aggregations

Variance (org.nd4j.linalg.api.ops.impl.accum.Variance): 6
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 4
lombok.val (lombok.val): 3
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 3
PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 2
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 2
BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer): 2
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 2
LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 2
ArrayList (java.util.ArrayList): 1
Set (java.util.Set): 1
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 1
StandardDeviation (org.nd4j.linalg.api.ops.impl.accum.StandardDeviation): 1
InvertedPredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp): 1
PostulateMetaOp (org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp): 1
PredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp): 1
ReduceMetaOp (org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp): 1
BaseRandomOp (org.nd4j.linalg.api.ops.random.BaseRandomOp): 1