Search in sources :

Example 6 with StatsReport

use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.

Class HistogramModule, method processRequest.

/**
 * Builds the histogram-page JSON for a session: score history, per-layer mean magnitudes of
 * parameters and updates, and the histograms from the most recent StatsReport. Data is taken
 * from the first worker listed for the session.
 *
 * @param sessionId session to report on
 * @return a not-found result for an unknown session; an empty JSON object when the session has
 *         no worker, no StatsInitializationReport, or no StatsReport updates; otherwise the
 *         serialized CompactModelAndGradient
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }
    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        // No worker has reported for this session: nothing to show yet
        return Results.ok(Json.toJson(Collections.EMPTY_MAP));
    }
    String workerId = workerIDs.get(0);
    Persistable staticInfo = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerId);
    if (!(staticInfo instanceof StatsInitializationReport)) {
        // Covers both "no static info" (null) and an unexpected type
        return Results.ok(Json.toJson(Collections.EMPTY_MAP));
    }
    StatsInitializationReport initReport = (StatsInitializationReport) staticInfo;
    String[] paramNames = initReport.getModelParamNames();
    //Infer layer names from param names: the layer name is the text before the first '_'
    Set<String> layerNameSet = new LinkedHashSet<>();
    if (paramNames != null) {
        for (String s : paramNames) {
            // indexOf instead of split("_")[0]: split returns an empty array for "_"
            int underscore = s.indexOf('_');
            layerNameSet.add(underscore < 0 ? s : s.substring(0, underscore));
        }
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);
    // Copy before sorting so the list returned by the storage is never modified in place
    List<Persistable> list = new ArrayList<>();
    List<Persistable> stored = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerId, 0);
    if (stored != null) {
        list.addAll(stored);
    }
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));
    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>(layerNameList.size());
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>(layerNameList.size());
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }
    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());
        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }
    if (last == null) {
        // No StatsReport at all: there are no histograms or timestamps to report
        return Results.ok(Json.toJson(Collections.EMPTY_MAP));
    }
    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    double lastScore = (scoreList.isEmpty() ? 0.0 : scoreList.get(scoreList.size() - 1));
    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) CompactModelAndGradient(org.deeplearning4j.ui.weights.beans.CompactModelAndGradient) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Example 7 with StatsReport

use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.

Class TrainModule, method getLayerLearningRates.

/**
 * Collects the learning rate of each parameter of one layer over time.
 *
 * @param layerIdx        index into {@code gi.getOriginalVertexName()} selecting the layer
 * @param gi              graph info; when null an empty map is returned
 * @param updates         stats updates (may be null); entries that are not StatsReport are skipped
 * @param iterationCounts iteration counts parallel to {@code updates} (may be null); when null,
 *                        or shorter than {@code updates}, each StatsReport's own iteration count is used
 * @param modelType       for ModelType.Layer the layer name itself is the parameter-key prefix,
 *                        otherwise the prefix is the layer name followed by "_"
 * @return map with "iterCounts" (int[]), "paramNames" (sorted List of parameter names) and
 *         "lrs" (Map from parameter name to float[]); arrays have one entry per StatsReport
 */
private Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return Collections.emptyMap();
    }
    String layerName = gi.getOriginalVertexName().get(layerIdx);
    // Loop-invariant: learning-rate keys belonging to this layer start with this prefix
    String prefix = (modelType == ModelType.Layer ? layerName : layerName + "_");
    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    Map<String, float[]> byName = new HashMap<>();
    int used = 0;
    if (updates != null) {
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            if (iterationCounts != null && uCount < iterationCounts.size()) {
                iterCounts[used] = iterationCounts.get(uCount);
            } else {
                iterCounts[used] = sp.getIterationCount();
            }
            Map<String, Double> lrs = sp.getLearningRates();
            if (lrs != null) {
                for (Map.Entry<String, Double> e : lrs.entrySet()) {
                    String p = e.getKey();
                    Double lr = e.getValue();
                    if (p == null || lr == null || !p.startsWith(prefix)) {
                        continue;
                    }
                    String layerParamName = p.substring(prefix.length());
                    float[] lrThisParam = byName.computeIfAbsent(layerParamName, k -> new float[size]);
                    lrThisParam[used] = lr.floatValue();
                }
            }
            used++;
        }
    }
    // Skipped (non-StatsReport) entries would otherwise leave trailing zero points in every series
    if (used < size) {
        int[] trimmedIters = new int[used];
        System.arraycopy(iterCounts, 0, trimmedIters, 0, used);
        iterCounts = trimmedIters;
        for (Map.Entry<String, float[]> e : byName.entrySet()) {
            float[] trimmed = new float[used];
            System.arraycopy(e.getValue(), 0, trimmed, 0, used);
            e.setValue(trimmed);
        }
    }
    List<String> paramNames = new ArrayList<>(byName.keySet());
    //Sorted for consistency
    Collections.sort(paramNames);
    Map<String, Object> ret = new HashMap<>();
    ret.put("iterCounts", iterCounts);
    ret.put("paramNames", paramNames);
    ret.put("lrs", byName);
    return ret;
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Example 8 with StatsReport

