use of ai.djl.nn.Parameter in project djl by deepjavalibrary.
the class BlockCoreTest method testEncode.
/**
 * Saves the parameters of {@code block} to a temporary file, loads them back into the same
 * block, and asserts that the parameter list is unchanged by the round trip.
 *
 * <p>Both streams are closed before the temporary file is deleted, and the file is removed
 * even when saving or loading fails.
 *
 * @param manager the manager passed to {@code block.loadParameters}
 * @param block the block whose parameters are saved and reloaded
 * @throws IOException if the temporary file cannot be created, written, read or deleted
 * @throws MalformedModelException if the saved parameters cannot be loaded back
 */
private void testEncode(NDManager manager, Block block) throws IOException, MalformedModelException {
    PairList<String, Parameter> original = block.getParameters();
    File temp = File.createTempFile("block", ".param");
    try {
        // Close the writer before reading so no handle on the file stays open.
        try (DataOutputStream os = new DataOutputStream(Files.newOutputStream(temp.toPath()))) {
            block.saveParameters(os);
        }
        try (DataInputStream is = new DataInputStream(Files.newInputStream(temp.toPath()))) {
            block.loadParameters(manager, is);
        }
    } finally {
        // Always clean up, including when save/load throws.
        Files.delete(temp.toPath());
    }
    PairList<String, Parameter> loaded = block.getParameters();
    int bound = original.size();
    // Without this check, extra parameters in the loaded list would go unnoticed.
    Assert.assertEquals(loaded.size(), bound);
    for (int idx = 0; idx < bound; idx++) {
        Assert.assertEquals(original.valueAt(idx), loaded.valueAt(idx));
    }
}
use of ai.djl.nn.Parameter in project djl by deepjavalibrary.
the class GRU method forwardInternal.
/**
 * {@inheritDoc}
 *
 * <p>Looks up the value of every parameter on the device of the first input, supplies a
 * zero-filled state when only one input is given, and invokes the GRU operator. When
 * {@code returnState} is false, every output after the first is closed and only the first
 * output is returned.
 */
@Override
protected NDList forwardInternal(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();

    // Collect parameter values in the iteration order of `parameters`.
    NDList weights = new NDList();
    for (Parameter weight : parameters.values()) {
        weights.add(parameterStore.getValue(weight, device, training));
    }

    // No state supplied by the caller: append a zero-filled one to the input list.
    if (inputs.size() == 1) {
        long stateCount = (long) numLayers * getNumDirections();
        long batchSize = input.size(batchFirst ? 0 : 1);
        inputs.add(input.getManager().zeros(new Shape(stateCount, batchSize, stateSize)));
    }

    NDList outputs =
            ex.gru(
                    input,
                    inputs.get(1),
                    weights,
                    hasBiases,
                    numLayers,
                    dropRate,
                    training,
                    bidirectional,
                    batchFirst);

    if (!returnState) {
        // Release everything after the first output; only that one is handed back.
        for (int i = 1; i < outputs.size(); i++) {
            outputs.get(i).close();
        }
        return new NDList(outputs.get(0));
    }
    return outputs;
}
use of ai.djl.nn.Parameter in project djl by deepjavalibrary.
the class LSTM method forwardInternal.
/**
 * {@inheritDoc}
 *
 * <p>Looks up the value of every parameter on the device of the first input, supplies
 * zero-filled hidden and cell states when only one input is given, and invokes the LSTM
 * operator. When {@code returnState} is false, every output after the first is closed and
 * only the first output is returned.
 */
