Search in sources:

Example 6 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getConfig.

/**
 * Loads the model configuration stored in the first static-info record of the current session.
 * <p>
 * Exactly one element of the returned triple is non-null, chosen from the recorded model class name:
 * a {@code MultiLayerConfiguration} when the name ends with "MultiLayerNetwork", a
 * {@code ComputationGraphConfiguration} when it ends with "ComputationGraph", and otherwise a single
 * {@code NeuralNetConfiguration} parsed from the stored JSON.
 *
 * @return the configuration triple, or {@code null} when no session is selected, no storage is
 *         registered for the session, no static info is available, the record carries no model
 *         class name / JSON, or the fallback JSON parse fails
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    if (currentSessionID == null) {
        return null;
    }
    StatsStorage ss = knownSessionIDs.get(currentSessionID);
    if (ss == null) {
        //A session ID is selected but no storage is registered for it: treat as "no data" instead of throwing
        return null;
    }
    List<Persistable> allStatic = ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return null;
    }
    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass == null || config == null) {
        //Nothing to parse; callers already have to handle a null result
        return null;
    }
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    }
    if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    }
    //Any other model class: try to read the JSON as a single layer configuration
    try {
        NeuralNetConfiguration layer = NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
        return new Triple<>(null, null, layer);
    } catch (Exception e) {
        //Best effort: report the failure and fall through to the "no configuration" result
        e.printStackTrace();
        return null;
    }
}
Also used : Triple(org.deeplearning4j.berkeley.Triple) StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration)

Example 7 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getLayerActivations.

/**
 * Collects the activation mean and standard deviation of one layer across a list of updates.
 * <p>
 * Only updates that are {@code StatsReport} instances and that contain both a mean and a standard
 * deviation for the layer contribute an entry; all other updates are skipped. Non-finite values are
 * replaced with {@code NAN_REPLACEMENT_VALUE}.
 *
 * @param index           index of the vertex in the graph info lists
 * @param gi              graph info; {@code null} yields {@code EMPTY_TRIPLE}
 * @param updates         updates to scan; may be {@code null}
 * @param iterationCounts iteration count to report for each position in {@code updates}; when
 *                        {@code null} the iteration count stored in each report is used instead
 * @return (iteration counts, means, standard deviations), all three arrays of equal length;
 *         {@code EMPTY_TRIPLE} when {@code gi} is null, {@code index} is out of range, or the
 *         vertex type is "input"
 */
private Triple<int[], float[], float[]> getLayerActivations(int index, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts) {
    if (gi == null) {
        return EMPTY_TRIPLE;
    }
    //Validate the index BEFORE using it for any lookup
    List<String> origNames = gi.getOriginalVertexName();
    if (index < 0 || index >= origNames.size()) {
        return EMPTY_TRIPLE;
    }
    //Index may be for an input, for example
    //NOTE(review): assumes the vertex type list is at least as long as the original-name list - confirm
    String type = gi.getVertexTypes().get(index);
    if ("input".equalsIgnoreCase(type)) {
        return EMPTY_TRIPLE;
    }
    String layerName = origNames.get(index);
    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    float[] mean = new float[size];
    float[] stdev = new float[size];
    int used = 0;
    if (updates != null) {
        //Position in 'updates', needed to look up the matching entry of 'iterationCounts'
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport)) {
                continue;
            }
            StatsReport sp = (StatsReport) u;
            Map<String, Double> means = sp.getMean(StatsType.Activations);
            Map<String, Double> stdevs = sp.getStdev(StatsType.Activations);
            Double layerMean = (means == null ? null : means.get(layerName));
            Double layerStdev = (stdevs == null ? null : stdevs.get(layerName));
            if (layerMean == null || layerStdev == null) {
                //Report has no complete activation statistics for this layer: skip it
                continue;
            }
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }
            mean[used] = layerMean.floatValue();
            stdev[used] = layerStdev.floatValue();
            if (!Float.isFinite(mean[used])) {
                mean[used] = (float) NAN_REPLACEMENT_VALUE;
            }
            if (!Float.isFinite(stdev[used])) {
                stdev[used] = (float) NAN_REPLACEMENT_VALUE;
            }
            used++;
        }
    }
    //Trim the arrays when some updates were skipped
    if (used != iterCounts.length) {
        iterCounts = Arrays.copyOf(iterCounts, used);
        mean = Arrays.copyOf(mean, used);
        stdev = Arrays.copyOf(stdev, used);
    }
    return new Triple<>(iterCounts, mean, stdev);
}
Also used : Triple(org.deeplearning4j.berkeley.Triple) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 8 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getModelGraph.

