Search in sources:

Example 1 with Parameter

Use of ai.djl.nn.Parameter in project djl by deepjavalibrary.

The class BlockCoreTest, method testEncode.

/**
 * Round-trips the parameters of {@code block} through a temporary file and checks that the
 * parameter list obtained after loading matches the list obtained before saving.
 *
 * <p>Both streams are closed before the file is deleted, and the file is deleted even when saving
 * or loading fails, so no temp files or open handles leak.
 *
 * <p>NOTE(review): {@code original} and {@code loaded} are both taken from
 * {@code block.getParameters()} and may therefore contain the very same {@code Parameter}
 * instances, in which case the element-wise comparison cannot detect a corrupted round trip.
 * Consider snapshotting the parameter arrays before saving.
 *
 * @param manager the manager passed to {@code Block.loadParameters}
 * @param block the block whose parameters are saved and then reloaded
 * @throws IOException if the temporary file cannot be created, written, read or deleted
 * @throws MalformedModelException if {@code Block.loadParameters} rejects the saved data
 */
private void testEncode(NDManager manager, Block block) throws IOException, MalformedModelException {
    PairList<String, Parameter> original = block.getParameters();
    File temp = File.createTempFile("block", ".param");
    try {
        // Close (and thereby flush) the output before reading the file back.
        try (DataOutputStream os = new DataOutputStream(Files.newOutputStream(temp.toPath()))) {
            block.saveParameters(os);
        }
        try (DataInputStream is = new DataInputStream(Files.newInputStream(temp.toPath()))) {
            block.loadParameters(manager, is);
        }
    } finally {
        // Always clean up; deleteIfExists avoids masking an earlier exception.
        Files.deleteIfExists(temp.toPath());
    }
    PairList<String, Parameter> loaded = block.getParameters();
    Assert.assertEquals(loaded.size(), original.size());
    for (int idx = 0; idx < original.size(); idx++) {
        Assert.assertEquals(original.valueAt(idx), loaded.valueAt(idx));
    }
}
Also used : DataOutputStream(java.io.DataOutputStream) Parameter(ai.djl.nn.Parameter) DataInputStream(java.io.DataInputStream) File(java.io.File)

Example 2 with Parameter

use of ai.djl.nn.Parameter in project djl by deepjavalibrary.

The class GRU, method forwardInternal.

/**
 * {@inheritDoc}
 *
 * <p>{@code inputs} holds the input sequence at index 0 and, optionally, the initial state at
 * index 1. If only the input is supplied, a zero state of shape
 * {@code (numLayers * getNumDirections(), batch, stateSize)} is created. The batch size is read
 * from dimension 0 when {@code batchFirst} is set and from dimension 1 otherwise. The default
 * state uses the input's data type and is created locally: the caller's {@code inputs} list is
 * not modified.
 *
 * <p>When {@code returnState} is false, every output after the first is closed and only the first
 * output is returned.
 */
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();
    NDList gruParams = new NDList();
    for (Parameter parameter : parameters.values()) {
        gruParams.add(parameterStore.getValue(parameter, device, training));
    }
    NDArray state;
    if (inputs.size() == 1) {
        int batchIndex = batchFirst ? 0 : 1;
        Shape stateShape = new Shape((long) numLayers * getNumDirections(), input.size(batchIndex), stateSize);
        // Build the default state locally rather than appending it to the caller's list, and
        // match the input's data type so non-float32 inputs do not get a mismatched state.
        state = input.getManager().zeros(stateShape, input.getDataType());
    } else {
        state = inputs.get(1);
    }
    NDList outputs = ex.gru(input, state, gruParams, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst);
    if (returnState) {
        return outputs;
    }
    // Only the first output is requested; release the others.
    outputs.stream().skip(1).forEach(NDArray::close);
    return new NDList(outputs.get(0));
}
Also used : Shape(ai.djl.ndarray.types.Shape) Device(ai.djl.Device) NDList(ai.djl.ndarray.NDList) Parameter(ai.djl.nn.Parameter) NDArray(ai.djl.ndarray.NDArray) NDArrayEx(ai.djl.ndarray.internal.NDArrayEx)

Example 3 with Parameter

use of ai.djl.nn.Parameter in project djl by deepjavalibrary.

The class LSTM, method forwardInternal.

/**
 * {@inheritDoc}
 *
 * <p>{@code inputs} holds the input sequence at index 0 and, optionally, the initial hidden state
 * and cell state at indices 1 and 2. If only the input is supplied, both states are zeros of shape
 * {@code (numLayers * getNumDirections(), batch, stateSize)}. The batch size is read from
 * dimension 0 when {@code batchFirst} is set and from dimension 1 otherwise. The default states use
 * the input's data type and are created locally: the caller's {@code inputs} list is not modified.
 *
 * <p>When {@code returnState} is false, every output after the first is closed and only the first
 * output is returned.
 */
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();
    NDList rnnParams = new NDList();
    for (Parameter parameter : parameters.values()) {
        rnnParams.add(parameterStore.getValue(parameter, device, training));
    }
    NDList states;
    if (inputs.size() == 1) {
        int batchIndex = batchFirst ? 0 : 1;
        Shape stateShape = new Shape((long) numLayers * getNumDirections(), input.size(batchIndex), stateSize);
        // Build the default states locally rather than appending them to the caller's list, and
        // match the input's data type so non-float32 inputs do not get mismatched states.
        states = new NDList(
                // hidden state
                input.getManager().zeros(stateShape, input.getDataType()),
                // cell
                input.getManager().zeros(stateShape, input.getDataType()));
    } else {
        states = new NDList(inputs.get(1), inputs.get(2));
    }
    NDList outputs = ex.lstm(input, states, rnnParams, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst);
    if (returnState) {
        return outputs;
    }
    // Only the first output is requested; release the others.
    outputs.stream().skip(1).forEach(NDArray::close);
    return new NDList(outputs.get(0));
}
Also used : Shape(ai.djl.ndarray.types.Shape) Device(ai.djl.Device) NDList(ai.djl.ndarray.NDList) Parameter(ai.djl.nn.Parameter) NDArray(ai.djl.ndarray.NDArray) NDArrayEx(ai.djl.ndarray.internal.NDArrayEx)

