Search in sources :

Example 11 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

the class TrainModule method getHardwareSoftwareInfo.

/**
 * Builds the hardware and software information tables, one pair of tables per unique JVM.
 * <p>
 * Reports are grouped by JVM UID ({@code getSwJvmUID()}); when several reports share a JVM UID the
 * last one encountered in the input list is used. JVM UIDs are sorted and each is assigned a
 * zero-based index, which (as a String) is the key in both returned maps.
 *
 * @param staticInfoAllWorkers static info for all workers; entries that are not
 *                             {@link StatsInitializationReport} instances are skipped
 * @param i18n                 source of the localized row labels
 * @return pair of (hardware tables, software tables); each map value is a {@code List<String[]>}
 *         of {label, value} rows
 */
private static Pair<Map<String, Object>, Map<String, Object>> getHardwareSoftwareInfo(List<Persistable> staticInfoAllWorkers, I18N i18n) {
    Map<String, Object> retHw = new HashMap<>();
    Map<String, Object> retSw = new HashMap<>();
    //First: map workers to JVMs
    Set<String> jvmIDs = new HashSet<>();
    Map<String, StatsInitializationReport> staticByJvm = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        if (!(p instanceof StatsInitializationReport)) {
            //Skip null/unexpected entries rather than failing the whole page with a ClassCastException
            continue;
        }
        StatsInitializationReport init = (StatsInitializationReport) p;
        String jvmuid = init.getSwJvmUID();
        jvmIDs.add(jvmuid);
        staticByJvm.put(jvmuid, init);
    }
    List<String> jvmList = new ArrayList<>(jvmIDs);
    Collections.sort(jvmList);
    //For each unique JVM, collect hardware and software info
    int count = 0;
    for (String jvm : jvmList) {
        StatsInitializationReport sr = staticByJvm.get(jvm);
        //---- Hardware Info ----
        List<String[]> hwInfo = new ArrayList<>();
        int numDevices = sr.getHwNumDevices();
        String[] deviceDescription = sr.getHwDeviceDescription();
        long[] devTotalMem = sr.getHwDeviceTotalMemory();
        hwInfo.add(new String[] { i18n.getMessage("train.system.hwTable.jvmMax"), String.valueOf(sr.getHwJvmMaxMemory()) });
        hwInfo.add(new String[] { i18n.getMessage("train.system.hwTable.offHeapMax"), String.valueOf(sr.getHwOffHeapMaxMemory()) });
        hwInfo.add(new String[] { i18n.getMessage("train.system.hwTable.jvmProcs"), String.valueOf(sr.getHwJvmAvailableProcessors()) });
        hwInfo.add(new String[] { i18n.getMessage("train.system.hwTable.computeDevices"), String.valueOf(numDevices) });
        for (int i = 0; i < numDevices; i++) {
            String label = i18n.getMessage("train.system.hwTable.deviceName") + " (" + i + ")";
            //Fall back to the device index when no description is available for this device
            String name = (deviceDescription == null || i >= deviceDescription.length ? String.valueOf(i) : deviceDescription[i]);
            hwInfo.add(new String[] { label, name });
            String memLabel = i18n.getMessage("train.system.hwTable.deviceMemory") + " (" + i + ")";
            //Must be the short-circuit '||': with '|' the length check is evaluated even when the array is null (NPE)
            String memBytes = (devTotalMem == null || i >= devTotalMem.length ? "-" : String.valueOf(devTotalMem[i]));
            hwInfo.add(new String[] { memLabel, memBytes });
        }
        retHw.put(String.valueOf(count), hwInfo);
        //---- Software Info -----
        String nd4jBackend = sr.getSwNd4jBackendClass();
        if (nd4jBackend != null && nd4jBackend.contains(".")) {
            //Reduce the fully qualified class name to its simple name, then map known factories to a short label
            int idx = nd4jBackend.lastIndexOf('.');
            nd4jBackend = nd4jBackend.substring(idx + 1);
            String temp;
            switch(nd4jBackend) {
                case "CpuNDArrayFactory":
                    temp = "CPU";
                    break;
                case "JCublasNDArrayFactory":
                    temp = "CUDA";
                    break;
                default:
                    temp = nd4jBackend;
            }
            nd4jBackend = temp;
        }
        String datatype = sr.getSwNd4jDataTypeName();
        if (datatype == null)
            datatype = "";
        else
            datatype = datatype.toLowerCase();
        List<String[]> swInfo = new ArrayList<>();
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.os"), sr.getSwOsName() });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.hostname"), sr.getSwHostName() });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.osArch"), sr.getSwArch() });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.jvmName"), sr.getSwJvmName() });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.jvmVersion"), sr.getSwJvmVersion() });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.nd4jBackend"), nd4jBackend });
        swInfo.add(new String[] { i18n.getMessage("train.system.swTable.nd4jDataType"), datatype });
        retSw.put(String.valueOf(count), swInfo);
        count++;
    }
    return new Pair<>(retHw, retSw);
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) Pair(org.deeplearning4j.berkeley.Pair)

