Search in sources :

Example 6 with StorageMetaData

use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

the class ParameterAveragingTrainingWorker method getFinalResult.

/**
 * Builds the final training result for this worker from the given network.
 *
 * The result carries the network parameters and score, the updater state view
 * (only when {@code saveUpdater} is set and the network has an updater), and any
 * listener data held by a {@link VanillaStatsStorageRouter} obtained from
 * {@code listenerRouterProvider}. Values that are not available are passed as null.
 *
 * @param network the network whose parameters, score and updater state are reported
 * @return a new {@link ParameterAveragingTrainingResult} for this worker
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    // Updater state is only reported when requested and when an updater exists.
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater updater = network.getUpdater();
        if (updater != null) {
            updaterState = updater.getStateViewArray();
        }
    }

    // Flush any queued operations when running on a grid executioner before reading results.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;

    // A null router fails the instanceof check below, so no separate null guard is needed.
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        storageMetaData = vanillaRouter.getStorageMetaData();
        listenerStaticInfo = vanillaRouter.getStaticInfo();
        listenerUpdates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(
            network.params(),
            updaterState,
            network.score(),
            storageMetaData,
            listenerStaticInfo,
            listenerUpdates);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)

Example 7 with StorageMetaData

use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

the class ParameterAveragingElementAddFunction method call.

/**
 * Folds one worker's training result into the running aggregation tuple.
 *
 * When {@code tuple} is null, a new tuple is built directly from {@code result}
 * with an aggregation count of 1. Otherwise parameters and updater state are
 * accumulated in place (via {@code addi}), the score is summed, training stats
 * are combined, and the listener collections of the result are appended to those
 * already held by the tuple.
 *
 * @param tuple  the aggregation so far, or null if this is the first element
 * @param result the worker result to add
 * @return the aggregation including {@code result}
 * @throws Exception declared by the implemented function interface
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple tuple, ParameterAveragingTrainingResult result) throws Exception {
    if (tuple == null) {
        return ParameterAveragingAggregationTuple.builder().parametersSum(result.getParameters()).updaterStateSum(result.getUpdaterState()).scoreSum(result.getScore()).aggregationsCount(1).sparkTrainingStats(result.getSparkTrainingStats()).listenerMetaData(result.getListenerMetaData()).listenerStaticInfo(result.getListenerStaticInfo()).listenerUpdates(result.getListenerUpdates()).build();
    }

    // In-place accumulation: the tuple's parameter sum array is modified directly.
    INDArray params = tuple.getParametersSum().addi(result.getParameters());

    INDArray updaterStateSum;
    if (tuple.getUpdaterStateSum() == null) {
        updaterStateSum = result.getUpdaterState();
    } else {
        updaterStateSum = tuple.getUpdaterStateSum();
        if (result.getUpdaterState() != null) {
            updaterStateSum.addi(result.getUpdaterState());
        }
    }

    double scoreSum = tuple.getScoreSum() + result.getScore();

    SparkTrainingStats stats = tuple.getSparkTrainingStats();
    if (result.getSparkTrainingStats() != null) {
        if (stats == null) {
            stats = result.getSparkTrainingStats();
        } else {
            stats.addOtherTrainingStats(result.getSparkTrainingStats());
        }
    }

    // Flush any queued operations when running on a grid executioner.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> listenerMetaData =
            mergeListenerData(tuple.getListenerMetaData(), result.getListenerMetaData());
    // Fix: the incoming static info must come from the result. Previously it was read
    // from the tuple again, so the tuple's entries were added to themselves and the
    // result's entries were lost.
    Collection<Persistable> listenerStaticInfo =
            mergeListenerData(tuple.getListenerStaticInfo(), result.getListenerStaticInfo());
    Collection<Persistable> listenerUpdates =
            mergeListenerData(tuple.getListenerUpdates(), result.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(params, updaterStateSum, scoreSum, tuple.getAggregationsCount() + 1, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}

/**
 * Appends {@code incoming} to {@code existing} in place and returns the merged collection.
 * If {@code existing} is null, {@code incoming} is returned as-is (possibly null).
 */
private static <T> Collection<T> mergeListenerData(Collection<T> existing, Collection<T> incoming) {
    if (existing == null) {
        return incoming;
    }
    if (incoming != null) {
        existing.addAll(incoming);
    }
    return existing;
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats)

Example 8 with StorageMetaData

use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

the class BaseStatsListener method doInit.

/**
 * Builds the initialization report for the given model and sends it, together
 * with the storage metadata, to the router.
 *
 * Depending on {@code initConfig}, the report includes software information
 * (OS, JVM, ND4J backend, hostname, executioner environment), hardware
 * information (processor count, device memory and names, JVM and off-heap memory
 * limits) and model information (JSON configuration, layer and parameter counts,
 * parameter names).
 *
 * @param model the model being monitored; must be a MultiLayerNetwork,
 *              ComputationGraph or Layer when model info collection is enabled
 * @throws RuntimeException if model info is collected and the model is of another type (or null)
 */
