use of ai.djl.util.PairList in project djl by deepjavalibrary.
the class IValueUtils method getInputs.
/**
 * Converts an {@link NDList} into the {@link IValue} inputs passed to a TorchScript module.
 *
 * <p>The name of each array decides how it is packed:
 *
 * <ul>
 *   <li>{@code group.key} - entry {@code key} of a string-keyed dict input named {@code group}
 *       (split on the first dot only)
 *   <li>{@code group[]} - an element of a list input
 *   <li>{@code group()} - an element of a tuple input
 *   <li>any other name, or no name - a standalone tensor input
 * </ul>
 *
 * @param ndList the arrays to convert; every element is cast to {@code PtNDArray}
 * @return one {@link IValue} per grouped or standalone input
 */
static IValue[] getInputs(NDList ndList) {
    // Each PairList holds the (key, array) entries that make up one module input
    List<PairList<String, PtNDArray>> outputs = new ArrayList<>();
    // Group name -> position in outputs, maintained by addToMap
    Map<String, Integer> indexMap = new ConcurrentHashMap<>();
    // Compiled once instead of once per array (Pattern.matches recompiles on every call)
    Pattern listPattern = Pattern.compile("\\w+\\[]");
    Pattern tuplePattern = Pattern.compile("\\w+\\(\\)");
    for (NDArray array : ndList) {
        String name = array.getName();
        String groupName = null;
        String entryKey = null;
        if (name != null) {
            if (name.contains(".")) {
                String[] parts = name.split("\\.", 2);
                groupName = parts[0];
                entryKey = parts[1];
            } else if (listPattern.matcher(name).matches()) {
                groupName = name;
                entryKey = "[]";
            } else if (tuplePattern.matcher(name).matches()) {
                groupName = name;
                entryKey = "()";
            }
        }
        if (groupName == null) {
            // plain tensor input: its own single-entry group with a null key
            PairList<String, PtNDArray> pl = new PairList<>();
            pl.add(null, (PtNDArray) array);
            outputs.add(pl);
        } else {
            int index = addToMap(indexMap, groupName, outputs);
            outputs.get(index).add(entryKey, (PtNDArray) array);
        }
    }
    IValue[] ret = new IValue[outputs.size()];
    for (int i = 0; i < outputs.size(); ++i) {
        PairList<String, PtNDArray> pl = outputs.get(i);
        // the first entry's key identifies the container kind of the whole group
        String key = pl.get(0).getKey();
        if (key == null) {
            // not List, Dict, Tuple input
            ret[i] = IValue.from(pl.get(0).getValue());
        } else if ("[]".equals(key)) {
            // list
            PtNDArray[] arrays = pl.values().toArray(new PtNDArray[0]);
            ret[i] = IValue.listFrom(arrays);
        } else if ("()".equals(key)) {
            // Tuple
            IValue[] arrays = pl.values().stream().map(IValue::from).toArray(IValue[]::new);
            ret[i] = IValue.tupleFrom(arrays);
        } else {
            // Dict: insertion-ordered so entries keep the order they had in the NDList
            Map<String, PtNDArray> map = new java.util.LinkedHashMap<>();
            for (Pair<String, PtNDArray> pair : pl) {
                map.put(pair.getKey(), pair.getValue());
            }
            ret[i] = IValue.stringMapFrom(map);
        }
    }
    return ret;
}
use of ai.djl.util.PairList in project djl by deepjavalibrary.
the class BaseModel method describeOutput.
/**
 * {@inheritDoc}
 *
 * <p>For a {@link SymbolBlock} the description is taken from the block itself. Otherwise the
 * output shapes are found by running a forward pass (not in training mode) on all-ones inputs
 * shaped per {@link #describeInput()}. The outputs are named {@code output0}, {@code output1},
 * and so on. The temporary input and output arrays are closed once their shapes have been read.
 */
