Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
In class ParameterAveragingTrainingWorker, method getFinalResult:
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }
    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
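Note that a null updater state is a normal outcome here: the updater only exists once the network has been fit. A minimal sketch of the same null-safe extraction as a standalone helper (UpdaterStateUtil and extractUpdaterState are hypothetical names, not part of DL4J; getUpdater and getStateViewArray are the calls used above):

import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class UpdaterStateUtil {
    //Mirrors the saveUpdater branch of getFinalResult: both the updater and its state view
    //may be absent before the first fit, so return null rather than throw in either case.
    public static INDArray extractUpdaterState(MultiLayerNetwork network, boolean saveUpdater) {
        if (!saveUpdater)
            return null;
        Updater u = network.getUpdater();
        return (u == null ? null : u.getStateViewArray());
    }
}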
Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
In class ParameterAveragingElementAddFunction, method call:
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple,
                ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        return ParameterAveragingAggregationTuple.builder()
                        .parametersSum(result.getParameters())
                        .updaterStateSum(result.getUpdaterState())
                        .scoreSum(result.getScore())
                        .aggregationsCount(1)
                        .sparkTrainingStats(result.getSparkTrainingStats())
                        .listenerMetaData(result.getListenerMetaData())
                        .listenerStaticInfo(result.getListenerStaticInfo())
                        .listenerUpdates(result.getListenerUpdates())
                        .build();
    }
    INDArray params = tuple.getParametersSum().addi(result.getParameters());
    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null)
            updaterStateSum.addi(result.getUpdaterState());
    }
    double scoreSum = tuple.getScoreSum() + result.getScore();
    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null)
            stats = result.getSparkTrainingStats();
        else
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> listenerMetaData = tuple.getListenerMetaData();
    if (listenerMetaData == null)
        listenerMetaData = result.getListenerMetaData();
    else {
        Collection<StorageMetaData> newMeta = result.getListenerMetaData();
        if (newMeta != null)
            listenerMetaData.addAll(newMeta);
    }
    Collection<Persistable> listenerStaticInfo = tuple.getListenerStaticInfo();
    if (listenerStaticInfo == null)
        listenerStaticInfo = result.getListenerStaticInfo();
    else {
        Collection<Persistable> newStatic = result.getListenerStaticInfo();
        if (newStatic != null)
            listenerStaticInfo.addAll(newStatic);
    }
    Collection<Persistable> listenerUpdates = tuple.getListenerUpdates();
    if (listenerUpdates == null)
        listenerUpdates = result.getListenerUpdates();
    else {
        Collection<Persistable> newUpdates = result.getListenerUpdates();
        if (newUpdates != null)
            listenerUpdates.addAll(newUpdates);
    }
    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum,
                    tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}
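This add function only accumulates sums and counts; turning them into averages is a separate, final step after all worker results have been folded in. A hedged sketch of that step (ParameterAveragingSketch and averageParameters are assumed names, not the actual DL4J master code; getParametersSum, getAggregationsCount, and ND4J's in-place divi are real):

import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ParameterAveragingSketch {
    //After N results have been folded in via call(...), aggregationsCount == N, so dividing
    //the parameter sum element-wise by N yields the averaged parameters.
    public static INDArray averageParameters(ParameterAveragingAggregationTuple tuple) {
        return tuple.getParametersSum().divi(tuple.getAggregationsCount());
    }
}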
Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
In class BaseStatsListener, method doInit:
private void doInit(Model model) {
    boolean backpropParamsOnly = backpropParamsOnly(model);
    //TODO support NTP
    long initTime = System.currentTimeMillis();
    StatsInitializationReport initReport = getNewInitializationReport();
    initReport.reportIDs(getSessionID(model), TYPE_ID, workerID, initTime);
    if (initConfig.collectSoftwareInfo()) {
        OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();
        RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();
        String arch = osBean.getArch();
        String osName = osBean.getName();
        String jvmName = runtime.getVmName();
        String jvmVersion = System.getProperty("java.version");
        String jvmSpecVersion = runtime.getSpecVersion();
        String nd4jBackendClass = Nd4j.getNDArrayFactory().getClass().getName();
        String nd4jDataTypeName = DataTypeUtil.getDtypeFromContext().name();
        String hostname = System.getenv("COMPUTERNAME");
        if (hostname == null || hostname.isEmpty()) {
            try {
                Process proc = Runtime.getRuntime().exec("hostname");
                try (InputStream stream = proc.getInputStream()) {
                    hostname = IOUtils.toString(stream);
                }
            } catch (Exception e) {
                //Ignore: the hostname is best-effort, and the report simply omits it on failure
            }
        }
        Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
        Map<String, String> envInfo = new HashMap<>();
        for (Map.Entry<Object, Object> e : p.entrySet()) {
            Object v = e.getValue();
            String value = (v == null ? "" : v.toString());
            envInfo.put(e.getKey().toString(), value);
        }
        initReport.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass,
                        nd4jDataTypeName, hostname, UIDProvider.getJVMUID(), envInfo);
    }
    if (initConfig.collectHardwareInfo()) {
        int availableProcessors = Runtime.getRuntime().availableProcessors();
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int nDevices = nativeOps.getAvailableDevices();
        long[] deviceTotalMem = null;
        //TODO
        String[] deviceDescription = null;
        if (nDevices > 0) {
            deviceTotalMem = new long[nDevices];
            deviceDescription = new String[nDevices];
            for (int i = 0; i < nDevices; i++) {
                try {
                    Pointer p = getDevicePointer(i);
                    if (p == null) {
                        deviceTotalMem[i] = 0;
                        deviceDescription[i] = "Device(" + i + ")";
                    } else {
                        deviceTotalMem[i] = nativeOps.getDeviceTotalMemory(p);
                        deviceDescription[i] = nativeOps.getDeviceName(p);
                        if (nDevices > 1) {
                            deviceDescription[i] = deviceDescription[i] + " (" + i + ")";
                        }
                    }
                } catch (Exception e) {
                    log.debug("Error getting device info", e);
                }
            }
        }
        long jvmMaxMemory = Runtime.getRuntime().maxMemory();
        long offheapMaxMemory = Pointer.maxBytes();
        initReport.reportHardwareInfo(availableProcessors, nDevices, jvmMaxMemory, offheapMaxMemory, deviceTotalMem,
                        deviceDescription, UIDProvider.getHardwareUID());
    }
    if (initConfig.collectModelInfo()) {
        String jsonConf;
        int numLayers;
        int numParams;
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork net = ((MultiLayerNetwork) model);
            jsonConf = net.getLayerWiseConfigurations().toJson();
            numLayers = net.getnLayers();
            numParams = net.numParams();
        } else if (model instanceof ComputationGraph) {
            ComputationGraph cg = ((ComputationGraph) model);
            jsonConf = cg.getConfiguration().toJson();
            numLayers = cg.getNumLayers();
            numParams = cg.numParams();
        } else if (model instanceof Layer) {
            Layer l = (Layer) model;
            jsonConf = l.conf().toJson();
            numLayers = 1;
            numParams = l.numParams();
        } else {
            throw new RuntimeException("Invalid model: Expected MultiLayerNetwork, ComputationGraph, or Layer. Got: "
                            + (model == null ? null : model.getClass()));
        }
        Map<String, INDArray> paramMap = model.paramTable(backpropParamsOnly);
        String[] paramNames = new String[paramMap.size()];
        int i = 0;
        for (String s : paramMap.keySet()) {
            //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example
            paramNames[i++] = s;
        }
        initReport.reportModelInfo(model.getClass().getName(), jsonConf, paramNames, numLayers, numParams);
    }
    StorageMetaData meta = getNewStorageMetaData(initTime, getSessionID(model), workerID);
    router.putStorageMetaData(meta);
    //TODO error handling
    router.putStaticInfo(initReport);
}