Search in sources:

Example 1 with JavaStatsInitializationReport

Use of org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport in the deeplearning4j project (by deeplearning4j).

Source: class TestStatsClasses, method testStatsInitializationReport.

/**
 * Round-trips a fully-populated {@link StatsInitializationReport} through both the SBE-based and the
 * Java-based implementations, for every combination of hardware / software / model info being
 * reported.
 *
 * <p>For each combination this checks that:
 * <ul>
 *   <li>{@code encode()} followed by {@code decode()} into a fresh instance yields a report equal to
 *       the original, whose getters return the reported values and whose {@code has*Info()} flags
 *       match what was reported;</li>
 *   <li>standard Java serialization of the original report yields an equal object.</li>
 * </ul>
 */
@Test
public void testStatsInitializationReport() throws Exception {
    boolean[] tf = new boolean[] { true, false };
    // useJ7 == true -> JavaStatsInitializationReport; false -> SbeStatsInitializationReport
    for (boolean useJ7 : new boolean[] { false, true }) {
        //IDs
        String sessionID = "sid";
        String typeID = "tid";
        String workerID = "wid";
        long timestamp = -1;
        //Hardware info
        int jvmAvailableProcessors = 1;
        int numDevices = 2;
        long jvmMaxMemory = 3;
        long offHeapMaxMemory = 4;
        long[] deviceTotalMemory = new long[] { 5, 6 };
        String[] deviceDescription = new String[] { "7", "8" };
        String hwUID = "8a";
        //Software info
        String arch = "9";
        String osName = "10";
        String jvmName = "11";
        String jvmVersion = "12";
        String jvmSpecVersion = "13";
        String nd4jBackendClass = "14";
        String nd4jDataTypeName = "15";
        String hostname = "15a";
        String jvmUID = "15b";
        Map<String, String> swEnvInfo = new HashMap<>();
        swEnvInfo.put("env15c-1", "SomeData");
        swEnvInfo.put("env15c-2", "OtherData");
        swEnvInfo.put("env15c-3", "EvenMoreData");
        //Model info
        String modelClassName = "16";
        String modelConfigJson = "17";
        String[] modelparamNames = new String[] { "18", "19", "20", "21" };
        int numLayers = 22;
        long numParams = 23;
        for (boolean hasHardwareInfo : tf) {
            for (boolean hasSoftwareInfo : tf) {
                for (boolean hasModelInfo : tf) {
                    StatsInitializationReport report;
                    if (useJ7) {
                        report = new JavaStatsInitializationReport();
                    } else {
                        report = new SbeStatsInitializationReport();
                    }
                    report.reportIDs(sessionID, typeID, workerID, timestamp);
                    if (hasHardwareInfo) {
                        report.reportHardwareInfo(jvmAvailableProcessors, numDevices, jvmMaxMemory, offHeapMaxMemory, deviceTotalMemory, deviceDescription, hwUID);
                    }
                    if (hasSoftwareInfo) {
                        report.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass, nd4jDataTypeName, hostname, jvmUID, swEnvInfo);
                    }
                    if (hasModelInfo) {
                        report.reportModelInfo(modelClassName, modelConfigJson, modelparamNames, numLayers, numParams);
                    }
                    byte[] asBytes = report.encode();
                    // Decode into a fresh instance of the same implementation
                    StatsInitializationReport report2;
                    if (useJ7) {
                        report2 = new JavaStatsInitializationReport();
                    } else {
                        report2 = new SbeStatsInitializationReport();
                    }
                    report2.decode(asBytes);
                    assertEquals(report, report2);
                    assertEquals(sessionID, report2.getSessionID());
                    assertEquals(typeID, report2.getTypeID());
                    assertEquals(workerID, report2.getWorkerID());
                    assertEquals(timestamp, report2.getTimeStamp());
                    if (hasHardwareInfo) {
                        assertEquals(jvmAvailableProcessors, report2.getHwJvmAvailableProcessors());
                        assertEquals(numDevices, report2.getHwNumDevices());
                        assertEquals(jvmMaxMemory, report2.getHwJvmMaxMemory());
                        assertEquals(offHeapMaxMemory, report2.getHwOffHeapMaxMemory());
                        assertArrayEquals(deviceTotalMemory, report2.getHwDeviceTotalMemory());
                        assertArrayEquals(deviceDescription, report2.getHwDeviceDescription());
                        assertEquals(hwUID, report2.getHwHardwareUID());
                        assertTrue(report2.hasHardwareInfo());
                    } else {
                        assertFalse(report2.hasHardwareInfo());
                    }
                    if (hasSoftwareInfo) {
                        assertEquals(arch, report2.getSwArch());
                        assertEquals(osName, report2.getSwOsName());
                        assertEquals(jvmName, report2.getSwJvmName());
                        assertEquals(jvmVersion, report2.getSwJvmVersion());
                        assertEquals(jvmSpecVersion, report2.getSwJvmSpecVersion());
                        assertEquals(nd4jBackendClass, report2.getSwNd4jBackendClass());
                        assertEquals(nd4jDataTypeName, report2.getSwNd4jDataTypeName());
                        assertEquals(jvmUID, report2.getSwJvmUID());
                        assertEquals(hostname, report2.getSwHostName());
                        assertEquals(swEnvInfo, report2.getSwEnvironmentInfo());
                        assertTrue(report2.hasSoftwareInfo());
                    } else {
                        assertFalse(report2.hasSoftwareInfo());
                    }
                    if (hasModelInfo) {
                        assertEquals(modelClassName, report2.getModelClassName());
                        assertEquals(modelConfigJson, report2.getModelConfigJson());
                        assertArrayEquals(modelparamNames, report2.getModelParamNames());
                        assertEquals(numLayers, report2.getModelNumLayers());
                        assertEquals(numParams, report2.getModelNumParams());
                        assertTrue(report2.hasModelInfo());
                    } else {
                        assertFalse(report2.hasModelInfo());
                    }
                    //Check standard Java serialization.
                    // The output stream must be closed (flushed) before reading the bytes back.
                    ByteArrayOutputStream baos = new ByteArrayOutputStream();
                    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                        oos.writeObject(report);
                    }
                    byte[] javaBytes = baos.toByteArray();
                    StatsInitializationReport report3;
                    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes))) {
                        report3 = (StatsInitializationReport) ois.readObject();
                    }
                    assertEquals(report, report3);
                }
            }
        }
    }
}
Also used: JavaStatsInitializationReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport), SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport), HashMap (java.util.HashMap), Test (org.junit.Test)