@Override
public PairList<String, Shape> describeOutput() {
    if (block instanceof SymbolBlock) {
        return ((SymbolBlock) block).describeOutput();
    }
    // Fake input used only to calculate output shapes. It is closed on exit so the
    // temporary arrays do not pile up in the model's manager on every call.
    try (NDList input = new NDList()) {
        for (Pair<String, Shape> pair : describeInput()) {
            input.add(manager.ones(pair.getValue()));
        }
        try (NDList output = block.forward(new ParameterStore(manager, true), input, false)) {
            List<String> outputNames = new ArrayList<>(output.size());
            List<Shape> outputShapes = new ArrayList<>(output.size());
            for (int i = 0; i < output.size(); i++) {
                outputNames.add("output" + i);
                outputShapes.add(output.get(i).getShape());
            }
            return new PairList<>(outputNames, outputShapes);
        }
    }
}
use of ai.djl.util.PairList in project djl by deepjavalibrary.
the class JnaUtils method createCachedOp.
/* Need tests
public static Pointer createSymbolFromJson(String json) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
return ref.getValue();
}
public static Pointer compose(Pointer symbol, String name, String[] keys) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXSymbolCompose(symbol, name, keys.length, keys, ref));
return ref.getValue();
}
public static Pointer grad(Pointer symbol, String name, int numWrt, String[] wrt) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXSymbolCompose(symbol, name, numWrt, wrt, ref));
return ref.getValue();
}
public static Shape[] inferShape(Pointer symbol, String[] keys) {
IntBuffer argIndex = IntBuffer.allocate(1);
IntBuffer argShapeData = IntBuffer.allocate(1);
IntBuffer inShapeSize = IntBuffer.allocate(1);
PointerByReference inShapeNDim = new PointerByReference();
PointerByReference inShapeData = new PointerByReference();
IntBuffer outShapeSize = IntBuffer.allocate(1);
PointerByReference outShapeNDim = new PointerByReference();
PointerByReference outShapeData = new PointerByReference();
IntBuffer auxShapeSize = IntBuffer.allocate(1);
PointerByReference auxShapeNDim = new PointerByReference();
PointerByReference auxShapeData = new PointerByReference();
IntBuffer complete = IntBuffer.allocate(1);
checkCall(
LIB.MXSymbolInferShape(
symbol,
keys.length,
keys,
argIndex.array(),
argShapeData.array(),
inShapeSize,
inShapeNDim,
inShapeData,
outShapeSize,
outShapeNDim,
outShapeData,
auxShapeSize,
auxShapeNDim,
auxShapeData,
complete));
if (complete.get() == 1) {
Shape[] ret = new Shape[keys.length];
// TODO: add implementation
return ret; // NOPMD
}
return null;
}
public static Pointer inferType(Pointer symbol, String[] keys) {
int[] argTypeData = new int[1];
IntBuffer inTypeSize = IntBuffer.allocate(1);
PointerByReference inTypeData = new PointerByReference();
IntBuffer outTypeSize = IntBuffer.allocate(1);
PointerByReference outTypeData = new PointerByReference();
IntBuffer auxTypeSize = IntBuffer.allocate(1);
PointerByReference auxTypeData = new PointerByReference();
IntBuffer complete = IntBuffer.allocate(1);
checkCall(
LIB.MXSymbolInferType(
symbol,
keys.length,
keys,
argTypeData,
inTypeSize,
inTypeData,
outTypeSize,
outTypeData,
auxTypeSize,
auxTypeData,
complete));
if (complete.get() == 1) {
return outTypeData.getValue();
}
return null;
}
public static Pointer quantizeSymbol(
Pointer symbol,
String[] excludedSymbols,
String[] offlineParams,
String quantizedDType,
byte calibQuantize) {
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXQuantizeSymbol(
symbol,
ref,
excludedSymbols.length,
excludedSymbols,
offlineParams.length,
offlineParams,
quantizedDType,
calibQuantize));
return ref.getValue();
}
public static Pointer setCalibTableToQuantizedSymbol(
Pointer symbol,
String[] layerNames,
FloatBuffer lowQuantiles,
FloatBuffer highQuantiles) {
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXSetCalibTableToQuantizedSymbol(
symbol, layerNames.length, layerNames, lowQuantiles, highQuantiles, ref));
return ref.getValue();
}
public static Pointer genBackendSubgraph(Pointer symbol, String backend) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXGenBackendSubgraph(symbol, backend, ref));
return ref.getValue();
}
*/
// ///////////////////////////////
// MXNet Executors
// ///////////////////////////////
/* Need tests
public static void freeExecutor(Pointer executor) {
checkCall(LIB.MXExecutorFree(executor));
}
public static String getExecutorDebugString(Pointer executor) {
String[] ret = new String[1];
checkCall(LIB.MXExecutorPrint(executor, ret));
return ret[0];
}
public static void forward(Pointer executor, boolean isTrain) {
checkCall(LIB.MXExecutorForward(executor, isTrain ? 1 : 0));
}
public static Pointer backward(Pointer executor, int length) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXExecutorBackward(executor, length, ref));
return ref.getValue();
}
public static Pointer backwardEx(Pointer executor, int length, boolean isTrain) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXExecutorBackwardEx(executor, length, ref, isTrain ? 1 : 0));
return ref.getValue();
}
public static NDArray[] getExecutorOutputs(MxNDManager manager, Pointer executor) {
IntBuffer outSize = IntBuffer.allocate(1);
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXExecutorOutputs(executor, outSize, ref));
int size = outSize.get();
Pointer[] pointers = ref.getValue().getPointerArray(0, size);
NDArray[] ndArrays = new NDArray[size];
for (int i = 0; i < size; ++i) {
ndArrays[i] = manager.create(pointers[i]);
}
return ndArrays;
}
public static Pointer bindExecutorSimple(
Symbol symbol,
Device device,
String[] g2cKeys,
int[] g2cDeviceTypes,
int[] g2cDeviceIds,
String[] argParams,
String[] argParamGradReqs,
String[] inputArgNames,
IntBuffer inputShapeData,
IntBuffer inputShapeIdx,
String[] inputDataTypeNames,
int[] inputDataTypes,
String[] inputStorageTypeNames,
int[] inputStorageTypes,
String[] sharedArgParams,
IntBuffer sharedBufferLen,
String[] sharedBufferNames,
PointerByReference sharedBufferHandles,
PointerByReference updatedSharedBufferNames,
PointerByReference updatedSharedBufferHandles,
IntBuffer numInArgs,
PointerByReference inArgs,
PointerByReference argGrads,
IntBuffer numAuxStates,
PointerByReference auxStates) {
int deviceId = device.getDeviceId();
int deviceType = DeviceType.toDeviceType(device);
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXExecutorSimpleBind(
symbol.getHandle(),
deviceType,
deviceId,
g2cKeys == null ? 0 : g2cKeys.length,
g2cKeys,
g2cDeviceTypes,
g2cDeviceIds,
argParams.length,
argParams,
argParamGradReqs,
inputArgNames.length,
inputArgNames,
inputShapeData.array(),
inputShapeIdx.array(),
inputDataTypeNames.length,
inputDataTypeNames,
inputDataTypes,
inputStorageTypeNames == null ? 0 : inputStorageTypeNames.length,
inputStorageTypeNames,
inputStorageTypes,
sharedArgParams.length,
sharedArgParams,
sharedBufferLen,
sharedBufferNames,
sharedBufferHandles,
updatedSharedBufferNames,
updatedSharedBufferHandles,
numInArgs,
inArgs,
argGrads,
numAuxStates,
auxStates,
null,
ref));
return ref.getValue();
}
public static Pointer bindExecutor(
Pointer executor, Device device, int len, int auxStatesLen) {
int deviceId = device.getDeviceId();
int deviceType = DeviceType.toDeviceType(device);
PointerByReference inArgs = new PointerByReference();
PointerByReference argGradStore = new PointerByReference();
IntBuffer gradReqType = IntBuffer.allocate(1);
PointerByReference auxStates = new PointerByReference();
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXExecutorBind(
executor,
deviceType,
deviceId,
len,
inArgs,
argGradStore,
gradReqType,
auxStatesLen,
auxStates,
ref));
return ref.getValue();
}
public static Pointer bindExecutorX(
Pointer executor,
Device device,
int len,
int auxStatesLen,
String[] keys,
int[] deviceTypes,
int[] deviceIds) {
int deviceId = device.getDeviceId();
int deviceType = DeviceType.toDeviceType(device);
PointerByReference inArgs = new PointerByReference();
PointerByReference argGradStore = new PointerByReference();
IntBuffer gradReqType = IntBuffer.allocate(1);
PointerByReference auxStates = new PointerByReference();
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXExecutorBindX(
executor,
deviceType,
deviceId,
keys.length,
keys,
deviceTypes,
deviceIds,
len,
inArgs,
argGradStore,
gradReqType,
auxStatesLen,
auxStates,
ref));
return ref.getValue();
}
public static Pointer bindExecutorEX(
Pointer executor,
Device device,
int len,
int auxStatesLen,
String[] keys,
int[] deviceTypes,
int[] deviceIds,
Pointer sharedExecutor) {
int deviceId = device.getDeviceId();
int deviceType = DeviceType.toDeviceType(device);
PointerByReference inArgs = new PointerByReference();
PointerByReference argGradStore = new PointerByReference();
IntBuffer gradReqType = IntBuffer.allocate(1);
PointerByReference auxStates = new PointerByReference();
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXExecutorBindEX(
executor,
deviceType,
deviceId,
keys.length,
keys,
deviceTypes,
deviceIds,
len,
inArgs,
argGradStore,
gradReqType,
auxStatesLen,
auxStates,
sharedExecutor,
ref));
return ref.getValue();
}
public static Pointer reshapeExecutor(
boolean partialShaping,
boolean allowUpSizing,
Device device,
String[] keys,
int[] deviceTypes,
int[] deviceIds,
String[] providedArgShapeNames,
IntBuffer providedArgShapeData,
IntBuffer providedArgShapeIdx,
IntBuffer numInArgs,
PointerByReference inArgs,
PointerByReference argGrads,
IntBuffer numAuxStates,
PointerByReference auxStates,
Pointer sharedExecutor) {
int deviceId = device.getDeviceId();
int deviceType = DeviceType.toDeviceType(device);
PointerByReference ref = new PointerByReference();
checkCall(
LIB.MXExecutorReshape(
partialShaping ? 1 : 0,
allowUpSizing ? 1 : 0,
deviceType,
deviceId,
keys.length,
keys,
deviceTypes,
deviceIds,
providedArgShapeNames.length,
providedArgShapeNames,
providedArgShapeData.array(),
providedArgShapeIdx.array(),
numInArgs,
inArgs,
argGrads,
numAuxStates,
auxStates,
sharedExecutor,
ref));
return ref.getValue();
}
public static Pointer getOptimizedSymbol(Pointer executor) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXExecutorGetOptimizedSymbol(executor, ref));
return ref.getValue();
}
public static void setMonitorCallback(
Pointer executor,
MxnetLibrary.ExecutorMonitorCallback callback,
Pointer callbackHandle) {
checkCall(LIB.MXExecutorSetMonitorCallback(executor, callback, callbackHandle));
}
*/
// ///////////////////////////////
// MXNet DataIter
// ///////////////////////////////
/*
public static Pointer[] listDataIters() {
IntBuffer outSize = IntBuffer.allocate(1);
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXListDataIters(outSize, ref));
return ref.getValue().getPointerArray(0, outSize.get());
}
public static Pointer createIter(Pointer iter, String[] keys, String[] values) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXDataIterCreateIter(iter, keys.length, keys, values, ref));
return ref.getValue();
}
public static String getIterInfo(Pointer iter) {
String[] name = new String[1];
String[] description = new String[1];
IntBuffer numArgs = IntBuffer.allocate(1);
PointerByReference argNames = new PointerByReference();
PointerByReference argTypes = new PointerByReference();
PointerByReference argDesc = new PointerByReference();
checkCall(
LIB.MXDataIterGetIterInfo(
iter, name, description, numArgs, argNames, argTypes, argDesc));
return name[0];
}
public static void freeDataIter(Pointer iter) {
checkCall(LIB.MXDataIterFree(iter));
}
public static int next(Pointer iter) {
IntBuffer ret = IntBuffer.allocate(1);
checkCall(LIB.MXDataIterNext(iter, ret));
return ret.get();
}
public static void beforeFirst(Pointer iter) {
checkCall(LIB.MXDataIterBeforeFirst(iter));
}
public static Pointer getData(Pointer iter) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXDataIterGetData(iter, ref));
return ref.getValue();
}
public static Pointer getIndex(Pointer iter) {
LongBuffer outSize = LongBuffer.wrap(new long[1]);
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXDataIterGetIndex(iter, ref, outSize));
return ref.getValue();
}
public static int getPadNum(Pointer iter) {
IntBuffer outSize = IntBuffer.allocate(1);
checkCall(LIB.MXDataIterGetPadNum(iter, outSize));
return outSize.get();
}
public static String getDataIterLabel(Pointer iter) {
PointerByReference ref = new PointerByReference();
checkCall(LIB.MXDataIterGetLabel(iter, ref));
return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
}
*/
/*
int MXRecordIOWriterCreate(String uri, PointerByReference out);
int MXRecordIOWriterFree(Pointer handle);
int MXRecordIOWriterWriteRecord(Pointer handle, String buf, NativeSize size);
int MXRecordIOWriterTell(Pointer handle, NativeSizeByReference pos);
int MXRecordIOReaderCreate(String uri, PointerByReference out);
int MXRecordIOReaderFree(Pointer handle);
int MXRecordIOReaderReadRecord(Pointer handle, String buf[], NativeSizeByReference size);
int MXRecordIOReaderSeek(Pointer handle, NativeSize pos);
int MXRecordIOReaderTell(Pointer handle, NativeSizeByReference pos);
int MXRtcCreate(ByteBuffer name, int num_input, int num_output, PointerByReference input_names,
PointerByReference output_names, PointerByReference inputs,
PointerByReference outputs, ByteBuffer kernel, PointerByReference out);
int MXRtcPush(Pointer handle, int num_input, int num_output, PointerByReference inputs,
PointerByReference outputs, int gridDimX, int gridDimY, int gridDimZ,
int blockDimX, int blockDimY, int blockDimZ);
int MXRtcFree(Pointer handle);
int MXCustomOpRegister(String op_type, MxnetLibrary.CustomOpPropCreator creator);
int MXCustomFunctionRecord(int num_inputs, PointerByReference inputs, int num_outputs,
PointerByReference outputs, MXCallbackList callbacks);
int MXRtcCudaModuleCreate(String source, int num_options, String options[], int num_exports,
String exports[], PointerByReference out);
int MXRtcCudaModuleFree(Pointer handle);
int MXRtcCudaKernelCreate(Pointer handle, String name, int num_args, IntBuffer is_ndarray,
IntBuffer is_const, IntBuffer arg_types, PointerByReference out);
int MXRtcCudaKernelFree(Pointer handle);
int MXRtcCudaKernelCall(Pointer handle, int dev_id, PointerByReference args, int grid_dim_x,
int grid_dim_y, int grid_dim_z, int block_dim_x, int block_dim_y,
int block_dim_z, int shared_mem);
int MXNDArrayGetSharedMemHandle(Pointer handle, IntBuffer shared_pid, IntBuffer shared_id);
int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, IntBuffer shape, int ndim,
int dtype, PointerByReference out);
*/
// ////////////////////////////////
// cached Op
// ////////////////////////////////
/**
 * Creates a {@link CachedOp} for the symbol held by the given block.
 *
 * <p>Every parameter of the block becomes an input of the cached op:
 *
 * <ul>
 *   <li>Initialized parameters are recorded in {@code param_indices}.
 *   <li>Uninitialized parameters are assumed to be data inputs. They are recorded in
 *       {@code data_indices} together with their names.
 * </ul>
 *
 * <p>For example, data indices {@code [0, 2, 4]} and param indices {@code [1, 3]} describe five
 * inputs, where positions 0, 2 and 4 are data. {@code static_alloc} and {@code static_shape}
 * are always set to {@code 1}.
 *
 * @param block the {@link MxSymbolBlock} that is loaded in the backend
 * @param manager the NDManager used to create NDArrays
 * @param training true if the CachedOp is created to forward in training, otherwise false;
 *     NOTE(review): this flag is not currently read by this method
 * @return the newly created {@link CachedOp}
 */
