Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j (by deeplearning4j).
Class TrainModule, method getMemory.
/**
 * Builds per-JVM memory-utilisation chart data for the training UI.
 *
 * <p>Workers are grouped by the JVM UID from their static info. JVMs are processed in sorted
 * order, and for each JVM the lexicographically-first worker ID is used as the representative.
 * That worker's {@link StatsReport} updates become time series of memory fractions
 * (current / max) for the JVM heap, off-heap memory and each device reported in its static info.
 * A fraction that is undefined (0 / 0) or unbounded (max of 0) is reported as 0, so the series
 * never contain NaN or Infinity.
 *
 * @param staticInfoAllWorkers static info for all workers; every entry is cast to
 *     {@link StatsInitializationReport}
 * @param updatesLastNMinutes recent updates; entries that are not {@link StatsReport}s are skipped
 * @param i18n used to look up the names of the JVM and off-heap series
 * @return map from JVM index ("0", "1", ...) to chart data with keys "times", "isDevice",
 *     "seriesNames", "values", "currentBytes" and "maxBytes"
 */
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, List<Persistable> updatesLastNMinutes, I18N i18n) {
    Map<String, Object> ret = new HashMap<>();
    //First: map workers to JVMs
    Set<String> jvmIDs = new HashSet<>();
    Map<String, String> workersToJvms = new HashMap<>();
    Map<String, Integer> workerNumDevices = new HashMap<>();
    Map<String, String[]> deviceNames = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        //TODO validation/checks
        StatsInitializationReport init = (StatsInitializationReport) p;
        String jvmuid = init.getSwJvmUID();
        workersToJvms.put(p.getWorkerID(), jvmuid);
        jvmIDs.add(jvmuid);
        int nDevices = init.getHwNumDevices();
        workerNumDevices.put(p.getWorkerID(), nDevices);
        if (nDevices > 0) {
            deviceNames.put(p.getWorkerID(), init.getHwDeviceDescription());
        }
    }
    List<String> jvmList = new ArrayList<>(jvmIDs);
    Collections.sort(jvmList);

    //For each unique JVM, collect memory info from its first worker (by sorted worker ID)
    int count = 0;
    for (String jvm : jvmList) {
        List<String> workersForJvm = new ArrayList<>();
        for (Map.Entry<String, String> e : workersToJvms.entrySet()) {
            if (e.getValue().equals(jvm)) {
                workersForJvm.add(e.getKey());
            }
        }
        Collections.sort(workersForJvm);
        String wid = workersForJvm.get(0);
        int numDevices = workerNumDevices.get(wid);

        Map<String, Object> jvmData = new HashMap<>();
        List<Long> timestamps = new ArrayList<>();
        List<Float> fracJvm = new ArrayList<>();
        List<Float> fracOffHeap = new ArrayList<>();
        // Layout of these arrays: [0] = JVM heap, [1] = off-heap, [2 + i] = device i
        long[] lastBytes = new long[2 + numDevices];
        long[] lastMaxBytes = new long[2 + numDevices];
        List<List<Float>> fracDeviceMem = new ArrayList<>(numDevices);
        for (int i = 0; i < numDevices; i++) {
            fracDeviceMem.add(new ArrayList<>());
        }

        for (Persistable p : updatesLastNMinutes) {
            //TODO single pass
            if (!p.getWorkerID().equals(wid) || !(p instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) p;
            timestamps.add(sp.getTimeStamp());
            long jvmCurrentBytes = sp.getJvmCurrentBytes();
            long jvmMaxBytes = sp.getJvmMaxBytes();
            long ohCurrentBytes = sp.getOffHeapCurrentBytes();
            long ohMaxBytes = sp.getOffHeapMaxBytes();
            fracJvm.add(memoryFraction(jvmCurrentBytes, jvmMaxBytes));
            fracOffHeap.add(memoryFraction(ohCurrentBytes, ohMaxBytes));
            lastBytes[0] = jvmCurrentBytes;
            lastBytes[1] = ohCurrentBytes;
            lastMaxBytes[0] = jvmMaxBytes;
            lastMaxBytes[1] = ohMaxBytes;
            if (numDevices > 0) {
                long[] devBytes = sp.getDeviceCurrentBytes();
                long[] devMaxBytes = sp.getDeviceMaxBytes();
                for (int i = 0; i < numDevices; i++) {
                    boolean haveDeviceData = devBytes != null && devMaxBytes != null
                            && i < devBytes.length && i < devMaxBytes.length;
                    if (haveDeviceData) {
                        fracDeviceMem.get(i).add(memoryFraction(devBytes[i], devMaxBytes[i]));
                        lastBytes[2 + i] = devBytes[i];
                        lastMaxBytes[2 + i] = devMaxBytes[i];
                    } else {
                        // Keep each device series the same length as "times" even when an
                        // update carries no data for this device
                        fracDeviceMem.get(i).add(0.0f);
                    }
                }
            }
        }

        List<List<Float>> fracUtilized = new ArrayList<>();
        fracUtilized.add(fracJvm);
        fracUtilized.add(fracOffHeap);
        fracUtilized.addAll(fracDeviceMem);
        String[] seriesNames = new String[2 + numDevices];
        seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
        seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
        boolean[] isDevice = new boolean[2 + numDevices];
        String[] devNames = deviceNames.get(wid);
        for (int i = 0; i < numDevices; i++) {
            seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
            isDevice[2 + i] = true;
        }
        jvmData.put("times", timestamps);
        jvmData.put("isDevice", isDevice);
        jvmData.put("seriesNames", seriesNames);
        jvmData.put("values", fracUtilized);
        jvmData.put("currentBytes", lastBytes);
        jvmData.put("maxBytes", lastMaxBytes);
        ret.put(String.valueOf(count), jvmData);
        count++;
    }
    return ret;
}