Example 12 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

the class TrainModule method getOverviewData.

/**
 * Builds the JSON payload for the training overview page for the current session and worker.
 * <p>
 * The result map always contains the keys "updateTimestamp", "scores", "scoresIter", "updateRatios",
 * "stdevActivations", "stdevGradients", "stdevUpdates", "perf" and "model". When there is no data
 * (no current session, session not found in storage, no worker ID, or no updates) the chart
 * collections are left empty and the table values are left as empty strings.
 *
 * @return an OK result wrapping the JSON-serialized map
 */
private Result getOverviewData() {
    Long lastUpdate = lastUpdateForSession.get(currentSessionID);
    if (lastUpdate == null)
        lastUpdate = -1L;
    I18N i18N = I18NProvider.getInstance();
    //First pass (optimize later): query all data...
    StatsStorage ss = (currentSessionID == null ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    //No data if there is no session, the session has no known storage, or there is no worker for the index.
    //The storage null check matters: the map lookup can miss, and ss is dereferenced below.
    boolean noData = (ss == null || wid == null);
    List<Integer> scoresIterCount = new ArrayList<>();
    List<Double> scores = new ArrayList<>();
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdate);
    result.put("scores", scores);
    result.put("scoresIter", scoresIterCount);
    //Get scores info
    List<Persistable> updates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    if (updates == null || updates.isEmpty()) {
        noData = true;
    }
    //Collect update ratios for weights
    //Collect standard deviations: activations, gradients, updates
    //Mean magnitude (updates) / mean magnitude (parameters)
    Map<String, List<Double>> updateRatios = new HashMap<>();
    result.put("updateRatios", updateRatios);
    Map<String, List<Double>> stdevActivations = new HashMap<>();
    Map<String, List<Double>> stdevGradients = new HashMap<>();
    Map<String, List<Double>> stdevUpdates = new HashMap<>();
    result.put("stdevActivations", stdevActivations);
    result.put("stdevGradients", stdevGradients);
    result.put("stdevUpdates", stdevUpdates);
    if (!noData) {
        //Use the first update to decide which series (map keys) will be charted
        Persistable u = updates.get(0);
        if (u instanceof StatsReport) {
            StatsReport sp = (StatsReport) u;
            Map<String, Double> map = sp.getMeanMagnitudes(StatsType.Parameters);
            if (map != null) {
                for (String s : map.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        //TODO: more robust "weights only" approach...
                        continue;
                    updateRatios.put(s, new ArrayList<>());
                }
            }
            Map<String, Double> stdGrad = sp.getStdev(StatsType.Gradients);
            if (stdGrad != null) {
                for (String s : stdGrad.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        //TODO: more robust "weights only" approach...
                        continue;
                    stdevGradients.put(s, new ArrayList<>());
                }
            }
            Map<String, Double> stdUpdate = sp.getStdev(StatsType.Updates);
            if (stdUpdate != null) {
                for (String s : stdUpdate.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        //TODO: more robust "weights only" approach...
                        continue;
                    stdevUpdates.put(s, new ArrayList<>());
                }
            }
            //Activations: all keys are kept (no "weights only" filter)
            Map<String, Double> stdAct = sp.getStdev(StatsType.Activations);
            if (stdAct != null) {
                for (String s : stdAct.keySet()) {
                    stdevActivations.put(s, new ArrayList<>());
                }
            }
        }
    }
    StatsReport last = null;
    int lastIterCount = -1;
    //Legacy issue - Spark training - iteration counts used to be reset, which means they could go 0,1,2,0,1,2, etc...
    //Or, they could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
    //Now, it should use the proper iteration counts
    boolean needToHandleLegacyIterCounts = false;
    if (!noData) {
        int totalUpdates = updates.size();
        int subsamplingFrequency = 1;
        if (totalUpdates > maxChartPoints) {
            subsamplingFrequency = totalUpdates / maxChartPoints;
        }
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        for (Persistable u : updates) {
            pCount++;
            if (!(u instanceof StatsReport))
                continue;
            last = (StatsReport) u;
            int iterCount = last.getIterationCount();
            //A non-increasing iteration count indicates the legacy (reset) numbering
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this - subsample the data
                if (pCount != lastUpdateIdx)
                    //Always keep the most recent value
                    continue;
            }
            scoresIterCount.add(iterCount);
            double lastScore = last.getScore();
            if (Double.isFinite(lastScore)) {
                scores.add(lastScore);
            } else {
                scores.add(NAN_REPLACEMENT_VALUE);
            }
            //Update ratios: mean magnitudes(updates) / mean magnitudes (parameters)
            Map<String, Double> updateMM = last.getMeanMagnitudes(StatsType.Updates);
            Map<String, Double> paramMM = last.getMeanMagnitudes(StatsType.Parameters);
            if (updateMM != null && paramMM != null && updateMM.size() > 0 && paramMM.size() > 0) {
                for (Map.Entry<String, List<Double>> e : updateRatios.entrySet()) {
                    String s = e.getKey();
                    List<Double> ratioHistory = e.getValue();
                    double currUpdate = updateMM.getOrDefault(s, 0.0);
                    double currParam = paramMM.getOrDefault(s, 0.0);
                    //Division by zero (missing/zero parameter magnitude) yields NaN/Infinity, replaced below
                    double ratio = currUpdate / currParam;
                    if (Double.isFinite(ratio)) {
                        ratioHistory.add(ratio);
                    } else {
                        ratioHistory.add(NAN_REPLACEMENT_VALUE);
                    }
                }
            }
            //Standard deviations: gradients, updates, activations
            Map<String, Double> stdGrad = last.getStdev(StatsType.Gradients);
            Map<String, Double> stdUpd = last.getStdev(StatsType.Updates);
            Map<String, Double> stdAct = last.getStdev(StatsType.Activations);
            if (stdGrad != null) {
                for (Map.Entry<String, List<Double>> e : stdevGradients.entrySet()) {
                    double d = stdGrad.getOrDefault(e.getKey(), 0.0);
                    e.getValue().add(fixNaN(d));
                }
            }
            if (stdUpd != null) {
                for (Map.Entry<String, List<Double>> e : stdevUpdates.entrySet()) {
                    double d = stdUpd.getOrDefault(e.getKey(), 0.0);
                    e.getValue().add(fixNaN(d));
                }
            }
            if (stdAct != null) {
                for (Map.Entry<String, List<Double>> e : stdevActivations.entrySet()) {
                    double d = stdAct.getOrDefault(e.getKey(), 0.0);
                    e.getValue().add(fixNaN(d));
                }
            }
        }
    }
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(scoresIterCount);
    }
    //----- Performance Info -----
    //Rows 0 (start time) and 1 (total runtime) are not populated by this method; their values stay empty
    String[][] perfInfo = new String[][] {
        { i18N.getMessage("train.overview.perftable.startTime"), "" },
        { i18N.getMessage("train.overview.perftable.totalRuntime"), "" },
        { i18N.getMessage("train.overview.perftable.lastUpdate"), "" },
        { i18N.getMessage("train.overview.perftable.totalParamUpdates"), "" },
        { i18N.getMessage("train.overview.perftable.updatesPerSec"), "" },
        { i18N.getMessage("train.overview.perftable.examplesPerSec"), "" } };
    if (last != null) {
        perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
        perfInfo[3][1] = String.valueOf(last.getTotalMinibatches());
        perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond()));
        perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond()));
    }
    result.put("perf", perfInfo);
    // ----- Model Info -----
    String[][] modelInfo = new String[][] {
        { i18N.getMessage("train.overview.modeltable.modeltype"), "" },
        { i18N.getMessage("train.overview.modeltable.nLayers"), "" },
        { i18N.getMessage("train.overview.modeltable.nParams"), "" } };
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            int nLayers = initReport.getModelNumLayers();
            long numParams = initReport.getModelNumParams();
            String className = initReport.getModelClassName();
            String modelType;
            if (className == null) {
                //No model class reported: leave the model type blank instead of failing on endsWith
                modelType = "";
            } else if (className.endsWith("MultiLayerNetwork")) {
                modelType = "MultiLayerNetwork";
            } else if (className.endsWith("ComputationGraph")) {
                modelType = "ComputationGraph";
            } else {
                //Unknown model class: show its simple name
                modelType = className;
                if (modelType.lastIndexOf('.') > 0) {
                    modelType = modelType.substring(modelType.lastIndexOf('.') + 1);
                }
            }
            modelInfo[0][1] = modelType;
            modelInfo[1][1] = String.valueOf(nLayers);
            modelInfo[2][1] = String.valueOf(numParams);
        }
    }
    result.put("model", modelInfo);
    return Results.ok(Json.toJson(result));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 13 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

