Search in sources:

Example 1 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

In class TestStatsStorage, method testFileStatsStore.

/**
 * Exercises the file-backed {@link StatsStorage} implementations end to end.
 *
 * <p>For each combination of report encoding ({@code useJ7Storage}) and backend
 * (i == 0: MapDB, i == 1: J7FileStatsStorage) this test:
 * <ol>
 *   <li>creates an empty store and checks listener notifications and queries,</li>
 *   <li>writes static info and several updates across sessions/types/workers,</li>
 *   <li>closes and re-opens the store from the same file and checks persistence.</li>
 * </ol>
 * The re-opened store is closed and the temp file removed afterwards so that
 * file handles are not leaked between iterations.
 */
@Test
public void testFileStatsStore() throws IOException {
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 2; i++) {
            File f;
            if (i == 0) {
                f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
            } else {
                f = Files.createTempFile("TestSqliteStatsStore", ".db").toFile();
            }
            //Don't want file to exist: the storage implementation creates it itself
            f.delete();
            StatsStorage ss;
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            CountingListener l = new CountingListener();
            ss.registerStatsStorageListener(l);
            assertEquals(1, ss.getListeners().size());
            assertEquals(0, ss.listSessionIDs().size());
            assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
            assertEquals(0, ss.listSessionIDs().size());
            ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(0, l.countUpdate);
            assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
            assertTrue(ss.sessionExists("sid0"));
            assertFalse(ss.sessionExists("sid1"));
            Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
            Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
            assertEquals(expected, p);
            List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
            assertEquals(Collections.singletonList(expected), allStatic);
            assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            //Put first update
            ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
            assertEquals(1, list.size());
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(1, l.countUpdate);
            //Put second update
            ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(2, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(3, l.countUpdate);
            //Put static info and update with different session, type and worker IDs
            ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
            assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            assertEquals(2, l.countNewSession);
            assertEquals(3, l.countNewWorkerId);
            assertEquals(2, l.countStaticInfo);
            assertEquals(4, l.countUpdate);
            //Close and re-open: everything written above must be read back from the file
            ss.close();
            assertTrue(ss.isClosed());
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid100", "tid200").size());
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            //Release the re-opened store's file handle before the next iteration, then best-effort cleanup
            ss.close();
            assertTrue(ss.isClosed());
            f.delete();
        }
    }
}
Also used : MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) J7FileStatsStorage(org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) J7FileStatsStorage(org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) File(java.io.File) Test(org.junit.Test)

Example 2 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

In class FlowIterationListener, method iterationDone.

/**
 * Event listener for each iteration.
 *
 * <p>Every {@code frequency}-th call, this sends the model state as an update to the stats storage
 * router. On the first such call it also sends the model structure as static info, and optionally
 * opens a browser on the flow UI. The first-iteration work is done exactly once, regardless of
 * whether a browser is opened.
 *
 * @param model     the model iterating
 * @param iteration the iteration
 */
@Override
public synchronized void iterationDone(Model model, int iteration) {
    if (iterationCount.incrementAndGet() % frequency == 0) {
        currTime = System.currentTimeMillis();
        if (firstIteration) {
            // On first pass we just build list of layers. However, for MultiLayerNetwork first pass is the last pass, since we know connections in advance
            ModelInfo info = buildModelInfo(model);
            // send ModelInfo to stats storage
            Persistable staticInfo = new FlowStaticPersistable(sessionID, workerID, System.currentTimeMillis(), info);
            ssr.putStaticInfo(staticInfo);
        }
        // update modelState
        buildModelState(model);
        Persistable updateInfo = new FlowUpdatePersistable(sessionID, workerID, System.currentTimeMillis(), modelState);
        ssr.putUpdate(updateInfo);
        if (firstIteration) {
            if (openBrowser) {
                UIServer uiServer = UIServer.getInstance();
                String path = "http://localhost:" + uiServer.getPort() + "/flow?sid=" + sessionID;
                try {
                    UiUtils.tryOpenBrowser(path, log);
                } catch (Exception e) {
                    // Opening a browser is best-effort only; training must not fail because of it
                    log.warn("Could not open browser for flow UI at {}", path, e);
                }
            }
            // Clear the flag unconditionally; previously it was only cleared when openBrowser was set,
            // which caused static info to be rebuilt and re-sent on every reporting iteration
            firstIteration = false;
        }
    }
    lastTime = System.currentTimeMillis();
}
Also used : FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable) Persistable(org.deeplearning4j.api.storage.Persistable) FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable) UIServer(org.deeplearning4j.ui.api.UIServer)

