Search in sources :

Example 6 with StatsInitializationReport

Use of org.deeplearning4j.ui.stats.api.StatsInitializationReport in the deeplearning4j project (by deeplearning4j).

From class HistogramModule, method processRequest.

/**
 * Builds the histogram / score payload for one training session and returns it as JSON.
 *
 * <p>Only the first worker returned by {@code listWorkerIDsForSession} is used. Layer names are
 * inferred from the model parameter names: the prefix before the first {@code '_'}.
 *
 * @param sessionId the session to report on
 * @return a 404 result if the session ID is unknown. An empty JSON object if the session has no
 *     workers, no {@link StatsInitializationReport}, or no {@link StatsReport} updates yet.
 *     Otherwise a {@link CompactModelAndGradient} holding the histograms of the most recent
 *     update and the full score / mean-magnitude history.
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }
    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        // Known session but no worker data yet: nothing to display (avoids get(0) on an empty list)
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    String workerID = workerIDs.get(0);
    Object staticInfo = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerID);
    // instanceof covers both "no static info yet" (null) and an unexpected type (avoids ClassCastException)
    if (!(staticInfo instanceof StatsInitializationReport)) {
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    StatsInitializationReport initReport = (StatsInitializationReport) staticInfo;
    String[] paramNames = initReport.getModelParamNames();
    // Infer layer names from param names (prefix before the first '_').
    // LinkedHashSet keeps first-seen order and silently ignores duplicates.
    Set<String> layerNameSet = new LinkedHashSet<>();
    if (paramNames != null) {
        for (String s : paramNames) {
            // indexOf instead of split("_")[0]: split drops trailing empty strings, so e.g. "_" would yield an empty array
            int idx = s.indexOf('_');
            layerNameSet.add(idx < 0 ? s : s.substring(0, idx));
        }
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);
    List<Persistable> list = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerID, 0);
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));
    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }
    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());
        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }
    if (last == null) {
        // No StatsReport received yet: last.getHistograms(...) / last.getTimeStamp() below would NPE
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    double lastScore = (scoreList.isEmpty() ? 0.0 : scoreList.get(scoreList.size() - 1));
    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) CompactModelAndGradient(org.deeplearning4j.ui.weights.beans.CompactModelAndGradient) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Example 7 with StatsInitializationReport

Use of org.deeplearning4j.ui.stats.api.StatsInitializationReport in the deeplearning4j project (by deeplearning4j).

From class TrainModule, method getMemory.

/**
 * Collects memory-utilization time series for each distinct JVM found in the static info.
 *
 * <p>JVMs are processed in sorted JVM-UID order. For each JVM, the lexicographically first worker
 * running in it is the representative. Only that worker's updates from
 * {@code updatesLastNMinutes} are used to build the series.
 *
 * @param staticInfoAllWorkers static info for all workers. Entries that are not
 *     {@link StatsInitializationReport} are skipped.
 * @param updatesLastNMinutes recent updates. Entries that are not {@link StatsReport}, or that
 *     belong to a non-representative worker, are skipped.
 * @param i18n source of the localized JVM / off-heap series names
 * @return a map keyed by "0", "1", ... (one entry per JVM, in sorted JVM-UID order). Each value
 *     is a map with keys "times", "isDevice", "seriesNames", "values", "currentBytes" and
 *     "maxBytes". Index 0 of the per-series arrays is JVM heap, index 1 is off-heap, and
 *     indices 2+ are devices.
 */
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, List<Persistable> updatesLastNMinutes, I18N i18n) {
    Map<String, Object> ret = new HashMap<>();
    //First: map workers to JVMs
    Map<String, String> workersToJvms = new HashMap<>();
    Map<String, Integer> workerNumDevices = new HashMap<>();
    Map<String, String[]> deviceNames = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        if (!(p instanceof StatsInitializationReport)) {
            // Unexpected static-info type: skip rather than fail with ClassCastException
            continue;
        }
        StatsInitializationReport init = (StatsInitializationReport) p;
        workersToJvms.put(p.getWorkerID(), init.getSwJvmUID());
        int nDevices = init.getHwNumDevices();
        workerNumDevices.put(p.getWorkerID(), nDevices);
        if (nDevices > 0) {
            deviceNames.put(p.getWorkerID(), init.getHwDeviceDescription());
        }
    }
    // Group workers by JVM from the final worker->JVM mapping. If a worker appears several times,
    // only its last JVM is kept, so every JVM key below has at least one worker.
    Map<String, List<String>> workersByJvm = new HashMap<>();
    for (Map.Entry<String, String> e : workersToJvms.entrySet()) {
        workersByJvm.computeIfAbsent(e.getValue(), k -> new ArrayList<>()).add(e.getKey());
    }
    List<String> jvmList = new ArrayList<>(workersByJvm.keySet());
    Collections.sort(jvmList);
    //For each unique JVM, collect memory info
    //Do this by selecting the first worker
    int count = 0;
    for (String jvm : jvmList) {
        List<String> workersForJvm = workersByJvm.get(jvm);
        Collections.sort(workersForJvm);
        String wid = workersForJvm.get(0);
        int numDevices = workerNumDevices.get(wid);
        Map<String, Object> jvmData = new HashMap<>();
        List<Long> timestamps = new ArrayList<>();
        List<Float> fracJvm = new ArrayList<>();
        List<Float> fracOffHeap = new ArrayList<>();
        long[] lastBytes = new long[2 + numDevices];
        long[] lastMaxBytes = new long[2 + numDevices];
        List<List<Float>> fracDeviceMem = null;
        if (numDevices > 0) {
            fracDeviceMem = new ArrayList<>(numDevices);
            for (int i = 0; i < numDevices; i++) {
                fracDeviceMem.add(new ArrayList<>());
            }
        }
        for (Persistable p : updatesLastNMinutes) {
            //TODO single pass
            if (!p.getWorkerID().equals(wid))
                continue;
            if (!(p instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) p;
            timestamps.add(sp.getTimeStamp());
            long jvmCurrentBytes = sp.getJvmCurrentBytes();
            long jvmMaxBytes = sp.getJvmMaxBytes();
            long ohCurrentBytes = sp.getOffHeapCurrentBytes();
            long ohMaxBytes = sp.getOffHeapMaxBytes();
            fracJvm.add(safeMemoryFraction(jvmCurrentBytes, jvmMaxBytes));
            fracOffHeap.add(safeMemoryFraction(ohCurrentBytes, ohMaxBytes));
            lastBytes[0] = jvmCurrentBytes;
            lastBytes[1] = ohCurrentBytes;
            lastMaxBytes[0] = jvmMaxBytes;
            lastMaxBytes[1] = ohMaxBytes;
            if (numDevices > 0) {
                long[] devBytes = sp.getDeviceCurrentBytes();
                long[] devMaxBytes = sp.getDeviceMaxBytes();
                for (int i = 0; i < numDevices; i++) {
                    // Guard against device arrays that are null or shorter than numDevices.
                    // Missing entries count as 0, so each device series stays aligned with "times".
                    long cur = (devBytes != null && i < devBytes.length) ? devBytes[i] : 0L;
                    long max = (devMaxBytes != null && i < devMaxBytes.length) ? devMaxBytes[i] : 0L;
                    fracDeviceMem.get(i).add(safeMemoryFraction(cur, max));
                    lastBytes[2 + i] = cur;
                    lastMaxBytes[2 + i] = max;
                }
            }
        }
        List<List<Float>> fracUtilized = new ArrayList<>();
        fracUtilized.add(fracJvm);
        fracUtilized.add(fracOffHeap);
        String[] seriesNames = new String[2 + numDevices];
        seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
        seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
        boolean[] isDevice = new boolean[2 + numDevices];
        String[] devNames = deviceNames.get(wid);
        for (int i = 0; i < numDevices; i++) {
            seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
            fracUtilized.add(fracDeviceMem.get(i));
            isDevice[2 + i] = true;
        }
        jvmData.put("times", timestamps);
        jvmData.put("isDevice", isDevice);
        jvmData.put("seriesNames", seriesNames);
        jvmData.put("values", fracUtilized);
        jvmData.put("currentBytes", lastBytes);
        jvmData.put("maxBytes", lastMaxBytes);
        ret.put(String.valueOf(count), jvmData);
        count++;
    }
    return ret;
}