the class TestRemoteReceiver method testRemoteBasic.

/**
 * Posts updates, storage metadata and static info through a {@link RemoteUIStatsStorageRouter}
 * pointed at http://localhost:9000 and checks that a {@link CollectionStatsStorageRouter}
 * registered as the UI server's remote listener ends up holding the same objects, in order.
 * Currently ignored.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    //Receiving side: plain collections filled by the router attached to the UI server
    List<Persistable> receivedUpdates = new ArrayList<>();
    List<Persistable> receivedStaticInfo = new ArrayList<>();
    List<StorageMetaData> receivedMetaData = new ArrayList<>();
    CollectionStatsStorageRouter receiver = new CollectionStatsStorageRouter(receivedMetaData, receivedStaticInfo, receivedUpdates);
    UIServer uiServer = UIServer.getInstance();
    uiServer.enableRemoteListener(receiver, false);

    //Sending side
    RemoteUIStatsStorageRouter sender = new RemoteUIStatsStorageRouter("http://localhost:9000");

    SbeStatsReport firstUpdate = new SbeStatsReport();
    firstUpdate.setDeviceCurrentBytes(new long[] { 1, 2 });
    firstUpdate.reportIterationCount(10);
    firstUpdate.reportIDs("sid", "tid", "wid", 123456);
    firstUpdate.reportPerformance(10, 20, 30, 40, 50);

    SbeStatsReport secondUpdate = new SbeStatsReport();
    secondUpdate.setDeviceCurrentBytes(new long[] { 3, 4 });
    secondUpdate.reportIterationCount(20);
    secondUpdate.reportIDs("sid2", "tid2", "wid2", 123456);
    secondUpdate.reportPerformance(11, 21, 31, 40, 50);

    StorageMetaData firstMeta = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData secondMeta = new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");

    SbeStatsInitializationReport initReport = new SbeStatsInitializationReport();
    //NOTE(review): "wid"/"tid" are passed in the opposite order to the update reports above - confirm this is intended
    initReport.reportIDs("sid", "wid", "tid", 3145253452L);
    initReport.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");

    //Post one item at a time, pausing between posts
    final long pauseBetweenPostsMs = 100;
    sender.putUpdate(firstUpdate);
    Thread.sleep(pauseBetweenPostsMs);
    sender.putStorageMetaData(firstMeta);
    Thread.sleep(pauseBetweenPostsMs);
    sender.putStaticInfo(initReport);
    Thread.sleep(pauseBetweenPostsMs);
    sender.putUpdate(secondUpdate);
    Thread.sleep(pauseBetweenPostsMs);
    sender.putStorageMetaData(secondMeta);
    //Longer wait before checking what arrived
    Thread.sleep(2000);

    //Sizes first, then contents and ordering
    assertEquals(2, receivedMetaData.size());
    assertEquals(2, receivedUpdates.size());
    assertEquals(1, receivedStaticInfo.size());
    assertEquals(Arrays.asList(firstUpdate, secondUpdate), receivedUpdates);
    assertEquals(Arrays.asList(firstMeta, secondMeta), receivedMetaData);
    assertEquals(Collections.singletonList(initReport), receivedStaticInfo);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Persistable(org.deeplearning4j.api.storage.Persistable) SbeStatsInitializationReport(org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport) SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) UIServer(org.deeplearning4j.ui.api.UIServer) ArrayList(java.util.ArrayList) CollectionStatsStorageRouter(org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 14 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

the class FlowListenerModule method getStaticInfo.

/**
 * Returns the model info of the first static-info record stored for the given session, as JSON.
 *
 * @param sessionID session to look up
 * @return OK with the text "Unknown session ID" if the session is not known; an empty OK if the
 *         session has no static info or its first record is not a {@link FlowStaticPersistable};
 *         otherwise OK with the JSON-serialized model info
 */