private void doInit(Model model) {
    boolean backpropParamsOnly = backpropParamsOnly(model);
    //TODO support NTP
    long initTime = System.currentTimeMillis();
    StatsInitializationReport initReport = getNewInitializationReport();
    initReport.reportIDs(getSessionID(model), TYPE_ID, workerID, initTime);
    if (initConfig.collectSoftwareInfo()) {
        OperatingSystemMXBean osBean = ManagementFactory.getOperatingSystemMXBean();
        RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();
        String arch = osBean.getArch();
        String osName = osBean.getName();
        String jvmName = runtime.getVmName();
        String jvmVersion = System.getProperty("java.version");
        String jvmSpecVersion = runtime.getSpecVersion();
        String nd4jBackendClass = Nd4j.getNDArrayFactory().getClass().getName();
        String nd4jDataTypeName = DataTypeUtil.getDtypeFromContext().name();
        // Prefer the COMPUTERNAME environment variable; otherwise fall back to the "hostname" command.
        String hostname = System.getenv("COMPUTERNAME");
        if (hostname == null || hostname.isEmpty()) {
            try {
                Process proc = Runtime.getRuntime().exec("hostname");
                try (InputStream stream = proc.getInputStream()) {
                    hostname = IOUtils.toString(stream);
                }
                // The command output ends with a line terminator; strip it so the
                // reported hostname is clean.
                if (hostname != null) {
                    hostname = hostname.trim();
                }
            } catch (Exception e) {
                // Best effort only: the hostname is optional, but record why it is missing.
                log.debug("Error getting hostname", e);
            }
        }
        Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
        Map<String, String> envInfo = new HashMap<>();
        for (Map.Entry<Object, Object> e : p.entrySet()) {
            Object v = e.getValue();
            // Null property values are reported as empty strings.
            String value = (v == null ? "" : v.toString());
            envInfo.put(e.getKey().toString(), value);
        }
        initReport.reportSoftwareInfo(arch, osName, jvmName, jvmVersion, jvmSpecVersion, nd4jBackendClass, nd4jDataTypeName, hostname, UIDProvider.getJVMUID(), envInfo);
    }
    if (initConfig.collectHardwareInfo()) {
        int availableProcessors = Runtime.getRuntime().availableProcessors();
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int nDevices = nativeOps.getAvailableDevices();
        long[] deviceTotalMem = null;
        //TODO
        String[] deviceDescription = null;
        if (nDevices > 0) {
            deviceTotalMem = new long[nDevices];
            deviceDescription = new String[nDevices];
            for (int i = 0; i < nDevices; i++) {
                try {
                    Pointer p = getDevicePointer(i);
                    if (p == null) {
                        // No pointer available for this device: report a placeholder entry.
                        deviceTotalMem[i] = 0;
                        deviceDescription[i] = "Device(" + i + ")";
                    } else {
                        deviceTotalMem[i] = nativeOps.getDeviceTotalMemory(p);
                        deviceDescription[i] = nativeOps.getDeviceName(p);
                        if (nDevices > 1) {
                            // Append the device index when more than one device is present.
                            deviceDescription[i] = deviceDescription[i] + " (" + i + ")";
                        }
                    }
                } catch (Exception e) {
                    log.debug("Error getting device info", e);
                }
            }
        }
        long jvmMaxMemory = Runtime.getRuntime().maxMemory();
        long offheapMaxMemory = Pointer.maxBytes();
        initReport.reportHardwareInfo(availableProcessors, nDevices, jvmMaxMemory, offheapMaxMemory, deviceTotalMem, deviceDescription, UIDProvider.getHardwareUID());
    }
    if (initConfig.collectModelInfo()) {
        String jsonConf;
        int numLayers;
        int numParams;
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork net = ((MultiLayerNetwork) model);
            jsonConf = net.getLayerWiseConfigurations().toJson();
            numLayers = net.getnLayers();
            numParams = net.numParams();
        } else if (model instanceof ComputationGraph) {
            ComputationGraph cg = ((ComputationGraph) model);
            jsonConf = cg.getConfiguration().toJson();
            numLayers = cg.getNumLayers();
            numParams = cg.numParams();
        } else if (model instanceof Layer) {
            Layer l = (Layer) model;
            jsonConf = l.conf().toJson();
            numLayers = 1;
            numParams = l.numParams();
        } else {
            throw new RuntimeException("Invalid model: Expected MultiLayerNetwork or ComputationGraph. Got: " + (model == null ? null : model.getClass()));
        }
        Map<String, INDArray> paramMap = model.paramTable(backpropParamsOnly);
        String[] paramNames = new String[paramMap.size()];
        int i = 0;
        for (String s : paramMap.keySet()) {
            //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example
            paramNames[i++] = s;
        }
        initReport.reportModelInfo(model.getClass().getName(), jsonConf, paramNames, numLayers, numParams);
    }
    StorageMetaData meta = getNewStorageMetaData(initTime, getSessionID(model), workerID);
    router.putStorageMetaData(meta);
    //TODO error handling
    router.putStaticInfo(initReport);
}
Also used : Pointer(org.bytedeco.javacpp.Pointer) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NativeOps(org.nd4j.nativeblas.NativeOps) InputStream(java.io.InputStream) RuntimeMXBean(java.lang.management.RuntimeMXBean) Layer(org.deeplearning4j.nn.api.Layer) StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) INDArray(org.nd4j.linalg.api.ndarray.INDArray) OperatingSystemMXBean(java.lang.management.OperatingSystemMXBean)

Aggregations

StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)8 Persistable (org.deeplearning4j.api.storage.Persistable)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)4 StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)2 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)2 SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)2 VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)2 SbeStorageMetaData (org.deeplearning4j.ui.storage.impl.SbeStorageMetaData)2 Test (org.junit.Test)2 BufferedReader (java.io.BufferedReader)1 DataOutputStream (java.io.DataOutputStream)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 InputStreamReader (java.io.InputStreamReader)1 Serializable (java.io.Serializable)1 OperatingSystemMXBean (java.lang.management.OperatingSystemMXBean)1 RuntimeMXBean (java.lang.management.RuntimeMXBean)1 HttpURLConnection (java.net.HttpURLConnection)1 MalformedURLException (java.net.MalformedURLException)1