public static CachedOp createCachedOp(MxSymbolBlock block, MxNDManager manager, boolean training) {
    Symbol symbol = block.getSymbol();
    List<Parameter> parameters = block.getAllParameters();
    // record data index in all inputs
    PairList<String, Integer> dataIndices = new PairList<>();
    // record parameter index in all inputs
    List<Integer> paramIndices = new ArrayList<>();
    int index = 0;
    for (Parameter parameter : parameters) {
        // We assume uninitialized parameters are data inputs
        if (parameter.isInitialized()) {
            paramIndices.add(index);
        } else {
            dataIndices.add(parameter.getName(), index);
        }
        ++index;
    }
    // static_alloc and static_shape are enabled by default
    String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
    String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};

    // Creating CachedOp
    Pointer symbolHandle = symbol.getHandle();
    PointerByReference ref = REFS.acquire();
    Pointer pointer;
    try {
        checkCall(LIB.MXCreateCachedOpEx(symbolHandle, keys.length, keys, values, ref));
        pointer = ref.getValue();
    } finally {
        // hand the pooled reference back even when the native call fails
        REFS.recycle(ref);
    }
    return new CachedOp(pointer, manager, parameters, paramIndices, dataIndices);
}
use of ai.djl.util.PairList in project djl by deepjavalibrary.
the class JnaUtils method getFunctionByName.
/**
 * Queries MXNet through {@code MXSymbolGetAtomicSymbolInfo} for an operator's arguments and
 * wraps them in a {@link FunctionInfo}.
 *
 * @param name the operator name passed to {@code MXSymbolGetAtomicSymbolInfo}
 * @param functionName the name recorded in the returned {@link FunctionInfo}
 * @param handle the native handle passed to the query and stored in the result
 * @return a {@link FunctionInfo} holding {@code handle}, {@code functionName} and the
 *     operator's argument names paired with their type strings
 */
