Search in sources :

Example 1 with SbeStatsReport

Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in project deeplearning4j (owner: deeplearning4j).

From class TestStatsClasses, method testSbeStatsUpdate.

/**
 * Round-trips {@link SbeStatsReport} through SBE {@code encode()}/{@code decode()} and through
 * standard Java serialization, for every combination of optional report sections.
 *
 * <p>Six top-level sections (performance, memory, garbage collection, score, learning rates and
 * dataset metadata) are each switched on or off. Three per-{@link StatsType} sections
 * (histograms, mean + standard deviation, mean magnitudes) are each driven by one of six enable
 * patterns over the four stats types (none, exactly one type, or all four). That yields
 * 2^6 x 6^3 = 13,824 combinations. For each combination the test checks that:
 * <ul>
 *   <li>the decoded report equals the original;</li>
 *   <li>every reported value is returned unchanged by the corresponding getter;</li>
 *   <li>the {@code has*} flags are true exactly for the sections that were reported;</li>
 *   <li>a copy produced by Java serialization equals the original.</li>
 * </ul>
 */
@Test
public void testSbeStatsUpdate() throws Exception {
    String[] paramNames = new String[] { "param0", "param1" };
    String[] layerNames = new String[] { "layer0", "layer1" };
    //IDs
    String sessionID = "sid";
    String typeID = "tid";
    String workerID = "wid";
    long time = System.currentTimeMillis();
    int duration = 123456;
    int iterCount = 123;
    long perfRuntime = 1;
    long perfTotalEx = 2;
    long perfTotalMB = 3;
    double perfEPS = 4.0;
    double perfMBPS = 5.0;
    long memJC = 6;
    long memJM = 7;
    long memOC = 8;
    long memOM = 9;
    long[] memDC = new long[] { 10, 11 };
    long[] memDM = new long[] { 12, 13 };
    String gc1Name = "14";
    int gcdc1 = 16;
    int gcdt1 = 17;
    String gc2Name = "18";
    int gcdc2 = 20;
    int gcdt2 = 21;
    double score = 22.0;
    Map<String, Double> lrByParam = new HashMap<>();
    lrByParam.put(paramNames[0], 22.5);
    lrByParam.put(paramNames[1], 22.75);
    Map<String, Histogram> pHist = new HashMap<>();
    pHist.put(paramNames[0], new Histogram(23, 24, 2, new int[] { 25, 26 }));
    pHist.put(paramNames[1], new Histogram(27, 28, 3, new int[] { 29, 30, 31 }));
    Map<String, Histogram> gHist = new HashMap<>();
    gHist.put(paramNames[0], new Histogram(230, 240, 2, new int[] { 250, 260 }));
    gHist.put(paramNames[1], new Histogram(270, 280, 3, new int[] { 290, 300, 310 }));
    Map<String, Histogram> uHist = new HashMap<>();
    uHist.put(paramNames[0], new Histogram(32, 33, 2, new int[] { 34, 35 }));
    uHist.put(paramNames[1], new Histogram(36, 37, 3, new int[] { 38, 39, 40 }));
    Map<String, Histogram> aHist = new HashMap<>();
    aHist.put(layerNames[0], new Histogram(41, 42, 2, new int[] { 43, 44 }));
    aHist.put(layerNames[1], new Histogram(45, 46, 3, new int[] { 47, 48, 47 }));
    Map<String, Double> pMean = new HashMap<>();
    pMean.put(paramNames[0], 49.0);
    pMean.put(paramNames[1], 50.0);
    Map<String, Double> gMean = new HashMap<>();
    gMean.put(paramNames[0], 49.1);
    gMean.put(paramNames[1], 50.1);
    Map<String, Double> uMean = new HashMap<>();
    uMean.put(paramNames[0], 51.0);
    uMean.put(paramNames[1], 52.0);
    Map<String, Double> aMean = new HashMap<>();
    aMean.put(layerNames[0], 53.0);
    aMean.put(layerNames[1], 54.0);
    Map<String, Double> pStd = new HashMap<>();
    pStd.put(paramNames[0], 55.0);
    pStd.put(paramNames[1], 56.0);
    Map<String, Double> gStd = new HashMap<>();
    gStd.put(paramNames[0], 55.1);
    gStd.put(paramNames[1], 56.1);
    Map<String, Double> uStd = new HashMap<>();
    uStd.put(paramNames[0], 57.0);
    uStd.put(paramNames[1], 58.0);
    Map<String, Double> aStd = new HashMap<>();
    aStd.put(layerNames[0], 59.0);
    aStd.put(layerNames[1], 60.0);
    Map<String, Double> pMM = new HashMap<>();
    pMM.put(paramNames[0], 61.0);
    pMM.put(paramNames[1], 62.0);
    Map<String, Double> gMM = new HashMap<>();
    gMM.put(paramNames[0], 61.1);
    gMM.put(paramNames[1], 62.1);
    Map<String, Double> uMM = new HashMap<>();
    uMM.put(paramNames[0], 63.0);
    uMM.put(paramNames[1], 64.0);
    Map<String, Double> aMM = new HashMap<>();
    aMM.put(layerNames[0], 65.0);
    aMM.put(layerNames[1], 66.0);
    List<Serializable> metaDataList = new ArrayList<>();
    metaDataList.add("meta1");
    metaDataList.add("meta2");
    metaDataList.add("meta3");
    Class<?> metaDataClass = String.class;

    // Per-StatsType fixtures. Index i of each list belongs to statsTypes[i], and is enabled by
    // collectHistograms[i] / collectMeanStdev[i] / collectMM[i] in the loops below.
    StatsType[] statsTypes = new StatsType[] { StatsType.Parameters, StatsType.Gradients, StatsType.Updates, StatsType.Activations };
    List<Map<String, Histogram>> histsByType = new ArrayList<>();
    histsByType.add(pHist);
    histsByType.add(gHist);
    histsByType.add(uHist);
    histsByType.add(aHist);
    List<Map<String, Double>> meansByType = new ArrayList<>();
    meansByType.add(pMean);
    meansByType.add(gMean);
    meansByType.add(uMean);
    meansByType.add(aMean);
    List<Map<String, Double>> stdevsByType = new ArrayList<>();
    stdevsByType.add(pStd);
    stdevsByType.add(gStd);
    stdevsByType.add(uStd);
    stdevsByType.add(aStd);
    List<Map<String, Double>> meanMagsByType = new ArrayList<>();
    meanMagsByType.add(pMM);
    meanMagsByType.add(gMM);
    meanMagsByType.add(uMM);
    meanMagsByType.add(aMM);

    boolean[] tf = new boolean[] { true, false };
    // Enable patterns over statsTypes: none, each single type on its own, then all four
    boolean[][] tf4 = new boolean[][] { { false, false, false, false }, { true, false, false, false }, { false, true, false, false }, { false, false, true, false }, { false, false, false, true }, { true, true, true, true } };
    //Total tests: 2^6 x 6^3 = 13,824 separate tests
    int testCount = 0;
    for (boolean collectPerformanceStats : tf) {
        for (boolean collectMemoryStats : tf) {
            for (boolean collectGCStats : tf) {
                for (boolean collectScore : tf) {
                    for (boolean collectLearningRates : tf) {
                        for (boolean collectMetaData : tf) {
                            for (boolean[] collectHistograms : tf4) {
                                for (boolean[] collectMeanStdev : tf4) {
                                    for (boolean[] collectMM : tf4) {
                                        SbeStatsReport report = new SbeStatsReport();
                                        report.reportIDs(sessionID, typeID, workerID, time);
                                        report.reportStatsCollectionDurationMS(duration);
                                        report.reportIterationCount(iterCount);
                                        if (collectPerformanceStats) {
                                            report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS, perfMBPS);
                                        }
                                        if (collectMemoryStats) {
                                            report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
                                        }
                                        if (collectGCStats) {
                                            report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
                                            report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
                                        }
                                        if (collectScore) {
                                            report.reportScore(score);
                                        }
                                        if (collectLearningRates) {
                                            report.reportLearningRates(lrByParam);
                                        }
                                        if (collectMetaData) {
                                            report.reportDataSetMetaData(metaDataList, metaDataClass);
                                        }
                                        // Per-type sections are reported grouped by section kind:
                                        // all histograms, then all mean/stdev pairs, then all mean magnitudes
                                        for (int i = 0; i < statsTypes.length; i++) {
                                            if (collectHistograms[i]) {
                                                report.reportHistograms(statsTypes[i], histsByType.get(i));
                                            }
                                        }
                                        for (int i = 0; i < statsTypes.length; i++) {
                                            if (collectMeanStdev[i]) {
                                                report.reportMean(statsTypes[i], meansByType.get(i));
                                                report.reportStdev(statsTypes[i], stdevsByType.get(i));
                                            }
                                        }
                                        for (int i = 0; i < statsTypes.length; i++) {
                                            if (collectMM[i]) {
                                                report.reportMeanMagnitudes(statsTypes[i], meanMagsByType.get(i));
                                            }
                                        }

                                        //SBE round trip
                                        byte[] bytes = report.encode();
                                        StatsReport report2 = new SbeStatsReport();
                                        report2.decode(bytes);
                                        assertEquals(report, report2);
                                        assertEquals(sessionID, report2.getSessionID());
                                        assertEquals(typeID, report2.getTypeID());
                                        assertEquals(workerID, report2.getWorkerID());
                                        assertEquals(time, report2.getTimeStamp());
                                        assertEquals(duration, report2.getStatsCollectionDurationMs());
                                        assertEquals(iterCount, report2.getIterationCount());
                                        if (collectPerformanceStats) {
                                            assertEquals(perfRuntime, report2.getTotalRuntimeMs());
                                            assertEquals(perfTotalEx, report2.getTotalExamples());
                                            assertEquals(perfTotalMB, report2.getTotalMinibatches());
                                            assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
                                            assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
                                            assertTrue(report2.hasPerformance());
                                        } else {
                                            assertFalse(report2.hasPerformance());
                                        }
                                        if (collectMemoryStats) {
                                            assertEquals(memJC, report2.getJvmCurrentBytes());
                                            assertEquals(memJM, report2.getJvmMaxBytes());
                                            assertEquals(memOC, report2.getOffHeapCurrentBytes());
                                            assertEquals(memOM, report2.getOffHeapMaxBytes());
                                            assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
                                            assertArrayEquals(memDM, report2.getDeviceMaxBytes());
                                            assertTrue(report2.hasMemoryUse());
                                        } else {
                                            assertFalse(report2.hasMemoryUse());
                                        }
                                        if (collectGCStats) {
                                            List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
                                            assertEquals(2, gcs.size());
                                            assertEquals(gc1Name, gcs.get(0).getFirst());
                                            assertArrayEquals(new int[] { gcdc1, gcdt1 }, gcs.get(0).getSecond());
                                            assertEquals(gc2Name, gcs.get(1).getFirst());
                                            assertArrayEquals(new int[] { gcdc2, gcdt2 }, gcs.get(1).getSecond());
                                            assertTrue(report2.hasGarbageCollection());
                                        } else {
                                            assertFalse(report2.hasGarbageCollection());
                                        }
                                        if (collectScore) {
                                            assertEquals(score, report2.getScore(), 0.0);
                                            assertTrue(report2.hasScore());
                                        } else {
                                            assertFalse(report2.hasScore());
                                        }
                                        if (collectLearningRates) {
                                            assertEquals(lrByParam.keySet(), report2.getLearningRates().keySet());
                                            for (String s : lrByParam.keySet()) {
                                                assertEquals(lrByParam.get(s), report2.getLearningRates().get(s), 1e-6);
                                            }
                                            assertTrue(report2.hasLearningRates());
                                        } else {
                                            assertFalse(report2.hasLearningRates());
                                        }
                                        if (collectMetaData) {
                                            assertNotNull(report2.getDataSetMetaData());
                                            assertEquals(metaDataList, report2.getDataSetMetaData());
                                            assertEquals(metaDataClass.getName(), report2.getDataSetMetaDataClassName());
                                            assertTrue(report2.hasDataSetMetaData());
                                        } else {
                                            assertFalse(report2.hasDataSetMetaData());
                                        }
                                        // Each per-type section must round-trip when enabled and be absent otherwise
                                        for (int i = 0; i < statsTypes.length; i++) {
                                            StatsType type = statsTypes[i];
                                            if (collectHistograms[i]) {
                                                assertEquals(histsByType.get(i), report2.getHistograms(type));
                                                assertTrue(report2.hasHistograms(type));
                                            } else {
                                                assertFalse(report2.hasHistograms(type));
                                            }
                                            if (collectMeanStdev[i]) {
                                                assertEquals(meansByType.get(i), report2.getMean(type));
                                                assertEquals(stdevsByType.get(i), report2.getStdev(type));
                                                assertTrue(report2.hasSummaryStats(type, SummaryType.Mean));
                                                assertTrue(report2.hasSummaryStats(type, SummaryType.Stdev));
                                            } else {
                                                assertFalse(report2.hasSummaryStats(type, SummaryType.Mean));
                                                assertFalse(report2.hasSummaryStats(type, SummaryType.Stdev));
                                            }
                                            if (collectMM[i]) {
                                                assertEquals(meanMagsByType.get(i), report2.getMeanMagnitudes(type));
                                                assertTrue(report2.hasSummaryStats(type, SummaryType.MeanMagnitudes));
                                            } else {
                                                assertFalse(report2.hasSummaryStats(type, SummaryType.MeanMagnitudes));
                                            }
                                        }

                                        //Check standard Java serialization
                                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
                                        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                                            oos.writeObject(report);
                                        }
                                        // Read only after the output stream is closed, so all buffered data is in baos
                                        SbeStatsReport report3;
                                        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
                                            report3 = (SbeStatsReport) ois.readObject();
                                        }
                                        assertEquals(report, report3);
                                        testCount++;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    assertEquals(13824, testCount);
}
Also used: HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), SbeStatsReport (org.deeplearning4j.ui.stats.impl.SbeStatsReport), Pair (org.deeplearning4j.berkeley.Pair), Test (org.junit.Test)

Example 2 with SbeStatsReport

Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in project deeplearning4j (owner: deeplearning4j).

From class TestStatsClasses, method testSbeStatsUpdateNullValues.

@Test
public void testSbeStatsUpdateNullValues() throws Exception {
    //new String[]{"param0", "param1"};
    String[] paramNames = null;
    long time = System.currentTimeMillis();
    int duration = 123456;
    int iterCount = 123;
    long perfRuntime = 1;
    long perfTotalEx = 2;
    long perfTotalMB = 3;
    double perfEPS = 4.0;
    double perfMBPS = 5.0;
    long memJC = 6;
    long memJM = 7;
    long memOC = 8;
    long memOM = 9;
    long[] memDC = null;
    long[] memDM = null;
    String gc1Name = null;
    int gcdc1 = 16;
    int gcdt1 = 17;
    String gc2Name = null;
    int gcdc2 = 20;
    int gcdt2 = 21;
    double score = 22.0;
    Map<String, Double> lrByParam = null;
    Map<String, Histogram> pHist = null;
    Map<String, Histogram> gHist = null;
    Map<String, Histogram> uHist = null;
    Map<String, Histogram> aHist = null;
    Map<String, Double> pMean = null;
    Map<String, Double> gMean = null;
    Map<String, Double> uMean = null;
    Map<String, Double> aMean = null;
    Map<String, Double> pStd = null;
    Map<String, Double> gStd = null;
    Map<String, Double> uStd = null;
    Map<String, Double> aStd = null;
    Map<String, Double> pMM = null;
    Map<String, Double> gMM = null;
    Map<String, Double> uMM = null;
    Map<String, Double> aMM = null;
    boolean[] tf = new boolean[] { true, false };
    boolean[][] tf4 = new boolean[][] { { false, false, false, false }, { true, false, false, false }, { false, true, false, false }, { false, false, true, false }, { false, false, false, true }, { true, true, true, true } };
    //Total tests: 2^6 x 6^3 = 13,824 separate tests
    int testCount = 0;
    for (boolean collectPerformanceStats : tf) {
        for (boolean collectMemoryStats : tf) {
            for (boolean collectGCStats : tf) {
                for (boolean collectDataSetMetaData : tf) {
                    for (boolean collectScore : tf) {
                        for (boolean collectLearningRates : tf) {
                            for (boolean[] collectHistograms : tf4) {
                                for (boolean[] collectMeanStdev : tf4) {
                                    for (boolean[] collectMM : tf4) {
                                        SbeStatsReport report = new SbeStatsReport();
                                        report.reportIDs(null, null, null, time);
                                        report.reportStatsCollectionDurationMS(duration);
                                        report.reportIterationCount(iterCount);
                                        if (collectPerformanceStats) {
                                            report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS, perfMBPS);
                                        }
                                        if (collectMemoryStats) {
                                            report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
                                        }
                                        if (collectGCStats) {
                                            report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
                                            report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
                                        }
                                        if (collectDataSetMetaData) {
                                        //TODO
                                        }
                                        if (collectScore) {
                                            report.reportScore(score);
                                        }
                                        if (collectLearningRates) {
                                            report.reportLearningRates(lrByParam);
                                        }
                                        if (collectHistograms[0]) {
                                            //Param hist
                                            report.reportHistograms(StatsType.Parameters, pHist);
                                        }
                                        if (collectHistograms[1]) {
                                            report.reportHistograms(StatsType.Gradients, gHist);
                                        }
                                        if (collectHistograms[2]) {
                                            //Update hist
                                            report.reportHistograms(StatsType.Updates, uHist);
                                        }
                                        if (collectHistograms[3]) {
                                            //Act hist
                                            report.reportHistograms(StatsType.Activations, aHist);
                                        }
                                        if (collectMeanStdev[0]) {
                                            //Param mean/stdev
                                            report.reportMean(StatsType.Parameters, pMean);
                                            report.reportStdev(StatsType.Parameters, pStd);
                                        }
                                        if (collectMeanStdev[1]) {
                                            //Param mean/stdev
                                            report.reportMean(StatsType.Gradients, gMean);
                                            report.reportStdev(StatsType.Gradients, gStd);
                                        }
                                        if (collectMeanStdev[2]) {
                                            //Update mean/stdev
                                            report.reportMean(StatsType.Updates, uMean);
                                            report.reportStdev(StatsType.Updates, uStd);
                                        }
                                        if (collectMeanStdev[3]) {
                                            //Act mean/stdev
                                            report.reportMean(StatsType.Activations, aMean);
                                            report.reportStdev(StatsType.Activations, aStd);
                                        }
                                        if (collectMM[0]) {
                                            //Param mean mag
                                            report.reportMeanMagnitudes(StatsType.Parameters, pMM);
                                        }
                                        if (collectMM[1]) {
                                            //Param mean mag
                                            report.reportMeanMagnitudes(StatsType.Gradients, gMM);
                                        }
                                        if (collectMM[2]) {
                                            //Update mm
                                            report.reportMeanMagnitudes(StatsType.Updates, uMM);
                                        }
                                        if (collectMM[3]) {
                                            //Act mm
                                            report.reportMeanMagnitudes(StatsType.Activations, aMM);
                                        }
                                        byte[] bytes = report.encode();
                                        StatsReport report2 = new SbeStatsReport();
                                        report2.decode(bytes);
                                        assertEquals(time, report2.getTimeStamp());
                                        assertEquals(duration, report2.getStatsCollectionDurationMs());
                                        assertEquals(iterCount, report2.getIterationCount());
                                        if (collectPerformanceStats) {
                                            assertEquals(perfRuntime, report2.getTotalRuntimeMs());
                                            assertEquals(perfTotalEx, report2.getTotalExamples());
                                            assertEquals(perfTotalMB, report2.getTotalMinibatches());
                                            assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
                                            assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
                                            Assert.assertTrue(report2.hasPerformance());
                                        } else {
                                            Assert.assertFalse(report2.hasPerformance());
                                        }
                                        if (collectMemoryStats) {
                                            assertEquals(memJC, report2.getJvmCurrentBytes());
                                            assertEquals(memJM, report2.getJvmMaxBytes());
                                            assertEquals(memOC, report2.getOffHeapCurrentBytes());
                                            assertEquals(memOM, report2.getOffHeapMaxBytes());
                                            assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
                                            assertArrayEquals(memDM, report2.getDeviceMaxBytes());
                                            Assert.assertTrue(report2.hasMemoryUse());
                                        } else {
                                            Assert.assertFalse(report2.hasMemoryUse());
                                        }
                                        if (collectGCStats) {
                                            List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
                                            Assert.assertEquals(2, gcs.size());
                                            assertNullOrZeroLength(gcs.get(0).getFirst());
                                            Assert.assertArrayEquals(new int[] { gcdc1, gcdt1 }, gcs.get(0).getSecond());
                                            assertNullOrZeroLength(gcs.get(1).getFirst());
                                            Assert.assertArrayEquals(new int[] { gcdc2, gcdt2 }, gcs.get(1).getSecond());
                                            Assert.assertTrue(report2.hasGarbageCollection());
                                        } else {
                                            Assert.assertFalse(report2.hasGarbageCollection());
                                        }
                                        if (collectDataSetMetaData) {
                                        //TODO
                                        }
                                        if (collectScore) {
                                            assertEquals(score, report2.getScore(), 0.0);
                                            Assert.assertTrue(report2.hasScore());
                                        } else {
                                            Assert.assertFalse(report2.hasScore());
                                        }
                                        if (collectLearningRates) {
                                            assertNull(report2.getLearningRates());
                                        } else {
                                            Assert.assertFalse(report2.hasLearningRates());
                                        }
                                        assertNull(report2.getHistograms(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Parameters));
                                        assertNull(report2.getHistograms(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Gradients));
                                        assertNull(report2.getHistograms(StatsType.Updates));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Updates));
                                        assertNull(report2.getHistograms(StatsType.Activations));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Activations));
                                        assertNull(report2.getMean(StatsType.Parameters));
                                        assertNull(report2.getStdev(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Gradients));
                                        assertNull(report2.getStdev(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Updates));
                                        assertNull(report2.getStdev(StatsType.Updates));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Activations));
                                        assertNull(report2.getStdev(StatsType.Activations));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Updates));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Activations));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes));
                                        //Check standard Java serialization
                                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
                                        ObjectOutputStream oos = new ObjectOutputStream(baos);
                                        oos.writeObject(report);
                                        oos.close();
                                        byte[] javaBytes = baos.toByteArray();
                                        ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes));
                                        SbeStatsReport report3 = (SbeStatsReport) ois.readObject();
                                        assertEquals(report, report3);
                                        testCount++;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    Assert.assertEquals(13824, testCount);
}
Also used : SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) Pair(org.deeplearning4j.berkeley.Pair) Test(org.junit.Test)

