Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
Class Shape, method isRowVectorShape.
/**
 * Checks whether the given shape information describes a row-vector shape.
 * <p>
 * The shape qualifies when its rank is 1, or when its rank is 2 and its
 * first dimension equals 1.
 *
 * @param shapeInfo the shape info to check
 * @return {@code true} if either condition holds, {@code false} otherwise
 */
public static boolean isRowVectorShape(DataBuffer shapeInfo) {
    int r = Shape.rank(shapeInfo);
    DataBuffer dims = Shape.shapeOf(shapeInfo);
    if (r == 1)
        return true;
    return r == 2 && dims.getInt(0) == 1;
}
Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
Class Shape, method createShapeInformation.
/**
 * Creates the shape information buffer for the given shape, stride, offset,
 * element-wise stride and order.
 * <p>
 * The returned buffer has length {@code 2 * rank + 4} and the layout
 * {@code [rank, shape[0..rank), stride[0..rank), offset, elementWiseStride, order]}.
 * It is created through {@link Nd4j#createBufferDetached(int[])} and marked constant.
 *
 * @param shape             the shape for the buffer
 * @param stride            the stride for the buffer; must have the same length as {@code shape}
 * @param offset            the offset for the buffer (stored as an {@code int})
 * @param elementWiseStride the element wise stride for the buffer
 * @param order             the order for the buffer (stored as its char code)
 * @return the shape information buffer given the parameters
 * @throws IllegalStateException if {@code shape} and {@code stride} differ in length
 */
public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    if (shape.length != stride.length)
        throw new IllegalStateException("Shape and stride must be the same length");
    int rank = shape.length;
    int[] shapeBuffer = new int[rank * 2 + 4];
    shapeBuffer[0] = rank;
    System.arraycopy(shape, 0, shapeBuffer, 1, rank);
    System.arraycopy(stride, 0, shapeBuffer, 1 + rank, rank);
    // NOTE(review): the shape info buffer is int-based, so an offset outside the int range is
    // silently truncated here — confirm whether callers can pass such offsets.
    shapeBuffer[2 * rank + 1] = (int) offset;
    shapeBuffer[2 * rank + 2] = elementWiseStride;
    shapeBuffer[2 * rank + 3] = order;
    // Detached creation: no workspace is passed to the buffer factory for shape information.
    DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer);
    ret.setConstant(true);
    return ret;
}
Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
Class Nd4j, method createBufferDetached.
/**
 * Creates an int-typed buffer from the given values. This method is NOT affected by
 * workspaces: unlike {@code createBuffer(double[])}, no current workspace is passed to
 * the data buffer factory.
 *
 * @param data the int values to create the buffer from
 * @return the created int buffer
 */
public static DataBuffer createBufferDetached(int[] data) {
    DataBuffer ret = DATA_BUFFER_FACTORY_INSTANCE.createInt(data);
    logCreationIfNecessary(ret);
    return ret;
}
Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
Class Nd4j, method createBuffer.
/**
 * Creates a buffer from the given doubles, typed according to the global data type.
 * <p>
 * {@code DOUBLE} keeps the values as doubles, {@code HALF} creates a half buffer, and any
 * other type converts the values to floats. When a workspace is currently active, the
 * factory's workspace-aware overload is used so the buffer is allocated in that workspace.
 *
 * @param data the data to create the buffer with
 * @return the created buffer
 */
public static DataBuffer createBuffer(double[] data) {
    final DataBuffer.Type type = dataType();
    final DataBuffer buffer;
    if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
        // No active workspace: plain factory allocation.
        if (type == DataBuffer.Type.DOUBLE) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data);
        } else if (type == DataBuffer.Type.HALF) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data);
        } else {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(ArrayUtil.toFloats(data));
        }
    } else {
        // Active workspace: allocate inside it (half/float overloads take float[] input).
        if (type == DataBuffer.Type.DOUBLE) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, Nd4j.getMemoryManager().getCurrentWorkspace());
        } else if (type == DataBuffer.Type.HALF) {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createHalf(ArrayUtil.toFloats(data), Nd4j.getMemoryManager().getCurrentWorkspace());
        } else {
            buffer = DATA_BUFFER_FACTORY_INSTANCE.createFloat(ArrayUtil.toFloats(data), Nd4j.getMemoryManager().getCurrentWorkspace());
        }
    }
    logCreationIfNecessary(buffer);
    return buffer;
}
Use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
Class NDArrayIndex, method resolve.
/**
 * Resolves user-supplied indexes against an array's shape information, producing an
 * index array that combines "all" indexes for unspecified dimensions with the intended
 * index for each dimension the caller did specify.
 * <p>
 * If any intended index is a {@link SpecifiedIndex}, every index is converted to a
 * {@code SpecifiedIndex}. Otherwise each intended index is validated against the size of
 * its dimension and, in the general case, the result is padded with
 * {@link NDArrayIndex#all()} up to {@code rank} plus the number of {@link NewAxis} indexes.
 *
 * @param shapeInfo       the shape information buffer of the array being indexed
 * @param intendedIndexes the indexes specified by the user
 * @return the resolved indexes (containing all where nothing is specified, and the intended index
 * for a particular dimension otherwise)
 */
