Use of org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException in project deeplearning4j by deeplearning4j.
In class KerasLstm, method getRecurrentDropout:
/**
 * Get LSTM recurrent weight dropout from Keras layer configuration. Currently unsupported.
 *
 * <p>Keras stores a <em>drop</em> probability in the {@code LAYER_FIELD_DROPOUT_U} field, while
 * DL4J works with a <em>retention</em> probability, so the configured value is converted as
 * {@code 1.0 - value}. Because dropout on LSTM recurrent connections is not supported, any
 * resulting retention probability below 1.0 (i.e. a Keras drop probability above 0) is rejected.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @return retention probability for recurrent connections; always 1.0 when this method returns
 *     normally (field absent, or present with a drop probability of 0)
 * @throws UnsupportedKerasConfigurationException if the configured recurrent dropout is greater than 0
 * @throws InvalidKerasConfigurationException if propagated from {@code getInnerLayerConfigFromConfig},
 *     or if the recurrent dropout field is present but not numeric
 */
public static double getRecurrentDropout(Map<String, Object> layerConfig)
        throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
    /* NOTE: Keras "dropout" parameter determines dropout probability,
     * while DL4J "dropout" parameter determines retention probability.
     */
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    double dropout = 1.0;
    if (innerConfig.containsKey(LAYER_FIELD_DROPOUT_U)) {
        Object rawDropout = innerConfig.get(LAYER_FIELD_DROPOUT_U);
        // A direct (double) cast of a boxed Object only succeeds for Double; a JSON parser may
        // yield Integer (e.g. for "0") or null, which would throw ClassCastException / NPE.
        if (!(rawDropout instanceof Number)) {
            throw new InvalidKerasConfigurationException(
                    "Recurrent dropout field " + LAYER_FIELD_DROPOUT_U
                            + " must be numeric but was: " + rawDropout);
        }
        dropout = 1.0 - ((Number) rawDropout).doubleValue();
    }
    if (dropout < 1.0) {
        throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
    }
    return dropout;
}
Use of org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException in project tika by apache.
In class DL4JInceptionV3Net, method initialize:
/**
 * Resolves the model artifacts and loads the Inception V3 network.
 *
 * <p>Steps:
 * <ol>
 *   <li>Resolve the weights file. If {@code modelWeightsPath} is an http(s) URL, it is downloaded
 *       via {@code cachedDownload} into {@code cacheDir}; otherwise it is resolved as a local file.</li>
 *   <li>Resolve the model JSON file from {@code modelJsonPath}.</li>
 *   <li>Load the class-index label map from {@code labelFile}.</li>
 *   <li>Create the image loader and import the Keras model plus weights into {@code graph}.</li>
 * </ol>
 *
 * <p>If the local weights file, the model JSON file, or the label map cannot be found or loaded,
 * the problem is logged and this method returns early without throwing and without assigning
 * {@code graph}. Download and model-import failures are rethrown as {@link TikaConfigException}.
 *
 * @param params configuration parameters (not read by this implementation)
 * @throws TikaConfigException if downloading the weights or importing the Keras model fails
 */
@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
    //STEP 1: resolve weights file, download if necessary
    if (modelWeightsPath.startsWith("http://") || modelWeightsPath.startsWith("https://")) {
        LOG.debug("Config instructed to download the weights file, doing so.");
        try {
            modelWeightsPath = cachedDownload(cacheDir, URI.create(modelWeightsPath)).getAbsolutePath();
        } catch (IOException e) {
            throw new TikaConfigException(e.getMessage(), e);
        }
    } else {
        File modelFile = retrieveFile(modelWeightsPath);
        // Null-check mirrors STEP 2: retrieveFile's result is treated as nullable there, and
        // dereferencing a null here would NPE instead of logging the missing file.
        if (modelFile == null || !modelFile.exists()) {
            LOG.error("modelWeights does not exist at :: {}", modelWeightsPath);
            return;
        }
        modelWeightsPath = modelFile.getAbsolutePath();
    }
    //STEP 2: resolve model JSON
    File modelJsonFile = retrieveFile(modelJsonPath);
    if (modelJsonFile == null || !modelJsonFile.exists()) {
        LOG.error("Could not locate file {}", modelJsonPath);
        return;
    }
    modelJsonPath = modelJsonFile.getAbsolutePath();
    //STEP 3: Load labels map
    try (InputStream stream = retrieveResource(labelFile)) {
        this.labelMap = loadClassIndex(stream);
    } catch (IOException | ParseException e) {
        LOG.error("Could not load labels map", e);
        return;
    }
    //STEP 4: initialize the graph
    try {
        this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels);
        LOG.info("Going to load Inception network...");
        // nanoTime is monotonic, so the elapsed time is unaffected by wall-clock adjustments.
        long startNanos = System.nanoTime();
        this.graph = KerasModelImport.importKerasModelAndWeights(modelJsonPath, modelWeightsPath, false);
        long time = (System.nanoTime() - startNanos) / 1_000_000L;
        LOG.info("Loaded the Inception model. Time taken={}ms", time);
    } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
        throw new TikaConfigException(e.getMessage(), e);
    }
}
Aggregations