
Example 11 with InvalidKerasConfigurationException

Use of org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException in the Apache Tika project (tika by apache).

Source: class DL4JInceptionV3Net, method initialize.

@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
    //STEP 1: resolve weights file, download if necessary
    if (modelWeightsPath.startsWith("http://") || modelWeightsPath.startsWith("https://")) {
        LOG.debug("Config instructed to download the weights file, doing so.");
        try {
            modelWeightsPath = cachedDownload(cacheDir, URI.create(modelWeightsPath)).getAbsolutePath();
        } catch (IOException e) {
            throw new TikaConfigException(e.getMessage(), e);
        }
    } else {
        File modelFile = retrieveFile(modelWeightsPath);
        if (modelFile == null || !modelFile.exists()) {
            LOG.error("modelWeights does not exist at :: {}", modelWeightsPath);
            return;
        }
        modelWeightsPath = modelFile.getAbsolutePath();
    }
    //STEP 2: resolve model JSON
    File modelJsonFile = retrieveFile(modelJsonPath);
    if (modelJsonFile == null || !modelJsonFile.exists()) {
        LOG.error("Could not locate file {}", modelJsonPath);
        return;
    }
    modelJsonPath = modelJsonFile.getAbsolutePath();
    //STEP 3: Load labels map
    try (InputStream stream = retrieveResource(labelFile)) {
        this.labelMap = loadClassIndex(stream);
    } catch (IOException | ParseException e) {
        LOG.error("Could not load labels map", e);
        return;
    }
    //STEP 4: initialize the graph
    try {
        this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels);
        LOG.info("Going to load Inception network...");
        long st = System.currentTimeMillis();
        this.graph = KerasModelImport.importKerasModelAndWeights(modelJsonPath, modelWeightsPath, false);
        long time = System.currentTimeMillis() - st;
        LOG.info("Loaded the Inception model. Time taken={}ms", time);
    } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
        throw new TikaConfigException(e.getMessage(), e);
    }
}
Also used in this source file: FileInputStream (java.io.FileInputStream), InputStream (java.io.InputStream), TikaConfigException (org.apache.tika.exception.TikaConfigException), IOException (java.io.IOException), ParseException (org.json.simple.parser.ParseException), UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException), File (java.io.File), InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException), NativeImageLoader (org.datavec.image.loader.NativeImageLoader)
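STEP 1 and STEP 2 of initialize() can be sketched in isolation. The class below is hypothetical and not part of Tika: it applies the same http:// / https:// prefix test to decide whether a configured path is remote, and resolves a local path to an absolute one. Unlike initialize(), which logs an error and returns (leaving the graph unset), this sketch throws when the file is missing, so a misconfiguration surfaces at load time. The cache-file naming rule (last URI path segment) is an assumption; the original cachedDownload helper's naming scheme is not shown here.

```java
import java.io.File;
import java.io.FileNotFoundException;
import java.net.URI;

// Hypothetical helper mirroring STEP 1/STEP 2 of initialize(); not a Tika API.
public class ModelPathResolver {

    // True when the configured path should be downloaded rather than read locally,
    // using the same prefix test as initialize().
    static boolean isRemote(String path) {
        return path.startsWith("http://") || path.startsWith("https://");
    }

    // Resolves a local path to an absolute path, failing fast if the file is missing
    // (initialize() logs and returns instead).
    static String resolveLocal(String path) throws FileNotFoundException {
        File file = new File(path);
        if (!file.exists()) {
            throw new FileNotFoundException("model file does not exist at: " + path);
        }
        return file.getAbsolutePath();
    }

    // Assumed naming rule: cache the download under the last segment of the URI path.
    static File cacheFileFor(File cacheDir, URI uri) {
        String p = uri.getPath();
        String name = p.substring(p.lastIndexOf('/') + 1);
        return new File(cacheDir, name);
    }

    public static void main(String[] args) {
        System.out.println(isRemote("https://example.com/inception.h5")); // true
        System.out.println(isRemote("models/inception.h5"));              // false
        System.out.println(cacheFileFor(new File("/tmp/cache"),
                URI.create("https://example.com/models/inception.h5")).getName()); // inception.h5
    }
}
```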

Aggregations

InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException): 11
UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException): 6
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 5
File (java.io.File): 1
FileInputStream (java.io.FileInputStream): 1
IOException (java.io.IOException): 1
InputStream (java.io.InputStream): 1
List (java.util.List): 1
TikaConfigException (org.apache.tika.exception.TikaConfigException): 1
NativeImageLoader (org.datavec.image.loader.NativeImageLoader): 1
ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex): 1
WeightInit (org.deeplearning4j.nn.weights.WeightInit): 1
ParseException (org.json.simple.parser.ParseException): 1
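STEP 4 of initialize() uses a multi-catch to collapse IOException, InvalidKerasConfigurationException and UnsupportedKerasConfigurationException into a single TikaConfigException, reusing the message and keeping the original exception as the cause. The sketch below reproduces that pattern without DL4J or Tika on the classpath. All exception classes and the importModel method are hypothetical stand-ins chosen for illustration, not the real library types or behavior.

```java
import java.io.IOException;

// Sketch of the STEP 4 wrapping pattern with hypothetical stand-in types.
public class WrapDemo {

    // Stand-in for TikaConfigException.
    static class ConfigException extends Exception {
        ConfigException(String msg, Throwable cause) { super(msg, cause); }
    }

    // Stand-in for InvalidKerasConfigurationException.
    static class InvalidModelConfigException extends Exception {
        InvalidModelConfigException(String msg) { super(msg); }
    }

    // Stand-in for UnsupportedKerasConfigurationException.
    static class UnsupportedModelConfigException extends Exception {
        UnsupportedModelConfigException(String msg) { super(msg); }
    }

    // Hypothetical import call that can fail in three distinct ways.
    static String importModel(String json) throws IOException,
            InvalidModelConfigException, UnsupportedModelConfigException {
        if (json == null) throw new IOException("cannot read model JSON");
        if (json.isEmpty()) throw new InvalidModelConfigException("empty model config");
        if (json.contains("CustomLayer")) {
            throw new UnsupportedModelConfigException("unsupported layer: CustomLayer");
        }
        return "graph";
    }

    static String load(String json) throws ConfigException {
        try {
            return importModel(json);
        } catch (IOException | InvalidModelConfigException | UnsupportedModelConfigException e) {
            // Same shape as initialize(): reuse the message, keep the cause for the stack trace.
            throw new ConfigException(e.getMessage(), e);
        }
    }

    public static void main(String[] args) {
        try {
            load("");
        } catch (ConfigException e) {
            // Prints: empty model config / cause: InvalidModelConfigException
            System.out.println(e.getMessage() + " / cause: " + e.getCause().getClass().getSimpleName());
        }
    }
}
```

Keeping the cause means a caller that catches only the configuration exception can still inspect getCause() to tell an unreadable file apart from an invalid or unsupported model definition.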