Search in sources:

Example 1 with StorageMetaData

Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

The class RemoteUIStatsStorageRouter, method tryPost.

/**
 * Serializes one queued item and POSTs it to the remote UI server.
 *
 * <p>The item is storage metadata if {@code toPost.getMeta()} is set, otherwise static info if
 * {@code toPost.getStaticInfo()} is set, otherwise the update. It is sent as a JSON object with the keys
 * {@code type} (the {@link StorageType} name), {@code class} (the implementing class name) and
 * {@code data} (the item's encoded bytes, Base64).
 *
 * @param toPost the item to send
 * @return {@code true} if the server answered HTTP 200; {@code false} for any other response code, or if
 *         an I/O error occurs while reading the response (a warning is logged in both cases)
 * @throws IOException if writing the request fails
 */
private boolean tryPost(ToPost toPost) throws IOException {
    HttpURLConnection connection = getConnection();
    String className;
    byte[] asBytes;
    StorageType type;
    if (toPost.getMeta() != null) {
        StorageMetaData smd = toPost.getMeta();
        className = smd.getClass().getName();
        asBytes = smd.encode();
        type = StorageType.MetaData;
    } else if (toPost.getStaticInfo() != null) {
        Persistable p = toPost.getStaticInfo();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.StaticInfo;
    } else {
        Persistable p = toPost.getUpdate();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.Update;
    }
    String base64 = DatatypeConverter.printBase64Binary(asBytes);
    Map<String, String> jsonObj = new LinkedHashMap<>();
    jsonObj.put("type", type.name());
    jsonObj.put("class", className);
    jsonObj.put("data", base64);
    String str;
    try {
        str = objectMapper.writeValueAsString(jsonObj);
    } catch (Exception e) {
        //Should never get an exception from simple Map<String,String>
        throw new RuntimeException(e);
    }

    // Encode explicitly as UTF-8: DataOutputStream.writeBytes keeps only the low byte of each char,
    // which would corrupt any non-ASCII character in the JSON. try-with-resources closes the stream
    // even if the write fails.
    try (DataOutputStream dos = new DataOutputStream(connection.getOutputStream())) {
        dos.write(str.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        dos.flush();
    }

    String forbiddenMsg = "Error posting to remote UI at {} (Response code: 403)."
            + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()";
    try {
        int responseCode = connection.getResponseCode();
        if (responseCode == 200) {
            return true;
        }
        if (responseCode == 403) {
            log.warn(forbiddenMsg, url);
        } else {
            log.warn("Error posting to remote UI - received response code {}\tContent: {}", responseCode,
                    readResponseBody(connection));
        }
        return false;
    } catch (IOException e) {
        // getMessage() may be null, so guard before inspecting it
        String msg = e.getMessage();
        if (msg != null && msg.contains("403 for URL")) {
            log.warn(forbiddenMsg, url, e);
        } else {
            log.warn("Error posting to remote UI at {}", url, e);
        }
        return false;
    }
}

/**
 * Reads the body of a non-200 response as UTF-8 text, concatenating its lines without separators.
 *
 * <p>Prefers {@link HttpURLConnection#getErrorStream()}, because {@code getInputStream()} throws for
 * 4xx/5xx status codes. It falls back to the regular input stream when no error stream is available.
 *
 * @param connection a connection whose response code has already been received
 * @return the response body, possibly empty
 * @throws IOException if neither stream can be read
 */
private static String readResponseBody(HttpURLConnection connection) throws IOException {
    java.io.InputStream stream = connection.getErrorStream();
    if (stream == null) {
        stream = connection.getInputStream();
    }
    StringBuilder response = new StringBuilder();
    try (BufferedReader in = new BufferedReader(
            new InputStreamReader(stream, java.nio.charset.StandardCharsets.UTF_8))) {
        String inputLine;
        while ((inputLine = in.readLine()) != null) {
            response.append(inputLine);
        }
    }
    return response.toString();
}
Also used: StorageType(org.deeplearning4j.api.storage.StorageType) Persistable(org.deeplearning4j.api.storage.Persistable) InputStreamReader(java.io.InputStreamReader) DataOutputStream(java.io.DataOutputStream) IOException(java.io.IOException) MalformedURLException(java.net.MalformedURLException) StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) HttpURLConnection(java.net.HttpURLConnection) BufferedReader(java.io.BufferedReader)

Example 2 with StorageMetaData

Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

The class ParameterAveragingElementCombineFunction, method call.

