Search in sources:

Example 1 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

From class TestStatsStorage, method testFileStatsStore.

/**
 * Round-trip test for the file-backed {@link StatsStorage} implementations.
 *
 * <p>For each report encoding ({@code useJ7Storage} is passed through to the report factories) and for
 * each backend ({@code i == 0}: {@link MapDBStatsStorage}, {@code i == 1}: {@link J7FileStatsStorage}), this
 * test writes static info and updates, checks the query API and listener callbacks, then closes and re-opens
 * the storage from the same file to verify the data was persisted. Each storage is closed and its temp file
 * removed at the end of the iteration so that file handles/locks are not leaked between iterations.
 */
@Test
public void testFileStatsStore() throws IOException {
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 2; i++) {
            File f;
            if (i == 0) {
                f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
            } else {
                f = Files.createTempFile("TestSqliteStatsStore", ".db").toFile();
            }
            //Don't want file to exist: the storage implementations create it themselves
            f.delete();
            StatsStorage ss;
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            CountingListener l = new CountingListener();
            ss.registerStatsStorageListener(l);
            assertEquals(1, ss.getListeners().size());
            assertEquals(0, ss.listSessionIDs().size());
            assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
            assertEquals(0, ss.listSessionIDs().size());
            ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(0, l.countUpdate);
            assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
            assertTrue(ss.sessionExists("sid0"));
            assertFalse(ss.sessionExists("sid1"));
            Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
            Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
            assertEquals(expected, p);
            List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
            assertEquals(Collections.singletonList(expected), allStatic);
            assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            //Put first update
            ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
            assertEquals(1, list.size());
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(1, l.countUpdate);
            //Put second update
            ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(2, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(3, l.countUpdate);
            //Put static info and update with different session, type and worker IDs
            ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
            assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            assertEquals(2, l.countNewSession);
            assertEquals(3, l.countNewWorkerId);
            assertEquals(2, l.countStaticInfo);
            assertEquals(4, l.countUpdate);
            //Close and re-open: everything written above must have been persisted to the file
            ss.close();
            assertTrue(ss.isClosed());
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid100", "tid200").size());
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            //Release the re-opened storage (file handle / lock) and clean up the temp file
            ss.close();
            assertTrue(ss.isClosed());
            f.delete();
        }
    }
}
Also used: MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage), StatsStorage (org.deeplearning4j.api.storage.StatsStorage), J7FileStatsStorage (org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage), Persistable (org.deeplearning4j.api.storage.Persistable), File (java.io.File), Test (org.junit.Test)

Example 2 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

From class TrainModule, method getConfig.

/**
 * Reconstructs the model configuration of the currently selected session from the first
 * {@link StatsInitializationReport} stored under {@link StatsListener#TYPE_ID}.
 *
 * @return a Triple with exactly one non-null element: the {@link MultiLayerConfiguration} (first),
 *         {@link ComputationGraphConfiguration} (second) or single-layer {@link NeuralNetConfiguration}
 *         (third); or {@code null} if no session is selected, the session is not registered, no static
 *         info is available, or the configuration cannot be parsed
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    if (currentSessionID == null)
        return null;
    //The session may not be (or no longer be) registered: avoid an NPE on the lookup result
    StatsStorage ss = knownSessionIDs.get(currentSessionID);
    if (ss == null)
        return null;
    List<Persistable> allStatic = ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty())
        return null;
    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass == null || config == null)
        return null;
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    } else {
        //Anything else is assumed to be a single layer serialized as a NeuralNetConfiguration
        try {
            NeuralNetConfiguration layer = NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layer);
        } catch (Exception e) {
            //Best-effort: an unparseable layer config is reported and treated as "no config"
            e.printStackTrace();
        }
    }
    return null;
}
Also used : Triple(org.deeplearning4j.berkeley.Triple) StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration)

Example 3 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

From class TrainModule, method getModelGraph.

/**
 * Returns the model graph of the currently selected session as JSON.
 *
 * <p>An empty OK response is returned when no session is selected, the session is not registered,
 * no static info has been stored for {@link StatsListener#TYPE_ID}, or {@code getGraphInfo()} yields nothing.
 */
