Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in the deeplearning4j project (by deeplearning4j).
Class TestStatsClasses, method testSbeStatsUpdate.
/**
 * Round-trips an {@link SbeStatsReport} through SBE encoding ({@code encode}/{@code decode}) and
 * through standard Java serialization for every combination of optional report sections:
 * six boolean section flags (2^6) times three 6-way masks over the four StatsTypes (6^3),
 * i.e. 13,824 cases. For each case it checks that every reported value survives the round trip
 * and that each {@code has...} query is true exactly when that section was reported.
 */
@Test
public void testSbeStatsUpdate() throws Exception {
    String[] paramNames = new String[] { "param0", "param1" };
    String[] layerNames = new String[] { "layer0", "layer1" };
    //IDs
    String sessionID = "sid";
    String typeID = "tid";
    String workerID = "wid";
    long time = System.currentTimeMillis();
    int duration = 123456;
    int iterCount = 123;
    long perfRuntime = 1;
    long perfTotalEx = 2;
    long perfTotalMB = 3;
    double perfEPS = 4.0;
    double perfMBPS = 5.0;
    long memJC = 6;
    long memJM = 7;
    long memOC = 8;
    long memOM = 9;
    long[] memDC = new long[] { 10, 11 };
    long[] memDM = new long[] { 12, 13 };
    String gc1Name = "14";
    int gcdc1 = 16;
    int gcdt1 = 17;
    String gc2Name = "18";
    int gcdc2 = 20;
    int gcdt2 = 21;
    double score = 22.0;
    Map<String, Double> lrByParam = new HashMap<>();
    lrByParam.put(paramNames[0], 22.5);
    lrByParam.put(paramNames[1], 22.75);
    Map<String, Histogram> pHist = new HashMap<>();
    pHist.put(paramNames[0], new Histogram(23, 24, 2, new int[] { 25, 26 }));
    pHist.put(paramNames[1], new Histogram(27, 28, 3, new int[] { 29, 30, 31 }));
    Map<String, Histogram> gHist = new HashMap<>();
    gHist.put(paramNames[0], new Histogram(230, 240, 2, new int[] { 250, 260 }));
    gHist.put(paramNames[1], new Histogram(270, 280, 3, new int[] { 290, 300, 310 }));
    Map<String, Histogram> uHist = new HashMap<>();
    uHist.put(paramNames[0], new Histogram(32, 33, 2, new int[] { 34, 35 }));
    uHist.put(paramNames[1], new Histogram(36, 37, 3, new int[] { 38, 39, 40 }));
    Map<String, Histogram> aHist = new HashMap<>();
    aHist.put(layerNames[0], new Histogram(41, 42, 2, new int[] { 43, 44 }));
    aHist.put(layerNames[1], new Histogram(45, 46, 3, new int[] { 47, 48, 47 }));
    Map<String, Double> pMean = new HashMap<>();
    pMean.put(paramNames[0], 49.0);
    pMean.put(paramNames[1], 50.0);
    Map<String, Double> gMean = new HashMap<>();
    gMean.put(paramNames[0], 49.1);
    gMean.put(paramNames[1], 50.1);
    Map<String, Double> uMean = new HashMap<>();
    uMean.put(paramNames[0], 51.0);
    uMean.put(paramNames[1], 52.0);
    Map<String, Double> aMean = new HashMap<>();
    aMean.put(layerNames[0], 53.0);
    aMean.put(layerNames[1], 54.0);
    Map<String, Double> pStd = new HashMap<>();
    pStd.put(paramNames[0], 55.0);
    pStd.put(paramNames[1], 56.0);
    Map<String, Double> gStd = new HashMap<>();
    gStd.put(paramNames[0], 55.1);
    gStd.put(paramNames[1], 56.1);
    Map<String, Double> uStd = new HashMap<>();
    uStd.put(paramNames[0], 57.0);
    uStd.put(paramNames[1], 58.0);
    Map<String, Double> aStd = new HashMap<>();
    aStd.put(layerNames[0], 59.0);
    aStd.put(layerNames[1], 60.0);
    Map<String, Double> pMM = new HashMap<>();
    pMM.put(paramNames[0], 61.0);
    pMM.put(paramNames[1], 62.0);
    Map<String, Double> gMM = new HashMap<>();
    gMM.put(paramNames[0], 61.1);
    gMM.put(paramNames[1], 62.1);
    Map<String, Double> uMM = new HashMap<>();
    uMM.put(paramNames[0], 63.0);
    uMM.put(paramNames[1], 64.0);
    Map<String, Double> aMM = new HashMap<>();
    aMM.put(layerNames[0], 65.0);
    aMM.put(layerNames[1], 66.0);
    List<Serializable> metaDataList = new ArrayList<>();
    metaDataList.add("meta1");
    metaDataList.add("meta2");
    metaDataList.add("meta3");
    Class<?> metaDataClass = String.class;
    boolean[] tf = new boolean[] { true, false };
    // Each mask selects which of {Parameters, Gradients, Updates, Activations} to report:
    // none, each one individually, or all four.
    boolean[][] tf4 = new boolean[][] { { false, false, false, false }, { true, false, false, false }, { false, true, false, false }, { false, false, true, false }, { false, false, false, true }, { true, true, true, true } };
    //Total tests: 2^6 x 6^3 = 13,824 separate tests
    int testCount = 0;
    for (boolean collectPerformanceStats : tf) {
        for (boolean collectMemoryStats : tf) {
            for (boolean collectGCStats : tf) {
                for (boolean collectScore : tf) {
                    for (boolean collectLearningRates : tf) {
                        for (boolean collectMetaData : tf) {
                            for (boolean[] collectHistograms : tf4) {
                                for (boolean[] collectMeanStdev : tf4) {
                                    for (boolean[] collectMM : tf4) {
                                        SbeStatsReport report = new SbeStatsReport();
                                        report.reportIDs(sessionID, typeID, workerID, time);
                                        report.reportStatsCollectionDurationMS(duration);
                                        report.reportIterationCount(iterCount);
                                        if (collectPerformanceStats) {
                                            report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS, perfMBPS);
                                        }
                                        if (collectMemoryStats) {
                                            report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
                                        }
                                        if (collectGCStats) {
                                            report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
                                            report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
                                        }
                                        if (collectScore) {
                                            report.reportScore(score);
                                        }
                                        if (collectLearningRates) {
                                            report.reportLearningRates(lrByParam);
                                        }
                                        if (collectMetaData) {
                                            report.reportDataSetMetaData(metaDataList, metaDataClass);
                                        }
                                        if (collectHistograms[0]) {
                                            //Param hist
                                            report.reportHistograms(StatsType.Parameters, pHist);
                                        }
                                        if (collectHistograms[1]) {
                                            //Grad hist
                                            report.reportHistograms(StatsType.Gradients, gHist);
                                        }
                                        if (collectHistograms[2]) {
                                            //Update hist
                                            report.reportHistograms(StatsType.Updates, uHist);
                                        }
                                        if (collectHistograms[3]) {
                                            //Act hist
                                            report.reportHistograms(StatsType.Activations, aHist);
                                        }
                                        if (collectMeanStdev[0]) {
                                            //Param mean/stdev
                                            report.reportMean(StatsType.Parameters, pMean);
                                            report.reportStdev(StatsType.Parameters, pStd);
                                        }
                                        if (collectMeanStdev[1]) {
                                            //Gradient mean/stdev
                                            report.reportMean(StatsType.Gradients, gMean);
                                            report.reportStdev(StatsType.Gradients, gStd);
                                        }
                                        if (collectMeanStdev[2]) {
                                            //Update mean/stdev
                                            report.reportMean(StatsType.Updates, uMean);
                                            report.reportStdev(StatsType.Updates, uStd);
                                        }
                                        if (collectMeanStdev[3]) {
                                            //Act mean/stdev
                                            report.reportMean(StatsType.Activations, aMean);
                                            report.reportStdev(StatsType.Activations, aStd);
                                        }
                                        if (collectMM[0]) {
                                            //Param mean mag
                                            report.reportMeanMagnitudes(StatsType.Parameters, pMM);
                                        }
                                        if (collectMM[1]) {
                                            //Gradient mean mag
                                            report.reportMeanMagnitudes(StatsType.Gradients, gMM);
                                        }
                                        if (collectMM[2]) {
                                            //Update mm
                                            report.reportMeanMagnitudes(StatsType.Updates, uMM);
                                        }
                                        if (collectMM[3]) {
                                            //Act mm
                                            report.reportMeanMagnitudes(StatsType.Activations, aMM);
                                        }

                                        // SBE round trip
                                        byte[] bytes = report.encode();
                                        StatsReport report2 = new SbeStatsReport();
                                        report2.decode(bytes);
                                        assertEquals(report, report2);
                                        assertEquals(sessionID, report2.getSessionID());
                                        assertEquals(typeID, report2.getTypeID());
                                        assertEquals(workerID, report2.getWorkerID());
                                        assertEquals(time, report2.getTimeStamp());
                                        assertEquals(duration, report2.getStatsCollectionDurationMs());
                                        assertEquals(iterCount, report2.getIterationCount());
                                        if (collectPerformanceStats) {
                                            assertEquals(perfRuntime, report2.getTotalRuntimeMs());
                                            assertEquals(perfTotalEx, report2.getTotalExamples());
                                            assertEquals(perfTotalMB, report2.getTotalMinibatches());
                                            assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
                                            assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
                                            assertTrue(report2.hasPerformance());
                                        } else {
                                            assertFalse(report2.hasPerformance());
                                        }
                                        if (collectMemoryStats) {
                                            assertEquals(memJC, report2.getJvmCurrentBytes());
                                            assertEquals(memJM, report2.getJvmMaxBytes());
                                            assertEquals(memOC, report2.getOffHeapCurrentBytes());
                                            assertEquals(memOM, report2.getOffHeapMaxBytes());
                                            assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
                                            assertArrayEquals(memDM, report2.getDeviceMaxBytes());
                                            assertTrue(report2.hasMemoryUse());
                                        } else {
                                            assertFalse(report2.hasMemoryUse());
                                        }
                                        if (collectGCStats) {
                                            List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
                                            assertEquals(2, gcs.size());
                                            assertEquals(gc1Name, gcs.get(0).getFirst());
                                            assertArrayEquals(new int[] { gcdc1, gcdt1 }, gcs.get(0).getSecond());
                                            assertEquals(gc2Name, gcs.get(1).getFirst());
                                            assertArrayEquals(new int[] { gcdc2, gcdt2 }, gcs.get(1).getSecond());
                                            assertTrue(report2.hasGarbageCollection());
                                        } else {
                                            assertFalse(report2.hasGarbageCollection());
                                        }
                                        if (collectScore) {
                                            assertEquals(score, report2.getScore(), 0.0);
                                            assertTrue(report2.hasScore());
                                        } else {
                                            assertFalse(report2.hasScore());
                                        }
                                        if (collectLearningRates) {
                                            assertEquals(lrByParam.keySet(), report2.getLearningRates().keySet());
                                            for (String s : lrByParam.keySet()) {
                                                assertEquals(lrByParam.get(s), report2.getLearningRates().get(s), 1e-6);
                                            }
                                            assertTrue(report2.hasLearningRates());
                                        } else {
                                            assertFalse(report2.hasLearningRates());
                                        }
                                        if (collectMetaData) {
                                            assertNotNull(report2.getDataSetMetaData());
                                            assertEquals(metaDataList, report2.getDataSetMetaData());
                                            assertEquals(metaDataClass.getName(), report2.getDataSetMetaDataClassName());
                                            assertTrue(report2.hasDataSetMetaData());
                                        } else {
                                            assertFalse(report2.hasDataSetMetaData());
                                        }
                                        if (collectHistograms[0]) {
                                            assertEquals(pHist, report2.getHistograms(StatsType.Parameters));
                                            assertTrue(report2.hasHistograms(StatsType.Parameters));
                                        } else {
                                            assertFalse(report2.hasHistograms(StatsType.Parameters));
                                        }
                                        if (collectHistograms[1]) {
                                            assertEquals(gHist, report2.getHistograms(StatsType.Gradients));
                                            assertTrue(report2.hasHistograms(StatsType.Gradients));
                                        } else {
                                            assertFalse(report2.hasHistograms(StatsType.Gradients));
                                        }
                                        if (collectHistograms[2]) {
                                            assertEquals(uHist, report2.getHistograms(StatsType.Updates));
                                            assertTrue(report2.hasHistograms(StatsType.Updates));
                                        } else {
                                            assertFalse(report2.hasHistograms(StatsType.Updates));
                                        }
                                        if (collectHistograms[3]) {
                                            assertEquals(aHist, report2.getHistograms(StatsType.Activations));
                                            assertTrue(report2.hasHistograms(StatsType.Activations));
                                        } else {
                                            assertFalse(report2.hasHistograms(StatsType.Activations));
                                        }
                                        if (collectMeanStdev[0]) {
                                            assertEquals(pMean, report2.getMean(StatsType.Parameters));
                                            assertEquals(pStd, report2.getStdev(StatsType.Parameters));
                                            assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean));
                                            assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean));
                                            assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev));
                                        }
                                        if (collectMeanStdev[1]) {
                                            assertEquals(gMean, report2.getMean(StatsType.Gradients));
                                            assertEquals(gStd, report2.getStdev(StatsType.Gradients));
                                            assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean));
                                            assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean));
                                            assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev));
                                        }
                                        if (collectMeanStdev[2]) {
                                            assertEquals(uMean, report2.getMean(StatsType.Updates));
                                            assertEquals(uStd, report2.getStdev(StatsType.Updates));
                                            assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean));
                                            assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean));
                                            assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev));
                                        }
                                        if (collectMeanStdev[3]) {
                                            assertEquals(aMean, report2.getMean(StatsType.Activations));
                                            assertEquals(aStd, report2.getStdev(StatsType.Activations));
                                            assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean));
                                            assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean));
                                            assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev));
                                        }
                                        if (collectMM[0]) {
                                            assertEquals(pMM, report2.getMeanMagnitudes(StatsType.Parameters));
                                            assertTrue(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes));
                                        }
                                        if (collectMM[1]) {
                                            assertEquals(gMM, report2.getMeanMagnitudes(StatsType.Gradients));
                                            assertTrue(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes));
                                        }
                                        if (collectMM[2]) {
                                            assertEquals(uMM, report2.getMeanMagnitudes(StatsType.Updates));
                                            assertTrue(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes));
                                        }
                                        if (collectMM[3]) {
                                            assertEquals(aMM, report2.getMeanMagnitudes(StatsType.Activations));
                                            assertTrue(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes));
                                        } else {
                                            assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes));
                                        }

                                        //Check standard Java serialization; both streams are closed
                                        // via try-with-resources (the input stream was previously leaked).
                                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
                                        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                                            oos.writeObject(report);
                                        }
                                        byte[] javaBytes = baos.toByteArray();
                                        SbeStatsReport report3;
                                        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes))) {
                                            report3 = (SbeStatsReport) ois.readObject();
                                        }
                                        assertEquals(report, report3);
                                        testCount++;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    assertEquals(13824, testCount);
}
Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in the deeplearning4j project (by deeplearning4j).
Class TestStatsClasses, method testSbeStatsUpdateNullValues.
/**
 * Checks that {@link SbeStatsReport} tolerates null IDs, null device-memory arrays, null GC names
 * and null maps for every optional section. Each report is round-tripped through SBE
 * ({@code encode}/{@code decode}) and through Java serialization. The case grid is the same as
 * {@code testSbeStatsUpdate}: 2^6 x 6^3 = 13,824 cases. Every null-map section must decode as
 * absent (null getter, false {@code has...} query).
 */
