Search in sources :

Example 16 with Persistable

Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class ConvolutionalIterationListener, method iterationDone.

/**
 * Event listener for each iteration.
 * <p>
 * Every {@code freq} iterations this collects, from every convolutional layer of the model, the
 * activations of ONE randomly chosen example of the current minibatch. It also restores the input
 * image of that same example. Both are passed to {@code rasterizeConvoLayers}, and the result is
 * stored through {@code ssr.putStaticInfo}.
 *
 * @param model     the model iterating; only {@link MultiLayerNetwork} and {@link ComputationGraph}
 *                  are inspected (any other model yields an empty activation list)
 * @param iteration the iteration number
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (iteration % freq != 0) {
        return;
    }
    // Both model types expose their layers the same way, so share one collection routine
    Layer[] layers;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    } else {
        layers = new Layer[0];
    }
    List<INDArray> tensors = new ArrayList<>();
    BufferedImage sourceImage = collectConvolutionActivations(layers, tensors, new Random());
    BufferedImage render = rasterizeConvoLayers(tensors, sourceImage);
    Persistable p = new ConvolutionListenerPersistable(sessionID, workerID, System.currentTimeMillis(), render);
    ssr.putStaticInfo(p);
    minibatchNum++;
}

/**
 * Appends to {@code tensors} the activations of a single example from every convolutional layer in
 * {@code layers}.
 * <p>
 * The example index is drawn once, uniformly from {@code [0, minibatchSize)}, using the first
 * convolutional layer's output. It is then reused for all later layers, so every rendered tensor
 * and the returned source image refer to the same example.
 *
 * @param layers  the model's layers; non-convolutional layers are skipped
 * @param tensors output list that receives one tensor per convolutional layer, in layer order
 * @param rnd     source of randomness for choosing the example
 * @return the restored input image for the chosen example, or {@code null} if there is no
 *         convolutional layer
 */