Example 2 with JavaStatsInitializationReport

Use of org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport in the deeplearning4j project (by deeplearning4j).

Source: class TestStatsClasses, method testStatsInitializationReportNullValues.

/**
 * Sanity check that a {@link StatsInitializationReport} whose IDs and optional string / array / map
 * fields are {@code null} can be encoded and decoded by both the SBE-based and the Java-based
 * implementations, for every combination of hardware / software / model info being reported.
 *
 * <p>After a round trip a missing value may come back either as {@code null} or as an empty value,
 * depending on the implementation. The SBE implementation is expected to decode the per-device
 * arrays as fixed-length arrays of zeros / empty strings. The assertions therefore accept
 * "null or zero length" where appropriate. Primitive fields and the timestamp must survive unchanged.
 */
@Test
public void testStatsInitializationReportNullValues() throws Exception {
    //Sanity check: shouldn't have any issues with encoding/decoding null values...
    boolean[] tf = new boolean[] { true, false };
    // useJ7 == true -> JavaStatsInitializationReport; false -> SbeStatsInitializationReport
    for (boolean useJ7 : new boolean[] { false, true }) {
        //Hardware info
        int jvmAvailableProcessors = 1;
        int numDevices = 2;
        long jvmMaxMemory = 3;
        long offHeapMaxMemory = 4;
        long[] deviceTotalMemory = null;
        String[] deviceDescription = null;
        String hwUID = null;
        //Software info
        String arch = null;
        String osName = null;
        String jvmName = null;
        String jvmVersion = null;
        String jvmSpecVersion = null;
        String nd4jBackendClass = null;
        String nd4jDataTypeName = null;
        String hostname = null;
        String jvmUID = null;
        Map<String, String> swEnvInfo = null;
        //Model info
        String modelClassName = null;
        String modelConfigJson = null;
        String[] modelparamNames = null;
        int numLayers = 22;
        long numParams = 23;
        for (boolean hasHardwareInfo : tf) {
            for (boolean hasSoftwareInfo : tf) {
                for (boolean hasModelInfo : tf) {
                    StatsInitializationReport report;
                    if (useJ7) {
                        report = new JavaStatsInitializationReport();
                    } else {
                        report = new SbeStatsInitializationReport();
                    }
                    report.reportIDs(null, null, null, -1);
                    if (hasHardwareInfo) {
                        report.reportHardwareInfo(jvmAvailableProcessors, numDevices, jvmMaxMemory, offHeapMaxMemory, deviceTotalMemory, deviceDescription, hwUID);
                    }
                    if (hasSoftwareInfo) {
                        report.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass, nd4jDataTypeName, hostname, jvmUID, swEnvInfo);
                    }
                    if (hasModelInfo) {
                        report.reportModelInfo(modelClassName, modelConfigJson, modelparamNames, numLayers, numParams);
                    }
                    byte[] asBytes = report.encode();
                    // Decode into a fresh instance of the same implementation
                    StatsInitializationReport report2;
                    if (useJ7) {
                        report2 = new JavaStatsInitializationReport();
                    } else {
                        report2 = new SbeStatsInitializationReport();
                    }
                    report2.decode(asBytes);
                    //IDs: null IDs may decode as either null or an empty string, depending on the
                    // implementation; the timestamp is a primitive and must survive unchanged
                    assertNullOrZeroLength(report2.getSessionID());
                    assertNullOrZeroLength(report2.getTypeID());
                    assertNullOrZeroLength(report2.getWorkerID());
                    assertEquals(-1L, report2.getTimeStamp());
                    if (hasHardwareInfo) {
                        assertEquals(jvmAvailableProcessors, report2.getHwJvmAvailableProcessors());
                        assertEquals(numDevices, report2.getHwNumDevices());
                        assertEquals(jvmMaxMemory, report2.getHwJvmMaxMemory());
                        assertEquals(offHeapMaxMemory, report2.getHwOffHeapMaxMemory());
                        if (useJ7) {
                            assertArrayEquals(null, report2.getHwDeviceTotalMemory());
                            assertArrayEquals(null, report2.getHwDeviceDescription());
                        } else {
                            //Edge case: nDevices = 2, but missing mem data -> expect long[] of 0s out, due to fixed encoding
                            assertArrayEquals(new long[] { 0, 0 }, report2.getHwDeviceTotalMemory());
                            //As above
                            assertArrayEquals(new String[] { "", "" }, report2.getHwDeviceDescription());
                        }
                        assertNullOrZeroLength(report2.getHwHardwareUID());
                        assertTrue(report2.hasHardwareInfo());
                    } else {
                        assertFalse(report2.hasHardwareInfo());
                    }
                    if (hasSoftwareInfo) {
                        assertNullOrZeroLength(report2.getSwArch());
                        assertNullOrZeroLength(report2.getSwOsName());
                        assertNullOrZeroLength(report2.getSwJvmName());
                        assertNullOrZeroLength(report2.getSwJvmVersion());
                        assertNullOrZeroLength(report2.getSwJvmSpecVersion());
                        assertNullOrZeroLength(report2.getSwNd4jBackendClass());
                        assertNullOrZeroLength(report2.getSwNd4jDataTypeName());
                        assertNullOrZeroLength(report2.getSwJvmUID());
                        assertNullOrZeroLength(report2.getSwHostName());
                        assertNull(report2.getSwEnvironmentInfo());
                        assertTrue(report2.hasSoftwareInfo());
                    } else {
                        assertFalse(report2.hasSoftwareInfo());
                    }
                    if (hasModelInfo) {
                        assertNullOrZeroLength(report2.getModelClassName());
                        assertNullOrZeroLength(report2.getModelConfigJson());
                        assertNullOrZeroLengthArray(report2.getModelParamNames());
                        assertEquals(numLayers, report2.getModelNumLayers());
                        assertEquals(numParams, report2.getModelNumParams());
                        assertTrue(report2.hasModelInfo());
                    } else {
                        assertFalse(report2.hasModelInfo());
                    }
                    //Check standard Java serialization.
                    // The output stream must be closed (flushed) before reading the bytes back.
                    ByteArrayOutputStream baos = new ByteArrayOutputStream();
                    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                        oos.writeObject(report);
                    }
                    byte[] javaBytes = baos.toByteArray();
                    StatsInitializationReport report3;
                    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(javaBytes))) {
                        report3 = (StatsInitializationReport) ois.readObject();
                    }
                    assertEquals(report, report3);
                }
            }
        }
    }
}
Also used: JavaStatsInitializationReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport), SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport), Test (org.junit.Test)

