Search in sources:

Example 21 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The class TrainModule, method getMemory.

/**
 * Builds the memory-utilization chart data for the training UI.
 *
 * <p>Workers are grouped by the JVM UID reported in their static info. For each distinct JVM
 * (in sorted UID order), the worker with the smallest ID is used as the representative. Its
 * entries in {@code updatesLastNMinutes} are turned into utilization series.
 *
 * @param staticInfoAllWorkers static info for every worker; each element is cast to
 *     {@link StatsInitializationReport}
 * @param updatesLastNMinutes recent updates for all workers; non-{@link StatsReport} entries are
 *     ignored
 * @param i18n used to look up the JVM-heap and off-heap series names
 * @return a map keyed by "0", "1", ... (one entry per JVM, in sorted JVM UID order). Each value
 *     is a map with the keys "times", "isDevice", "seriesNames", "values", "currentBytes" and
 *     "maxBytes"
 */
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, List<Persistable> updatesLastNMinutes, I18N i18n) {
    Map<String, Object> ret = new HashMap<>();
    //First: map workers to JVMs, and record each worker's device count and device names
    Map<String, String> workersToJvms = new HashMap<>();
    Map<String, Integer> workerNumDevices = new HashMap<>();
    Map<String, String[]> deviceNames = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        //TODO validation/checks
        StatsInitializationReport init = (StatsInitializationReport) p;
        String workerID = p.getWorkerID();
        workersToJvms.put(workerID, init.getSwJvmUID());
        int nDevices = init.getHwNumDevices();
        workerNumDevices.put(workerID, nDevices);
        if (nDevices > 0) {
            deviceNames.put(workerID, init.getHwDeviceDescription());
        }
    }
    //Invert the worker->JVM mapping in a single pass. Deriving the JVM set from workersToJvms
    //(instead of from every JVM UID seen) guarantees each JVM has at least one worker, even when
    //a worker's static info was reported more than once with different JVM UIDs.
    Map<String, List<String>> jvmsToWorkers = new HashMap<>();
    for (Map.Entry<String, String> e : workersToJvms.entrySet()) {
        List<String> workers = jvmsToWorkers.get(e.getValue());
        if (workers == null) {
            workers = new ArrayList<>();
            jvmsToWorkers.put(e.getValue(), workers);
        }
        workers.add(e.getKey());
    }
    List<String> jvmList = new ArrayList<>(jvmsToWorkers.keySet());
    Collections.sort(jvmList);
    //For each unique JVM, collect memory info from its first worker (in sorted worker ID order)
    int count = 0;
    for (String jvm : jvmList) {
        String wid = Collections.min(jvmsToWorkers.get(jvm));
        Map<String, Object> jvmData = buildJvmMemoryData(wid, workerNumDevices.get(wid), deviceNames.get(wid), updatesLastNMinutes, i18n);
        ret.put(String.valueOf(count), jvmData);
        count++;
    }
    return ret;
}

/**
 * Builds the chart data for one representative worker. It records one timestamp per matching
 * update, plus one utilization series each for JVM heap, off-heap memory and every device.
 *
 * @param wid worker whose {@link StatsReport} updates are used
 * @param numDevices number of devices reported in the worker's static info
 * @param devNames device descriptions from static info; may be null or shorter than numDevices
 * @param updates recent updates for all workers
 * @param i18n used to look up the JVM-heap and off-heap series names
 * @return map with keys "times", "isDevice", "seriesNames", "values", "currentBytes", "maxBytes"
 */