private BufferedImage collectConvolutionActivations(Layer[] layers, List<INDArray> tensors, Random rnd) {
    BufferedImage sourceImage = null;
    int sampleIdx = -1;
    for (Layer layer : layers) {
        if (layer.type() != Layer.Type.CONVOLUTIONAL) {
            continue;
        }
        INDArray output = layer.activate();
        int batchSize = output.shape()[0];
        if (sampleIdx < 0) {
            // nextInt(batchSize) covers every example, including index 0.
            // The previous nextInt(n - 1) + 1 skipped example 0 and threw when the minibatch size was 1.
            sampleIdx = rnd.nextInt(batchSize);
            INDArray inputs = ((ConvolutionLayer) layer).input();
            try {
                // NOTE(review): assumes 4-d [minibatch, channels, height, width] data, so the TAD over
                // dims 3,2,1 at index sampleIdx is one example - TODO confirm
                sourceImage = restoreRGBImage(inputs.tensorAlongDimension(sampleIdx, new int[] { 3, 2, 1 }));
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        // Defensive clamp in case a later layer reports a smaller leading dimension
        int idx = Math.min(sampleIdx, batchSize - 1);
        tensors.add(output.tensorAlongDimension(idx, 3, 2, 1));
    }
    return sourceImage;
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) ArrayList(java.util.ArrayList) Layer(org.deeplearning4j.nn.api.Layer) ConvolutionLayer(org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) BufferedImage(java.awt.image.BufferedImage) IOException(java.io.IOException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Random(java.util.Random) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 17 with Persistable

Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getLayerLearningRates.

/**
 * Builds the learning-rate-over-time chart data for a single layer.
 *
 * @param layerIdx        index into {@code gi.getOriginalVertexName()} of the layer to chart
 * @param gi              graph info for the model; if {@code null} an empty map is returned
 * @param updates         stored updates; entries that are not {@link StatsReport}s are ignored
 * @param iterationCounts optional iteration counts with ONE entry PER StatsReport in
 *                        {@code updates} (in order). If {@code null}, each report's own
 *                        {@code getIterationCount()} is used instead.
 * @param modelType       determines the parameter-name prefix: for {@code ModelType.Layer} the
 *                        prefix is the layer name itself, otherwise it is the layer name plus "_"
 * @return a map with:
 *         "iterCounts" (int[] of iteration numbers),
 *         "paramNames" (sorted list of parameter names),
 *         "lrs" (map from parameter name to float[] learning rates, aligned with "iterCounts")
 */
private Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return Collections.emptyMap();
    }
    String layerName = gi.getOriginalVertexName().get(layerIdx);
    // The prefix does not depend on the update, so compute it once
    String prefix = (modelType == ModelType.Layer ? layerName : layerName + "_");
    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    Map<String, float[]> byName = new HashMap<>();
    // Number of StatsReports seen so far; this is the column index into all output arrays
    int used = 0;
    if (updates != null) {
        for (Persistable u : updates) {
            if (!(u instanceof StatsReport)) {
                continue;
            }
            StatsReport sp = (StatsReport) u;
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                // iterationCounts is built with one entry per StatsReport (non-StatsReport updates
                // are skipped), so index it by report ordinal rather than by position in 'updates'
                iterCounts[used] = iterationCounts.get(used);
            }
            Map<String, Double> lrs = sp.getLearningRates();
            if (lrs != null) {
                for (Map.Entry<String, Double> entry : lrs.entrySet()) {
                    String p = entry.getKey();
                    Double lr = entry.getValue();
                    if (lr == null || !p.startsWith(prefix)) {
                        continue;
                    }
                    String layerParamName = p.substring(prefix.length());
                    float[] lrThisParam = byName.get(layerParamName);
                    if (lrThisParam == null) {
                        lrThisParam = new float[size];
                        byName.put(layerParamName, lrThisParam);
                    }
                    lrThisParam[used] = lr.floatValue();
                }
            }
            used++;
        }
    }
    // Drop the unused trailing slots left by non-StatsReport entries, so the chart does not show
    // spurious zero points
    if (used < size) {
        iterCounts = java.util.Arrays.copyOf(iterCounts, used);
        for (Map.Entry<String, float[]> entry : byName.entrySet()) {
            entry.setValue(java.util.Arrays.copyOf(entry.getValue(), used));
        }
    }
    List<String> paramNames = new ArrayList<>(byName.keySet());
    //Sorted for consistency
    Collections.sort(paramNames);
    Map<String, Object> ret = new HashMap<>();
    ret.put("iterCounts", iterCounts);
    ret.put("paramNames", paramNames);
    ret.put("lrs", byName);
    return ret;
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Example 18 with Persistable

Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getDefaultSession.

/**
 * Chooses a default session when none is currently selected.
 * <p>
 * The chosen session is the one whose first static-info record (type {@code StatsListener.TYPE_ID})
 * has the latest timestamp. Sessions without static info are ignored. If no session qualifies,
 * {@code currentSessionID} stays {@code null}.
 */
private void getDefaultSession() {
    if (currentSessionID == null) {
        String newestSession = null;
        long newestTimestamp = Long.MIN_VALUE;
        for (Map.Entry<String, StatsStorage> candidate : knownSessionIDs.entrySet()) {
            String candidateId = candidate.getKey();
            List<Persistable> infos = candidate.getValue().getAllStaticInfos(candidateId, StatsListener.TYPE_ID);
            boolean hasStaticInfo = infos != null && !infos.isEmpty();
            if (hasStaticInfo) {
                long timestamp = infos.get(0).getTimeStamp();
                if (timestamp > newestTimestamp) {
                    newestTimestamp = timestamp;
                    newestSession = candidateId;
                }
            }
        }
        if (newestSession != null) {
            currentSessionID = newestSession;
        }
    }
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 19 with Persistable

Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getSystemData.

/**
 * Returns the "system" tab data for the current session as JSON.
 * <p>
 * The response has the keys "updateTimestamp", "memory", "hardware" and "software". Memory data is
 * taken from the updates in the 5 minutes before the most recent update from any worker.
 *
 * @return an OK result whose body is the JSON map described above
 */
public Result getSystemData() {
    // Guard the lookup: some Map implementations (e.g. ConcurrentHashMap) reject null keys
    Long lastUpdate = (currentSessionID == null ? null : lastUpdateForSession.get(currentSessionID));
    if (lastUpdate == null) {
        lastUpdate = -1L;
    }
    I18N i18n = I18NProvider.getInstance();
    //First: get the MOST RECENT update...
    //Then get all updates from most recent - 5 minutes -> TODO make this configurable...
    StatsStorage ss = (currentSessionID == null ? null : knownSessionIDs.get(currentSessionID));
    // A session whose storage has since been removed is treated as "no data" rather than causing an NPE
    boolean noData = (ss == null);
    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList() : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    List<Persistable> latestUpdates = (noData ? Collections.<Persistable>emptyList() : ss.getLatestUpdateAllWorkers(currentSessionID, StatsListener.TYPE_ID));
    long lastUpdateTime = -1;
    if (latestUpdates == null || latestUpdates.isEmpty()) {
        noData = true;
    } else {
        for (Persistable p : latestUpdates) {
            lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
        }
    }
    //TODO Make configurable
    long fromTime = lastUpdateTime - 5L * 60 * 1000;
    List<Persistable> lastNMinutes = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, fromTime));
    Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n);
    Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n);
    Map<String, Object> ret = new HashMap<>();
    ret.put("updateTimestamp", lastUpdate);
    ret.put("memory", mem);
    ret.put("hardware", hwSwInfo.getFirst());
    ret.put("software", hwSwInfo.getSecond());
    return ok(Json.toJson(ret));
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 20 with Persistable

Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