private Result getModelGraph() {
    if (currentSessionID == null)
        return ok();
    //Guard against a session ID that is not (or no longer) registered
    StatsStorage ss = knownSessionIDs.get(currentSessionID);
    if (ss == null)
        return ok();
    List<Persistable> allStatic = ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty())
        return ok();
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null)
        return ok();
    return ok(Json.toJson(gi));
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 4 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

From class TrainModule, method sessionInfo.

/**
 * Summarizes every known session as JSON, keyed by session ID.
 *
 * <p>Per session this reports: number of workers ({@code numWorkers}), earliest static-info timestamp
 * ({@code initTime}), latest update timestamp across workers ({@code lastUpdate}), the worker IDs
 * ({@code workers}, only if there is at least one) and, from the first static-info report, the model
 * type, number of layers and number of parameters. Missing values are reported as empty strings.
 */
private Result sessionInfo() {
    Map<String, Object> dataEachSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        Map<String, Object> dataThisSession = new HashMap<>();
        String sid = entry.getKey();
        StatsStorage ss = entry.getValue();
        List<String> workerIDs = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
        int workerCount = (workerIDs == null ? 0 : workerIDs.size());
        List<Persistable> staticInfo = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
        //Earliest static-info timestamp; MAX_VALUE acts as "not found"
        long initTime = Long.MAX_VALUE;
        if (staticInfo != null) {
            for (Persistable p : staticInfo) {
                initTime = Math.min(p.getTimeStamp(), initTime);
            }
        }
        //Latest update timestamp across all workers; MIN_VALUE acts as "not found".
        //Null-checked like the other storage query results above.
        long lastUpdateTime = Long.MIN_VALUE;
        List<Persistable> lastUpdatesAllWorkers = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
        if (lastUpdatesAllWorkers != null) {
            for (Persistable p : lastUpdatesAllWorkers) {
                lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
            }
        }
        dataThisSession.put("numWorkers", workerCount);
        dataThisSession.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
        dataThisSession.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);
        // add list of worker IDs
        if (workerCount > 0) {
            dataThisSession.put("workers", workerIDs);
        }
        //Model info: type, # layers, # params...
        if (staticInfo != null && staticInfo.size() > 0) {
            StatsInitializationReport sr = (StatsInitializationReport) staticInfo.get(0);
            String modelClassName = sr.getModelClassName();
            if (modelClassName == null) {
                modelClassName = "";
            } else if (modelClassName.endsWith("MultiLayerNetwork")) {
                modelClassName = "MultiLayerNetwork";
            } else if (modelClassName.endsWith("ComputationGraph")) {
                modelClassName = "ComputationGraph";
            }
            int numLayers = sr.getModelNumLayers();
            long numParams = sr.getModelNumParams();
            dataThisSession.put("modelType", modelClassName);
            dataThisSession.put("numLayers", numLayers);
            dataThisSession.put("numParams", numParams);
        } else {
            dataThisSession.put("modelType", "");
            dataThisSession.put("numLayers", "");
            dataThisSession.put("numParams", "");
        }
        dataEachSession.put(sid, dataThisSession);
    }
    return ok(Json.toJson(dataEachSession));
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable)

Example 5 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

From class TrainModule, method getOverviewData.

/**
 * Builds the JSON payload for the training overview page of the current session and worker.
 *
 * <p>The payload contains:
 * <ul>
 *   <li>score history ({@code scores}, {@code scoresIter}), subsampled to at most about {@code maxChartPoints} points;</li>
 *   <li>update/parameter mean-magnitude ratios ({@code updateRatios});</li>
 *   <li>standard deviation histories ({@code stdevActivations}, {@code stdevGradients}, {@code stdevUpdates});</li>
 *   <li>performance info ({@code perf});</li>
 *   <li>model info ({@code model}).</li>
 * </ul>
 * If no session/worker/update data is available, the collections are empty and the tables hold empty values.
 */