/**
 * Returns {@code currentBytes / maxBytes} as a float. Returns 0 when the ratio is NaN (0 / 0) or
 * infinite (non-zero / 0), because neither value can be serialised into valid JSON for the charts.
 */
private static float memoryFraction(long currentBytes, long maxBytes) {
    double frac = currentBytes / ((double) maxBytes);
    if (Double.isNaN(frac) || Double.isInfinite(frac)) {
        return 0.0f;
    }
    return (float) frac;
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j (by deeplearning4j).
Class TestStatsListener, method testListenerBasic.
/**
 * Trains a one-layer Iris classifier for three fit calls with a stats listener attached. It then
 * checks that the storage holds exactly one session, one type, one worker, static info and
 * three updates. Runs once with {@code StatsListener} and once with {@code J7StatsListener}.
 */
@Test
public void testListenerBasic() {
    boolean[] listenerVariants = { false, true };
    for (boolean useJ7 : listenerVariants) {
        // Whole Iris dataset (150 examples) as a single minibatch
        DataSet irisData = new IrisDataSetIterator(150, 150).next();
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(4).nOut(3).build())
                .pretrain(false)
                .backprop(true)
                .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();

        //in-memory
        StatsStorage storage = new MapDBStatsStorage();
        if (!useJ7) {
            network.setListeners(new StatsListener(storage));
        } else {
            network.setListeners(new J7StatsListener(storage));
        }

        int fitCalls = 3;
        for (int iter = 0; iter < fitCalls; iter++) {
            network.fit(irisData);
        }

        List<String> sessionIDs = storage.listSessionIDs();
        assertEquals(1, sessionIDs.size());
        String sessionID = sessionIDs.get(0);

        List<String> typeIDs = storage.listTypeIDsForSession(sessionID);
        assertEquals(1, typeIDs.size());
        String typeID = typeIDs.get(0);

        List<String> workerIDs = storage.listWorkerIDsForSession(sessionID);
        assertEquals(1, workerIDs.size());
        String workerID = workerIDs.get(0);

        Persistable staticInfo = storage.getStaticInfo(sessionID, typeID, workerID);
        assertNotNull(staticInfo);
        System.out.println(staticInfo);

        // One update is expected per fit call
        List<Persistable> updates = storage.getAllUpdatesAfter(sessionID, typeID, workerID, 0);
        assertEquals(fitCalls, updates.size());
        for (Persistable update : updates) {
            System.out.println(update);
        }
    }
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j (by deeplearning4j).
Class TestStatsStorage, method testStatsStorage.
/**
 * Exercises the StatsStorage query and listener contract against three implementations:
 * a file-backed MapDBStatsStorage (i == 0), a J7FileStatsStorage (i == 1) and an
 * InMemoryStatsStorage (i == 2).
 *
 * <p>The whole sequence runs twice, with {@code useJ7Storage} false and then true. That flag is
 * passed to getInitReport/getReport (helpers defined elsewhere); presumably it selects the J7
 * vs. default report implementation — TODO confirm against those helpers.
 *
 * <p>The CountingListener assertions depend on the exact order of the put calls below.
 */
@Test
public void testStatsStorage() throws IOException {
for (boolean useJ7Storage : new boolean[] { false, true }) {
for (int i = 0; i < 3; i++) {
StatsStorage ss;
switch(i) {
case 0:
File f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
//Don't want file to exist...
// (only the unique temp path is wanted; the storage is then built on that non-existent path)
// NOTE(review): the file is never removed after the test — consider cleanup
f.delete();
ss = new MapDBStatsStorage.Builder().file(f).build();
break;
case 1:
File f2 = Files.createTempFile("TestJ7FileStatsStore", ".db").toFile();
//Don't want file to exist...
f2.delete();
ss = new J7FileStatsStorage(f2);
break;
case 2:
ss = new InMemoryStatsStorage();
break;
default:
throw new RuntimeException();
}
// CountingListener (defined elsewhere) exposes per-event counters that are checked after each put
CountingListener l = new CountingListener();
ss.registerStatsStorageListener(l);
assertEquals(1, ss.getListeners().size());
// A freshly created storage must be empty
assertEquals(0, ss.listSessionIDs().size());
assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
assertEquals(0, ss.listSessionIDs().size());
// getInitReport(0, 0, 0, ...) yields session "sid0", type "tid0", worker "wid0" (per the asserts below)
ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
assertEquals(1, l.countNewSession);
assertEquals(1, l.countNewWorkerId);
assertEquals(1, l.countStaticInfo);
assertEquals(0, l.countUpdate);
assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
assertTrue(ss.sessionExists("sid0"));
assertFalse(ss.sessionExists("sid1"));
// Stored static info must equal a freshly generated report with the same arguments
Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
assertEquals(expected, p);
List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
assertEquals(Collections.singletonList(expected), allStatic);
// Static info alone must not count as an update
assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
//Put first update
// (the 4th argument of getReport is the timestamp used in getUpdate lookups below)
ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
assertEquals(1, list.size());
assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
// An update for an already-known session/worker only increments the update counter
assertEquals(1, l.countNewSession);
assertEquals(1, l.countNewWorkerId);
assertEquals(1, l.countStaticInfo);
assertEquals(1, l.countUpdate);
//Put second update
ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
// Still one worker, so "latest for all workers" still has one entry, now the newer timestamp
assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
// Update from a second worker ("wid1") in the same session: expect a new-worker event
ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
assertEquals(1, l.countNewSession);
assertEquals(2, l.countNewWorkerId);
assertEquals(1, l.countStaticInfo);
assertEquals(3, l.countUpdate);
//Put static info and update with different session, type and worker IDs
ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
// Data for the new session must not leak into "sid0"
assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
// Debug output only; not part of the assertions. NOTE(review): looks like a leftover — consider removing
List<String> temp = ss.listWorkerIDsForSession("sid100");
System.out.println("temp: " + temp);
assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
// Cumulative totals across both sessions
assertEquals(2, l.countNewSession);
assertEquals(3, l.countNewWorkerId);
assertEquals(2, l.countStaticInfo);
assertEquals(4, l.countUpdate);
// NOTE(review): ss is never closed here — confirm whether the file-backed storages need closing
}
}
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j (by deeplearning4j).
Class ConvolutionalListenerModule, method getImage.
/**
 * Serves the most recent convolution-listener image for the last seen session/worker as JPEG.
 *
 * <p>An empty {@code image/jpeg} response is returned in these cases: nothing has been received
 * yet; the stored static info is not a {@link ConvolutionListenerPersistable}; it holds no image;
 * or the image cannot be encoded. Every response uses the registered media type
 * {@code image/jpeg}; the non-standard {@code image/jpg} is never sent.
 *
 * @return an HTTP 200 result whose body is the JPEG bytes, or an empty body
 */
private Result getImage() {
    if (lastTimeStamp <= 0 || lastStorage == null) {
        return ok(new byte[0]).as("image/jpeg");
    }
    Persistable p = lastStorage.getStaticInfo(lastSessionID, TYPE_ID, lastWorkerID);
    if (!(p instanceof ConvolutionListenerPersistable)) {
        return ok(new byte[0]).as("image/jpeg");
    }
    BufferedImage bi = ((ConvolutionListenerPersistable) p).getImg();
    if (bi == null) {
        // ImageIO.write throws IllegalArgumentException for a null image
        return ok(new byte[0]).as("image/jpeg");
    }
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    try {
        // ImageIO.write returns false (writing nothing) when no JPEG writer accepts this image type
        if (!ImageIO.write(bi, "jpg", baos)) {
            log.warn("Error displaying image: no JPEG writer available for image type " + bi.getType());
            return ok(new byte[0]).as("image/jpeg");
        }
    } catch (IOException e) {
        log.warn("Error displaying image", e);
        // Do not send a partially encoded (corrupt) image
        return ok(new byte[0]).as("image/jpeg");
    }
    return ok(baos.toByteArray()).as("image/jpeg");
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j (by deeplearning4j).
Class FlowListenerModule, method getUpdate.
/**
 * Returns the model state of the first latest-update record for the given session as JSON.
 *
 * @param sessionID session to look up in {@code knownSessionIDs}
 * @return "Unknown session ID" if the session is not known; an empty OK result if there are no
 *     updates or the first one is not a {@link FlowUpdatePersistable}; otherwise the JSON-encoded
 *     model state
 */
private Result getUpdate(String sessionID) {
    if (knownSessionIDs.containsKey(sessionID)) {
        StatsStorage storage = knownSessionIDs.get(sessionID);
        List<Persistable> latest = storage.getLatestUpdateAllWorkers(sessionID, TYPE_ID);
        boolean haveUpdate = latest != null && !latest.isEmpty();
        if (haveUpdate) {
            Persistable first = latest.get(0);
            if (first instanceof FlowUpdatePersistable) {
                FlowUpdatePersistable flowUpdate = (FlowUpdatePersistable) first;
                return ok(Json.toJson(flowUpdate.getModelState()));
            }
        }
        return ok();
    }
    return ok("Unknown session ID");
}
Aggregations