Class TrainModule, method getModelData.

/**
 * Returns the per-layer "model" tab data for the current session and worker as JSON.
 * <p>
 * Only {@link StatsReport} updates are used. If there are more than {@code maxChartPoints} of
 * them, they are subsampled; the first and the most recent report are always kept. The resulting
 * {@code updates} list and {@code iterationCounts} list stay index-aligned (one entry per kept
 * report) for the chart helpers.
 *
 * @param str the layer index, as a decimal string
 * @return an OK result whose body is the JSON map of chart and histogram data
 * @throws NumberFormatException if {@code str} is not a valid integer (no validation yet)
 */
private Result getModelData(String str) {
    // Guard the lookup: some Map implementations (e.g. ConcurrentHashMap) reject null keys
    Long lastUpdateTime = (currentSessionID == null ? null : lastUpdateForSession.get(currentSessionID));
    if (lastUpdateTime == null) {
        lastUpdateTime = -1L;
    }
    //TODO validation
    int layerIdx = Integer.parseInt(str);
    I18N i18N = I18NProvider.getInstance();
    //First pass (optimize later): query all data...
    StatsStorage ss = (currentSessionID == null ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    // No data if there is no session, its storage was removed, or there is no worker to show
    boolean noData = (ss == null || wid == null);
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }
    // Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);
    //First: get all data, and subsample it if necessary, to avoid returning too many points...
    List<Persistable> allUpdates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    List<Persistable> updates = null;
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    if (allUpdates != null) {
        // Keep only StatsReports, so that 'updates' and 'iterationCounts' are index-aligned and the
        // "always keep the most recent" check refers to the last report
        List<Persistable> reports = new ArrayList<>(allUpdates.size());
        for (Persistable p : allUpdates) {
            if (p instanceof StatsReport) {
                reports.add(p);
            }
        }
        int nReports = reports.size();
        int subsamplingFrequency = (maxChartPoints > 0 && nReports > maxChartPoints ? nReports / maxChartPoints : 1);
        updates = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int lastIterCount = -1;
        for (int i = 0; i < nReports; i++) {
            Persistable p = reports.get(i);
            int iterCount = ((StatsReport) p).getIterationCount();
            // Non-increasing iteration counts indicate legacy data that must be cleaned afterwards
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            // Subsample, but always keep the first and the most recent value
            boolean keep = (i == 0 || i == nReports - 1 || i % subsamplingFrequency == 0);
            if (keep) {
                updates.add(p);
                iterationCounts.add(iterCount);
            }
        }
    }
    //Now, it should use the proper iteration counts
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }
    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null) {
        mt = ModelType.MLN;
    } else if (conf.getSecond() != null) {
        mt = ModelType.CG;
    } else {
        mt = ModelType.Layer;
    }
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);
    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);
    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);
    //Parameters histogram data (from the most recent kept report, if any)
    Persistable lastUpdate = (updates != null && !updates.isEmpty() ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);
    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);
    return ok(Json.toJson(result));
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable)30 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)14 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)7 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)6 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)6 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 Test (org.junit.Test)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IOException (java.io.IOException)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable)3 FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)3 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)3 BufferedImage (java.awt.image.BufferedImage)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2