use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.

Class TrainModule, method getModelData.

/**
 * Builds the model-page JSON for one layer of the current session and worker. The payload holds
 * static layer info, mean-magnitude ratios, activations, learning rates, and parameter/update
 * histograms from the latest report.
 *
 * @param str layer index as a string (parsed with Integer.parseInt; not otherwise validated)
 * @return OK result whose JSON always contains "updateTimestamp"; the other keys are added only
 *         when both a configuration and graph info are available
 */
private Result getModelData(String str) {
    Long lastUpdateTime = (currentSessionID == null ? null : lastUpdateForSession.get(currentSessionID));
    if (lastUpdateTime == null)
        lastUpdateTime = -1L;
    //TODO validation
    int layerIdx = Integer.parseInt(str);
    I18N i18N = I18NProvider.getInstance();
    //Model info for layer
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    if (ss == null) {
        // Unknown session: treat as "no data" rather than dereferencing a null storage below
        noData = true;
    }
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }
    // Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);

    //First: get all data, keep only StatsReport entries, and subsample to avoid returning too many points.
    //Afterwards 'updates' and 'iterationCounts' are index-aligned, one entry per retained StatsReport.
    List<Persistable> updates = null;
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    List<Persistable> allUpdates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    if (allUpdates != null) {
        List<Persistable> reports = new ArrayList<>(allUpdates.size());
        for (Persistable p : allUpdates) {
            if (p instanceof StatsReport) {
                reports.add(p);
            }
        }
        int nReports = reports.size();
        int subsamplingFrequency = (maxChartPoints > 0 && nReports > maxChartPoints ? nReports / maxChartPoints : 1);
        int lastReportIdx = nReports - 1;
        updates = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int lastIterCount = -1;
        for (int i = 0; i < nReports; i++) {
            Persistable p = reports.get(i);
            int iterCount = ((StatsReport) p).getIterationCount();
            // Non-increasing iteration counts require the legacy clean-up below
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            // Keep every subsamplingFrequency-th point, and always keep the most recent one
            boolean keep = subsamplingFrequency <= 1 || i % subsamplingFrequency == 0 || i == lastReportIdx;
            if (!keep) {
                continue;
            }
            updates.add(p);
            iterationCounts.add(iterCount);
        }
    }
    //Now, it should use the proper iteration counts
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }
    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null)
        mt = ModelType.MLN;
    else if (conf.getSecond() != null)
        mt = ModelType.CG;
    else
        mt = ModelType.Layer;
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);
    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);
    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);
    //Parameters histogram data, from the most recent retained StatsReport (null if none)
    Persistable lastUpdate = (updates != null && !updates.isEmpty() ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);
    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);
    return ok(Json.toJson(result));
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration)

Example 9 with StatsReport

use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.

Class TrainModule, method getMemory.