private static Map<String, Object> buildJvmMemoryData(String wid, int numDevices, String[] devNames, List<Persistable> updates, I18N i18n) {
    List<Long> timestamps = new ArrayList<>();
    List<Float> fracJvm = new ArrayList<>();
    List<Float> fracOffHeap = new ArrayList<>();
    //Index 0 = JVM heap, 1 = off-heap, 2+i = device i; values come from the latest matching update
    long[] lastBytes = new long[2 + numDevices];
    long[] lastMaxBytes = new long[2 + numDevices];
    List<List<Float>> fracDeviceMem = new ArrayList<>();
    for (int i = 0; i < numDevices; i++) {
        fracDeviceMem.add(new ArrayList<Float>());
    }
    for (Persistable p : updates) {
        //TODO single pass over the updates for all JVMs
        if (!(p instanceof StatsReport) || !wid.equals(p.getWorkerID())) {
            continue;
        }
        StatsReport sp = (StatsReport) p;
        timestamps.add(sp.getTimeStamp());
        long jvmCurrentBytes = sp.getJvmCurrentBytes();
        long jvmMaxBytes = sp.getJvmMaxBytes();
        long ohCurrentBytes = sp.getOffHeapCurrentBytes();
        long ohMaxBytes = sp.getOffHeapMaxBytes();
        fracJvm.add(memoryFraction(jvmCurrentBytes, jvmMaxBytes));
        fracOffHeap.add(memoryFraction(ohCurrentBytes, ohMaxBytes));
        lastBytes[0] = jvmCurrentBytes;
        lastBytes[1] = ohCurrentBytes;
        lastMaxBytes[0] = jvmMaxBytes;
        lastMaxBytes[1] = ohMaxBytes;
        if (numDevices > 0) {
            long[] devBytes = sp.getDeviceCurrentBytes();
            long[] devMaxBytes = sp.getDeviceMaxBytes();
            for (int i = 0; i < numDevices; i++) {
                boolean available = devBytes != null && devMaxBytes != null && i < devBytes.length && i < devMaxBytes.length;
                if (available) {
                    fracDeviceMem.get(i).add(memoryFraction(devBytes[i], devMaxBytes[i]));
                    lastBytes[2 + i] = devBytes[i];
                    lastMaxBytes[2 + i] = devMaxBytes[i];
                } else {
                    //Missing device stats: record 0 so every series stays aligned with 'timestamps'
                    fracDeviceMem.get(i).add(0.0f);
                }
            }
        }
    }
    List<List<Float>> fracUtilized = new ArrayList<>();
    fracUtilized.add(fracJvm);
    fracUtilized.add(fracOffHeap);
    String[] seriesNames = new String[2 + numDevices];
    seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
    seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
    boolean[] isDevice = new boolean[2 + numDevices];
    for (int i = 0; i < numDevices; i++) {
        seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
        fracUtilized.add(fracDeviceMem.get(i));
        isDevice[2 + i] = true;
    }
    Map<String, Object> jvmData = new HashMap<>();
    jvmData.put("times", timestamps);
    jvmData.put("isDevice", isDevice);
    jvmData.put("seriesNames", seriesNames);
    jvmData.put("values", fracUtilized);
    jvmData.put("currentBytes", lastBytes);
    jvmData.put("maxBytes", lastMaxBytes);
    return jvmData;
}

/**
 * Returns {@code current / max} as a float. It returns 0 when the ratio is NaN or infinite,
 * e.g. when {@code max} is 0. Checking only for NaN would let Infinity through for a positive
 * {@code current}.
 */
private static float memoryFraction(long current, long max) {
    double frac = current / (double) max;
    if (Double.isNaN(frac) || Double.isInfinite(frac)) {
        return 0.0f;
    }
    return (float) frac;
}
Also used : StatsInitializationReport(org.deeplearning4j.ui.stats.api.StatsInitializationReport) Persistable(org.deeplearning4j.api.storage.Persistable) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 22 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The class TestStatsListener, method testListenerBasic.

/**
 * Trains a single-layer Iris classifier for three fits, once with {@code StatsListener} and once
 * with {@code J7StatsListener}. Both listeners write to an in-memory {@code MapDBStatsStorage}.
 * The test then checks that exactly one session, type and worker were recorded, that static info
 * exists, and that there is one update per fit.
 */
