Search in sources :

Example 6 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

In the class KerasBatchNormalization, the method setWeights:

/**
     * Set weights for layer.
     *
     * @param weights   Map from parameter name to INDArray.
     * @throws InvalidKerasConfigurationException if any of the four required
     *         batch-normalization parameters (beta, gamma, running mean,
     *         running std) is missing from the map
     */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    if (weights.containsKey(PARAM_NAME_BETA))
        this.weights.put(BatchNormalizationParamInitializer.BETA, weights.get(PARAM_NAME_BETA));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_BETA + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_GAMMA))
        this.weights.put(BatchNormalizationParamInitializer.GAMMA, weights.get(PARAM_NAME_GAMMA));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_GAMMA + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_RUNNING_MEAN))
        this.weights.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, weights.get(PARAM_NAME_RUNNING_MEAN));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_RUNNING_MEAN + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_RUNNING_STD))
        this.weights.put(BatchNormalizationParamInitializer.GLOBAL_VAR, weights.get(PARAM_NAME_RUNNING_STD));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_RUNNING_STD + " does not exist in weights");
    if (weights.size() > 4) {
        /* Map.keySet() is a live view backed by the map: removing from it
         * directly would silently delete entries from the caller's weights
         * map. Remove from a copy's key set instead. */
        Set<String> paramNames = new HashMap<String, INDArray>(weights).keySet();
        paramNames.remove(PARAM_NAME_BETA);
        paramNames.remove(PARAM_NAME_GAMMA);
        paramNames.remove(PARAM_NAME_RUNNING_MEAN);
        paramNames.remove(PARAM_NAME_RUNNING_STD);
        // Strip the surrounding brackets from the default Set.toString() form.
        String unknownParamNames = paramNames.toString();
        log.warn("Attempting to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 7 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

In the class KerasBatchNormalization, the method getBatchNormMode:

/**
     * Get BatchNormalization "mode" from Keras layer configuration. Most modes currently unsupported.
     *
     * @param layerConfig            dictionary containing Keras layer configuration
     * @param enforceTrainingConfig  whether to enforce training-only configuration (currently unused here)
     * @return                       the Keras batch-normalization mode as an int
     * @throws InvalidKerasConfigurationException     if the mode field is missing
     * @throws UnsupportedKerasConfigurationException if the mode is one DL4J cannot import
     */
protected int getBatchNormMode(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_MODE))
        throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing " + LAYER_FIELD_MODE + " field");
    int mode = (int) innerConfig.get(LAYER_FIELD_MODE);
    // Only feature-wise normalization (mode 0) can be mapped onto DL4J's
    // BatchNormalization layer; reject the sample-wise and per-batch modes.
    if (mode == LAYER_BATCHNORM_MODE_1)
        throw new UnsupportedKerasConfigurationException("Keras BatchNormalization mode " + LAYER_BATCHNORM_MODE_1 + " (sample-wise) not supported");
    if (mode == LAYER_BATCHNORM_MODE_2)
        throw new UnsupportedKerasConfigurationException("Keras BatchNormalization (per-batch statistics during testing) " + LAYER_BATCHNORM_MODE_2 + " not supported");
    return mode;
}
Also used : UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 8 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

In the class KerasMerge, the method getMergeMode:

/**
 * Map the Keras Merge layer "mode" field onto a DL4J ElementWiseVertex operation.
 *
 * @param layerConfig   dictionary containing Keras layer configuration
 * @return              the element-wise op for sum/mul modes, or null for concat
 *                      (concatenation is handled by a different vertex type)
 * @throws InvalidKerasConfigurationException     if the mode field is missing
 * @throws UnsupportedKerasConfigurationException for modes DL4J cannot import
 */