Example 3 with SbeStatsReport

use of org.deeplearning4j.ui.stats.impl.SbeStatsReport in project deeplearning4j by deeplearning4j.

the class TestStatsStorage method getReport.

/**
 * Creates a small, fully-populated stats report for storage tests.
 *
 * <p>The report carries the IDs {@code "sid" + sid}, {@code "tid" + tid} and {@code "wid" + wid}
 * with the given timestamp, a fixed score of 100.0, and fixed performance figures
 * (1000, 1001, 1002, 1003.0, 1004.0).
 *
 * @param sid          numeric suffix for the session ID
 * @param tid          numeric suffix for the type ID
 * @param wid          numeric suffix for the worker ID
 * @param time         timestamp passed to {@code reportIDs}
 * @param useJ7Storage if true, a {@link JavaStatsReport} is built; otherwise a {@link SbeStatsReport}
 * @return the populated report
 */
private static StatsReport getReport(int sid, int tid, int wid, long time, boolean useJ7Storage) {
    StatsReport report = useJ7Storage ? new JavaStatsReport() : new SbeStatsReport();

    String sessionId = "sid" + sid;
    String typeId = "tid" + tid;
    String workerId = "wid" + wid;
    report.reportIDs(sessionId, typeId, workerId, time);

    report.reportScore(100.0);
    report.reportPerformance(1000, 1001, 1002, 1003.0, 1004.0);
    return report;
}
Also used : SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) JavaStatsReport(org.deeplearning4j.ui.stats.impl.java.JavaStatsReport) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Example 4 with SbeStatsReport