@Test
public void testSbeStatsUpdateNullValues() throws Exception {
    long time = System.currentTimeMillis();
    int duration = 123456;
    int iterCount = 123;
    long perfRuntime = 1;
    long perfTotalEx = 2;
    long perfTotalMB = 3;
    double perfEPS = 4.0;
    double perfMBPS = 5.0;
    long memJC = 6;
    long memJM = 7;
    long memOC = 8;
    long memOM = 9;
    long[] memDC = null;
    long[] memDM = null;
    String gc1Name = null;
    int gcdc1 = 16;
    int gcdt1 = 17;
    String gc2Name = null;
    int gcdc2 = 20;
    int gcdt2 = 21;
    double score = 22.0;
    Map<String, Double> lrByParam = null;
    Map<String, Histogram> pHist = null;
    Map<String, Histogram> gHist = null;
    Map<String, Histogram> uHist = null;
    Map<String, Histogram> aHist = null;
    Map<String, Double> pMean = null;
    Map<String, Double> gMean = null;
    Map<String, Double> uMean = null;
    Map<String, Double> aMean = null;
    Map<String, Double> pStd = null;
    Map<String, Double> gStd = null;
    Map<String, Double> uStd = null;
    Map<String, Double> aStd = null;
    Map<String, Double> pMM = null;
    Map<String, Double> gMM = null;
    Map<String, Double> uMM = null;
    Map<String, Double> aMM = null;
    boolean[] tf = new boolean[] { true, false };
    // Each mask selects which of {Parameters, Gradients, Updates, Activations} to report:
    // none, each one individually, or all four.
    boolean[][] tf4 = new boolean[][] { { false, false, false, false }, { true, false, false, false }, { false, true, false, false }, { false, false, true, false }, { false, false, false, true }, { true, true, true, true } };
    //Total tests: 2^6 x 6^3 = 13,824 separate tests
    int testCount = 0;
    for (boolean collectPerformanceStats : tf) {
        for (boolean collectMemoryStats : tf) {
            for (boolean collectGCStats : tf) {
                // TODO: null data set metadata is not exercised yet. This loop only doubles the
                // case count so that it matches testSbeStatsUpdate.
                for (boolean collectDataSetMetaData : tf) {
                    for (boolean collectScore : tf) {
                        for (boolean collectLearningRates : tf) {
                            for (boolean[] collectHistograms : tf4) {
                                for (boolean[] collectMeanStdev : tf4) {
                                    for (boolean[] collectMM : tf4) {
                                        SbeStatsReport report = new SbeStatsReport();
                                        report.reportIDs(null, null, null, time);
                                        report.reportStatsCollectionDurationMS(duration);
                                        report.reportIterationCount(iterCount);
                                        if (collectPerformanceStats) {
                                            report.reportPerformance(perfRuntime, perfTotalEx, perfTotalMB, perfEPS, perfMBPS);
                                        }
                                        if (collectMemoryStats) {
                                            report.reportMemoryUse(memJC, memJM, memOC, memOM, memDC, memDM);
                                        }
                                        if (collectGCStats) {
                                            report.reportGarbageCollection(gc1Name, gcdc1, gcdt1);
                                            report.reportGarbageCollection(gc2Name, gcdc2, gcdt2);
                                        }
                                        if (collectScore) {
                                            report.reportScore(score);
                                        }
                                        if (collectLearningRates) {
                                            report.reportLearningRates(lrByParam);
                                        }
                                        if (collectHistograms[0]) {
                                            //Param hist
                                            report.reportHistograms(StatsType.Parameters, pHist);
                                        }
                                        if (collectHistograms[1]) {
                                            //Grad hist
                                            report.reportHistograms(StatsType.Gradients, gHist);
                                        }
                                        if (collectHistograms[2]) {
                                            //Update hist
                                            report.reportHistograms(StatsType.Updates, uHist);
                                        }
                                        if (collectHistograms[3]) {
                                            //Act hist
                                            report.reportHistograms(StatsType.Activations, aHist);
                                        }
                                        if (collectMeanStdev[0]) {
                                            //Param mean/stdev
                                            report.reportMean(StatsType.Parameters, pMean);
                                            report.reportStdev(StatsType.Parameters, pStd);
                                        }
                                        if (collectMeanStdev[1]) {
                                            //Gradient mean/stdev
                                            report.reportMean(StatsType.Gradients, gMean);
                                            report.reportStdev(StatsType.Gradients, gStd);
                                        }
                                        if (collectMeanStdev[2]) {
                                            //Update mean/stdev
                                            report.reportMean(StatsType.Updates, uMean);
                                            report.reportStdev(StatsType.Updates, uStd);
                                        }
                                        if (collectMeanStdev[3]) {
                                            //Act mean/stdev
                                            report.reportMean(StatsType.Activations, aMean);
                                            report.reportStdev(StatsType.Activations, aStd);
                                        }
                                        if (collectMM[0]) {
                                            //Param mean mag
                                            report.reportMeanMagnitudes(StatsType.Parameters, pMM);
                                        }
                                        if (collectMM[1]) {
                                            //Gradient mean mag
                                            report.reportMeanMagnitudes(StatsType.Gradients, gMM);
                                        }
                                        if (collectMM[2]) {
                                            //Update mm
                                            report.reportMeanMagnitudes(StatsType.Updates, uMM);
                                        }
                                        if (collectMM[3]) {
                                            //Act mm
                                            report.reportMeanMagnitudes(StatsType.Activations, aMM);
                                        }

                                        // SBE round trip
                                        byte[] bytes = report.encode();
                                        StatsReport report2 = new SbeStatsReport();
                                        report2.decode(bytes);
                                        assertEquals(time, report2.getTimeStamp());
                                        assertEquals(duration, report2.getStatsCollectionDurationMs());
                                        assertEquals(iterCount, report2.getIterationCount());
                                        if (collectPerformanceStats) {
                                            assertEquals(perfRuntime, report2.getTotalRuntimeMs());
                                            assertEquals(perfTotalEx, report2.getTotalExamples());
                                            assertEquals(perfTotalMB, report2.getTotalMinibatches());
                                            assertEquals(perfEPS, report2.getExamplesPerSecond(), 0.0);
                                            assertEquals(perfMBPS, report2.getMinibatchesPerSecond(), 0.0);
                                            Assert.assertTrue(report2.hasPerformance());
                                        } else {
                                            Assert.assertFalse(report2.hasPerformance());
                                        }
                                        if (collectMemoryStats) {
                                            assertEquals(memJC, report2.getJvmCurrentBytes());
                                            assertEquals(memJM, report2.getJvmMaxBytes());
                                            assertEquals(memOC, report2.getOffHeapCurrentBytes());
                                            assertEquals(memOM, report2.getOffHeapMaxBytes());
                                            assertArrayEquals(memDC, report2.getDeviceCurrentBytes());
                                            assertArrayEquals(memDM, report2.getDeviceMaxBytes());
                                            Assert.assertTrue(report2.hasMemoryUse());
                                        } else {
                                            Assert.assertFalse(report2.hasMemoryUse());
                                        }
                                        if (collectGCStats) {
                                            List<Pair<String, int[]>> gcs = report2.getGarbageCollectionStats();
                                            Assert.assertEquals(2, gcs.size());
                                            // Null GC names may decode as null or as an empty string
                                            assertNullOrZeroLength(gcs.get(0).getFirst());
                                            Assert.assertArrayEquals(new int[] { gcdc1, gcdt1 }, gcs.get(0).getSecond());
                                            assertNullOrZeroLength(gcs.get(1).getFirst());
                                            Assert.assertArrayEquals(new int[] { gcdc2, gcdt2 }, gcs.get(1).getSecond());
                                            Assert.assertTrue(report2.hasGarbageCollection());
                                        } else {
                                            Assert.assertFalse(report2.hasGarbageCollection());
                                        }
                                        if (collectScore) {
                                            assertEquals(score, report2.getScore(), 0.0);
                                            Assert.assertTrue(report2.hasScore());
                                        } else {
                                            Assert.assertFalse(report2.hasScore());
                                        }
                                        if (collectLearningRates) {
                                            assertNull(report2.getLearningRates());
                                        } else {
                                            Assert.assertFalse(report2.hasLearningRates());
                                        }
                                        // Every map-valued section was reported with a null map (or
                                        // not at all), so all of them must decode as absent.
                                        assertNull(report2.getHistograms(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Parameters));
                                        assertNull(report2.getHistograms(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Gradients));
                                        assertNull(report2.getHistograms(StatsType.Updates));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Updates));
                                        assertNull(report2.getHistograms(StatsType.Activations));
                                        Assert.assertFalse(report2.hasHistograms(StatsType.Activations));
                                        assertNull(report2.getMean(StatsType.Parameters));
                                        assertNull(report2.getStdev(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Gradients));
                                        assertNull(report2.getStdev(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Updates));
                                        assertNull(report2.getStdev(StatsType.Updates));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.Stdev));
                                        assertNull(report2.getMean(StatsType.Activations));
                                        assertNull(report2.getStdev(StatsType.Activations));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Mean));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.Stdev));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Parameters));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Gradients));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Gradients, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Updates));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes));
                                        assertNull(report2.getMeanMagnitudes(StatsType.Activations));
                                        Assert.assertFalse(report2.hasSummaryStats(StatsType.Activations, SummaryType.MeanMagnitudes));

                                        //Check standard Java serialization; both streams are closed
                                        // via try-with-resources (the input stream was previously leaked).
                                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
                                        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                                            oos.writeObject(report);
                                        }
                                        byte[] javaBytes = baos.toByteArray();
                                        SbeStatsReport report3;
                                        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes))) {
                                            report3 = (SbeStatsReport) ois.readObject();
                                        }
                                        assertEquals(report, report3);
                                        testCount++;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    Assert.assertEquals(13824, testCount);
}
Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in the deeplearning4j project (by deeplearning4j).
Class TestStatsStorage, method getReport.
/**
 * Builds a small stats report for storage tests.
 *
 * @param sid          numeric suffix for the session ID ("sid" + sid)
 * @param tid          numeric suffix for the type ID ("tid" + tid)
 * @param wid          numeric suffix for the worker ID ("wid" + wid)
 * @param time         timestamp passed to {@code reportIDs}
 * @param useJ7Storage true to build a {@link JavaStatsReport}, false for a {@link SbeStatsReport}
 * @return a report carrying the IDs, a score of 100.0 and fixed performance figures
 */