public ElementWiseVertex.Op getMergeMode(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_MODE))
        throw new InvalidKerasConfigurationException("Keras Merge layer config missing " + LAYER_FIELD_MODE + " field");
    String mergeMode = (String) innerConfig.get(LAYER_FIELD_MODE);
    if (mergeMode.equals(LAYER_MERGE_MODE_SUM))
        return ElementWiseVertex.Op.Add;
    if (mergeMode.equals(LAYER_MERGE_MODE_MUL))
        return ElementWiseVertex.Op.Product;
    if (mergeMode.equals(LAYER_MERGE_MODE_CONCAT))
        // Concatenation is signalled by returning null.
        return null;
    // ave, cos, dot, max, and anything unrecognized are all unsupported.
    throw new UnsupportedKerasConfigurationException("Keras Merge layer mode " + mergeMode + " not supported");
}
Also used : ElementWiseVertex(org.deeplearning4j.nn.conf.graph.ElementWiseVertex) UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 9 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

In the class KerasZeroPadding, the method getPaddingFromConfig:

/**
     * Get zero padding from Keras layer configuration.
     *
     * @param layerConfig       dictionary containing Keras layer configuration
     * @return                  padding as an int array; for ZeroPadding2D this is
     *                          {top, bottom, left, right}
     * @throws InvalidKerasConfigurationException     if the padding field is missing or malformed
     * @throws UnsupportedKerasConfigurationException for padding layer types DL4J cannot import
     */
public int[] getPaddingFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_PADDING))
        throw new InvalidKerasConfigurationException("Field " + LAYER_FIELD_PADDING + " not found in Keras ZeroPadding layer");
    // NOTE(review): assumes the config stores padding as a list of integers — confirm against parser.
    List<Integer> paddingList = (List<Integer>) innerConfig.get(LAYER_FIELD_PADDING);
    switch(this.className) {
        case LAYER_CLASS_NAME_ZERO_PADDING_2D:
            if (paddingList.size() == 2) {
                /* Keras (height, width) shorthand: expand to (top, bottom, left,
                 * right) directly, instead of appending to paddingList, which
                 * would mutate the list owned by the layer configuration. */
                int padHeight = paddingList.get(0);
                int padWidth = paddingList.get(1);
                return new int[] { padHeight, padHeight, padWidth, padWidth };
            }
            if (paddingList.size() != 4)
                throw new InvalidKerasConfigurationException("Found Keras ZeroPadding2D layer with invalid " + paddingList.size() + "D padding.");
            break;
        case LAYER_CLASS_NAME_ZERO_PADDING_1D:
            throw new UnsupportedKerasConfigurationException("Keras ZeroPadding1D layer not supported");
        default:
            throw new UnsupportedKerasConfigurationException("Keras " + this.className + " padding layer not supported");
    }
    // Already in 4-element form: unbox into the result array.
    int[] padding = new int[paddingList.size()];
    for (int i = 0; i < paddingList.size(); i++) padding[i] = paddingList.get(i);
    return padding;
}
Also used : List(java.util.List) UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 10 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

In the class KerasLstm, the method setWeights:

/**
     * Set weights for layer.
     *
     * @param weights   Map from Keras parameter name to INDArray
     * @throws InvalidKerasConfigurationException if any of the twelve required
     *         LSTM parameters (W_*, U_*, b_* for gates c, f, o, i) is missing
     */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    /* Keras stores LSTM parameters in distinct arrays (e.g., the recurrent weights
         * are stored in four matrices: U_c, U_f, U_i, U_o) while DL4J stores them
         * concatenated into one matrix (e.g., U = [ U_c U_f U_o U_i ]). Thus we have
         * to map the Keras weight matrix to its corresponding DL4J weight submatrix.
         */
    INDArray W_c = getRequiredWeight(weights, KERAS_PARAM_NAME_W_C);
    INDArray W_f = getRequiredWeight(weights, KERAS_PARAM_NAME_W_F);
    INDArray W_o = getRequiredWeight(weights, KERAS_PARAM_NAME_W_O);
    INDArray W_i = getRequiredWeight(weights, KERAS_PARAM_NAME_W_I);
    this.weights.put(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, concatGates(0, W_c, W_f, W_o, W_i));
    INDArray U_c = getRequiredWeight(weights, KERAS_PARAM_NAME_U_C);
    INDArray U_f = getRequiredWeight(weights, KERAS_PARAM_NAME_U_F);
    INDArray U_o = getRequiredWeight(weights, KERAS_PARAM_NAME_U_O);
    INDArray U_i = getRequiredWeight(weights, KERAS_PARAM_NAME_U_I);
    /* DL4J has three additional columns in its recurrent weights matrix that don't appear in Keras LSTMs.
         * These are for peephole connections. Since Keras doesn't use them, we zero them out.
         */
    this.weights.put(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, concatGates(3, U_c, U_f, U_o, U_i));
    INDArray b_c = getRequiredWeight(weights, KERAS_PARAM_NAME_B_C);
    INDArray b_f = getRequiredWeight(weights, KERAS_PARAM_NAME_B_F);
    INDArray b_o = getRequiredWeight(weights, KERAS_PARAM_NAME_B_O);
    INDArray b_i = getRequiredWeight(weights, KERAS_PARAM_NAME_B_I);
    this.weights.put(GravesLSTMParamInitializer.BIAS_KEY, concatGates(0, b_c, b_f, b_o, b_i));
    if (weights.size() > NUM_WEIGHTS_IN_KERAS_LSTM) {
        /* Map.keySet() is a live view backed by the map: removing from it
         * directly would silently delete entries from the caller's weights
         * map. Remove from a copy's key set instead. */
        Set<String> paramNames = new HashMap<String, INDArray>(weights).keySet();
        paramNames.remove(KERAS_PARAM_NAME_W_C);
        paramNames.remove(KERAS_PARAM_NAME_W_F);
        paramNames.remove(KERAS_PARAM_NAME_W_I);
        paramNames.remove(KERAS_PARAM_NAME_W_O);
        paramNames.remove(KERAS_PARAM_NAME_U_C);
        paramNames.remove(KERAS_PARAM_NAME_U_F);
        paramNames.remove(KERAS_PARAM_NAME_U_I);
        paramNames.remove(KERAS_PARAM_NAME_U_O);
        paramNames.remove(KERAS_PARAM_NAME_B_C);
        paramNames.remove(KERAS_PARAM_NAME_B_F);
        paramNames.remove(KERAS_PARAM_NAME_B_I);
        paramNames.remove(KERAS_PARAM_NAME_B_O);
        // Strip the surrounding brackets from the default Set.toString() form.
        String unknownParamNames = paramNames.toString();
        log.warn("Attempting to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}

/** Look up a required LSTM parameter by name, failing with the standard message when absent. */
private static INDArray getRequiredWeight(Map<String, INDArray> weights, String paramName)
                throws InvalidKerasConfigurationException {
    if (!weights.containsKey(paramName))
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + paramName);
    return weights.get(paramName);
}

/** Horizontally concatenate per-gate matrices (in the order given) into one DL4J
 *  parameter matrix, leaving {@code extraColumns} trailing columns zeroed
 *  (used for DL4J's peephole-connection columns, which Keras does not have). */
private static INDArray concatGates(int extraColumns, INDArray... gates) {
    int totalColumns = extraColumns;
    for (INDArray gate : gates)
        totalColumns += gate.columns();
    // Nd4j.zeros guarantees the peephole columns stay zero.
    INDArray result = Nd4j.zeros(gates[0].rows(), totalColumns);
    int offset = 0;
    for (INDArray gate : gates) {
        result.put(new INDArrayIndex[] { NDArrayIndex.interval(0, result.rows()),
                        NDArrayIndex.interval(offset, offset + gate.columns()) }, gate);
        offset += gate.columns();
    }
    return result;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Aggregations

InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)11 UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 File (java.io.File)1 FileInputStream (java.io.FileInputStream)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 List (java.util.List)1 TikaConfigException (org.apache.tika.exception.TikaConfigException)1 NativeImageLoader (org.datavec.image.loader.NativeImageLoader)1 ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex)1 WeightInit (org.deeplearning4j.nn.weights.WeightInit)1 ParseException (org.json.simple.parser.ParseException)1