use of org.deeplearning4j.ui.stats.impl.SbeStatsReport in project deeplearning4j by deeplearning4j.

the class TestRemoteReceiver method testRemoteBasic.

/**
 * Sends stats updates, storage metadata and static (initialization) info through a
 * {@link RemoteUIStatsStorageRouter} pointed at http://localhost:9000. It then checks that the
 * {@link CollectionStatsStorageRouter} registered with the UI server's remote listener received
 * exactly the objects that were sent, in order.
 *
 * <p>Marked {@code @Ignore}: it starts the singleton {@link UIServer} and talks to it over HTTP.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    // These collections are filled while the server handles the HTTP posts, which presumably
    // happens on a thread other than this one. Synchronized wrappers make the appends
    // visible here and safe to read.
    List<Persistable> updates = Collections.synchronizedList(new ArrayList<Persistable>());
    List<Persistable> staticInfo = Collections.synchronizedList(new ArrayList<Persistable>());
    List<StorageMetaData> metaData = Collections.synchronizedList(new ArrayList<StorageMetaData>());
    CollectionStatsStorageRouter collectionRouter = new CollectionStatsStorageRouter(metaData, staticInfo, updates);
    UIServer s = UIServer.getInstance();
    s.enableRemoteListener(collectionRouter, false);
    try {
        RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
        SbeStatsReport update1 = new SbeStatsReport();
        update1.setDeviceCurrentBytes(new long[] { 1, 2 });
        update1.reportIterationCount(10);
        update1.reportIDs("sid", "tid", "wid", 123456);
        update1.reportPerformance(10, 20, 30, 40, 50);
        SbeStatsReport update2 = new SbeStatsReport();
        update2.setDeviceCurrentBytes(new long[] { 3, 4 });
        update2.reportIterationCount(20);
        update2.reportIDs("sid2", "tid2", "wid2", 123456);
        update2.reportPerformance(11, 21, 31, 40, 50);
        StorageMetaData smd1 = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
        StorageMetaData smd2 = new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");
        SbeStatsInitializationReport init1 = new SbeStatsInitializationReport();
        // Argument order is (sessionID, typeID, workerID, timestamp), the same as the update reports
        init1.reportIDs("sid", "tid", "wid", 3145253452L);
        init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");

        // Short pauses between posts keep them arriving in the order the assertions expect
        remoteRouter.putUpdate(update1);
        Thread.sleep(100);
        remoteRouter.putStorageMetaData(smd1);
        Thread.sleep(100);
        remoteRouter.putStaticInfo(init1);
        Thread.sleep(100);
        remoteRouter.putUpdate(update2);
        Thread.sleep(100);
        remoteRouter.putStorageMetaData(smd2);

        // Delivery is asynchronous. Poll with an upper bound instead of one fixed sleep: this
        // finishes quickly when the server is fast and tolerates a slow one.
        long deadline = System.currentTimeMillis() + 10000L;
        while ((metaData.size() < 2 || updates.size() < 2 || staticInfo.size() < 1)
                && System.currentTimeMillis() < deadline) {
            Thread.sleep(50);
        }

        // Snapshot under each list's lock (the ArrayList copy constructor uses toArray()) before comparing
        List<StorageMetaData> metaDataSnapshot = new ArrayList<>(metaData);
        List<Persistable> updatesSnapshot = new ArrayList<>(updates);
        List<Persistable> staticInfoSnapshot = new ArrayList<>(staticInfo);
        assertEquals(2, metaDataSnapshot.size());
        assertEquals(2, updatesSnapshot.size());
        assertEquals(1, staticInfoSnapshot.size());
        assertEquals(Arrays.asList(update1, update2), updatesSnapshot);
        assertEquals(Arrays.asList(smd1, smd2), metaDataSnapshot);
        assertEquals(Collections.singletonList(init1), staticInfoSnapshot);
    } finally {
        // UIServer is a shared singleton: do not leave its remote listener routed to this test's collections
        s.disableRemoteListener();
    }
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Persistable(org.deeplearning4j.api.storage.Persistable) SbeStatsInitializationReport(org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport) SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) UIServer(org.deeplearning4j.ui.api.UIServer) ArrayList(java.util.ArrayList) CollectionStatsStorageRouter(org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

SbeStatsReport (org.deeplearning4j.ui.stats.impl.SbeStatsReport): 4, Test (org.junit.Test): 3, ArrayList (java.util.ArrayList): 2, Pair (org.deeplearning4j.berkeley.Pair): 2, HashMap (java.util.HashMap): 1, Persistable (org.deeplearning4j.api.storage.Persistable): 1, StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData): 1, CollectionStatsStorageRouter (org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter): 1, RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter): 1, UIServer (org.deeplearning4j.ui.api.UIServer): 1, StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 1, SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport): 1, JavaStatsReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsReport): 1, SbeStorageMetaData (org.deeplearning4j.ui.storage.impl.SbeStorageMetaData): 1, Ignore (org.junit.Ignore): 1