Usage of `org.nd4j.linalg.api.complex.IComplexNDArray` in project nd4j by deeplearning4j.
Class `NativeOpExecutioner`, method `exec`.
/**
 * Executes a scalar op (an op combining every element of x with a single scalar, writing to z).
 * Complex inputs and JAVA execution mode are delegated to the superclass; everything else is
 * dispatched to the native {@code loop} in double or float precision depending on x's buffer type.
 *
 * @param op the scalar op to execute
 * @throws ND4JIllegalStateException if x and z have different lengths
 */
private void exec(ScalarOp op) {
    if (op.x() instanceof IComplexNDArray || executionMode() == ExecutionMode.JAVA) {
        super.exec(op);
        return;
    }

    long st = profilingHookIn(op);
    validateDataType(Nd4j.dataType(), op);

    if (op.x().lengthLong() != op.z().lengthLong()) {
        throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
    }

    if (op.getDimension() != null) {
        // Dimension-wise execution is handled by invoke(...).
        // NOTE(review): this returns without profilingHookOut(op, st); confirm invoke(...) closes
        // the profiling window itself, otherwise this path leaks an open profiling measurement.
        invoke(op, op.getDimension());
        return;
    }

    // Strided fast path: both buffers expose a usable element-wise stride and the op does not ask
    // for special (shape-info based) execution. The duplicated isExecSpecial() check of the
    // original condition has been dropped; evaluation order of the remaining calls is unchanged.
    boolean strided = op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().elementWiseStride() >= 1;

    if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
        if (strided) {
            loop.execScalarDouble(null, op.opNum(), (DoublePointer) op.x().data().addressPointer(), op.x().elementWiseStride(), (DoublePointer) op.z().data().addressPointer(), op.z().elementWiseStride(), op.scalar().doubleValue(), (DoublePointer) getPointerForExtraArgs(op), op.n());
        } else {
            loop.execScalarDouble(null, op.opNum(), (DoublePointer) op.x().data().addressPointer(), (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(), (DoublePointer) op.z().data().addressPointer(), (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().doubleValue(), (DoublePointer) getPointerForExtraArgs(op));
        }
    } else {
        if (strided) {
            loop.execScalarFloat(null, op.opNum(), (FloatPointer) op.x().data().addressPointer(), op.x().elementWiseStride(), (FloatPointer) op.z().data().addressPointer(), op.z().elementWiseStride(), op.scalar().floatValue(), (FloatPointer) getPointerForExtraArgs(op), op.n());
        } else {
            loop.execScalarFloat(null, op.opNum(), (FloatPointer) op.x().data().addressPointer(), (IntPointer) op.x().shapeInfoDataBuffer().addressPointer(), (FloatPointer) op.z().data().addressPointer(), (IntPointer) op.z().shapeInfoDataBuffer().addressPointer(), op.scalar().floatValue(), (FloatPointer) getPointerForExtraArgs(op));
        }
    }
    profilingHookOut(op, st);
}
Usage of `org.nd4j.linalg.api.complex.IComplexNDArray` in project nd4j by deeplearning4j.
Class `JCublasNDArrayFactory`, method `createComplex`.
/**
 * Builds a complex vector from an interleaved buffer of (real, imaginary) float pairs.
 *
 * @param dim interleaved values: dim[2k] is the real part and dim[2k + 1] the imaginary part
 *            of element k
 * @return a complex ndarray with dim.length / 2 elements
 * @throws IllegalArgumentException if dim has an odd number of entries
 */
@Override
public IComplexNDArray createComplex(float[] dim) {
    boolean oddLength = dim.length % 2 != 0;
    if (oddLength) {
        throw new IllegalArgumentException("Complex nd array buffers must have an even number of elements");
    }
    int elementCount = dim.length / 2;
    IComplexNDArray result = Nd4j.createComplex(elementCount);
    for (int k = 0; k < elementCount; k++) {
        float real = dim[2 * k];
        float imag = dim[2 * k + 1];
        result.putScalar(k, Nd4j.createDouble(real, imag));
    }
    return result;
}
Usage of `org.nd4j.linalg.api.complex.IComplexNDArray` in project nd4j by deeplearning4j.
Class `NDArrayStrings`, method `format`.
/**
 * Recursively renders an ndarray as a bracketed string.
 *
 * Rank-0 arrays are printed as a single number (or complex value). Rank-1 arrays and row vectors
 * are delegated to vectorToString. Higher-rank arrays are printed slice by slice, each slice
 * formatted by a recursive call.
 *
 * @param arr       the array (or slice) to render
 * @param offset    current nesting depth; used as the number of spaces to indent continuation lines
 * @param summarize when true and there are more than 6 slices, only the first 3 and last 3 slices
 *                  are printed, with " ..." in between
 * @return the formatted string
 */
private String format(INDArray arr, int offset, boolean summarize) {
int rank = arr.rank();
if (arr.isScalar() && rank == 0) {
// true scalar i.e shape = [] not legacy which is [1,1]
if (arr instanceof IComplexNDArray) {
return ((IComplexNDArray) arr).getComplex(0).toString();
}
// ///
double arrElement = arr.getDouble(0);
// Use scientific notation for tiny non-zero values or very large magnitudes, unless the
// caller disabled format switching via dontOverrideFormat.
if (!dontOverrideFormat && ((Math.abs(arrElement) < this.minToPrintWithoutSwitching && arrElement != 0) || (Math.abs(arrElement) >= this.maxToPrintWithoutSwitching))) {
// switch to scientific notation
String asString = new DecimalFormat(scientificFormat).format(arrElement);
// from E to small e
asString = asString.replace('E', 'e');
return asString;
} else {
// Exact zero is always printed as "0", independent of decimalFormat.
if (arr.getDouble(0) == 0)
return "0";
return decimalFormat.format(arr.getDouble(0));
}
} else if (rank == 1) {
// true vector
return vectorToString(arr, summarize);
} else if (arr.isRowVector()) {
// a slice from a higher dim array
// At the outermost level (offset 0) the row vector gets an extra pair of enclosing brackets.
if (offset == 0) {
StringBuilder sb = new StringBuilder();
sb.append("[");
sb.append(vectorToString(arr, summarize));
sb.append("]");
return sb.toString();
}
return vectorToString(arr, summarize);
} else {
// One nesting level deeper: continuation lines for the slices are indented by one more space.
offset++;
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < arr.slices(); i++) {
// Summarized output keeps slices 0..2 and the last 3; the first skipped slice (i == 3)
// emits the " ..." marker, and the rest of the skipped range emits nothing.
if (summarize && i > 2 && i < arr.slices() - 3) {
if (i == 3) {
sb.append(" ...");
sb.append(newLineSep + " \n");
sb.append(StringUtils.repeat("\n", rank - 2));
sb.append(StringUtils.repeat(" ", offset));
}
} else {
// Rank-3 arrays whose slice is a row vector get an extra bracket around that slice
// (closed below in both the separator and last-slice branches).
if (arr.rank() == 3 && arr.slice(i).isRowVector())
sb.append("[");
// hack fix for slice issue with 'f' order
if (arr.ordering() == 'f' && arr.rank() > 2 && arr.size(arr.rank() - 1) == 1) {
sb.append(format(arr.dup('c').slice(i), offset, summarize));
} else if (arr.rank() <= 1) {
// NOTE(review): this branch looks unreachable; rank-0 and rank-1 inputs return earlier
// in this method, so rank is at least 2 here. Confirm before removing.
sb.append(format(Nd4j.scalar(arr.getDouble(0)), offset, summarize));
} else {
sb.append(format(arr.slice(i), offset, summarize));
}
// Between slices: separator, (rank - 2) blank lines, then indentation. The last slice
// only closes its optional rank-3 bracket.
if (i != arr.slices() - 1) {
if (arr.rank() == 3 && arr.slice(i).isRowVector())
sb.append("]");
sb.append(newLineSep + " \n");
sb.append(StringUtils.repeat("\n", rank - 2));
sb.append(StringUtils.repeat(" ", offset));
} else {
if (arr.rank() == 3 && arr.slice(i).isRowVector())
sb.append("]");
}
}
}
sb.append("]");
return sb.toString();
}
}
Usage of `org.nd4j.linalg.api.complex.IComplexNDArray` in project nd4j by deeplearning4j.
Class `ComplexNDArrayUtil`, method `truncate`.
/**
 * Truncates a complex ndarray to {@code n} elements along {@code dimension}.
 *
 * For vectors, a new 1 x n row copy of the first n elements is returned. If n exceeds the vector's
 * length, nothing can be truncated and {@code nd} is returned unchanged.
 *
 * For other arrays, if {@code nd.size(dimension)} is greater than n, a new array with that
 * dimension shrunk to n is returned. Otherwise {@code nd} itself is returned.
 *
 * @param nd        the ndarray to truncate
 * @param n         the number of elements to truncate to
 * @param dimension the dimension to truncate along; not used for vectors
 * @return the truncated ndarray, or {@code nd} itself when no truncation is needed
 */