Example 3 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

In class RemoteUIStatsStorageRouter, method tryPost.

/**
 * Attempts a single HTTP POST of one item (metadata, static info or update) to the remote UI.
 *
 * <p>The item is encoded with its own {@code encode()} method, wrapped as base64 in a small JSON
 * object ({@code type}, {@code class}, {@code data}), and written to the connection returned by
 * {@code getConnection()}.
 *
 * @param toPost the item to send; the first non-null of meta, static info, or update is used
 * @return {@code true} if the server answered with HTTP 200, {@code false} otherwise (the failure is logged)
 * @throws IOException if the connection cannot be opened or the request body cannot be written
 */
private boolean tryPost(ToPost toPost) throws IOException {
    HttpURLConnection connection = getConnection();
    String className;
    byte[] asBytes;
    StorageType type;
    if (toPost.getMeta() != null) {
        StorageMetaData smd = toPost.getMeta();
        className = smd.getClass().getName();
        asBytes = smd.encode();
        type = StorageType.MetaData;
    } else if (toPost.getStaticInfo() != null) {
        Persistable p = toPost.getStaticInfo();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.StaticInfo;
    } else {
        Persistable p = toPost.getUpdate();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.Update;
    }
    String base64 = DatatypeConverter.printBase64Binary(asBytes);
    Map<String, String> jsonObj = new LinkedHashMap<>();
    jsonObj.put("type", type.name());
    jsonObj.put("class", className);
    jsonObj.put("data", base64);
    String str;
    try {
        str = objectMapper.writeValueAsString(jsonObj);
    } catch (Exception e) {
        //Should never get an exception from simple Map<String,String>
        throw new RuntimeException(e);
    }
    // Write as UTF-8: DataOutputStream.writeBytes would silently drop the high byte of non-ASCII chars.
    // try-with-resources so the stream is closed even if writing fails.
    try (DataOutputStream dos = new DataOutputStream(connection.getOutputStream())) {
        dos.write(str.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        dos.flush();
    }
    String forbiddenMsg = "Error posting to remote UI at {} (Response code: 403)."
                    + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()";
    try {
        int responseCode = connection.getResponseCode();
        if (responseCode == 200) {
            return true;
        }
        if (responseCode == 403) {
            log.warn(forbiddenMsg, url);
            return false;
        }
        // For 4xx/5xx the body is only available from the error stream; getInputStream() would throw
        java.io.InputStream body = connection.getErrorStream();
        if (body == null) {
            body = connection.getInputStream();
        }
        StringBuilder response = new StringBuilder();
        try (BufferedReader in = new BufferedReader(
                        new InputStreamReader(body, java.nio.charset.StandardCharsets.UTF_8))) {
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                response.append(inputLine);
            }
        }
        log.warn("Error posting to remote UI - received response code {}\tContent: {}", responseCode, response.toString());
        return false;
    } catch (IOException e) {
        // getMessage() may be null, so guard before inspecting it
        String msg = e.getMessage();
        if (msg != null && msg.contains("403 for URL")) {
            log.warn(forbiddenMsg, url, e);
        } else {
            log.warn("Error posting to remote UI at {}", url, e);
        }
        return false;
    }
}
Also used : StorageType(org.deeplearning4j.api.storage.StorageType) Persistable(org.deeplearning4j.api.storage.Persistable) InputStreamReader(java.io.InputStreamReader) DataOutputStream(java.io.DataOutputStream) IOException(java.io.IOException) MalformedURLException(java.net.MalformedURLException) IOException(java.io.IOException) StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) HttpURLConnection(java.net.HttpURLConnection) BufferedReader(java.io.BufferedReader)