private Result getStaticInfo(String sessionID) {
    if (!knownSessionIDs.containsKey(sessionID)) {
        return ok("Unknown session ID");
    }
    List<Persistable> staticInfos = knownSessionIDs.get(sessionID).getAllStaticInfos(sessionID, TYPE_ID);
    //Only the first record is considered
    Persistable first = (staticInfos == null || staticInfos.isEmpty()) ? null : staticInfos.get(0);
    if (first instanceof FlowStaticPersistable) {
        FlowStaticPersistable flowInfo = (FlowStaticPersistable) first;
        return ok(Json.toJson(flowInfo.getModelInfo()));
    }
    //Nothing stored, or the first record is of a different type
    return ok();
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable) FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable) Persistable(org.deeplearning4j.api.storage.Persistable) FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable)

Example 15 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

the class HistogramModule method processRequest.

/**
 * Builds the histogram page payload (scores, mean-magnitude histories and the latest histograms)
 * for the first worker of the given session.
 *
 * @param sessionId session to report on
 * @return 404 "Unknown session ID: ..." if the session has no storage; OK with an empty JSON map if
 *         the session has no workers, no {@link StatsInitializationReport} static info, or no
 *         {@link StatsReport} updates yet; otherwise OK with the JSON-serialized
 *         {@link CompactModelAndGradient}
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }
    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        //No workers have reported yet: nothing to show (previously an IndexOutOfBoundsException)
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    String workerId = workerIDs.get(0);
    Persistable staticInfo = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerId);
    if (!(staticInfo instanceof StatsInitializationReport)) {
        //Covers both "no static info" (null) and an unexpected record type
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    StatsInitializationReport initReport = (StatsInitializationReport) staticInfo;
    String[] paramNames = initReport.getModelParamNames();
    //Infer layer names from param names: the layer name is the text before the first '_'
    Set<String> layerNameSet = new LinkedHashSet<>();
    if (paramNames != null) {
        for (String s : paramNames) {
            if (s == null) {
                continue;
            }
            //indexOf/substring instead of split("_")[0]: split returns an empty array for names made only of '_'
            int idx = s.indexOf('_');
            layerNameSet.add(idx < 0 ? s : s.substring(0, idx));
        }
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);
    //Sort a private copy by timestamp, so the list handed out by the storage is never modified
    List<Persistable> list = new ArrayList<>();
    List<Persistable> stored = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerId, 0);
    if (stored != null) {
        list.addAll(stored);
    }
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));
    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }
    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());
        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }
    if (last == null) {
        //No StatsReport updates yet: the histogram/timestamp lookups below would dereference null
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    //Note: the "gradients" field is populated from the Updates histograms
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    double lastScore = (scoreList.isEmpty() ? 0.0 : scoreList.get(scoreList.size() - 1));
    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    //        g.setPath(subPath);
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    //        g.setLayerNames(layerNames);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) CompactModelAndGradient(org.deeplearning4j.ui.weights.beans.CompactModelAndGradient) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable)30 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)14 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)7 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)6 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)6 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 Test (org.junit.Test)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IOException (java.io.IOException)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable)3 FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)3 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)3 BufferedImage (java.awt.image.BufferedImage)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2