/**
 * Builds per-JVM memory-utilization chart data.
 *
 * <p>Workers are grouped by the JVM UID in their StatsInitializationReport. For each JVM, the
 * worker with the lexicographically smallest ID is used as the data source.
 *
 * @param staticInfoAllWorkers static info for all workers; entries that are not
 *                             StatsInitializationReport, or lack a worker ID or JVM UID, are skipped
 * @param updatesLastNMinutes  recent updates (may be null); only StatsReport entries from the
 *                             selected worker are used
 * @param i18n                 source of the localized JVM / off-heap series names
 * @return map from "0", "1", ... (one entry per JVM, in sorted JVM-UID order) to a map with keys
 *         "times", "isDevice", "seriesNames", "values", "currentBytes" and "maxBytes"
 */
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, List<Persistable> updatesLastNMinutes, I18N i18n) {
    Map<String, Object> ret = new HashMap<>();
    if (staticInfoAllWorkers == null) {
        return ret;
    }
    List<Persistable> updates = (updatesLastNMinutes == null ? Collections.<Persistable>emptyList() : updatesLastNMinutes);
    //First: map workers to JVMs
    Set<String> jvmIDs = new HashSet<>();
    Map<String, String> workersToJvms = new HashMap<>();
    Map<String, Integer> workerNumDevices = new HashMap<>();
    Map<String, String[]> deviceNames = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        if (!(p instanceof StatsInitializationReport)) {
            continue;
        }
        StatsInitializationReport init = (StatsInitializationReport) p;
        String workerId = p.getWorkerID();
        String jvmuid = init.getSwJvmUID();
        // Without both IDs the worker cannot be grouped (null would break sorting/equals below)
        if (workerId == null || jvmuid == null) {
            continue;
        }
        workersToJvms.put(workerId, jvmuid);
        jvmIDs.add(jvmuid);
        int nDevices = Math.max(0, init.getHwNumDevices());
        workerNumDevices.put(workerId, nDevices);
        if (nDevices > 0) {
            deviceNames.put(workerId, init.getHwDeviceDescription());
        }
    }
    List<String> jvmList = new ArrayList<>(jvmIDs);
    Collections.sort(jvmList);
    //For each unique JVM, collect memory info
    //Do this by selecting the first worker
    int count = 0;
    for (String jvm : jvmList) {
        List<String> workersForJvm = new ArrayList<>();
        for (Map.Entry<String, String> e : workersToJvms.entrySet()) {
            if (jvm.equals(e.getValue())) {
                workersForJvm.add(e.getKey());
            }
        }
        Collections.sort(workersForJvm);
        String wid = workersForJvm.get(0);
        int numDevices = workerNumDevices.get(wid);
        Map<String, Object> jvmData = new HashMap<>();
        List<Long> timestamps = new ArrayList<>();
        List<Float> fracJvm = new ArrayList<>();
        List<Float> fracOffHeap = new ArrayList<>();
        // Index 0: JVM heap, 1: off-heap, 2..: devices
        long[] lastBytes = new long[2 + numDevices];
        long[] lastMaxBytes = new long[2 + numDevices];
        List<List<Float>> fracDeviceMem = new ArrayList<>(numDevices);
        for (int i = 0; i < numDevices; i++) {
            fracDeviceMem.add(new ArrayList<>());
        }
        for (Persistable p : updates) {
            //TODO single pass
            if (!(p instanceof StatsReport) || !wid.equals(p.getWorkerID()))
                continue;
            StatsReport sp = (StatsReport) p;
            timestamps.add(sp.getTimeStamp());
            long jvmCurrentBytes = sp.getJvmCurrentBytes();
            long jvmMaxBytes = sp.getJvmMaxBytes();
            long ohCurrentBytes = sp.getOffHeapCurrentBytes();
            long ohMaxBytes = sp.getOffHeapMaxBytes();
            fracJvm.add(memoryFraction(jvmCurrentBytes, jvmMaxBytes));
            fracOffHeap.add(memoryFraction(ohCurrentBytes, ohMaxBytes));
            lastBytes[0] = jvmCurrentBytes;
            lastBytes[1] = ohCurrentBytes;
            lastMaxBytes[0] = jvmMaxBytes;
            lastMaxBytes[1] = ohMaxBytes;
            if (numDevices > 0) {
                long[] devBytes = sp.getDeviceCurrentBytes();
                long[] devMaxBytes = sp.getDeviceMaxBytes();
                for (int i = 0; i < numDevices; i++) {
                    // Missing device entries count as 0 so every series keeps one value per timestamp
                    long cur = (devBytes != null && i < devBytes.length ? devBytes[i] : 0L);
                    long max = (devMaxBytes != null && i < devMaxBytes.length ? devMaxBytes[i] : 0L);
                    fracDeviceMem.get(i).add(memoryFraction(cur, max));
                    lastBytes[2 + i] = cur;
                    lastMaxBytes[2 + i] = max;
                }
            }
        }
        List<List<Float>> fracUtilized = new ArrayList<>();
        fracUtilized.add(fracJvm);
        fracUtilized.add(fracOffHeap);
        String[] seriesNames = new String[2 + numDevices];
        seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
        seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
        boolean[] isDevice = new boolean[2 + numDevices];
        String[] devNames = deviceNames.get(wid);
        for (int i = 0; i < numDevices; i++) {
            seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
            fracUtilized.add(fracDeviceMem.get(i));
            isDevice[2 + i] = true;
        }
        jvmData.put("times", timestamps);
        jvmData.put("isDevice", isDevice);
        jvmData.put("seriesNames", seriesNames);
        jvmData.put("values", fracUtilized);
        jvmData.put("currentBytes", lastBytes);
        jvmData.put("maxBytes", lastMaxBytes);
        ret.put(String.valueOf(count), jvmData);
        count++;
    }
    return ret;
}

/**
 * Returns currentBytes / maxBytes, or 0 when maxBytes is not positive. A zero maximum would
 * otherwise give NaN or Infinity, and a negative "undefined" maximum a negative fraction.
 */
private static float memoryFraction(long currentBytes, long maxBytes) {
    if (maxBytes <= 0) {
        return 0.0f;
    }
    return (float) (currentBytes / (double) maxBytes);
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Aggregations

StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 9
Persistable (org.deeplearning4j.api.storage.Persistable): 7
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4
StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 3
StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 3
Triple (org.deeplearning4j.berkeley.Triple): 1
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1
Histogram (org.deeplearning4j.ui.stats.api.Histogram): 1
SbeStatsReport (org.deeplearning4j.ui.stats.impl.SbeStatsReport): 1
JavaStatsReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsReport): 1
CompactModelAndGradient (org.deeplearning4j.ui.weights.beans.CompactModelAndGradient): 1