/**
 * Merges two partial parameter-averaging aggregates into a single aggregate.
 *
 * <p>When either argument is null, or carries no parameter sum (possible when there is less data than
 * executors), the other argument is returned unchanged. Otherwise the parameter sums, updater-state
 * sums, score sums, aggregation counts, training stats and listener collections are combined.
 * v1's arrays, stats and collections are updated in place and reused in the returned tuple.
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null) {
        return v2;
    }
    if (v2 == null) {
        return v1;
    }
    // One side may exist but be empty when there is less data than executors
    if (v1.getParametersSum() == null) {
        return v2;
    }
    if (v2.getParametersSum() == null) {
        return v1;
    }

    INDArray summedParams = v1.getParametersSum().addi(v2.getParametersSum());

    INDArray summedUpdaterState = v1.getUpdaterStateSum();
    if (summedUpdaterState == null) {
        summedUpdaterState = v2.getUpdaterStateSum();
    } else if (v2.getUpdaterStateSum() != null) {
        summedUpdaterState.addi(v2.getUpdaterStateSum());
    }

    double totalScore = v1.getScoreSum() + v2.getScoreSum();
    int totalAggregations = v1.getAggregationsCount() + v2.getAggregationsCount();

    SparkTrainingStats mergedStats = v1.getSparkTrainingStats();
    SparkTrainingStats otherStats = v2.getSparkTrainingStats();
    if (otherStats != null) {
        if (mergedStats == null) {
            mergedStats = otherStats;
        } else {
            mergedStats.addOtherTrainingStats(otherStats);
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> mergedMetaData = combineListenerData(v1.getListenerMetaData(), v2.getListenerMetaData());
    Collection<Persistable> mergedStaticInfo = combineListenerData(v1.getListenerStaticInfo(), v2.getListenerStaticInfo());
    Collection<Persistable> mergedUpdates = combineListenerData(v1.getListenerUpdates(), v2.getListenerUpdates());

    return new ParameterAveragingAggregationTuple(summedParams, summedUpdaterState, totalScore, totalAggregations, mergedStats, mergedMetaData, mergedStaticInfo, mergedUpdates);
}

/**
 * Merges two possibly-null listener collections.
 *
 * @return {@code extra} if {@code target} is null; otherwise {@code target}, with all of {@code extra}
 *         appended in place when {@code extra} is non-null
 */
private static <T> Collection<T> combineListenerData(Collection<T> target, Collection<T> extra) {
    if (target == null) {
        return extra;
    }
    if (extra != null) {
        target.addAll(extra);
    }
    return target;
}
Also used: StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats)

Example 3 with StorageMetaData

Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

The class TestRemoteReceiver, method testRemoteBasic.

/**
 * Manual (ignored) end-to-end check of remote UI posting.
 *
 * <p>Updates, storage metadata and static info are posted through a RemoteUIStatsStorageRouter pointed
 * at http://localhost:9000. They must arrive, in posting order, in the CollectionStatsStorageRouter that
 * is registered with the UIServer via enableRemoteListener.
 * NOTE(review): delivery is awaited with Thread.sleep, so this test is timing-dependent. It also assumes
 * the UI server listens on port 9000; confirm against the UIServer configuration.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    // Receiving side: everything the UI server accepts from remote listeners lands in these lists
    List<Persistable> updates = new ArrayList<>();
    List<Persistable> staticInfo = new ArrayList<>();
    List<StorageMetaData> metaData = new ArrayList<>();
    CollectionStatsStorageRouter collectionRouter = new CollectionStatsStorageRouter(metaData, staticInfo, updates);
    UIServer s = UIServer.getInstance();
    s.enableRemoteListener(collectionRouter, false);
    // Sending side
    RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
    SbeStatsReport update1 = new SbeStatsReport();
    update1.setDeviceCurrentBytes(new long[] { 1, 2 });
    update1.reportIterationCount(10);
    update1.reportIDs("sid", "tid", "wid", 123456);
    update1.reportPerformance(10, 20, 30, 40, 50);
    SbeStatsReport update2 = new SbeStatsReport();
    update2.setDeviceCurrentBytes(new long[] { 3, 4 });
    update2.reportIterationCount(20);
    update2.reportIDs("sid2", "tid2", "wid2", 123456);
    update2.reportPerformance(11, 21, 31, 40, 50);
    StorageMetaData smd1 = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData smd2 = new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");
    SbeStatsInitializationReport init1 = new SbeStatsInitializationReport();
    // NOTE(review): argument order is ("sid", "wid", "tid"), unlike update1.reportIDs("sid", "tid", "wid", ...).
    // Possibly swapped by accident; harmless here because only equality of the received object is asserted.
    init1.reportIDs("sid", "wid", "tid", 3145253452L);
    init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");
    // Post interleaved item types, pausing between posts; the assertions below check arrival order
    remoteRouter.putUpdate(update1);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd1);
    Thread.sleep(100);
    remoteRouter.putStaticInfo(init1);
    Thread.sleep(100);
    remoteRouter.putUpdate(update2);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd2);
    // Give the remote router time to deliver everything before checking
    Thread.sleep(2000);
    assertEquals(2, metaData.size());
    assertEquals(2, updates.size());
    assertEquals(1, staticInfo.size());
    assertEquals(Arrays.asList(update1, update2), updates);
    assertEquals(Arrays.asList(smd1, smd2), metaData);
    assertEquals(Collections.singletonList(init1), staticInfo);
}
Also used: StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Persistable(org.deeplearning4j.api.storage.Persistable) SbeStatsInitializationReport(org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport) SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) UIServer(org.deeplearning4j.ui.api.UIServer) ArrayList(java.util.ArrayList) CollectionStatsStorageRouter(org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 4 with StorageMetaData

Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

The class TestStorageMetaData, method testStorageMetaData.

/**
 * Verifies that SbeStorageMetaData survives an encode/decode round trip. This is checked both with every
 * field populated and with null fields, and in both cases re-encoding the decoded copy must reproduce the
 * original bytes exactly.
 */