public static IComplexNDArray truncate(IComplexNDArray nd, int n, int dimension) {
    if (nd.isVector()) {
        // Previously n > length read past the end of nd via getComplex(i). Returning the original
        // matches the documented "same or greater" contract and the non-vector branch below.
        if (n > nd.length()) {
            return nd;
        }
        IComplexNDArray truncated = Nd4j.createComplex(new int[] { 1, n });
        for (int i = 0; i < n; i++) {
            truncated.putScalar(i, nd.getComplex(i));
        }
        return truncated;
    }
    if (nd.size(dimension) > n) {
        int[] shape = ArrayUtil.copy(nd.shape());
        shape[dimension] = n;
        IComplexNDArray ret = Nd4j.createComplex(shape);
        // NOTE(review): this copies the first ret.length() elements in linear-view order. That
        // equals a slice along `dimension` only when `dimension` is outermost in the linear
        // ordering. Confirm the result for other dimensions.
        IComplexNDArray ndLinear = nd.linearView();
        IComplexNDArray retLinear = ret.linearView();
        for (int i = 0; i < ret.length(); i++) {
            retLinear.putScalar(i, ndLinear.getComplex(i));
        }
        return ret;
    }
    return nd;
}
Usage of `org.nd4j.linalg.api.complex.IComplexNDArray` in project nd4j by deeplearning4j.
Class `ComplexNDArrayUtil`, method `center`.
/**
 * Center an array: extracts the region of {@code arr} with the requested shape, positioned so that
 * its start index along each dimension is floor((arr.shape - shape) / 2).
 *
 * @param arr   the arr to center
 * @param shape the target shape; entries below 1 are treated as 1. The caller's array is not
 *              modified.
 * @return {@code arr} itself if it has fewer elements than prod(shape). Otherwise, the center
 *         portion of the array based on the specified shape: arr.get(...) over the computed
 *         intervals for multi-dimensional shapes, or a new complex array for 1-D shapes.
 */
