Usage of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
From class TestStatsStorage, method testFileStatsStore.
/**
 * Exercises both file-backed StatsStorage implementations: MapDBStatsStorage when i == 0 and
 * J7FileStatsStorage when i == 1.
 *
 * <p>Each is run once per value of {@code useJ7Storage}, which is passed through to
 * getInitReport/getReport. (Presumably it selects the report implementation; confirm against
 * those helpers.) The test checks listener callback counts, session/type/worker listings and
 * update retrieval, and that the data is still readable after the storage is closed and
 * re-opened on the same file.
 */
@Test
public void testFileStatsStore() throws IOException {
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 2; i++) {
            File f;
            if (i == 0) {
                f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
            } else {
                f = Files.createTempFile("TestSqliteStatsStore", ".db").toFile();
            }
            //Don't want file to exist: the storage implementation should create it itself
            f.delete();
            StatsStorage ss;
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            CountingListener l = new CountingListener();
            ss.registerStatsStorageListener(l);
            assertEquals(1, ss.getListeners().size());
            assertEquals(0, ss.listSessionIDs().size());
            assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
            assertEquals(0, ss.listSessionIDs().size());
            ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(0, l.countUpdate);
            assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
            assertTrue(ss.sessionExists("sid0"));
            assertFalse(ss.sessionExists("sid1"));
            Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
            Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
            assertEquals(expected, p);
            List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
            assertEquals(Collections.singletonList(expected), allStatic);
            assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            //Put first update
            ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
            assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
            assertEquals(1, list.size());
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(1, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(1, l.countUpdate);
            //Put second update
            ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(1, l.countNewSession);
            assertEquals(2, l.countNewWorkerId);
            assertEquals(1, l.countStaticInfo);
            assertEquals(3, l.countUpdate);
            //Put static info and update with different session, type and worker IDs
            ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
            assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            assertEquals(2, l.countNewSession);
            assertEquals(3, l.countNewWorkerId);
            assertEquals(2, l.countStaticInfo);
            assertEquals(4, l.countUpdate);
            //Close and re-open: everything written above must be read back from the file
            ss.close();
            assertTrue(ss.isClosed());
            if (i == 0) {
                ss = new MapDBStatsStorage.Builder().file(f).build();
            } else {
                ss = new J7FileStatsStorage(f);
            }
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(1, ss.getLatestUpdateAllWorkers("sid100", "tid200").size());
            assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
            // Release the re-opened storage (previously leaked on every iteration), then
            // remove the temp file on a best-effort basis.
            ss.close();
            assertTrue(ss.isClosed());
            f.delete();
        }
    }
}
Usage of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
From class FlowIterationListener, method iterationDone.
/**
 * Event listener for each iteration.
 *
 * <p>On every {@code frequency}-th call (counted with {@code iterationCount}) this rebuilds the
 * model state and posts it to the stats storage router as an update. On the first such call it
 * also posts the static model description before the update. If {@code openBrowser} is set, it
 * then tries to open the flow UI in a browser. The {@code lastTime} field is refreshed on every
 * call.
 *
 * @param model     the model iterating
 * @param iteration the iteration
 */
@Override
public synchronized void iterationDone(Model model, int iteration) {
    if (iterationCount.incrementAndGet() % frequency == 0) {
        currTime = System.currentTimeMillis();
        if (firstIteration) {
            // On first pass we just build list of layers. However, for MultiLayerNetwork first pass is the last pass, since we know connections in advance
            ModelInfo info = buildModelInfo(model);
            // send ModelInfo to stats storage
            Persistable staticInfo = new FlowStaticPersistable(sessionID, workerID, System.currentTimeMillis(), info);
            ssr.putStaticInfo(staticInfo);
        }
        // update modelState
        buildModelState(model);
        Persistable updateInfo = new FlowUpdatePersistable(sessionID, workerID, System.currentTimeMillis(), modelState);
        ssr.putUpdate(updateInfo);
        if (firstIteration) {
            if (openBrowser) {
                UIServer uiServer = UIServer.getInstance();
                String path = "http://localhost:" + uiServer.getPort() + "/flow?sid=" + sessionID;
                try {
                    UiUtils.tryOpenBrowser(path, log);
                } catch (Exception e) {
                    // Opening a browser is best-effort; report the failure instead of hiding it
                    log.warn("Could not open browser for flow UI at " + path, e);
                }
            }
            // Clear the flag whether or not a browser was requested. Previously it was only
            // cleared inside the openBrowser branch, so with openBrowser == false the static
            // info was rebuilt and re-posted on every reported iteration.
            firstIteration = false;
        }
    }
    lastTime = System.currentTimeMillis();
}
Usage of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
From class RemoteUIStatsStorageRouter, method tryPost.
/**
 * Serializes one queued item and POSTs it to the remote UI as JSON.
 *
 * <p>The item is metadata, static info or an update, checked in that order. The JSON object has
 * the keys {@code type} (the StorageType name), {@code class} (the encoded object's class name)
 * and {@code data} (Base64 of the object's {@code encode()} bytes).
 *
 * @param toPost the item to send
 * @return true if the server answered HTTP 200; false if it answered with any other code or the
 *         response could not be read (the problem is logged)
 * @throws IOException if obtaining the connection's output stream or writing the request fails
 */