Example 4 with Parameter

use of ai.djl.nn.Parameter in project djl by deepjavalibrary.

The class RNN, method forwardInternal.

/**
 * {@inheritDoc}
 *
 * <p>{@code inputs} holds the input sequence at index 0 and, optionally, the initial state at
 * index 1. If only the input is supplied, a zero state of shape
 * {@code (numLayers * getNumDirections(), batch, stateSize)} is created. The batch size is read
 * from dimension 0 when {@code batchFirst} is set and from dimension 1 otherwise. The default
 * state uses the input's data type and is created locally: the caller's {@code inputs} list is
 * not modified.
 *
 * <p>When {@code returnState} is false, every output after the first is closed and only the first
 * output is returned.
 */
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
    NDArray input = inputs.head();
    NDArrayEx ex = input.getNDArrayInternal();
    Device device = input.getDevice();
    NDList rnnParams = new NDList();
    for (Parameter parameter : parameters.values()) {
        rnnParams.add(parameterStore.getValue(parameter, device, training));
    }
    NDArray state;
    if (inputs.size() == 1) {
        int batchIndex = batchFirst ? 0 : 1;
        Shape stateShape = new Shape((long) numLayers * getNumDirections(), input.size(batchIndex), stateSize);
        // Build the default state locally rather than appending it to the caller's list, and
        // match the input's data type so non-float32 inputs do not get a mismatched state.
        state = input.getManager().zeros(stateShape, input.getDataType());
    } else {
        state = inputs.get(1);
    }
    NDList outputs = ex.rnn(input, state, rnnParams, hasBiases, numLayers, activation, dropRate, training, bidirectional, batchFirst);
    if (returnState) {
        return outputs;
    }
    // Only the first output is requested; release the others.
    outputs.stream().skip(1).forEach(NDArray::close);
    return new NDList(outputs.get(0));
}
Also used : Shape(ai.djl.ndarray.types.Shape) Device(ai.djl.Device) NDList(ai.djl.ndarray.NDList) Parameter(ai.djl.nn.Parameter) NDArray(ai.djl.ndarray.NDArray) NDArrayEx(ai.djl.ndarray.internal.NDArrayEx)

Example 5 with Parameter

use of ai.djl.nn.Parameter in project djl by deepjavalibrary.

The class RecurrentBlock, method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Sets the shape of every direct parameter of this block from its name. The layer index is
 * parsed from the second {@code '_'}-separated token of the name. Layer 0 takes its input size
 * from dimension 2 of {@code inputs[0]}; higher layers use {@code stateSize * getNumDirections()}.
 * Names containing {@code BIAS} get shape {@code (gates * stateSize)}. Names containing {@code i2h}
 * get {@code (gates * stateSize, inputSize)}. Names containing {@code h2h} get
 * {@code (gates * stateSize, stateSize)}.
 *
 * @throws IllegalArgumentException if a parameter name contains none of {@code BIAS},
 *     {@code i2h} or {@code h2h}
 */
@Override
public void prepare(Shape[] inputs) {
    Shape inputShape = inputs[0];
    ParameterList parameters = getDirectParameters();
    for (Pair<String, Parameter> pair : parameters) {
        String name = pair.getKey();
        Parameter parameter = pair.getValue();
        int layer = Integer.parseInt(name.split("_")[1]);
        // Only the first layer consumes the raw input features; deeper layers consume the
        // previous layer's output, whose size depends on the number of directions.
        long inputSize = layer > 0 ? stateSize * getNumDirections() : inputShape.get(2);
        if (name.contains("BIAS")) {
            parameter.setShape(new Shape(gates * stateSize));
        } else if (name.contains("i2h")) {
            parameter.setShape(new Shape(gates * stateSize, inputSize));
        } else if (name.contains("h2h")) {
            parameter.setShape(new Shape(gates * stateSize, stateSize));
        } else {
            throw new IllegalArgumentException("Invalid parameter name: " + name);
        }
    }
}
Also used : Shape(ai.djl.ndarray.types.Shape) ParameterList(ai.djl.nn.ParameterList) Parameter(ai.djl.nn.Parameter)

Aggregations

Parameter (ai.djl.nn.Parameter): 23; NDArray (ai.djl.ndarray.NDArray): 15; NDList (ai.djl.ndarray.NDList): 15; Shape (ai.djl.ndarray.types.Shape): 15; Model (ai.djl.Model): 10; NDManager (ai.djl.ndarray.NDManager): 10; Block (ai.djl.nn.Block): 10; DefaultTrainingConfig (ai.djl.training.DefaultTrainingConfig): 10; Trainer (ai.djl.training.Trainer): 10; TrainingConfig (ai.djl.training.TrainingConfig): 10; Batch (ai.djl.training.dataset.Batch): 10; Test (org.testng.annotations.Test): 10; Device (ai.djl.Device): 4; NDArrayEx (ai.djl.ndarray.internal.NDArrayEx): 3; HashSet (java.util.HashSet): 2; MalformedModelException (ai.djl.MalformedModelException): 1; CachedOp (ai.djl.mxnet.engine.CachedOp): 1; Symbol (ai.djl.mxnet.engine.Symbol): 1; ParameterList (ai.djl.nn.ParameterList): 1; ZooModel (ai.djl.repository.zoo.ZooModel): 1