
Example 6 with UnsupportedKerasConfigurationException

Use of org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException in the project deeplearning4j by deeplearning4j.

In class KerasLstm, method getRecurrentDropout.

/**
 * Get LSTM recurrent weight dropout from Keras layer configuration. Nonzero recurrent
 * dropout is currently unsupported.
 *
 * @param layerConfig dictionary containing Keras layer configuration
 * @return            DL4J retention probability for recurrent connections (always 1.0,
 *                    because any nonzero Keras recurrent dropout is rejected)
 * @throws UnsupportedKerasConfigurationException if recurrent dropout is greater than 0
 * @throws InvalidKerasConfigurationException     if the inner layer configuration cannot be extracted
 */
public static double getRecurrentDropout(Map<String, Object> layerConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
    /* NOTE: the Keras "dropout" parameter is a drop probability,
     * while the DL4J "dropout" parameter is a retention probability.
     */
    Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
    double dropout = 1.0;
    if (innerConfig.containsKey(LAYER_FIELD_DROPOUT_U)) {
        // Go through Number: the parsed JSON value may be an Integer (0) or a Double (0.2).
        dropout = 1.0 - ((Number) innerConfig.get(LAYER_FIELD_DROPOUT_U)).doubleValue();
    }
    if (dropout < 1.0)
        throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
    return dropout;
}
Also used : UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException)
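The conversion described in the NOTE comment (a Keras drop probability p becomes a DL4J retention probability 1 - p) can be exercised in isolation. The following is a minimal standalone sketch, not DL4J code: the class RecurrentDropoutSketch, the method toRetentionProbability and the constant FIELD_DROPOUT_U are hypothetical, and the key "dropout_U" is an assumption about the Keras 1.x field that LAYER_FIELD_DROPOUT_U refers to. Unlike getRecurrentDropout, it returns the converted value instead of rejecting nonzero dropout, so the arithmetic stays visible.

```java
import java.util.HashMap;
import java.util.Map;

class RecurrentDropoutSketch {
    // Hypothetical stand-in for LAYER_FIELD_DROPOUT_U; assumed Keras 1.x key name.
    static final String FIELD_DROPOUT_U = "dropout_U";

    /** Converts a Keras drop probability into a DL4J-style retention probability. */
    static double toRetentionProbability(Map<String, Object> innerConfig) {
        Object raw = innerConfig.get(FIELD_DROPOUT_U);
        if (raw == null) {
            return 1.0; // field absent: every activation is retained
        }
        // JSON parsers may yield Integer (e.g. 0) or Double (e.g. 0.2); Number covers both.
        double dropProbability = ((Number) raw).doubleValue();
        if (dropProbability < 0.0 || dropProbability > 1.0) {
            throw new IllegalArgumentException("Keras dropout must lie in [0, 1]: " + dropProbability);
        }
        return 1.0 - dropProbability;
    }

    public static void main(String[] args) {
        Map<String, Object> cfg = new HashMap<>();
        System.out.println(toRetentionProbability(cfg)); // 1.0
        cfg.put(FIELD_DROPOUT_U, 0);
        System.out.println(toRetentionProbability(cfg)); // 1.0
        cfg.put(FIELD_DROPOUT_U, 0.25);
        System.out.println(toRetentionProbability(cfg)); // 0.75
    }
}
```

Reading the value through Number avoids the ClassCastException that a direct (double) cast would raise when the JSON holds an integer literal such as 0.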

Example 7 with UnsupportedKerasConfigurationException

Use of org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException in the project tika by apache.

In class DL4JInceptionV3Net, method initialize.

@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
    //STEP 1: resolve weights file, download if necessary
    if (modelWeightsPath.startsWith("http://") || modelWeightsPath.startsWith("https://")) {
        LOG.debug("Config instructed to download the weights file, doing so.");
        try {
            modelWeightsPath = cachedDownload(cacheDir, URI.create(modelWeightsPath)).getAbsolutePath();
        } catch (IOException e) {
            throw new TikaConfigException(e.getMessage(), e);
        }
    } else {
        File modelFile = retrieveFile(modelWeightsPath);
        if (modelFile == null || !modelFile.exists()) {
            LOG.error("Model weights file does not exist at {}", modelWeightsPath);
            return;
        }
        modelWeightsPath = modelFile.getAbsolutePath();
    }
    //STEP 2: resolve model JSON
    File modelJsonFile = retrieveFile(modelJsonPath);
    if (modelJsonFile == null || !modelJsonFile.exists()) {
        LOG.error("Could not locate file {}", modelJsonPath);
        return;
    }
    modelJsonPath = modelJsonFile.getAbsolutePath();
    //STEP 3: Load labels map
    try (InputStream stream = retrieveResource(labelFile)) {
        this.labelMap = loadClassIndex(stream);
    } catch (IOException | ParseException e) {
        LOG.error("Could not load labels map", e);
        return;
    }
    //STEP 4: initialize the graph
    try {
        this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels);
        LOG.info("Going to load Inception network...");
        long st = System.currentTimeMillis();
        this.graph = KerasModelImport.importKerasModelAndWeights(modelJsonPath, modelWeightsPath, false);
        long time = System.currentTimeMillis() - st;
        LOG.info("Loaded the Inception model. Time taken={}ms", time);
    } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
        throw new TikaConfigException(e.getMessage(), e);
    }
}
Also used : FileInputStream(java.io.FileInputStream) InputStream(java.io.InputStream) TikaConfigException(org.apache.tika.exception.TikaConfigException) IOException(java.io.IOException) ParseException(org.json.simple.parser.ParseException) UnsupportedKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException) File(java.io.File) InvalidKerasConfigurationException(org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException) NativeImageLoader(org.datavec.image.loader.NativeImageLoader)
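Steps 1 and 2 of initialize decide whether a configured location is remote or local and, for local files, whether the file exists. A minimal standard-library sketch of that decision follows, assuming the same http:// / https:// prefix rule as the original (matched case-insensitively here). WeightsPathResolver, isRemote and resolveLocal are hypothetical names; the download itself (cachedDownload in the original) is left out.

```java
import java.io.File;
import java.io.IOException;
import java.util.Optional;

class WeightsPathResolver {
    /** True when the location should be downloaded rather than read from disk. */
    static boolean isRemote(String location) {
        return location.regionMatches(true, 0, "http://", 0, 7)
                || location.regionMatches(true, 0, "https://", 0, 8);
    }

    /** Absolute path of an existing local file, or empty if the file is missing. */
    static Optional<String> resolveLocal(String location) {
        File file = new File(location);
        return file.exists() ? Optional.of(file.getAbsolutePath()) : Optional.empty();
    }

    public static void main(String[] args) throws IOException {
        File tmp = File.createTempFile("weights", ".h5"); // stand-in weights file
        tmp.deleteOnExit();
        System.out.println(isRemote("https://example.org/inception.h5")); // true
        System.out.println(isRemote(tmp.getPath()));                       // false
        System.out.println(resolveLocal(tmp.getPath()).isPresent());       // true
        System.out.println(resolveLocal("no/such/file.h5").isPresent());   // false
    }
}
```

Returning Optional.empty() rather than logging and returning lets the caller decide whether a missing file should abort initialization, for example by throwing TikaConfigException.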

Aggregations

UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException): 7
InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException): 6
File (java.io.File): 1
FileInputStream (java.io.FileInputStream): 1
IOException (java.io.IOException): 1
InputStream (java.io.InputStream): 1
List (java.util.List): 1
TikaConfigException (org.apache.tika.exception.TikaConfigException): 1
NativeImageLoader (org.datavec.image.loader.NativeImageLoader): 1
ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex): 1
WeightInit (org.deeplearning4j.nn.weights.WeightInit): 1
ParseException (org.json.simple.parser.ParseException): 1