@Test
public void testListenerBasic() {
    boolean[] listenerVariants = { false, true };
    for (boolean useJ7 : listenerVariants) {
        DataSet irisData = new IrisDataSetIterator(150, 150).next();
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(4).nOut(3).build())
                .pretrain(false)
                .backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //in-memory storage, fed by whichever listener implementation this iteration covers
        StatsStorage storage = new MapDBStatsStorage();
        if (!useJ7) {
            net.setListeners(new StatsListener(storage));
        } else {
            net.setListeners(new J7StatsListener(storage));
        }

        for (int fitCount = 0; fitCount < 3; fitCount++) {
            net.fit(irisData);
        }

        List<String> sessionIDs = storage.listSessionIDs();
        assertEquals(1, sessionIDs.size());
        String sessionID = sessionIDs.get(0);

        List<String> typeIDs = storage.listTypeIDsForSession(sessionID);
        assertEquals(1, typeIDs.size());
        String typeID = typeIDs.get(0);

        List<String> workerIDs = storage.listWorkerIDsForSession(sessionID);
        assertEquals(1, workerIDs.size());
        String workerID = workerIDs.get(0);

        Persistable staticInfo = storage.getStaticInfo(sessionID, typeID, workerID);
        assertNotNull(staticInfo);
        System.out.println(staticInfo);

        //One update is expected per call to fit()
        List<Persistable> updates = storage.getAllUpdatesAfter(sessionID, typeID, workerID, 0);
        assertEquals(3, updates.size());
        for (Persistable update : updates) {
            System.out.println(update);
        }
    }
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 23 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The class TestStatsStorage, method testStatsStorage.

/**
 * Round-trips static info and updates through three {@link StatsStorage} implementations:
 * MapDB file, J7 file and in-memory. Each is run once per value of {@code useJ7Storage}, which
 * is passed to {@code getInitReport}/{@code getReport} (presumably selecting the J7 report
 * variant; confirm against those helpers).
 *
 * <p>Checks session/type/worker ID listing, static-info and update retrieval, update record
 * counts, and the number of events delivered to a registered {@code CountingListener}. Per the
 * assertions below, {@code getInitReport(s, t, w, ...)} and {@code getReport(s, t, w, timestamp,
 * ...)} produce reports for session "sid" + s, type "tid" + t and worker "wid" + w.
 */
@Test
public void testStatsStorage() throws IOException {
    //Outer loop: report variant; inner loop: storage implementation (selected by the switch)
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 3; i++) {
            StatsStorage ss;
            switch(i) {
                case 0:
                    File f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
                    //Don't want file to exist...
                    //(only a unique temp path is needed; the storage is built on a path with no file)
                    f.delete();
                    ss = new MapDBStatsStorage.Builder().file(f).build();
                    break;
                case 1:
                    File f2 = Files.createTempFile("TestJ7FileStatsStore", ".db").toFile();
                    //Don't want file to exist...
                    //(same reasoning as the MapDB case above)
                    f2.delete();
                    ss = new J7FileStatsStorage(f2);
                    break;
                case 2:
                    ss = new InMemoryStatsStorage();
                    break;
                default:
                    throw new RuntimeException();
            }
            //Listener counts events (new session, new worker ID, static info, update) for later checks
            CountingListener l = new CountingListener();
            ss.registerStatsStorageListener(l);
            assertEquals(1, ss.getListeners().size());
            //Freshly created storage: no sessions and no updates
            assertEquals(0, ss.listSessionIDs().size());
            assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
            assertEquals(0, ss.listSessionIDs().size());
            //Static info for sid0/tid0/wid0: expect one new-session, one new-worker and one static-info event
            ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(0, l.countUpdate);
            assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
            assertTrue(ss.sessionExists("sid0"));
            assertFalse(ss.sessionExists("sid1"));
            //A freshly built, equal report must match what the storage returns
            Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
            Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
            assertEquals(expected, p);
            List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
            assertEquals(Collections.singletonList(expected), allStatic);
            //Static info alone does not count as an update
            assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            //Put first update
            //(the 4th getReport argument is the timestamp later passed to getUpdate)
            ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
            assertEquals(1, list.size());
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            //Same session and worker as the static info: only the update counter moves
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(1, l.countUpdate);
            //Put second update
            ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            //The latest update is now the 12346 one; the earlier one stays addressable by timestamp
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            //Update from a second worker (wid1) in the same session: a new-worker event, no new session
            ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(2, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(3, l.countUpdate);
            //Put static info and update with different session, type and worker IDs
            ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
            //Data for sid0 must be unaffected by the new session
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
            assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            //Debug output of the worker list for sid100
            List<String> temp = ss.listWorkerIDsForSession("sid100");
            System.out.println("temp: " + temp);
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            //Totals: two sessions, three workers (wid0, wid1, wid300), two static infos, four updates
            assertEquals(2, l.countNewSession);
            assertEquals(3, l.countNewWorkerId);
            assertEquals(2, l.countStaticInfo);
            assertEquals(4, l.countUpdate);
        }
    }
}
Also used : MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) J7FileStatsStorage(org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) J7FileStatsStorage(org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) File(java.io.File) Test(org.junit.Test)

