Search in sources:

Example 31 with IComplexNDArray

Use of org.nd4j.linalg.api.complex.IComplexNDArray in project nd4j by deeplearning4j.

Class NativeOpExecutioner, method exec.

/**
 * Executes a scalar op through the native loop.
 *
 * <p>Complex inputs, and {@code ExecutionMode.JAVA}, are handed to the superclass
 * implementation. Otherwise the op's data type is validated, x/z lengths are checked, ops that
 * carry dimensions are dispatched to {@code invoke(op, dimensions)}, and everything else runs
 * through the native double or float scalar kernel selected by the data type of {@code op.x()}.
 *
 * @param op the scalar op to execute
 * @throws ND4JIllegalStateException if {@code op.x()} and {@code op.z()} differ in length
 */
private void exec(ScalarOp op) {
    if (op.x() instanceof IComplexNDArray || executionMode() == ExecutionMode.JAVA) {
        super.exec(op);
        return;
    }
    long st = profilingHookIn(op);
    validateDataType(Nd4j.dataType(), op);
    if (op.x().lengthLong() != op.z().lengthLong())
        throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
    if (op.getDimension() != null) {
        // NOTE(review): this early return skips profilingHookOut(op, st); confirm whether
        // invoke(...) closes the profiling scope itself.
        invoke(op, op.getDimension());
        return;
    }
    // The element-wise-stride kernel is used only when both x and z report a positive
    // element-wise stride and the op is not flagged exec-special; otherwise the shape-info
    // kernel is used.
    boolean useStrided = op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().elementWiseStride() >= 1;
    if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
        if (useStrided) {
            loop.execScalarDouble(null, op.opNum(), (DoublePointer) op.x().data().addressPointer(), op.x().elementWiseStride(), (DoublePointer) op.z().data().addressPointer(), op.z().elementWiseStride(), op.scalar().doubleValue(), (DoublePointer) getPointerForExtraArgs(op), op.n());
        } else {
            loop.execScalarDouble(null, op.opNum(), (DoublePointer) op.x().data().addressPointer(), (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(), (DoublePointer) op.z().data().addressPointer(), (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().doubleValue(), (DoublePointer) getPointerForExtraArgs(op));
        }
    } else {
        if (useStrided) {
            loop.execScalarFloat(null, op.opNum(), (FloatPointer) op.x().data().addressPointer(), op.x().elementWiseStride(), (FloatPointer) op.z().data().addressPointer(), op.z().elementWiseStride(), op.scalar().floatValue(), (FloatPointer) getPointerForExtraArgs(op), op.n());
        } else {
            loop.execScalarFloat(null, op.opNum(), (FloatPointer) op.x().data().addressPointer(), (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(), (FloatPointer) op.z().data().addressPointer(), (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().floatValue(), (FloatPointer) getPointerForExtraArgs(op));
        }
    }
    profilingHookOut(op, st);
}
Also used: ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)

Example 32 with IComplexNDArray

Use of org.nd4j.linalg.api.complex.IComplexNDArray in project nd4j by deeplearning4j.

Class JCublasNDArrayFactory, method createComplex.

/**
 * Builds a complex array from interleaved (real, imaginary) float pairs.
 *
 * @param dim interleaved values: {@code dim[2k]} is the real part and {@code dim[2k + 1]} the
 *            imaginary part of element {@code k}
 * @return a complex array holding {@code dim.length / 2} elements
 * @throws IllegalArgumentException if {@code dim} has an odd number of elements
 */
@Override
public IComplexNDArray createComplex(float[] dim) {
    if ((dim.length & 1) != 0) {
        throw new IllegalArgumentException("Complex nd array buffers must have an even number of elements");
    }
    final int pairs = dim.length / 2;
    final IComplexNDArray result = Nd4j.createComplex(pairs);
    for (int k = 0; k < pairs; k++) {
        final int re = 2 * k;
        result.putScalar(k, Nd4j.createDouble(dim[re], dim[re + 1]));
    }
    return result;
}
Also used: IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)

Example 33 with IComplexNDArray

Use of org.nd4j.linalg.api.complex.IComplexNDArray in project nd4j by deeplearning4j.

Class NDArrayStrings, method format.

/**
 * Recursively renders {@code arr} as a bracketed, multi-line string.
 *
 * <p>Rank-0 scalars are printed as a single number (or via {@code getComplex(0).toString()}
 * for complex arrays). Rank-1 arrays and row vectors are delegated to {@code vectorToString}.
 * Anything else is printed slice by slice, recursing into {@code format} for each slice.
 *
 * @param arr       the array (or slice) to render
 * @param offset    current nesting depth; incremented before recursing into slices and used as
 *                  the number of spaces that indent the line following each slice separator
 * @param summarize when true, slices with index 3 through {@code slices() - 4} are replaced by a
 *                  single " ..." marker, so only the first three and last three are printed
 * @return the formatted string
 */
private String format(INDArray arr, int offset, boolean summarize) {
    int rank = arr.rank();
    if (arr.isScalar() && rank == 0) {
        // true scalar i.e shape = [] not legacy which is [1,1]
        if (arr instanceof IComplexNDArray) {
            return ((IComplexNDArray) arr).getComplex(0).toString();
        }
        // Real scalar: switch to scientific notation for non-zero magnitudes below
        // minToPrintWithoutSwitching or at/above maxToPrintWithoutSwitching, unless
        // dontOverrideFormat is set.
        double arrElement = arr.getDouble(0);
        if (!dontOverrideFormat && ((Math.abs(arrElement) < this.minToPrintWithoutSwitching && arrElement != 0) || (Math.abs(arrElement) >= this.maxToPrintWithoutSwitching))) {
            // switch to scientific notation
            String asString = new DecimalFormat(scientificFormat).format(arrElement);
            // from E to small e
            asString = asString.replace('E', 'e');
            return asString;
        } else {
            // Exact zero is special-cased to a bare "0" instead of going through decimalFormat.
            if (arr.getDouble(0) == 0)
                return "0";
            return decimalFormat.format(arr.getDouble(0));
        }
    } else if (rank == 1) {
        // true vector
        return vectorToString(arr, summarize);
    } else if (arr.isRowVector()) {
        // a slice from a higher dim array
        // At the top level (offset == 0) the row is wrapped in an extra pair of brackets;
        // nested row vectors are emitted without them.
        if (offset == 0) {
            StringBuilder sb = new StringBuilder();
            sb.append("[");
            sb.append(vectorToString(arr, summarize));
            sb.append("]");
            return sb.toString();
        }
        return vectorToString(arr, summarize);
    } else {
        // Higher-rank array: render each slice one level deeper. Slices are separated by
        // newLineSep + " \n", then (rank - 2) extra newlines, then `offset` spaces.
        offset++;
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < arr.slices(); i++) {
            if (summarize && i > 2 && i < arr.slices() - 3) {
                // Elided slice: the " ..." marker and its separator are emitted once, at i == 3.
                if (i == 3) {
                    sb.append(" ...");
                    sb.append(newLineSep + " \n");
                    sb.append(StringUtils.repeat("\n", rank - 2));
                    sb.append(StringUtils.repeat(" ", offset));
                }
            } else {
                // Rank-3 arrays whose slice is a row vector get that slice wrapped in its own
                // brackets; the closing bracket is appended below, before any separator.
                if (arr.rank() == 3 && arr.slice(i).isRowVector())
                    sb.append("[");
                // hack fix for slice issue with 'f' order
                if (arr.ordering() == 'f' && arr.rank() > 2 && arr.size(arr.rank() - 1) == 1) {
                    sb.append(format(arr.dup('c').slice(i), offset, summarize));
                } else if (arr.rank() <= 1) {
                    // NOTE(review): this branch appears unreachable here, since rank-1 arrays and
                    // rank-0 scalars are handled above; confirm before relying on it.
                    sb.append(format(Nd4j.scalar(arr.getDouble(0)), offset, summarize));
                } else {
                    sb.append(format(arr.slice(i), offset, summarize));
                }
                if (i != arr.slices() - 1) {
                    if (arr.rank() == 3 && arr.slice(i).isRowVector())
                        sb.append("]");
                    sb.append(newLineSep + " \n");
                    sb.append(StringUtils.repeat("\n", rank - 2));
                    sb.append(StringUtils.repeat(" ", offset));
                } else {
                    // Last slice: close the rank-3 row-vector bracket with no trailing separator.
                    if (arr.rank() == 3 && arr.slice(i).isRowVector())
                        sb.append("]");
                }
            }
        }
        sb.append("]");
        return sb.toString();
    }
}
Also used: DecimalFormat (java.text.DecimalFormat), IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)