@Test
public void testStorageMetaData() {
    // Case 1: all fields set, including the extra Serializable payload
    long timestamp = 123456;
    Serializable extra = "ExtraMetaData";
    StorageMetaData original = new SbeStorageMetaData(timestamp, "sessionID", "typeID", "workerID", "org.some.class.InitType", "org.some.class.UpdateType", extra);
    byte[] encoded = original.encode();
    StorageMetaData decoded = new SbeStorageMetaData();
    decoded.decode(encoded);
    assertEquals(original, decoded);
    assertArrayEquals(encoded, decoded.encode());

    // Case 2: sanity check with null values. The (String) cast presumably disambiguates between
    // constructor overloads.
    StorageMetaData withNulls = new SbeStorageMetaData(0, null, null, null, null, (String) null);
    byte[] nullEncoded = withNulls.encode();
    StorageMetaData nullDecoded = new SbeStorageMetaData();
    nullDecoded.decode(nullEncoded);
    // Decoding may yield either null or empty strings; in practice these fields should never be null anyway
    assertNullOrZeroLength(nullDecoded.getSessionID());
    assertNullOrZeroLength(nullDecoded.getTypeID());
    assertNullOrZeroLength(nullDecoded.getWorkerID());
    assertNullOrZeroLength(nullDecoded.getInitTypeClass());
    assertNullOrZeroLength(nullDecoded.getUpdateTypeClass());
    assertArrayEquals(nullEncoded, nullDecoded.encode());
}
Also used: StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Serializable(java.io.Serializable) Test(org.junit.Test)

Example 5 with StorageMetaData

Use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.

The class ParameterAveragingTrainingWorker, method getFinalResult.

/**
 * Packages this worker's final state for parameter averaging.
 *
 * <p>The result contains the network parameters and score, plus the updater state view when
 * {@code saveUpdater} is set and the network has an updater. When the listener router is a
 * VanillaStatsStorageRouter, the listener metadata, static info and updates it collected are included too;
 * otherwise those three are null.
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    ComputationGraphUpdater graphUpdater = saveUpdater ? network.getUpdater() : null;
    INDArray updaterState = null;
    if (graphUpdater != null) {
        updaterState = graphUpdater.getStateViewArray();
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;
    // instanceof is false for a null router, so no separate null check is needed
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), metaData, staticInfo, updates);
}
Also used: StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) Persistable(org.deeplearning4j.api.storage.Persistable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VanillaStatsStorageRouter(org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter)

Aggregations

StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)8 Persistable (org.deeplearning4j.api.storage.Persistable)6 INDArray (org.nd4j.linalg.api.ndarray.INDArray)5 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)4 StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)2 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)2 SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)2 VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter)2 SbeStorageMetaData (org.deeplearning4j.ui.storage.impl.SbeStorageMetaData)2 Test (org.junit.Test)2 BufferedReader (java.io.BufferedReader)1 DataOutputStream (java.io.DataOutputStream)1 IOException (java.io.IOException)1 InputStream (java.io.InputStream)1 InputStreamReader (java.io.InputStreamReader)1 Serializable (java.io.Serializable)1 OperatingSystemMXBean (java.lang.management.OperatingSystemMXBean)1 RuntimeMXBean (java.lang.management.RuntimeMXBean)1 HttpURLConnection (java.net.HttpURLConnection)1 MalformedURLException (java.net.MalformedURLException)1