Search in sources:

Example 1 with CompactModelAndGradient

Use of org.deeplearning4j.ui.weights.beans.CompactModelAndGradient in the deeplearning4j project (by deeplearning4j).

From class RemoteHistogramIterationListener, method iterationDone:

/**
 * Records a histogram/magnitude snapshot of the model every {@code iterations} calls and POSTs it
 * as JSON to {@code target}.
 *
 * <p>For each snapshot this method:
 * <ol>
 *   <li>builds 20-bin histograms of every gradient and parameter array;</li>
 *   <li>appends each array's mean magnitude (L1 norm divided by element count) to the per-layer
 *       histories {@code meanMagHistoryUpdates} / {@code meanMagHistoryParams};</li>
 *   <li>appends the model score to {@code scoreHistory};</li>
 *   <li>sends a {@link CompactModelAndGradient} to the remote UI.</li>
 * </ol>
 * A failure while processing gradients is logged and skipped. Parameter and score reporting
 * still happen.
 *
 * @param model     the model being trained
 * @param iteration iteration number supplied by the caller. It is not read here; the internal
 *                  {@code curIteration} counter drives the reporting frequency.
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (curIteration % iterations == 0) {
        Map<String, Map> newGrad = new LinkedHashMap<>();
        try {
            Map<String, INDArray> grad = model.gradient().gradientForVariable();
            if (meanMagHistoryParams.isEmpty()) {
                //Initialize: one history map per layer, up to the highest layer index seen in the gradient keys
                int maxLayerIdx = -1;
                for (String s : grad.keySet()) {
                    maxLayerIdx = Math.max(maxLayerIdx, indexFromString(s));
                }
                if (maxLayerIdx == -1)
                    maxLayerIdx = 0;
                for (int i = 0; i <= maxLayerIdx; i++) {
                    meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                    meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
                }
            }
            for (Map.Entry<String, INDArray> entry : grad.entrySet()) {
                String param = entry.getKey();
                String newName = cssSafeParamName(param);
                //dup() because the gradient array might be a view
                HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup()).setBinCount(20).setRounding(6).build();
                newGrad.put(newName, histogram.getData());
                meanMagnitudeSeries(meanMagHistoryUpdates, indexFromString(param), newName)
                        .add(meanMagnitude(entry.getValue()));
            }
        } catch (Exception e) {
            // Best effort: a gradient failure should not prevent parameter/score reporting,
            // but keep the cause so the problem is diagnosable.
            log.warn("Skipping gradients update", e);
        }
        //Process parameters: duplicate + calculate and store mean magnitudes
        Map<String, INDArray> params = model.paramTable();
        Map<String, Map> newParams = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            String param = entry.getKey();
            String newName = cssSafeParamName(param);
            //dup() because params might be a view
            HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup()).setBinCount(20).setRounding(6).build();
            newParams.put(newName, histogram.getData());
            meanMagnitudeSeries(meanMagHistoryParams, indexFromString(param), newName)
                    .add(meanMagnitude(entry.getValue()));
        }
        double score = model.score();
        scoreHistory.add(score);
        CompactModelAndGradient g = new CompactModelAndGradient();
        g.setGradients(newGrad);
        g.setParameters(newParams);
        g.setScore(score);
        g.setScores(scoreHistory);
        g.setPath(subPath);
        g.setUpdateMagnitudes(meanMagHistoryUpdates);
        g.setParamMagnitudes(meanMagHistoryParams);
        g.setLayerNames(layerNames);
        g.setLastUpdateTime(System.currentTimeMillis());
        Response resp = target.request(MediaType.APPLICATION_JSON).accept(MediaType.APPLICATION_JSON).post(Entity.entity(g, MediaType.APPLICATION_JSON));
        try {
            log.debug("{}", resp);
        } finally {
            // The response body is never read; close it so the underlying connection is released
            resp.close();
        }
        if (firstIteration) {
            // Report the UI address once; previously this URL was built and then discarded
            String url = connectionInfo.getFullAddress() + subPath + "?sid=" + connectionInfo.getSessionId();
            log.info("Histogram UI available at {}", url);
            firstIteration = false;
        }
    }
    curIteration += 1;
}

/**
 * Returns a name that is a valid CSS identifier. A CSS identifier can't start with a digit
 * (http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier), so such names get a
 * {@code "param_"} prefix.
 */