public static IComplexNDArray center(IComplexNDArray arr, int[] shape) {
    if (arr.length() < ArrayUtil.prod(shape))
        return arr;
    // Clamp a copy: the original clamped the caller's `shape` array in place, silently mutating
    // the argument.
    int[] targetShape = shape.clone();
    for (int i = 0; i < targetShape.length; i++) {
        if (targetShape[i] < 1) {
            targetShape[i] = 1;
        }
    }
    INDArray shapeMatrix = NDArrayUtil.toNDArray(targetShape);
    INDArray currShape = NDArrayUtil.toNDArray(arr.shape());
    INDArray startIndex = Transforms.floor(currShape.sub(shapeMatrix).divi(Nd4j.scalar(2)));
    INDArray endIndex = startIndex.add(shapeMatrix);
    INDArrayIndex[] indexes = Indices.createFromStartAndEnd(startIndex, endIndex);
    if (shapeMatrix.length() > 1)
        return arr.get(indexes);
    // 1-D target: copy elements [start, end) into a fresh array.
    IComplexNDArray ret = Nd4j.createComplex(new int[] { (int) shapeMatrix.getDouble(0) });
    int start = (int) startIndex.getDouble(0);
    int end = (int) endIndex.getDouble(0);
    int count = 0;
    for (int i = start; i < end; i++) {
        ret.putScalar(count++, arr.getComplex(i));
    }
    return ret;
}
Aggregations