Search in sources :

Example 1 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

the class KerasConvolution method setWeights.

/**
 * Set weights for layer.
 *
 * Copies the Keras kernel and bias into this layer's weight map under the DL4J
 * convolution parameter keys. The kernel is reordered first, depending on the
 * dimension ordering reported by {@code getDimOrder()}.
 *
 * The supplied map is only read; it is never modified by this method.
 *
 * @param weights   Map from parameter name to INDArray.
 * @throws InvalidKerasConfigurationException if the kernel or bias parameter is
 *                  missing, or if the dimension ordering is not recognised
 */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    if (weights.containsKey(KERAS_PARAM_NAME_W)) {
        /* Theano and TensorFlow backends store convolutional weights
         * with a different dimensional ordering than DL4J so we need
         * to permute them to match.
         *
         * DL4J: (# outputs, # inputs, # rows, # cols)
         */
        INDArray kerasParamValue = weights.get(KERAS_PARAM_NAME_W);
        INDArray paramValue;
        switch(this.getDimOrder()) {
            case TENSORFLOW:
                /* TensorFlow convolutional weights: # rows, # cols, # inputs, # outputs */
                paramValue = kerasParamValue.permute(3, 2, 0, 1);
                break;
            case THEANO:
                /* Theano convolutional weights match DL4J: # outputs, # inputs, # rows, # cols
                 * Theano's default behavior is to rotate filters by 180 degree before application,
                 * so every (rows, cols) slice has its flattened element order reversed below.
                 */
                paramValue = kerasParamValue.dup();
                for (int i = 0; i < paramValue.tensorssAlongDimension(2, 3); i++) {
                    //dup required since we only want data from the view not the whole array
                    INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup();
                    double[] flattenedFilter = copyFilter.ravel().data().asDouble();
                    ArrayUtils.reverse(flattenedFilter);
                    INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape());
                    //manipulating weights in place to save memory:
                    //the view is zeroed and then the reversed values are added into it
                    INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3);
                    inPlaceFilter.muli(0).addi(newFilter);
                }
                break;
            default:
                throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder());
        }
        this.weights.put(ConvolutionParamInitializer.WEIGHT_KEY, paramValue);
    } else {
        throw new InvalidKerasConfigurationException("Parameter " + KERAS_PARAM_NAME_W + " does not exist in weights");
    }
    if (weights.containsKey(KERAS_PARAM_NAME_B)) {
        this.weights.put(ConvolutionParamInitializer.BIAS_KEY, weights.get(KERAS_PARAM_NAME_B));
    } else {
        throw new InvalidKerasConfigurationException("Parameter " + KERAS_PARAM_NAME_B + " does not exist in weights");
    }
    if (weights.size() > 2) {
        // keySet() is a live view backed by the map: removing from it directly would
        // delete the kernel and bias entries from the caller's map. Work on a copy.
        Set<String> paramNames = new java.util.HashSet<String>(weights.keySet());
        paramNames.remove(KERAS_PARAM_NAME_W);
        paramNames.remove(KERAS_PARAM_NAME_B);
        String unknownParamNames = paramNames.toString();
        // substring strips the surrounding '[' and ']' produced by Set.toString()
        log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 2 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

the class KerasDense method setWeights.

/**
 * Set weights for layer.
 *
 * Copies the Keras weight matrix and bias into this layer's weight map under
 * the DL4J default parameter keys. The supplied map is only read; it is never
 * modified by this method.
 *
 * @param weights   Map from parameter name to INDArray.
 * @throws InvalidKerasConfigurationException if the weight or bias parameter is missing
 */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    if (weights.containsKey(KERAS_PARAM_NAME_W)) {
        this.weights.put(DefaultParamInitializer.WEIGHT_KEY, weights.get(KERAS_PARAM_NAME_W));
    } else {
        throw new InvalidKerasConfigurationException("Parameter " + KERAS_PARAM_NAME_W + " does not exist in weights");
    }
    if (weights.containsKey(KERAS_PARAM_NAME_B)) {
        this.weights.put(DefaultParamInitializer.BIAS_KEY, weights.get(KERAS_PARAM_NAME_B));
    } else {
        throw new InvalidKerasConfigurationException("Parameter " + KERAS_PARAM_NAME_B + " does not exist in weights");
    }
    if (weights.size() > 2) {
        // keySet() is a live view backed by the map: removing from it directly would
        // delete the weight and bias entries from the caller's map. Work on a copy.
        Set<String> paramNames = new java.util.HashSet<String>(weights.keySet());
        paramNames.remove(KERAS_PARAM_NAME_W);
        paramNames.remove(KERAS_PARAM_NAME_B);
        String unknownParamNames = paramNames.toString();
        // substring strips the surrounding '[' and ']' produced by Set.toString()
        log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 3 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

the class KerasEmbedding method setWeights.

/**
 * Set weights for layer.
 *
 * Stores the Keras embedding matrix under the DL4J weight key. When the
 * supplied map has no bias entry, a zero bias sized by {@code W.size(1)} is
 * used instead and a warning is logged. Any parameter names other than the
 * weight and bias names are reported in a warning.
 *
 * The supplied map is only read; it is never modified by this method.
 *
 * @param weights   Map from parameter name to INDArray.
 * @throws InvalidKerasConfigurationException if the weight parameter is missing
 */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    if (!weights.containsKey(KERAS_PARAM_NAME_W)) {
        throw new InvalidKerasConfigurationException("Parameter " + KERAS_PARAM_NAME_W + " does not exist in weights");
    }
    INDArray W = weights.get(KERAS_PARAM_NAME_W);
    // Keep the substitute bias local rather than inserting it into the caller's map.
    INDArray bias;
    if (weights.containsKey(KERAS_PARAM_NAME_B)) {
        bias = weights.get(KERAS_PARAM_NAME_B);
    } else {
        log.warn("Setting DL4J EmbeddingLayer bias to zero.");
        bias = Nd4j.zeros(W.size(1));
    }
    this.weights.put(DefaultParamInitializer.WEIGHT_KEY, W);
    this.weights.put(DefaultParamInitializer.BIAS_KEY, bias);
    // keySet() is a live view backed by the map, so filter a copy. Both known
    // names are removed so the bias is not misreported as an unknown parameter.
    Set<String> paramNames = new java.util.HashSet<String>(weights.keySet());
    paramNames.remove(KERAS_PARAM_NAME_W);
    paramNames.remove(KERAS_PARAM_NAME_B);
    if (!paramNames.isEmpty()) {
        String unknownParamNames = paramNames.toString();
        // substring strips the surrounding '[' and ']' produced by Set.toString()
        log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 4 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

the class KerasLstm method getRecurrentWeightInitFromConfig.

/**
 * Get LSTM recurrent weight initialization from Keras layer configuration.
 *
 * Reads the {@code LAYER_FIELD_INNER_INIT} entry of the inner layer
 * configuration and maps it through {@code mapWeightInitialization}.
 *
 * @param layerConfig       dictionary containing Keras layer configuration
 * @param train             when true, an unsupported initializer is rethrown;
 *                          when false, XAVIER is substituted and a warning is logged
 * @return                  the mapped weight initialization scheme
 * @throws InvalidKerasConfigurationException if the inner-init field is missing
 * @throws UnsupportedKerasConfigurationException if the initializer cannot be
 *                          mapped and {@code train} is true
 */
public static WeightInit getRecurrentWeightInitFromConfig(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> config = getInnerLayerConfigFromConfig(layerConfig);
    if (!config.containsKey(LAYER_FIELD_INNER_INIT)) {
        throw new InvalidKerasConfigurationException("Keras LSTM layer config missing " + LAYER_FIELD_INNER_INIT + " field");
    }
    String initName = (String) config.get(LAYER_FIELD_INNER_INIT);
    try {
        return mapWeightInitialization(initName);
    } catch (UnsupportedKerasConfigurationException unsupported) {
        if (train) {
            throw unsupported;
        }
        // Outside of training, fall back to a default scheme rather than failing.
        log.warn("Unknown weight initializer " + initName + " (Using XAVIER instead).");
        return WeightInit.XAVIER;
    }
}
Also used : WeightInit(org.deeplearning4j.nn.weights.WeightInit) UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Example 5 with InvalidKerasConfigurationException

use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.

the class KerasLstm method getForgetBiasInitFromConfig.

/**
 * Get LSTM forget gate bias initialization from Keras layer configuration.
 *
 * Reads the {@code LAYER_FIELD_FORGET_BIAS_INIT} entry of the inner layer
 * configuration and translates it into a numeric bias value.
 *
 * @param layerConfig       dictionary containing Keras layer configuration
 * @param train             when true, an unrecognised value raises an exception;
 *                          when false, 1.0 is substituted and a warning is logged
 * @return                  0.0 for the "zero" initializer, 1.0 for the "one"
 *                          initializer or for the non-training fallback
 * @throws InvalidKerasConfigurationException if the forget-bias-init field is missing
 * @throws UnsupportedKerasConfigurationException if the value is unrecognised
 *                          and {@code train} is true
 */
public static double getForgetBiasInitFromConfig(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> config = getInnerLayerConfigFromConfig(layerConfig);
    if (!config.containsKey(LAYER_FIELD_FORGET_BIAS_INIT)) {
        throw new InvalidKerasConfigurationException("Keras LSTM layer config missing " + LAYER_FIELD_FORGET_BIAS_INIT + " field");
    }
    String biasInitName = (String) config.get(LAYER_FIELD_FORGET_BIAS_INIT);
    if (biasInitName.equals(LSTM_FORGET_BIAS_INIT_ZERO)) {
        return 0.0;
    }
    if (biasInitName.equals(LSTM_FORGET_BIAS_INIT_ONE)) {
        return 1.0;
    }
    // Anything else is unsupported: fatal when training, otherwise fall back to 1.
    if (train) {
        throw new UnsupportedKerasConfigurationException("Unsupported LSTM forget gate bias initialization: " + biasInitName);
    }
    log.warn("Unsupported LSTM forget gate bias initialization: " + biasInitName + " (using 1 instead)");
    return 1.0;
}
Also used : UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)

Aggregations

InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException)11 UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 File (java.io.File)1 FileInputStream (java.io.FileInputStream)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 List (java.util.List)1 TikaConfigException (org.apache.tika.exception.TikaConfigException)1 NativeImageLoader (org.datavec.image.loader.NativeImageLoader)1 ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex)1 WeightInit (org.deeplearning4j.nn.weights.WeightInit)1 ParseException (org.json.simple.parser.ParseException)1