/**
 * Returns the graph structure of the current session's model as JSON.
 *
 * @return {@code ok(Json.toJson(graphInfo))} when graph info is available; an empty {@code ok()}
 *         when no session is selected, no storage is registered for the session, the session has
 *         no static info, or {@code getGraphInfo()} returns null
 */
private Result getModelGraph() {
    if (currentSessionID == null) {
        return ok();
    }
    StatsStorage ss = knownSessionIDs.get(currentSessionID);
    if (ss == null) {
        //A session ID is selected but no storage is registered for it: answer as for "no data"
        return ok();
    }
    List<Persistable> allStatic = ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 9 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method sessionInfo.

/**
 * Builds a JSON summary of every known session, keyed by session ID.
 * <p>
 * Each session entry contains: "numWorkers", "initTime" (earliest static-info timestamp),
 * "lastUpdate" (latest timestamp among the most recent updates of all workers), "workers"
 * (only when there is at least one worker), and "modelType" / "numLayers" / "numParams" taken
 * from the first static-info record. Values that are not available are reported as "".
 *
 * @return {@code ok(...)} wrapping the JSON form of the per-session map
 */
private Result sessionInfo() {
    //Display, for each session: session ID, start time, number of workers, last update
    Map<String, Object> dataEachSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        Map<String, Object> dataThisSession = new HashMap<>();
        String sid = entry.getKey();
        StatsStorage ss = entry.getValue();
        List<String> workerIDs = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
        int workerCount = (workerIDs == null ? 0 : workerIDs.size());
        List<Persistable> staticInfo = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
        //Long.MAX_VALUE / Long.MIN_VALUE act as "not found" sentinels for the two timestamps
        long initTime = Long.MAX_VALUE;
        if (staticInfo != null) {
            for (Persistable p : staticInfo) {
                initTime = Math.min(p.getTimeStamp(), initTime);
            }
        }
        long lastUpdateTime = Long.MIN_VALUE;
        List<Persistable> lastUpdatesAllWorkers = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
        //Guarded like the other storage results above, so a session without updates cannot fail the whole listing
        if (lastUpdatesAllWorkers != null) {
            for (Persistable p : lastUpdatesAllWorkers) {
                lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
            }
        }
        dataThisSession.put("numWorkers", workerCount);
        dataThisSession.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
        dataThisSession.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);
        // add hashmap of workers
        if (workerCount > 0) {
            dataThisSession.put("workers", workerIDs);
        }
        //Model info: type, # layers, # params...
        if (staticInfo != null && !staticInfo.isEmpty()) {
            StatsInitializationReport sr = (StatsInitializationReport) staticInfo.get(0);
            String modelClassName = sr.getModelClassName();
            if (modelClassName == null) {
                //No class name recorded: report it the same way as "no model info"
                modelClassName = "";
            } else if (modelClassName.endsWith("MultiLayerNetwork")) {
                modelClassName = "MultiLayerNetwork";
            } else if (modelClassName.endsWith("ComputationGraph")) {
                modelClassName = "ComputationGraph";
            }
            int numLayers = sr.getModelNumLayers();
            long numParams = sr.getModelNumParams();
            dataThisSession.put("modelType", modelClassName);
            dataThisSession.put("numLayers", numLayers);
            dataThisSession.put("numParams", numParams);
        } else {
            dataThisSession.put("modelType", "");
            dataThisSession.put("numLayers", "");
            dataThisSession.put("numParams", "");
        }
        dataEachSession.put(sid, dataThisSession);
    }
    return ok(Json.toJson(dataEachSession));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 10 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getLayerInfoTable.

/**
 * Builds the two-column (label, value) table describing one layer/vertex of the model.
 * <p>
 * The first two rows are always the layer name and the layer type. When static info is available
 * for the worker, further rows are appended depending on the model class and layer kind:
 * nIn / size for feed-forward layers, parameter count, weight init and updater for layers with
 * parameters, pooling type / kernel / stride / padding for convolution and subsampling layers,
 * and finally the activation function. For a "VariationalAutoencoder" model class the rows come
 * from the vertex info held by {@code gi}.
 *
 * @param layerIdx index of the layer/vertex in the lists held by {@code gi}
 * @param gi       graph info providing vertex names, types and info
 * @param i18N     source of the localized row labels
 * @param noData   when true, no storage lookup is made and only the two base rows are returned
 * @param ss       storage to read the static info from; a null storage is treated like {@code noData}
 * @param wid      worker ID whose static info is read
 * @return the table rows, each a {@code String[]} of length 2
 */