private Result getOverviewData() {
    //Only look up with a real session ID: null-hostile maps (e.g. ConcurrentHashMap) throw on null keys
    Long lastUpdate = (currentSessionID == null ? null : lastUpdateForSession.get(currentSessionID));
    if (lastUpdate == null)
        lastUpdate = -1L;
    I18N i18N = I18NProvider.getInstance();
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    if (ss == null) {
        //No session selected, or the session ID is not (or no longer) registered: nothing to query
        noData = true;
    }
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }
    List<Integer> scoresIterCount = new ArrayList<>();
    List<Double> scores = new ArrayList<>();
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdate);
    result.put("scores", scores);
    result.put("scoresIter", scoresIterCount);
    //Get scores info
    List<Persistable> updates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    if (updates == null || updates.size() == 0) {
        noData = true;
    }
    //Collect update ratios for weights
    //Collect standard deviations: activations, gradients, updates
    //Mean magnitude (updates) / mean magnitude (parameters)
    Map<String, List<Double>> updateRatios = new HashMap<>();
    result.put("updateRatios", updateRatios);
    Map<String, List<Double>> stdevActivations = new HashMap<>();
    Map<String, List<Double>> stdevGradients = new HashMap<>();
    Map<String, List<Double>> stdevUpdates = new HashMap<>();
    result.put("stdevActivations", stdevActivations);
    result.put("stdevGradients", stdevGradients);
    result.put("stdevUpdates", stdevUpdates);
    if (!noData) {
        //The set of series (chart lines) is fixed by the keys present in the first update
        Persistable u = updates.get(0);
        if (u instanceof StatsReport) {
            StatsReport sp = (StatsReport) u;
            addEmptyHistories(sp.getMeanMagnitudes(StatsType.Parameters), updateRatios, true);
            addEmptyHistories(sp.getStdev(StatsType.Gradients), stdevGradients, true);
            addEmptyHistories(sp.getStdev(StatsType.Updates), stdevUpdates, true);
            addEmptyHistories(sp.getStdev(StatsType.Activations), stdevActivations, false);
        }
    }
    StatsReport last = null;
    int lastIterCount = -1;
    //Legacy issue - Spark training - iteration counts used to be reset... which means: could go 0,1,2,0,1,2, etc...
    //Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
    //Now, it should use the proper iteration counts
    boolean needToHandleLegacyIterCounts = false;
    if (!noData) {
        double lastScore;
        int totalUpdates = updates.size();
        int subsamplingFrequency = 1;
        if (totalUpdates > maxChartPoints) {
            subsamplingFrequency = totalUpdates / maxChartPoints;
        }
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        for (Persistable u : updates) {
            pCount++;
            if (!(u instanceof StatsReport))
                continue;
            last = (StatsReport) u;
            int iterCount = last.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this - subsample the data
                if (pCount != lastUpdateIdx)
                    //Always keep the most recent value
                    continue;
            }
            scoresIterCount.add(iterCount);
            lastScore = last.getScore();
            if (Double.isFinite(lastScore)) {
                scores.add(lastScore);
            } else {
                scores.add(NAN_REPLACEMENT_VALUE);
            }
            //Update ratios: mean magnitudes(updates) / mean magnitudes (parameters)
            Map<String, Double> updateMM = last.getMeanMagnitudes(StatsType.Updates);
            Map<String, Double> paramMM = last.getMeanMagnitudes(StatsType.Parameters);
            if (updateMM != null && paramMM != null && updateMM.size() > 0 && paramMM.size() > 0) {
                for (String s : updateRatios.keySet()) {
                    List<Double> ratioHistory = updateRatios.get(s);
                    double currUpdate = updateMM.getOrDefault(s, 0.0);
                    double currParam = paramMM.getOrDefault(s, 0.0);
                    double ratio = currUpdate / currParam;
                    if (Double.isFinite(ratio)) {
                        ratioHistory.add(ratio);
                    } else {
                        ratioHistory.add(NAN_REPLACEMENT_VALUE);
                    }
                }
            }
            //Standard deviations: gradients, updates, activations
            Map<String, Double> stdGrad = last.getStdev(StatsType.Gradients);
            Map<String, Double> stdUpd = last.getStdev(StatsType.Updates);
            Map<String, Double> stdAct = last.getStdev(StatsType.Activations);
            if (stdGrad != null) {
                for (String s : stdevGradients.keySet()) {
                    double d = stdGrad.getOrDefault(s, 0.0);
                    stdevGradients.get(s).add(fixNaN(d));
                }
            }
            if (stdUpd != null) {
                for (String s : stdevUpdates.keySet()) {
                    double d = stdUpd.getOrDefault(s, 0.0);
                    stdevUpdates.get(s).add(fixNaN(d));
                }
            }
            if (stdAct != null) {
                for (String s : stdevActivations.keySet()) {
                    double d = stdAct.getOrDefault(s, 0.0);
                    stdevActivations.get(s).add(fixNaN(d));
                }
            }
        }
    }
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(scoresIterCount);
    }
    //----- Performance Info -----
    //NOTE(review): start time [0] and total runtime [1] are never populated here
    String[][] perfInfo = new String[][] { { i18N.getMessage("train.overview.perftable.startTime"), "" }, { i18N.getMessage("train.overview.perftable.totalRuntime"), "" }, { i18N.getMessage("train.overview.perftable.lastUpdate"), "" }, { i18N.getMessage("train.overview.perftable.totalParamUpdates"), "" }, { i18N.getMessage("train.overview.perftable.updatesPerSec"), "" }, { i18N.getMessage("train.overview.perftable.examplesPerSec"), "" } };
    if (last != null) {
        perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
        perfInfo[3][1] = String.valueOf(last.getTotalMinibatches());
        perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond()));
        perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond()));
    }
    result.put("perf", perfInfo);
    // ----- Model Info -----
    String[][] modelInfo = new String[][] { { i18N.getMessage("train.overview.modeltable.modeltype"), "" }, { i18N.getMessage("train.overview.modeltable.nLayers"), "" }, { i18N.getMessage("train.overview.modeltable.nParams"), "" } };
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            int nLayers = initReport.getModelNumLayers();
            long numParams = initReport.getModelNumParams();
            String className = initReport.getModelClassName();
            String modelType;
            if (className.endsWith("MultiLayerNetwork")) {
                modelType = "MultiLayerNetwork";
            } else if (className.endsWith("ComputationGraph")) {
                modelType = "ComputationGraph";
            } else {
                //Strip the package prefix, if any
                modelType = className;
                if (modelType.lastIndexOf('.') > 0) {
                    modelType = modelType.substring(modelType.lastIndexOf('.') + 1);
                }
            }
            modelInfo[0][1] = modelType;
            modelInfo[1][1] = String.valueOf(nLayers);
            modelInfo[2][1] = String.valueOf(numParams);
        }
    }
    result.put("model", modelInfo);
    return Results.ok(Json.toJson(result));
}