/**
 * Returns {@code current / max} as a float. Returns 0 when the ratio is NaN ({@code 0/0}) or
 * infinite ({@code x/0}), so the utilization series never contains NaN or Infinity.
 */
private static float safeMemoryFraction(long current, long max) {
    double frac = current / ((double) max);
    return (Double.isNaN(frac) || Double.isInfinite(frac)) ? 0.0f : (float) frac;
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 8 with StatsInitializationReport

Use of org.deeplearning4j.ui.stats.api.StatsInitializationReport in the deeplearning4j project (by deeplearning4j).

From class TestStatsStorage, method getInitReport.

/**
 * Builds a fully populated {@link StatsInitializationReport} fixture for the storage tests.
 *
 * <p>The model, ID, hardware and software sections are all filled with fixed dummy values. Only
 * the ID strings vary with the arguments.
 *
 * @param idNumber number appended to {@code "sid"} in the reported IDs
 * @param tid number appended to {@code "tid"} in the reported IDs
 * @param wid number appended to {@code "wid"} in the reported IDs
 * @param useJ7Storage {@code true} to build a {@code JavaStatsInitializationReport};
 *     {@code false} to build a {@code SbeStatsInitializationReport}
 * @return the populated report
 */
private static StatsInitializationReport getInitReport(int idNumber, int tid, int wid, boolean useJ7Storage) {
    final StatsInitializationReport report = useJ7Storage
            ? new JavaStatsInitializationReport()
            : new SbeStatsInitializationReport();

    final String[] paramNames = { "p0", "p1" };
    report.reportModelInfo("classname", "jsonconfig", paramNames, 1, 10);

    final String sessionPart = "sid" + idNumber;
    final String typePart = "tid" + tid;
    final String workerPart = "wid" + wid;
    report.reportIDs(sessionPart, typePart, workerPart, 12345);

    final long[] deviceValues = { 3000, 4000 };
    final String[] deviceLabels = { "dev0", "dev1" };
    report.reportHardwareInfo(0, 2, 1000, 2000, deviceValues, deviceLabels, "hardwareuid");

    final Map<String, String> environment = new HashMap<>();
    environment.put("envInfo0", "value0");
    environment.put("envInfo1", "value1");
    report.reportSoftwareInfo("arch", "osName", "jvmName", "jvmVersion", "1.8", "backend", "dtype",
            "hostname", "jvmuid", environment);

    return report;
}
Also used : SbeStatsInitializationReport(org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport) JavaStatsInitializationReport(org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport) StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) HashMap(java.util.HashMap)

Aggregations

StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)8 Persistable (org.deeplearning4j.api.storage.Persistable)7 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)4 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)3 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)2 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)2 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)2 HashMap (java.util.HashMap)1 Pair (org.deeplearning4j.berkeley.Pair)1 Triple (org.deeplearning4j.berkeley.Triple)1 Updater (org.deeplearning4j.nn.conf.Updater)1 GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex)1 LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex)1 ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer)1 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)1 Layer (org.deeplearning4j.nn.conf.layers.Layer)1 SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer)1 WeightInit (org.deeplearning4j.nn.weights.WeightInit)1 SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport)1