public static INDArrayIndex[] resolve(DataBuffer shapeInfo, INDArrayIndex... intendedIndexes) {
    if (containsSpecifiedIndex(intendedIndexes))
        return resolveAsSpecifiedIndexes(shapeInfo, intendedIndexes);

    int rank = Shape.rank(shapeInfo);
    DataBuffer shape = Shape.shapeOf(shapeInfo);
    /*
     * If it's a vector and index asking
     * for a scalar just return the array
     */
    if (intendedIndexes.length >= rank || Shape.isVector(shapeInfo) && intendedIndexes.length == 1) {
        // Single index into a row vector: point(0) on the first axis, the index along the vector.
        // NOTE(review): isRowVectorShape is also true for rank 1, which still yields two indexes here.
        if (Shape.isRowVectorShape(shapeInfo) && intendedIndexes.length == 1) {
            INDArrayIndex[] ret = new INDArrayIndex[2];
            ret[0] = NDArrayIndex.point(0);
            int size;
            if (1 == shape.getInt(0) && rank == 2)
                size = shape.getInt(1);
            else
                size = shape.getInt(0);
            ret[1] = validate(size, intendedIndexes[0]);
            return ret;
        }
        List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length);
        for (int i = 0; i < intendedIndexes.length; i++) {
            // Indexes past the rank have no dimension size to validate against and pass through.
            if (i < rank)
                retList.add(validate(shape.getInt(i), intendedIndexes[i]));
            else
                retList.add(intendedIndexes[i]);
        }
        return retList.toArray(new INDArrayIndex[retList.size()]);
    }

    List<INDArrayIndex> retList = new ArrayList<>(intendedIndexes.length + 1);
    int numNewAxes = 0;
    // isMatrix receives the full shape info, like the isVector/isRowVectorShape calls above.
    // The shapeOf(...) view was passed here before; presumably isMatrix reads the rank from
    // element 0 (as Shape.rank does), so the view would have made it treat a dimension as the rank.
    if (Shape.isMatrix(shapeInfo) && intendedIndexes.length == 1) {
        retList.add(validate(shape.getInt(0), intendedIndexes[0]));
        retList.add(NDArrayIndex.all());
    } else {
        for (int i = 0; i < intendedIndexes.length; i++) {
            retList.add(validate(shape.getInt(i), intendedIndexes[i]));
            if (intendedIndexes[i] instanceof NewAxis)
                numNewAxes++;
        }
    }
    int length = rank + numNewAxes;
    // fill the rest with all
    while (retList.size() < length)
        retList.add(NDArrayIndex.all());
    return retList.toArray(new INDArrayIndex[retList.size()]);
}

/** Returns true if any of the given indexes is a {@link SpecifiedIndex}. */
private static boolean containsSpecifiedIndex(INDArrayIndex[] indexes) {
    for (INDArrayIndex index : indexes) {
        if (index instanceof SpecifiedIndex)
            return true;
    }
    return false;
}

/**
 * Converts each intended index to a {@link SpecifiedIndex}.
 * <p>
 * The conversion works as follows:
 * <ul>
 *   <li>Specified indexes are kept as-is.</li>
 *   <li>"All" becomes the full {@code [0, size)} range of its dimension.</li>
 *   <li>"Empty" becomes an empty index.</li>
 *   <li>An interval becomes its begin/end/stride range.</li>
 * </ul>
 * The "all" check precedes the interval check so it takes priority if one type is a subtype
 * of the other.
 * NOTE(review): any other index kind (e.g. one created by point(...) or a NewAxis) is left
 * as {@code null} in the result — confirm callers never mix those with specified indexes.
 */
private static INDArrayIndex[] resolveAsSpecifiedIndexes(DataBuffer shapeInfo, INDArrayIndex[] intendedIndexes) {
    DataBuffer shape = Shape.shapeOf(shapeInfo);
    INDArrayIndex[] ret = new INDArrayIndex[intendedIndexes.length];
    for (int i = 0; i < intendedIndexes.length; i++) {
        INDArrayIndex index = intendedIndexes[i];
        if (index instanceof SpecifiedIndex) {
            ret[i] = index;
        } else if (index instanceof NDArrayIndexAll) {
            // FIXME: LONG
            ret[i] = new SpecifiedIndex(ArrayUtil.range(0L, (long) shape.getInt(i)));
        } else if (index instanceof NDArrayIndexEmpty) {
            ret[i] = new SpecifiedIndex(new long[0]);
        } else if (index instanceof IntervalIndex) {
            IntervalIndex intervalIndex = (IntervalIndex) index;
            ret[i] = new SpecifiedIndex(ArrayUtil.range(intervalIndex.begin, intervalIndex.end(), intervalIndex.stride()));
        }
    }
    return ret;
}
Aggregations