private static StatsReport getReport(int sid, int tid, int wid, long time, boolean useJ7Storage) {
    StatsReport result = useJ7Storage ? new JavaStatsReport() : new SbeStatsReport();

    String sessionId = "sid" + sid;
    String typeId = "tid" + tid;
    String workerId = "wid" + wid;
    result.reportIDs(sessionId, typeId, workerId, time);

    result.reportScore(100.0);
    result.reportPerformance(1000, 1001, 1002, 1003.0, 1004.0);
    return result;
}
Usage of org.deeplearning4j.ui.stats.impl.SbeStatsReport in the deeplearning4j project (by deeplearning4j).
Class TestRemoteReceiver, method testRemoteBasic.
/**
 * Sends updates, storage metadata and static info through a {@link RemoteUIStatsStorageRouter}
 * to a local {@link UIServer} with the remote listener enabled. It then checks that the backing
 * {@link CollectionStatsStorageRouter} received exactly those objects, in order.
 * Ignored by default: it needs a UI server on localhost:9000 and relies on fixed sleeps.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    List<Persistable> updates = new ArrayList<>();
    List<Persistable> staticInfo = new ArrayList<>();
    List<StorageMetaData> metaData = new ArrayList<>();
    CollectionStatsStorageRouter collectionRouter = new CollectionStatsStorageRouter(metaData, staticInfo, updates);
    UIServer s = UIServer.getInstance();
    s.enableRemoteListener(collectionRouter, false);
    RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");

    SbeStatsReport update1 = new SbeStatsReport();
    update1.setDeviceCurrentBytes(new long[] { 1, 2 });
    update1.reportIterationCount(10);
    update1.reportIDs("sid", "tid", "wid", 123456);
    update1.reportPerformance(10, 20, 30, 40, 50);

    SbeStatsReport update2 = new SbeStatsReport();
    update2.setDeviceCurrentBytes(new long[] { 3, 4 });
    update2.reportIterationCount(20);
    update2.reportIDs("sid2", "tid2", "wid2", 123456);
    update2.reportPerformance(11, 21, 31, 40, 50);

    StorageMetaData smd1 = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData smd2 = new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");

    SbeStatsInitializationReport init1 = new SbeStatsInitializationReport();
    // reportIDs takes (sessionID, typeID, workerID, time), the same order used by the updates
    // above. Previously the worker and type IDs were passed swapped.
    init1.reportIDs("sid", "tid", "wid", 3145253452L);
    init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");

    // Short pauses between sends so they are delivered in a predictable order.
    remoteRouter.putUpdate(update1);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd1);
    Thread.sleep(100);
    remoteRouter.putStaticInfo(init1);
    Thread.sleep(100);
    remoteRouter.putUpdate(update2);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd2);
    // Give the remote router time to deliver everything before asserting.
    Thread.sleep(2000);

    assertEquals(2, metaData.size());
    assertEquals(2, updates.size());
    assertEquals(1, staticInfo.size());
    assertEquals(Arrays.asList(update1, update2), updates);
    assertEquals(Arrays.asList(smd1, smd2), metaData);
    assertEquals(Collections.singletonList(init1), staticInfo);
}
Aggregations