Example 24 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The class ConvolutionalListenerModule, method getImage.

/**
 * Serves the most recent convolution-listener image as a JPEG response.
 *
 * <p>Looks up the static info for the last seen session/worker under {@code TYPE_ID}. If it is a
 * {@link ConvolutionListenerPersistable} holding an image, that image is JPEG-encoded and
 * returned. In every other case an empty body is returned: nothing recorded yet, wrong
 * persistable type, missing image, no JPEG writer for the image, or an encoding error. The
 * content type is always the registered MIME type "image/jpeg".
 *
 * @return HTTP 200 result with JPEG bytes, or an empty body if no image can be produced
 */
private Result getImage() {
    if (lastTimeStamp > 0 && lastStorage != null) {
        Persistable p = lastStorage.getStaticInfo(lastSessionID, TYPE_ID, lastWorkerID);
        if (p instanceof ConvolutionListenerPersistable) {
            BufferedImage bi = ((ConvolutionListenerPersistable) p).getImg();
            //ImageIO.write throws IllegalArgumentException for a null image, so skip it
            if (bi != null) {
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                try {
                    //write() returns false (writing nothing) when no writer supports this image
                    if (ImageIO.write(bi, "jpg", baos)) {
                        return ok(baos.toByteArray()).as("image/jpeg");
                    }
                    log.warn("Error displaying image: no suitable JPEG writer found");
                } catch (IOException e) {
                    //Don't serve a partially written image; fall through to the empty response
                    log.warn("Error displaying image", e);
                }
            }
        }
    }
    return ok(new byte[0]).as("image/jpeg");
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) ConvolutionListenerPersistable(org.deeplearning4j.ui.weights.ConvolutionListenerPersistable) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) ConvolutionListenerPersistable(org.deeplearning4j.ui.weights.ConvolutionListenerPersistable) BufferedImage(java.awt.image.BufferedImage)

Example 25 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

The class FlowListenerModule, method getUpdate.

/**
 * Returns the latest flow model state for a session as JSON.
 *
 * <p>Sessions missing from {@code knownSessionIDs} get the plain-text reply "Unknown session ID".
 * Otherwise the first entry of the latest per-worker updates for {@code TYPE_ID} is used. If
 * there is no such entry, or it is not a {@link FlowUpdatePersistable}, an empty OK response is
 * returned.
 *
 * @param sessionID session whose flow state is requested
 * @return the JSON-encoded model state, or an empty/"Unknown session ID" response
 */
private Result getUpdate(String sessionID) {
    if (!knownSessionIDs.containsKey(sessionID)) {
        return ok("Unknown session ID");
    }
    List<Persistable> latestPerWorker = knownSessionIDs.get(sessionID).getLatestUpdateAllWorkers(sessionID, TYPE_ID);
    boolean noUpdates = latestPerWorker == null || latestPerWorker.isEmpty();
    Persistable first = noUpdates ? null : latestPerWorker.get(0);
    //instanceof is false for null, so "no updates" and "wrong type" share the fallback below
    if (first instanceof FlowUpdatePersistable) {
        FlowUpdatePersistable flowUpdate = (FlowUpdatePersistable) first;
        return ok(Json.toJson(flowUpdate.getModelState()));
    }
    return ok();
}
Also used : StatsStorage(org.deeplearning4j.api.storage.StatsStorage) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable) FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable) Persistable(org.deeplearning4j.api.storage.Persistable) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable)30 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)14 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)7 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)6 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)6 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 Test (org.junit.Test)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IOException (java.io.IOException)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable)3 FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)3 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)3 BufferedImage (java.awt.image.BufferedImage)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2