Example 34 with IComplexNDArray

Use of org.nd4j.linalg.api.complex.IComplexNDArray in project nd4j by deeplearning4j.

Class ComplexNDArrayUtil, method truncate.

/**
 * Truncates an ndarray to at most {@code n} entries along {@code dimension}.
 * If the array already has {@code n} or fewer entries there, the original array is returned.
 *
 * <p>Vectors are handled separately: {@code dimension} is ignored, and when the vector is longer
 * than {@code n} its first {@code n} elements are copied into a new {@code [1, n]} array.
 *
 * @param nd        the ndarray to truncate
 * @param n         the number of elements to keep
 * @param dimension the axis to truncate along (ignored for vectors)
 * @return a new truncated array, or {@code nd} itself when no truncation is needed
 */
public static IComplexNDArray truncate(IComplexNDArray nd, int n, int dimension) {
    if (nd.isVector()) {
        // Nothing to cut; reading n elements from a shorter vector would run out of bounds.
        if (n >= nd.length())
            return nd;
        IComplexNDArray truncated = Nd4j.createComplex(new int[] { 1, n });
        for (int i = 0; i < n; i++)
            truncated.putScalar(i, nd.getComplex(i));
        return truncated;
    }
    if (nd.size(dimension) > n) {
        int[] shape = ArrayUtil.copy(nd.shape());
        shape[dimension] = n;
        IComplexNDArray ret = Nd4j.createComplex(shape);
        // Copy element by element using the same multi-index in source and destination.
        // A flat linear copy of the first prod(shape) elements is only correct when the
        // truncated axis is the slowest-varying one in linear order.
        int[] index = new int[shape.length];
        int total = ArrayUtil.prod(shape);
        for (int i = 0; i < total; i++) {
            ret.putScalar(index, nd.getComplex(index));
            // Advance the multi-index like an odometer, last axis fastest.
            for (int d = shape.length - 1; d >= 0; d--) {
                if (++index[d] < shape[d])
                    break;
                index[d] = 0;
            }
        }
        return ret;
    }
    return nd;
}
Also used: IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)