Example 4 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

In class ParameterAveragingElementCombineFunction, method call.

/**
 * Merges two partial parameter-averaging aggregates into a single aggregate.
 *
 * <p>Parameter and updater-state sums are accumulated in place (via {@code addi}) into the arrays
 * held by {@code v1}. Score sums and aggregation counts are added. Training stats and listener
 * collections from {@code v2} are appended to those of {@code v1}, or taken as-is when
 * {@code v1} has none.
 *
 * @param v1 first partial aggregate (may be null, or have no parameters)
 * @param v2 second partial aggregate (may be null, or have no parameters)
 * @return the combined aggregate; if either side is null or empty, the other side is returned unchanged
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null) {
        return v2;
    }
    if (v2 == null) {
        return v1;
    }
    // With less data than executors, one (or both) sides may have no contents at all
    if (v1.getParametersSum() == null) {
        return v2;
    }
    if (v2.getParametersSum() == null) {
        return v1;
    }

    INDArray paramsTotal = v1.getParametersSum().addi(v2.getParametersSum());

    INDArray updaterTotal = v1.getUpdaterStateSum();
    if (updaterTotal == null) {
        updaterTotal = v2.getUpdaterStateSum();
    } else {
        INDArray otherUpdater = v2.getUpdaterStateSum();
        if (otherUpdater != null) {
            updaterTotal.addi(otherUpdater);
        }
    }

    double totalScore = v1.getScoreSum() + v2.getScoreSum();
    int totalAggregations = v1.getAggregationsCount() + v2.getAggregationsCount();

    SparkTrainingStats combinedStats = v1.getSparkTrainingStats();
    SparkTrainingStats otherStats = v2.getSparkTrainingStats();
    if (otherStats != null) {
        if (combinedStats == null) {
            combinedStats = otherStats;
        } else {
            combinedStats.addOtherTrainingStats(otherStats);
        }
    }

    // Make sure any queued (grid) ops have completed before the arrays are handed on
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> mergedMeta = mergeListenerCollections(v1.getListenerMetaData(), v2.getListenerMetaData());
    Collection<Persistable> mergedStatic = mergeListenerCollections(v1.getListenerStaticInfo(), v2.getListenerStaticInfo());
    Collection<Persistable> mergedUpdates = mergeListenerCollections(v1.getListenerUpdates(), v2.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(paramsTotal, updaterTotal, totalScore, totalAggregations, combinedStats, mergedMeta, mergedStatic, mergedUpdates);
}

/**
 * Appends {@code second} to {@code first} (mutating {@code first}) and returns it.
 * If {@code first} is null, {@code second} is returned as-is (possibly null).
 */
private static <T> Collection<T> mergeListenerCollections(Collection<T> first, Collection<T> second) {
    if (first == null) {
        return second;
    }
    if (second != null) {
        first.addAll(second);
    }
    return first;
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats)

Example 5 with Persistable

use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.

In class TrainModule, method getLayerMeanMagnitudes.

//TODO float precision for smaller transfers?
//First: iteration. Second: ratios, by parameter
/**
 * Collects mean-magnitude series for the parameters of one layer across a list of stats updates.
 *
 * <p>For every {@link StatsReport} in {@code updates}, one iteration count is recorded. Then, for
 * each parameter whose key starts with the layer's prefix, three values are appended: the parameter
 * mean magnitude, the update mean magnitude, and their ratio (update / parameter). Non-finite
 * magnitudes are replaced by {@code NAN_REPLACEMENT_VALUE}. A report without a parameter or update
 * mean-magnitude map is treated as having an empty map, instead of failing with an NPE.
 *
 * @param layerIdx        index of the layer/vertex in {@code gi}
 * @param gi              graph info; if null, empty results are returned
 * @param updates         stats updates to scan; non-{@link StatsReport} entries are skipped (may be null)
 * @param iterationCounts optional iteration counts aligned with {@code updates}; if null, the count
 *                        from each report is used
 * @param modelType       model type; determines which vertex name and prefix form are used
 * @return iteration counts plus ratio, parameter and update mean magnitudes, keyed by parameter name
 */