Example 3 with JavaStatsInitializationReport

Use of org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport in the deeplearning4j project (by deeplearning4j).

Source: class TestStatsStorage, method getInitReport.

/**
 * Creates an initialization report populated with fixed model, ID, hardware and software info,
 * for use as test data.
 *
 * @param idNumber     suffix appended to {@code "sid"} to form the session ID
 * @param tid          suffix appended to {@code "tid"} to form the type ID
 * @param wid          suffix appended to {@code "wid"} to form the worker ID
 * @param useJ7Storage {@code true} to build a {@link JavaStatsInitializationReport},
 *                     {@code false} to build a {@link SbeStatsInitializationReport}
 * @return the populated report (timestamp 12345, two devices, two parameter names,
 *         two environment entries)
 */
private static StatsInitializationReport getInitReport(int idNumber, int tid, int wid, boolean useJ7Storage) {
    StatsInitializationReport initReport = useJ7Storage
            ? new JavaStatsInitializationReport()
            : new SbeStatsInitializationReport();

    // Model info: class name, config JSON, parameter names, number of layers, number of params
    String[] paramNames = { "p0", "p1" };
    initReport.reportModelInfo("classname", "jsonconfig", paramNames, 1, 10);

    // IDs are derived from the numeric arguments
    String sessionId = "sid" + idNumber;
    String typeId = "tid" + tid;
    String workerId = "wid" + wid;
    initReport.reportIDs(sessionId, typeId, workerId, 12345);

    // Hardware info for two devices
    long[] deviceMemory = { 3000, 4000 };
    String[] deviceNames = { "dev0", "dev1" };
    initReport.reportHardwareInfo(0, 2, 1000, 2000, deviceMemory, deviceNames, "hardwareuid");

    // Software info, including a small environment map
    Map<String, String> environment = new HashMap<>();
    environment.put("envInfo0", "value0");
    environment.put("envInfo1", "value1");
    initReport.reportSoftwareInfo("arch", "osName", "jvmName", "jvmVersion", "1.8", "backend", "dtype",
            "hostname", "jvmuid", environment);

    return initReport;
}
Also used: SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport), JavaStatsInitializationReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport), StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport), HashMap (java.util.HashMap)

Aggregations

SbeStatsInitializationReport (org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport): 3 · JavaStatsInitializationReport (org.deeplearning4j.ui.stats.impl.java.JavaStatsInitializationReport): 3 · HashMap (java.util.HashMap): 2 · Test (org.junit.Test): 2 · StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 1