private boolean tryPost(ToPost toPost) throws IOException {
    HttpURLConnection connection = getConnection();
    String className;
    byte[] asBytes;
    StorageType type;
    if (toPost.getMeta() != null) {
        StorageMetaData smd = toPost.getMeta();
        className = smd.getClass().getName();
        asBytes = smd.encode();
        type = StorageType.MetaData;
    } else if (toPost.getStaticInfo() != null) {
        Persistable p = toPost.getStaticInfo();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.StaticInfo;
    } else {
        Persistable p = toPost.getUpdate();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.Update;
    }
    String base64 = DatatypeConverter.printBase64Binary(asBytes);
    Map<String, String> jsonObj = new LinkedHashMap<>();
    jsonObj.put("type", type.name());
    jsonObj.put("class", className);
    jsonObj.put("data", base64);
    String str;
    try {
        str = objectMapper.writeValueAsString(jsonObj);
    } catch (Exception e) {
        //Should never get an exception from simple Map<String,String>
        throw new RuntimeException(e);
    }
    // The body is JSON built from enum/class names and Base64 text, so it is ASCII and
    // writeBytes (low byte per char) is lossless. try-with-resources closes the stream even if
    // the write fails.
    try (DataOutputStream dos = new DataOutputStream(connection.getOutputStream())) {
        dos.writeBytes(str);
        dos.flush();
    }
    try {
        int responseCode = connection.getResponseCode();
        if (responseCode != 200) {
            if (responseCode == 403) {
                log.warn("Error posting to remote UI at {} (Response code: 403)." + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()", url);
            } else {
                log.warn("Error posting to remote UI - received response code {}\tContent: {}", responseCode, readResponseBody(connection, responseCode));
            }
            return false;
        }
    } catch (IOException e) {
        String msg = e.getMessage();
        // getMessage() may be null for some IOExceptions
        if (msg != null && msg.contains("403 for URL")) {
            log.warn("Error posting to remote UI at {} (Response code: 403)." + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()", url, e);
        } else {
            log.warn("Error posting to remote UI at {}", url, e);
        }
        return false;
    }
    return true;
}

/**
 * Reads the HTTP response body as text. For codes >= 400 it reads the error stream, because
 * HttpURLConnection.getInputStream() throws for those codes. Returns "" when no body is
 * available.
 */
private static String readResponseBody(HttpURLConnection connection, int responseCode) throws IOException {
    java.io.InputStream body = responseCode >= 400 ? connection.getErrorStream() : connection.getInputStream();
    if (body == null) {
        return "";
    }
    StringBuilder response = new StringBuilder();
    try (BufferedReader in = new BufferedReader(new InputStreamReader(body, java.nio.charset.StandardCharsets.UTF_8))) {
        String inputLine;
        while ((inputLine = in.readLine()) != null) {
            response.append(inputLine);
        }
    }
    return response.toString();
}
Usage of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
From class ParameterAveragingElementCombineFunction, method call.
/**
 * Combines two partial parameter-averaging aggregates into one.
 *
 * <p>It sums the parameter and updater-state arrays, the score sums and the aggregation counts.
 * It merges the Spark training stats and concatenates the listener metadata, static info and
 * updates. When both sides hold a value, v1's objects are modified in place (via addi, addAll and
 * addOtherTrainingStats) and reused in the result.
 *
 * @param v1 first partial aggregate; may be null
 * @param v2 second partial aggregate; may be null
 * @return the combined aggregate. If either argument is null, or carries no parameter sum, the
 *         other argument is returned unchanged.
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null) {
        return v2;
    }
    if (v2 == null) {
        return v1;
    }
    // Fewer data partitions than executors: a side may exist yet hold nothing to combine
    if (v1.getParametersSum() == null) {
        return v2;
    }
    if (v2.getParametersSum() == null) {
        return v1;
    }

    INDArray summedParams = v1.getParametersSum().addi(v2.getParametersSum());

    INDArray leftUpdater = v1.getUpdaterStateSum();
    INDArray rightUpdater = v2.getUpdaterStateSum();
    INDArray summedUpdater;
    if (leftUpdater == null) {
        summedUpdater = rightUpdater;
    } else {
        if (rightUpdater != null) {
            leftUpdater.addi(rightUpdater);
        }
        summedUpdater = leftUpdater;
    }

    double totalScore = v1.getScoreSum() + v2.getScoreSum();
    int totalAggregations = v1.getAggregationsCount() + v2.getAggregationsCount();

    SparkTrainingStats mergedStats = v1.getSparkTrainingStats();
    SparkTrainingStats otherStats = v2.getSparkTrainingStats();
    if (otherStats != null) {
        if (mergedStats == null) {
            mergedStats = otherStats;
        } else {
            mergedStats.addOtherTrainingStats(otherStats);
        }
    }

    // Make sure queued (grid) ops have completed before the result leaves this function
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> mergedMeta = appendListenerData(v1.getListenerMetaData(), v2.getListenerMetaData());
    Collection<Persistable> mergedStatic = appendListenerData(v1.getListenerStaticInfo(), v2.getListenerStaticInfo());
    Collection<Persistable> mergedUpdates = appendListenerData(v1.getListenerUpdates(), v2.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(summedParams, summedUpdater, totalScore, totalAggregations, mergedStats, mergedMeta, mergedStatic, mergedUpdates);
}

/**
 * Merges two collections of listener data. If {@code first} is null, {@code second} is returned
 * as-is. Otherwise every element of {@code second} (when non-null) is added to {@code first},
 * which is modified in place and returned.
 */
private static <T> Collection<T> appendListenerData(Collection<T> first, Collection<T> second) {
    if (first == null) {
        return second;
    }
    if (second != null) {
        first.addAll(second);
    }
    return first;
}
Usage of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
From class TrainModule, method getLayerMeanMagnitudes.
//TODO float precision for smaller transfers?
/**
 * Collects, for a single layer, the mean magnitudes of its parameters and of their updates across
 * the given reports, plus the update:parameter ratio.
 *
 * <p>Values are keyed by parameter name with the layer-name prefix stripped. Non-finite
 * magnitudes and ratios are replaced with {@code NAN_REPLACEMENT_VALUE}.
 *
 * @param layerIdx        index of the layer/vertex within {@code gi}
 * @param gi              graph info for the model; if null, empty results are returned
 * @param updates         stored reports. Entries that are not StatsReport instances are skipped,
 *                        as are reports without a parameter mean-magnitude map.
 * @param iterationCounts optional iteration numbers, indexed in parallel with {@code updates}.
 *                        If null, each report's own iteration count is used.
 * @param modelType       model type. For CG the graph's vertex names are used; otherwise the
 *                        original vertex names are used. For Layer the prefix has no "_" separator.
 * @return iteration counts first; then ratios, parameter mean magnitudes and update mean
 *         magnitudes, each as a per-parameter series
 */
private MeanMagnitudes getLayerMeanMagnitudes(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
    }
    String layerName = gi.getVertexNames().get(layerIdx);
    if (modelType != ModelType.CG) {
        //Get the original name, for the index...
        layerName = gi.getOriginalVertexName().get(layerIdx);
    }
    String layerType = gi.getVertexTypes().get(layerIdx);
    if ("input".equalsIgnoreCase(layerType)) {
        //TODO better checking - other vertices, etc
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
    }
    List<Integer> iterCounts = new ArrayList<>();
    Map<String, List<Double>> ratioValues = new HashMap<>();
    Map<String, List<Double>> outParamMM = new HashMap<>();
    Map<String, List<Double>> outUpdateMM = new HashMap<>();
    if (updates != null) {
        // Prefix that identifies this layer's parameter keys; it does not change per report, so compute it once
        String prefix = (modelType == ModelType.Layer) ? layerName : layerName + "_";
        int pCount = -1;
        for (Persistable u : updates) {
            // Index into iterationCounts, which is parallel to 'updates' (so count skipped entries too)
            pCount++;
            if (!(u instanceof StatsReport)) {
                continue;
            }
            StatsReport sp = (StatsReport) u;
            if (iterationCounts != null) {
                iterCounts.add(iterationCounts.get(pCount));
            } else {
                iterCounts.add(sp.getIterationCount());
            }
            //Info we want, for each parameter in this layer: mean magnitudes for parameters, updates AND the ratio of these
            Map<String, Double> paramMM = sp.getMeanMagnitudes(StatsType.Parameters);
            Map<String, Double> updateMM = sp.getMeanMagnitudes(StatsType.Updates);
            if (paramMM == null) {
                // Report has no parameter mean magnitudes (previously an NPE); nothing to collect
                continue;
            }
            for (String s : paramMM.keySet()) {
                if (!s.startsWith(prefix)) {
                    continue;
                }
                //Relevant parameter for this layer...
                String layerParam = s.substring(prefix.length());
                double pmm = paramMM.getOrDefault(s, 0.0);
                // A missing update map is treated like a missing key: magnitude 0
                double umm = (updateMM == null) ? 0.0 : updateMM.getOrDefault(s, 0.0);
                if (!Double.isFinite(pmm)) {
                    pmm = NAN_REPLACEMENT_VALUE;
                }
                if (!Double.isFinite(umm)) {
                    umm = NAN_REPLACEMENT_VALUE;
                }
                double ratio;
                if (umm == 0.0 && pmm == 0.0) {
                    //To avoid NaN from 0/0
                    ratio = 0.0;
                } else {
                    ratio = umm / pmm;
                }
                // umm / 0.0 yields Infinity; apply the same replacement used for pmm/umm above
                if (!Double.isFinite(ratio)) {
                    ratio = NAN_REPLACEMENT_VALUE;
                }
                ratioValues.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(ratio);
                outParamMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(pmm);
                outUpdateMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(umm);
            }
        }
    }
    return new MeanMagnitudes(iterCounts, ratioValues, outParamMM, outUpdateMM);
}
Aggregations