private static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
    String[] nameRef = {name};
    // The following outputs are required by the native signature but not used here
    String[] description = new String[1];
    String[] keyVarArgs = new String[1];
    String[] returnType = new String[1];
    IntBuffer numArgs = IntBuffer.allocate(1);
    PointerByReference argNameRef = REFS.acquire();
    PointerByReference argTypeRef = REFS.acquire();
    PointerByReference argDescRef = REFS.acquire();
    PairList<String, String> arguments = new PairList<>();
    try {
        checkCall(
                LIB.MXSymbolGetAtomicSymbolInfo(
                        handle,
                        nameRef,
                        description,
                        numArgs,
                        argNameRef,
                        argTypeRef,
                        argDescRef,
                        keyVarArgs,
                        returnType));
        int count = numArgs.get();
        if (count != 0) {
            String[] argNames =
                    argNameRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
            String[] argTypes =
                    argTypeRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
            for (int i = 0; i < argNames.length; i++) {
                arguments.add(argNames[i], argTypes[i]);
            }
        }
    } finally {
        // always return the pooled references, even if the native call fails
        REFS.recycle(argNameRef);
        REFS.recycle(argTypeRef);
        REFS.recycle(argDescRef);
    }
    return new FunctionInfo(handle, functionName, arguments);
}
use of ai.djl.util.PairList in project djl by deepjavalibrary.
the class FunctionInfo method invoke.
/**
 * Calls an operator with the given arguments.
 *
 * @param manager the manager to attach the result to; it is cast to {@code MxNDManager}, so
 *     any other implementation causes a {@link ClassCastException}
 * @param src the input NDArray(s) to the operator
 * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
 *     String>}
 * @return the operator's output NDArrays, created through {@code manager}; an output whose
 *     reported sparse format is not {@code DENSE} is created with that format
 */
public NDArray[] invoke(NDManager manager, NDArray[] src, PairList<String, ?> params) {
    checkDevices(src);
    // each pair is (native output handle, storage format reported for that output)
    PairList<Pointer, SparseFormat> pairList = JnaUtils.imperativeInvoke(handle, src, null, params);
    final MxNDManager mxManager = (MxNDManager) manager;
    return pairList.stream()
            .map(
                    pair -> {
                        if (pair.getValue() != SparseFormat.DENSE) {
                            return mxManager.create(pair.getKey(), pair.getValue());
                        }
                        return mxManager.create(pair.getKey());
                    })
            .toArray(MxNDArray[]::new);
}
Aggregations