Use of org.deeplearning4j.ui.weights.beans.CompactModelAndGradient in project deeplearning4j by deeplearning4j.
Class RemoteHistogramIterationListener, method iterationDone.
/**
 * On every {@code iterations}-th call, builds histograms and mean-magnitude histories for the
 * model's gradients and parameters and POSTs them, together with the score history, to
 * {@code target} as a {@link CompactModelAndGradient}.
 *
 * <p>Gradient collection is best-effort: if it fails, a warning (with the cause) is logged and
 * whatever gradient data was gathered up to that point is still sent along with the parameters.
 *
 * @param model     model whose gradient, parameter table and score are reported
 * @param iteration iteration number supplied by the caller; not used here, because the listener
 *                  counts its own calls in {@code curIteration}
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (curIteration % iterations == 0) {
        Map<String, Map> newGrad = new LinkedHashMap<>();
        try {
            Map<String, INDArray> grad = model.gradient().gradientForVariable();
            if (meanMagHistoryParams.isEmpty()) {
                // First report: pre-size both histories with one map per layer index.
                int maxLayerIdx = -1;
                for (String s : grad.keySet()) {
                    maxLayerIdx = Math.max(maxLayerIdx, indexFromString(s));
                }
                if (maxLayerIdx == -1)
                    maxLayerIdx = 0;
                for (int i = 0; i <= maxLayerIdx; i++) {
                    meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
                    meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
                }
            }
            collectHistogramsAndMeanMagnitudes(grad, newGrad, meanMagHistoryUpdates);
        } catch (Exception e) {
            // Deliberately non-fatal, but keep the cause so the failure can be diagnosed.
            log.warn("Skipping gradients update", e);
        }

        // Process parameters: duplicate + calculate and store mean magnitudes
        Map<String, Map> newParams = new LinkedHashMap<>();
        collectHistogramsAndMeanMagnitudes(model.paramTable(), newParams, meanMagHistoryParams);

        double score = model.score();
        scoreHistory.add(score);

        CompactModelAndGradient g = new CompactModelAndGradient();
        g.setGradients(newGrad);
        g.setParameters(newParams);
        g.setScore(score);
        g.setScores(scoreHistory);
        g.setPath(subPath);
        g.setUpdateMagnitudes(meanMagHistoryUpdates);
        g.setParamMagnitudes(meanMagHistoryParams);
        g.setLayerNames(layerNames);
        g.setLastUpdateTime(System.currentTimeMillis());

        Response resp = target.request(MediaType.APPLICATION_JSON).accept(MediaType.APPLICATION_JSON)
                .post(Entity.entity(g, MediaType.APPLICATION_JSON));
        try {
            log.debug("{}", resp);
        } finally {
            // Release the underlying connection; the response body is not consumed here.
            resp.close();
        }

        if (firstIteration) {
            // The address was previously assembled and then discarded; report it once instead.
            StringBuilder builder = new StringBuilder(connectionInfo.getFullAddress());
            builder.append(subPath).append("?sid=").append(connectionInfo.getSessionId());
            log.info("UI Histogram URL: {}", builder);
            firstIteration = false;
        }
    }
    curIteration += 1;
}

/**
 * For every named array, stores a 20-bin histogram in {@code histograms} and appends the array's
 * mean magnitude (L1 norm divided by length) to the per-layer {@code history}.
 *
 * @param arrays     arrays keyed by parameter name
 * @param histograms receives the histogram data, keyed by the CSS-safe parameter name
 * @param history    per-layer history ({@code history.get(i)} belongs to layer index i); grown
 *                   on demand so that the array's layer index is always addressable
 */
private void collectHistogramsAndMeanMagnitudes(Map<String, INDArray> arrays, Map<String, Map> histograms,
        List<Map<String, List<Double>>> history) {
    for (Map.Entry<String, INDArray> entry : arrays.entrySet()) {
        String param = entry.getKey();
        INDArray values = entry.getValue();

        // CSS identifier can't start with digit http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier
        String newName = Character.isDigit(param.charAt(0)) ? "param_" + param : param;

        // dup() because the array might be a view
        HistogramBin histogram = new HistogramBin.Builder(values.dup()).setBinCount(20).setRounding(6).build();
        histograms.put(newName, histogram.getData());

        // Work out layer index. A negative result is mapped to slot 0, mirroring the -1 -> 0
        // fallback used when the histories are first sized.
        // NOTE(review): assumes indexFromString signals "no index" with a negative value - confirm.
        int idx = Math.max(0, indexFromString(param));
        // Grow until idx is addressable; adding a single slot is not enough when indices skip.
        while (idx >= history.size()) {
            history.add(new LinkedHashMap<String, List<Double>>());
        }

        Map<String, List<Double>> layerHistory = history.get(idx);
        List<Double> list = layerHistory.get(newName);
        if (list == null) {
            list = new ArrayList<>();
            layerHistory.put(newName, list);
        }
        double meanMag = values.norm1Number().doubleValue() / values.length();
        list.add(meanMag);
    }
}
Use of org.deeplearning4j.ui.weights.beans.CompactModelAndGradient in project deeplearning4j by deeplearning4j.
Class HistogramModule, method processRequest.
/**
 * Builds the histogram payload for one session from the stats stored for its first worker:
 * score history, per-layer mean-magnitude histories, and the histograms of the most recent
 * report.
 *
 * @param sessionId ID of the session to report on
 * @return a 404 result if the session ID is unknown; an empty JSON object if the session has no
 *         worker, no initialization report, or no {@link StatsReport} yet; otherwise the
 *         JSON-serialized {@link CompactModelAndGradient}
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }

    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        // Nothing has been reported for this session yet.
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    // Only the first worker's data is displayed.
    String workerId = workerIDs.get(0);

    StatsInitializationReport initReport =
            (StatsInitializationReport) ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerId);
    if (initReport == null)
        return Results.ok(Json.toJson(Collections.emptyMap()));

    String[] paramNames = initReport.getModelParamNames();
    //Infer layer names from param names: the part before the first '_', in first-seen order
    Set<String> layerNameSet = new LinkedHashSet<>();
    for (String s : paramNames) {
        layerNameSet.add(s.split("_")[0]);
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);

    // Sort a private copy so the list handed out by the storage is never reordered in place.
    List<Persistable> stored = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerId, 0);
    List<Persistable> list = new ArrayList<>();
    if (stored != null) {
        list.addAll(stored);
    }
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));

    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }

    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());

        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }

    if (last == null) {
        // No stats report yet: there are no histograms or timestamp to send.
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }

    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    // last != null guarantees at least one score has been recorded.
    double lastScore = scoreList.get(scoreList.size() - 1);

    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    // NOTE: path and layer names are not populated by this method.
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Aggregations