private static String cssSafeParamName(String param) {
    return Character.isDigit(param.charAt(0)) ? "param_" + param : param;
}

/**
 * Returns the mean-magnitude series for {@code name} in layer {@code idx}. Empty layer maps are
 * appended to {@code history} until index {@code idx} exists, and a new series is created if
 * none is present.
 */
private static List<Double> meanMagnitudeSeries(List<Map<String, List<Double>>> history, int idx, String name) {
    // Grow by as many layers as needed, not just one, so that get(idx) cannot run past the end
    while (idx >= history.size()) {
        history.add(new LinkedHashMap<String, List<Double>>());
    }
    Map<String, List<Double>> map = history.get(idx);
    List<Double> list = map.get(name);
    if (list == null) {
        list = new ArrayList<>();
        map.put(name, list);
    }
    return list;
}

/** Mean absolute value of the array's elements: its L1 norm divided by its element count. */
private static double meanMagnitude(INDArray arr) {
    return arr.norm1Number().doubleValue() / arr.length();
}
Also used: CompactModelAndGradient (org.deeplearning4j.ui.weights.beans.CompactModelAndGradient), Response (javax.ws.rs.core.Response), INDArray (org.nd4j.linalg.api.ndarray.INDArray)

Example 2 with CompactModelAndGradient

Use of org.deeplearning4j.ui.weights.beans.CompactModelAndGradient in the deeplearning4j project (by deeplearning4j).

From class HistogramModule, method processRequest:

/**
 * Builds the histogram-view payload for one session from the stats stored for its first worker.
 *
 * <p>The payload contains the score history, per-layer mean-magnitude histories for parameters
 * and updates, and the parameter/update histograms from the most recent {@link StatsReport}.
 *
 * @param sessionId the session to report on
 * @return 404 if the session ID is unknown; an empty JSON object if the session has no workers,
 *         no initialization report, or no {@link StatsReport} updates yet; otherwise the
 *         JSON-serialized {@link CompactModelAndGradient}
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }
    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        // No worker has reported for this session yet; get(0) below would throw
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    // Only the first worker's stats are displayed
    String workerID = workerIDs.get(0);
    StatsInitializationReport initReport = (StatsInitializationReport) ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerID);
    if (initReport == null) {
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    //Infer layer names from param names: the layer name is the part before the first '_'
    String[] paramNames = initReport.getModelParamNames();
    Set<String> layerNameSet = new LinkedHashSet<>();
    for (String s : paramNames) {
        layerNameSet.add(s.split("_")[0]);
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);
    List<Persistable> list = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerID, 0);
    list.sort((a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));
    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }
    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());
        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }
    if (last == null) {
        // No StatsReport yet, so there are no histograms or timestamp to report (avoids an NPE below)
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    // scoreList is non-empty here: every StatsReport, including 'last', contributed a score
    double lastScore = scoreList.get(scoreList.size() - 1);
    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    // Path and layer names are not set for this view
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Also used: StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport), StatsStorage (org.deeplearning4j.api.storage.StatsStorage), Persistable (org.deeplearning4j.api.storage.Persistable), CompactModelAndGradient (org.deeplearning4j.ui.weights.beans.CompactModelAndGradient), StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)

Aggregations

CompactModelAndGradient (org.deeplearning4j.ui.weights.beans.CompactModelAndGradient): 2; Response (javax.ws.rs.core.Response): 1; Persistable (org.deeplearning4j.api.storage.Persistable): 1; StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 1; StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 1; StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1