Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.
The class KerasBatchNormalization, method setWeights.
/**
* Set weights for layer.
*
* @param weights Map from parameter name to INDArray.
*/
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    if (weights.containsKey(PARAM_NAME_BETA))
        this.weights.put(BatchNormalizationParamInitializer.BETA, weights.get(PARAM_NAME_BETA));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_BETA + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_GAMMA))
        this.weights.put(BatchNormalizationParamInitializer.GAMMA, weights.get(PARAM_NAME_GAMMA));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_GAMMA + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_RUNNING_MEAN))
        this.weights.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, weights.get(PARAM_NAME_RUNNING_MEAN));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_RUNNING_MEAN + " does not exist in weights");
    if (weights.containsKey(PARAM_NAME_RUNNING_STD))
        this.weights.put(BatchNormalizationParamInitializer.GLOBAL_VAR, weights.get(PARAM_NAME_RUNNING_STD));
    else
        throw new InvalidKerasConfigurationException("Parameter " + PARAM_NAME_RUNNING_STD + " does not exist in weights");
    if (weights.size() > 4) {
        /* Copy the key set before removing entries: Map.keySet() is a live view,
         * and removing from it directly would mutate the caller's map. */
        Set<String> paramNames = new HashSet<String>(weights.keySet());
        paramNames.remove(PARAM_NAME_BETA);
        paramNames.remove(PARAM_NAME_GAMMA);
        paramNames.remove(PARAM_NAME_RUNNING_MEAN);
        paramNames.remove(PARAM_NAME_RUNNING_STD);
        String unknownParamNames = paramNames.toString();
        log.warn("Attempting to set weights for unknown parameters: "
                + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
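For context, a minimal usage sketch. It assumes the PARAM_NAME_* constants resolve to the Keras 1.x parameter names "beta", "gamma", "running_mean", and "running_std", and picks an arbitrary channel count of 32; both are illustrative assumptions, not confirmed by the snippet above.

    // Hypothetical weights map for a BatchNormalization layer with 32 channels.
    Map<String, INDArray> kerasWeights = new HashMap<String, INDArray>();
    kerasWeights.put("beta", Nd4j.zeros(1, 32));         // learned shift
    kerasWeights.put("gamma", Nd4j.ones(1, 32));         // learned scale
    kerasWeights.put("running_mean", Nd4j.zeros(1, 32)); // accumulated mean
    kerasWeights.put("running_std", Nd4j.ones(1, 32));   // accumulated variance
    // Stores the arrays under DL4J's BETA/GAMMA/GLOBAL_MEAN/GLOBAL_VAR keys.
    layer.setWeights(kerasWeights);

Any extra entries in the map beyond these four would only trigger the warning branch, not an exception.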
Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.
The class KerasBatchNormalization, method getBatchNormMode.
/**
 * Get the BatchNormalization "mode" from the Keras layer configuration. Most modes are currently unsupported.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @param enforceTrainingConfig whether to enforce loading of training-related configuration
 * @return Keras BatchNormalization mode
 * @throws InvalidKerasConfigurationException if the mode field is missing
 * @throws UnsupportedKerasConfigurationException if the mode is not supported
 */
protected int getBatchNormMode(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_MODE))
        throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing " + LAYER_FIELD_MODE + " field");
    int batchNormMode = (int) innerConfig.get(LAYER_FIELD_MODE);
    switch (batchNormMode) {
        case LAYER_BATCHNORM_MODE_1:
            throw new UnsupportedKerasConfigurationException(
                    "Keras BatchNormalization mode " + LAYER_BATCHNORM_MODE_1 + " (sample-wise) not supported");
        case LAYER_BATCHNORM_MODE_2:
            throw new UnsupportedKerasConfigurationException(
                    "Keras BatchNormalization mode " + LAYER_BATCHNORM_MODE_2 + " (per-batch statistics during testing) not supported");
    }
    return batchNormMode;
}
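A hedged sketch of the configuration this method parses, assuming LAYER_FIELD_MODE is the Keras JSON key "mode" and that getInnerLayerConfigFromConfig unwraps the nested "config" map; since the method is protected, the call below would have to come from a subclass or the same package.

    // Hypothetical Keras 1.x layer config: the mode field lives in the nested "config" map.
    Map<String, Object> innerConfig = new HashMap<String, Object>();
    innerConfig.put("mode", 0); // feature-wise normalization, the one supported mode
    Map<String, Object> layerConfig = new HashMap<String, Object>();
    layerConfig.put("class_name", "BatchNormalization");
    layerConfig.put("config", innerConfig);
    int mode = getBatchNormMode(layerConfig, false); // returns 0; modes 1 and 2 would throw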
Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.
The class KerasMerge, method getMergeMode.
/**
 * Get the Keras Merge layer "mode" from the layer configuration and map it to
 * the corresponding DL4J ElementWiseVertex.Op.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @return element-wise operation, or null for concatenation (which is not element-wise)
 * @throws InvalidKerasConfigurationException if the mode field is missing
 * @throws UnsupportedKerasConfigurationException if the mode is not supported
 */
public ElementWiseVertex.Op getMergeMode(Map<String, Object> layerConfig)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_MODE))
        throw new InvalidKerasConfigurationException("Keras Merge layer config missing " + LAYER_FIELD_MODE + " field");
    ElementWiseVertex.Op op = null;
    String mergeMode = (String) innerConfig.get(LAYER_FIELD_MODE);
    switch (mergeMode) {
        case LAYER_MERGE_MODE_SUM:
            op = ElementWiseVertex.Op.Add;
            break;
        case LAYER_MERGE_MODE_MUL:
            op = ElementWiseVertex.Op.Product;
            break;
        case LAYER_MERGE_MODE_CONCAT:
            // Concatenation is not an element-wise operation; leave op null.
            break;
        case LAYER_MERGE_MODE_AVE:
        case LAYER_MERGE_MODE_COS:
        case LAYER_MERGE_MODE_DOT:
        case LAYER_MERGE_MODE_MAX:
        default:
            throw new UnsupportedKerasConfigurationException("Keras Merge layer mode " + mergeMode + " not supported");
    }
    return op;
}
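Illustrative behavior, under the assumption that the LAYER_MERGE_MODE_* constants hold the Keras mode strings ("sum", "mul", "concat", "ave", "cos", "dot", "max") and that the mode field sits in a nested "config" map as in the other layers:

    // Hypothetical config for a Keras Merge layer using element-wise sum.
    Map<String, Object> innerConfig = new HashMap<String, Object>();
    innerConfig.put("mode", "sum");
    Map<String, Object> layerConfig = new HashMap<String, Object>();
    layerConfig.put("config", innerConfig);
    ElementWiseVertex.Op op = kerasMerge.getMergeMode(layerConfig); // ElementWiseVertex.Op.Add
    // "mul" maps to Op.Product; "concat" yields null (no element-wise op);
    // "ave", "cos", "dot", and "max" throw UnsupportedKerasConfigurationException.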
Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.
The class KerasZeroPadding, method getPaddingFromConfig.
/**
 * Get zero padding from the Keras layer configuration.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @return padding as an int array (expanded to four values for 2D layers)
 * @throws InvalidKerasConfigurationException if the padding field is missing or malformed
 * @throws UnsupportedKerasConfigurationException if the layer type is not supported
 */
public int[] getPaddingFromConfig(Map<String, Object> layerConfig)
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    if (!innerConfig.containsKey(LAYER_FIELD_PADDING))
        throw new InvalidKerasConfigurationException("Field " + LAYER_FIELD_PADDING + " not found in Keras ZeroPadding layer");
    List<Integer> paddingList = (List<Integer>) innerConfig.get(LAYER_FIELD_PADDING);
    switch (this.className) {
        case LAYER_CLASS_NAME_ZERO_PADDING_2D:
            /* A two-element list [a, b] denotes symmetric padding;
             * expand it to the four-element form [a, a, b, b]. */
            if (paddingList.size() == 2) {
                paddingList.add(paddingList.get(1));
                paddingList.add(1, paddingList.get(0));
            }
            if (paddingList.size() != 4)
                throw new InvalidKerasConfigurationException("Found Keras ZeroPadding2D layer with invalid " + paddingList.size() + "D padding.");
            break;
        case LAYER_CLASS_NAME_ZERO_PADDING_1D:
            throw new UnsupportedKerasConfigurationException("Keras ZeroPadding1D layer not supported");
        default:
            throw new UnsupportedKerasConfigurationException("Keras " + this.className + " padding layer not supported");
    }
    /* Unbox the padding values into a primitive int array. */
    int[] padding = new int[paddingList.size()];
    for (int i = 0; i < paddingList.size(); i++)
        padding[i] = paddingList.get(i);
    return padding;
}
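A worked trace of the 2D expansion above: a symmetric two-element padding list [1, 2] becomes [1, 1, 2, 2] before unboxing. (The key name "padding" for LAYER_FIELD_PADDING is an assumption for illustration.)

    // Trace of the list manipulation for paddingList = [1, 2]:
    List<Integer> paddingList = new ArrayList<Integer>(Arrays.asList(1, 2));
    paddingList.add(paddingList.get(1));    // append second value:     [1, 2, 2]
    paddingList.add(1, paddingList.get(0)); // insert first at index 1: [1, 1, 2, 2]
    // getPaddingFromConfig would then return new int[] {1, 1, 2, 2}.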
Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in project deeplearning4j by deeplearning4j.
The class KerasLstm, method setWeights.
/**
 * Set weights for layer.
 *
 * @param weights Map from Keras parameter name to INDArray.
 */
@Override
public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    this.weights = new HashMap<String, INDArray>();
    /* Keras stores LSTM parameters in distinct arrays (e.g., the recurrent weights
     * are stored in four matrices: U_c, U_f, U_i, U_o) while DL4J stores them
     * concatenated into one matrix (e.g., U = [ U_c U_f U_o U_i ]). Thus we have
     * to map each Keras weight matrix to its corresponding DL4J weight submatrix.
     */
    INDArray W_c;
    if (weights.containsKey(KERAS_PARAM_NAME_W_C))
        W_c = weights.get(KERAS_PARAM_NAME_W_C);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_W_C);
    INDArray W_f;
    if (weights.containsKey(KERAS_PARAM_NAME_W_F))
        W_f = weights.get(KERAS_PARAM_NAME_W_F);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_W_F);
    INDArray W_o;
    if (weights.containsKey(KERAS_PARAM_NAME_W_O))
        W_o = weights.get(KERAS_PARAM_NAME_W_O);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_W_O);
    INDArray W_i;
    if (weights.containsKey(KERAS_PARAM_NAME_W_I))
        W_i = weights.get(KERAS_PARAM_NAME_W_I);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_W_I);
    /* Concatenate the four Keras input weight matrices column-wise, in gate order [c f o i]. */
    INDArray W = Nd4j.zeros(W_c.rows(), W_c.columns() + W_f.columns() + W_o.columns() + W_i.columns());
    W.put(new INDArrayIndex[] { NDArrayIndex.interval(0, W.rows()),
            NDArrayIndex.interval(0, W_c.columns()) }, W_c);
    W.put(new INDArrayIndex[] { NDArrayIndex.interval(0, W.rows()),
            NDArrayIndex.interval(W_c.columns(), W_c.columns() + W_f.columns()) }, W_f);
    W.put(new INDArrayIndex[] { NDArrayIndex.interval(0, W.rows()),
            NDArrayIndex.interval(W_c.columns() + W_f.columns(),
                    W_c.columns() + W_f.columns() + W_o.columns()) }, W_o);
    W.put(new INDArrayIndex[] { NDArrayIndex.interval(0, W.rows()),
            NDArrayIndex.interval(W_c.columns() + W_f.columns() + W_o.columns(),
                    W_c.columns() + W_f.columns() + W_o.columns() + W_i.columns()) }, W_i);
    this.weights.put(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, W);
    INDArray U_c;
    if (weights.containsKey(KERAS_PARAM_NAME_U_C))
        U_c = weights.get(KERAS_PARAM_NAME_U_C);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_U_C);
    INDArray U_f;
    if (weights.containsKey(KERAS_PARAM_NAME_U_F))
        U_f = weights.get(KERAS_PARAM_NAME_U_F);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_U_F);
    INDArray U_o;
    if (weights.containsKey(KERAS_PARAM_NAME_U_O))
        U_o = weights.get(KERAS_PARAM_NAME_U_O);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_U_O);
    INDArray U_i;
    if (weights.containsKey(KERAS_PARAM_NAME_U_I))
        U_i = weights.get(KERAS_PARAM_NAME_U_I);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_U_I);
    /* DL4J has three additional columns in its recurrent weights matrix that don't appear in Keras LSTMs.
     * These are for peephole connections. Since Keras doesn't use them, we zero them out.
     */
    INDArray U = Nd4j.zeros(U_c.rows(), U_c.columns() + U_f.columns() + U_o.columns() + U_i.columns() + 3);
    U.put(new INDArrayIndex[] { NDArrayIndex.interval(0, U.rows()),
            NDArrayIndex.interval(0, U_c.columns()) }, U_c);
    U.put(new INDArrayIndex[] { NDArrayIndex.interval(0, U.rows()),
            NDArrayIndex.interval(U_c.columns(), U_c.columns() + U_f.columns()) }, U_f);
    U.put(new INDArrayIndex[] { NDArrayIndex.interval(0, U.rows()),
            NDArrayIndex.interval(U_c.columns() + U_f.columns(),
                    U_c.columns() + U_f.columns() + U_o.columns()) }, U_o);
    U.put(new INDArrayIndex[] { NDArrayIndex.interval(0, U.rows()),
            NDArrayIndex.interval(U_c.columns() + U_f.columns() + U_o.columns(),
                    U_c.columns() + U_f.columns() + U_o.columns() + U_i.columns()) }, U_i);
    this.weights.put(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, U);
    INDArray b_c;
    if (weights.containsKey(KERAS_PARAM_NAME_B_C))
        b_c = weights.get(KERAS_PARAM_NAME_B_C);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_B_C);
    INDArray b_f;
    if (weights.containsKey(KERAS_PARAM_NAME_B_F))
        b_f = weights.get(KERAS_PARAM_NAME_B_F);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_B_F);
    INDArray b_o;
    if (weights.containsKey(KERAS_PARAM_NAME_B_O))
        b_o = weights.get(KERAS_PARAM_NAME_B_O);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_B_O);
    INDArray b_i;
    if (weights.containsKey(KERAS_PARAM_NAME_B_I))
        b_i = weights.get(KERAS_PARAM_NAME_B_I);
    else
        throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + KERAS_PARAM_NAME_B_I);
    /* Concatenate the four gate bias vectors in the same [c f o i] order. */
    INDArray b = Nd4j.zeros(b_c.rows(), b_c.columns() + b_f.columns() + b_o.columns() + b_i.columns());
    b.put(new INDArrayIndex[] { NDArrayIndex.interval(0, b.rows()),
            NDArrayIndex.interval(0, b_c.columns()) }, b_c);
    b.put(new INDArrayIndex[] { NDArrayIndex.interval(0, b.rows()),
            NDArrayIndex.interval(b_c.columns(), b_c.columns() + b_f.columns()) }, b_f);
    b.put(new INDArrayIndex[] { NDArrayIndex.interval(0, b.rows()),
            NDArrayIndex.interval(b_c.columns() + b_f.columns(),
                    b_c.columns() + b_f.columns() + b_o.columns()) }, b_o);
    b.put(new INDArrayIndex[] { NDArrayIndex.interval(0, b.rows()),
            NDArrayIndex.interval(b_c.columns() + b_f.columns() + b_o.columns(),
                    b_c.columns() + b_f.columns() + b_o.columns() + b_i.columns()) }, b_i);
    this.weights.put(GravesLSTMParamInitializer.BIAS_KEY, b);
    if (weights.size() > NUM_WEIGHTS_IN_KERAS_LSTM) {
        /* Copy the key set before removing entries: Map.keySet() is a live view,
         * and removing from it directly would mutate the caller's map. */
        Set<String> paramNames = new HashSet<String>(weights.keySet());
        paramNames.remove(KERAS_PARAM_NAME_W_C);
        paramNames.remove(KERAS_PARAM_NAME_W_F);
        paramNames.remove(KERAS_PARAM_NAME_W_I);
        paramNames.remove(KERAS_PARAM_NAME_W_O);
        paramNames.remove(KERAS_PARAM_NAME_U_C);
        paramNames.remove(KERAS_PARAM_NAME_U_F);
        paramNames.remove(KERAS_PARAM_NAME_U_I);
        paramNames.remove(KERAS_PARAM_NAME_U_O);
        paramNames.remove(KERAS_PARAM_NAME_B_C);
        paramNames.remove(KERAS_PARAM_NAME_B_F);
        paramNames.remove(KERAS_PARAM_NAME_B_I);
        paramNames.remove(KERAS_PARAM_NAME_B_O);
        String unknownParamNames = paramNames.toString();
        log.warn("Attempting to set weights for unknown parameters: "
                + unknownParamNames.substring(1, unknownParamNames.length() - 1));
    }
}
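To make the gate concatenation concrete, a small shape sketch with hypothetical sizes (input size 10, hidden size 20). Nd4j.hstack is used purely for illustration; it produces the same column layout as the interval puts above.

    int nIn = 10, nOut = 20; // hypothetical layer sizes
    // Each Keras gate matrix is nIn x nOut; DL4J wants one nIn x (4 * nOut) input weight matrix.
    INDArray W_c = Nd4j.rand(nIn, nOut), W_f = Nd4j.rand(nIn, nOut);
    INDArray W_o = Nd4j.rand(nIn, nOut), W_i = Nd4j.rand(nIn, nOut);
    INDArray W = Nd4j.hstack(W_c, W_f, W_o, W_i); // 10 x 80, gate order [c f o i]
    // The recurrent matrix gets three extra zeroed columns for DL4J's peephole weights.
    INDArray U = Nd4j.hstack(Nd4j.rand(nOut, nOut), Nd4j.rand(nOut, nOut),
            Nd4j.rand(nOut, nOut), Nd4j.rand(nOut, nOut), Nd4j.zeros(nOut, 3)); // 20 x 83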