@Override
protected NDList forwardInternal(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();

    // Collect parameter values in the iteration order of `parameters`.
    NDList weights = new NDList();
    for (Parameter weight : parameters.values()) {
        weights.add(parameterStore.getValue(weight, device, training));
    }

    // No states supplied by the caller: append zero-filled hidden and cell states.
    if (inputs.size() == 1) {
        long stateCount = (long) numLayers * getNumDirections();
        long batchSize = input.size(batchFirst ? 0 : 1);
        Shape stateShape = new Shape(stateCount, batchSize, stateSize);
        // hidden state
        inputs.add(input.getManager().zeros(stateShape));
        // cell
        inputs.add(input.getManager().zeros(stateShape));
    }

    NDList states = new NDList(inputs.get(1), inputs.get(2));
    NDList outputs =
            ex.lstm(
                    input,
                    states,
                    weights,
                    hasBiases,
                    numLayers,
                    dropRate,
                    training,
                    bidirectional,
                    batchFirst);

    if (!returnState) {
        // Release everything after the first output; only that one is handed back.
        for (int i = 1; i < outputs.size(); i++) {
            outputs.get(i).close();
        }
        return new NDList(outputs.get(0));
    }
    return outputs;
}
use of ai.djl.nn.Parameter in project djl by deepjavalibrary.
the class RNN method forwardInternal.
/**
 * {@inheritDoc}
 *
 * <p>Looks up the value of every parameter on the device of the first input, supplies a
 * zero-filled state when only one input is given, and invokes the RNN operator with the
 * configured {@code activation}. When {@code returnState} is false, every output after the
 * first is closed and only the first output is returned.
 */
@Override
protected NDList forwardInternal(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();

    // Collect parameter values in the iteration order of `parameters`.
    NDList weights = new NDList();
    for (Parameter weight : parameters.values()) {
        weights.add(parameterStore.getValue(weight, device, training));
    }

    // No state supplied by the caller: append a zero-filled one to the input list.
    if (inputs.size() == 1) {
        long stateCount = (long) numLayers * getNumDirections();
        long batchSize = input.size(batchFirst ? 0 : 1);
        inputs.add(input.getManager().zeros(new Shape(stateCount, batchSize, stateSize)));
    }

    NDList outputs =
            ex.rnn(
                    input,
                    inputs.get(1),
                    weights,
                    hasBiases,
                    numLayers,
                    activation,
                    dropRate,
                    training,
                    bidirectional,
                    batchFirst);

    if (!returnState) {
        // Release everything after the first output; only that one is handed back.
        for (int i = 1; i < outputs.size(); i++) {
            outputs.get(i).close();
        }
        return new NDList(outputs.get(0));
    }
    return outputs;
}
use of ai.djl.nn.Parameter in project djl by deepjavalibrary.
the class RecurrentBlock method prepare.
/**
 * {@inheritDoc}
 *
 * <p>Assigns a shape to every direct parameter based on its name. The token after the first
 * underscore in the name is parsed as the layer index; names containing {@code "BIAS"},
 * {@code "i2h"} or {@code "h2h"} select the shape, checked in that order.
 *
 * @throws IllegalArgumentException if a parameter name matches none of the known kinds
 */
@Override
public void prepare(Shape[] inputs) {
    Shape inputShape = inputs[0];
    ParameterList directParams = getDirectParameters();
    for (Pair<String, Parameter> entry : directParams) {
        String paramName = entry.getKey();
        Parameter param = entry.getValue();

        int layerIndex = Integer.parseInt(paramName.split("_")[1]);
        // Layer 0 takes its width from dimension 2 of the input shape; later layers use
        // stateSize times the number of directions instead.
        long featureSize = inputShape.get(2);
        if (layerIndex > 0) {
            featureSize = stateSize * getNumDirections();
        }

        Shape shape;
        if (paramName.contains("BIAS")) {
            shape = new Shape(gates * stateSize);
        } else if (paramName.contains("i2h")) {
            shape = new Shape(gates * stateSize, featureSize);
        } else if (paramName.contains("h2h")) {
            shape = new Shape(gates * stateSize, stateSize);
        } else {
            throw new IllegalArgumentException("Invalid parameter name");
        }
        param.setShape(shape);
    }
}
Aggregations