Example 35 with IComplexNDArray

Use of org.nd4j.linalg.api.complex.IComplexNDArray in project nd4j by deeplearning4j.

Class ComplexNDArrayUtil, method center.

/**
 * Extracts the centered sub-array of {@code arr} with the given shape.
 *
 * <p>If {@code arr} has fewer elements than {@code prod(shape)}, {@code arr} itself is returned.
 * Requested extents below 1 are treated as 1. Along each axis the region starts at
 * {@code floor((arr.shape[d] - shape[d]) / 2)}.
 *
 * @param arr   the array to center
 * @param shape the shape of the centered region; not modified by this method
 * @return the center portion of {@code arr} with the requested shape, or {@code arr} itself
 *         when it has fewer elements than the requested shape
 */
public static IComplexNDArray center(IComplexNDArray arr, int[] shape) {
    if (arr.length() < ArrayUtil.prod(shape))
        return arr;
    // Clamp a private copy so the caller's shape array is left untouched.
    int[] target = ArrayUtil.copy(shape);
    for (int i = 0; i < target.length; i++) {
        if (target[i] < 1)
            target[i] = 1;
    }
    INDArray shapeMatrix = NDArrayUtil.toNDArray(target);
    INDArray currShape = NDArrayUtil.toNDArray(arr.shape());
    // Start each axis at floor((current - target) / 2) so the region sits in the middle.
    INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2)));
    INDArray endIndex = startIndex.add(shapeMatrix);
    INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);
    if (shapeMatrix.length() > 1)
        return arr.get(indexes);
    // Single-axis target: copy the centered range element by element.
    IComplexNDArray ret = Nd4j.createComplex(new int[] { (int) shapeMatrix.getDouble(0) });
    int start = (int) startIndex.getDouble(0);
    int end = (int) endIndex.getDouble(0);
    int count = 0;
    for (int i = start; i < end; i++) {
        ret.putScalar(count++, arr.getComplex(i));
    }
    return ret;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex), IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)

Aggregations

IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray): 74; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 6; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 5; IComplexNumber (org.nd4j.linalg.api.complex.IComplexNumber): 3; DecimalFormat (java.text.DecimalFormat): 2; INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex): 2; AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 1; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 1; IComplexDouble (org.nd4j.linalg.api.complex.IComplexDouble): 1