private MeanMagnitudes getLayerMeanMagnitudes(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
    }
    String layerName = gi.getVertexNames().get(layerIdx);
    if (modelType != ModelType.CG) {
        //Get the original name, for the index...
        layerName = gi.getOriginalVertexName().get(layerIdx);
    }
    String layerType = gi.getVertexTypes().get(layerIdx);
    if ("input".equalsIgnoreCase(layerType)) {
        //TODO better checking - other vertices, etc
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
    }
    // Parameter keys for this layer start with this prefix; it does not change per update/parameter
    String prefix = (modelType == ModelType.Layer) ? layerName : layerName + "_";
    List<Integer> iterCounts = new ArrayList<>();
    Map<String, List<Double>> ratioValues = new HashMap<>();
    Map<String, List<Double>> outParamMM = new HashMap<>();
    Map<String, List<Double>> outUpdateMM = new HashMap<>();
    if (updates != null) {
        int pCount = -1;
        for (Persistable u : updates) {
            pCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            if (iterationCounts != null) {
                iterCounts.add(iterationCounts.get(pCount));
            } else {
                iterCounts.add(sp.getIterationCount());
            }
            //Info we want, for each parameter in this layer: mean magnitudes for parameters, updates AND the ratio of these
            Map<String, Double> paramMM = sp.getMeanMagnitudes(StatsType.Parameters);
            Map<String, Double> updateMM = sp.getMeanMagnitudes(StatsType.Updates);
            if (paramMM == null) {
                // No parameter magnitudes in this report: nothing to record for it
                continue;
            }
            if (updateMM == null) {
                // Missing update magnitudes behave like missing keys: default to 0.0 below
                updateMM = Collections.emptyMap();
            }
            for (Map.Entry<String, Double> entry : paramMM.entrySet()) {
                String s = entry.getKey();
                if (!s.startsWith(prefix)) {
                    continue;
                }
                //Relevant parameter for this layer...
                String layerParam = s.substring(prefix.length());
                double pmm = entry.getValue();
                double umm = updateMM.getOrDefault(s, 0.0);
                if (!Double.isFinite(pmm)) {
                    pmm = NAN_REPLACEMENT_VALUE;
                }
                if (!Double.isFinite(umm)) {
                    umm = NAN_REPLACEMENT_VALUE;
                }
                //Avoid NaN from 0/0
                double ratio = (umm == 0.0 && pmm == 0.0) ? 0.0 : umm / pmm;
                ratioValues.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(ratio);
                outParamMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(pmm);
                outUpdateMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(umm);
            }
        }
    }
    return new MeanMagnitudes(iterCounts, ratioValues, outParamMM, outUpdateMM);
}
Also used : Persistable(org.deeplearning4j.api.storage.Persistable) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StatsReport(org.deeplearning4j.ui.stats.api.StatsReport)

Aggregations

Persistable (org.deeplearning4j.api.storage.Persistable)30 StatsStorage (org.deeplearning4j.api.storage.StatsStorage)14 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)7 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)6 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)6 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 Test (org.junit.Test)5 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IOException (java.io.IOException)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)3 FlowStaticPersistable (org.deeplearning4j.ui.flow.data.FlowStaticPersistable)3 FlowUpdatePersistable (org.deeplearning4j.ui.flow.data.FlowUpdatePersistable)3 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)3 BufferedImage (java.awt.image.BufferedImage)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2