/**
 * Registers an empty history list in {@code dest} for each key of {@code source}.
 *
 * @param source      per-parameter values from a stats report; may be null, in which case nothing is added
 * @param dest        map from parameter name to its (initially empty) history
 * @param weightsOnly if true, only keys ending in 'w' (case-insensitive) are registered
 *                    (TODO: more robust "weights only" approach...)
 */
private static void addEmptyHistories(Map<String, Double> source, Map<String, List<Double>> dest, boolean weightsOnly) {
    if (source == null)
        return;
    for (String s : source.keySet()) {
        if (weightsOnly && !s.toLowerCase().endsWith("w"))
            continue;
        dest.put(s, new ArrayList<>());
    }
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Aggregations

StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 22; Persistable (org.deeplearning4j.api.storage.Persistable): 14; Test (org.junit.Test): 10; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 8; IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 7; InMemoryStatsStorage (org.deeplearning4j.ui.storage.InMemoryStatsStorage): 7; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 6; StatsListener (org.deeplearning4j.ui.stats.StatsListener): 6; Ignore (org.junit.Ignore): 6; UIServer (org.deeplearning4j.ui.api.UIServer): 5; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 5; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 4; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 4; ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 4; StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 4; MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage): 4; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 3; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 3; StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 3; File (java.io.File): 2