private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData, StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerName"), gi.getVertexNames().get(layerIdx) });
    //The type value is filled in at the end, once the configuration has been inspected
    String[] layerTypeRow = new String[] { i18N.getMessage("train.model.layerinfotable.layerType"), "" };
    layerInfoRows.add(layerTypeRow);

    Persistable p = (noData || ss == null ? null : ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid));
    if (p == null) {
        return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
    }

    StatsInitializationReport initReport = (StatsInitializationReport) p;
    String configJson = initReport.getModelConfigJson();
    String modelClass = initReport.getModelClassName();
    if (modelClass == null) {
        //No class name recorded: none of the branches below applies, only the base rows are returned
        modelClass = "";
    }
    //TODO error handling...
    String layerType = "";
    Layer layer = null;
    NeuralNetConfiguration nnc = null;
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
        //-1 because of input
        int confIdx = layerIdx - 1;
        if (confIdx >= 0) {
            nnc = conf.getConf(confIdx);
            layer = nnc.getLayer();
        } else {
            //Input layer
            layerType = "Input";
        }
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
        String vertexName = gi.getVertexNames().get(layerIdx);
        //Single lookup; instanceof is false for a missing (null) vertex
        GraphVertex gv = conf.getVertices().get(vertexName);
        if (gv instanceof LayerVertex) {
            nnc = ((LayerVertex) gv).getLayerConf();
            layer = nnc.getLayer();
        } else if (conf.getNetworkInputs().contains(vertexName)) {
            layerType = "Input";
        } else if (gv != null) {
            layerType = gv.getClass().getSimpleName();
        }
    } else if (modelClass.endsWith("VariationalAutoencoder")) {
        layerType = gi.getVertexTypes().get(layerIdx);
        Map<String, String> map = gi.getVertexInfo().get(layerIdx);
        for (Map.Entry<String, String> entry : map.entrySet()) {
            layerInfoRows.add(new String[] { entry.getKey(), entry.getValue() });
        }
    }

    if (layer != null) {
        layerType = getLayerType(layer);
        String activationFn = null;
        if (layer instanceof FeedForwardLayer) {
            FeedForwardLayer ffl = (FeedForwardLayer) layer;
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNIn"), String.valueOf(ffl.getNIn()) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffl.getNOut()) });
            //A layer without an activation function simply gets no activation row
            Object fn = layer.getActivationFn();
            activationFn = (fn == null ? null : fn.toString());
        }
        int nParams = layer.initializer().numParams(nnc);
        layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams) });
        if (nParams > 0) {
            WeightInit wi = layer.getWeightInit();
            String str = (wi == null ? "" : wi.toString());
            if (wi == WeightInit.DISTRIBUTION) {
                str += layer.getDist();
            }
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str });
            Updater u = layer.getUpdater();
            String us = (u == null ? "" : u.toString());
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerUpdater"), us });
            //TODO: Maybe L1/L2, dropout, updater-specific values etc
        }
        if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
            int[] kernel;
            int[] stride;
            int[] padding;
            if (layer instanceof ConvolutionLayer) {
                ConvolutionLayer cl = (ConvolutionLayer) layer;
                kernel = cl.getKernelSize();
                stride = cl.getStride();
                padding = cl.getPadding();
            } else {
                SubsamplingLayer ssl = (SubsamplingLayer) layer;
                kernel = ssl.getKernelSize();
                stride = ssl.getStride();
                padding = ssl.getPadding();
                //Subsampling layers get a pooling-type row and never an activation row
                activationFn = null;
                layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), ssl.getPoolingType().toString() });
            }
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnKernel"), Arrays.toString(kernel) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnStride"), Arrays.toString(stride) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnPadding"), Arrays.toString(padding) });
        }
        if (activationFn != null) {
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerActivationFn"), activationFn });
        }
    }
    layerTypeRow[1] = layerType;
    return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) Persistable(org.deeplearning4j.api.storage.Persistable) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) WeightInit(org.deeplearning4j.nn.weights.WeightInit) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) Updater(org.deeplearning4j.nn.conf.Updater) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable)30 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)14 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)7 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)6 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)6 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 Test (org.junit.Test)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IOException (java.io.IOException)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable)3 FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)3 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)3 BufferedImage (java.awt.image.BufferedImage)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2