Use of org.deeplearning4j.ui.weights.HistogramBin in project deeplearning4j by deeplearning4j.
From the class RemoteFlowIterationListener, method buildModelState.
/**
 * Refreshes {@code modelState} from the given model: throughput figures, the current score,
 * elapsed training time, per-layer learning rates, and per-layer parameter histograms.
 *
 * <p>Only parameter-table entries named {@code <layerIndex>_<key>} with a numeric layer index
 * are considered; of those, only the keys {@code w}, {@code rw}, {@code rwf} and {@code b}
 * (case-insensitive) produce a histogram.
 *
 * @param model the model whose input, score, layers and parameter table are read
 */
protected void buildModelState(Model model) {
    // first we update performance state
    long timeSpent = currTime - lastTime;
    float timeSec = timeSpent / 1000f;
    // NOTE(review): if timeSpent is 0 the two rates below evaluate to Infinity (or NaN);
    // confirm the consumers of modelState tolerate that before changing it.
    INDArray input = model.input();
    // Length of one slice over dimensions 1..rank-1; presumably dimension 0 is the
    // mini-batch dimension, making the quotient the number of examples - TODO confirm.
    long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
    long numSamples = input.lengthLong() / tadLength;
    modelState.addPerformanceSamples(numSamples / timeSec);
    modelState.addPerformanceBatches(1 / timeSec);
    modelState.setIterationTime(timeSpent);

    // now model score (read once, recorded both in the history and as the current value)
    float score = (float) model.score();
    modelState.addScore(score);
    modelState.setScore(score);
    modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));

    // learning rates are only available for the two known network types
    Layer[] layers = null;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    }
    if (layers != null) {
        List<Double> lrs = new ArrayList<>();
        for (Layer layer : layers) {
            lrs.add(layer.conf().getLayer().getLearningRate());
        }
        modelState.setLearningRates(lrs);
    }

    // and now update model params
    Map<String, INDArray> params = model.paramTable();
    Map<Integer, LayerParams> layerParamsMap = new LinkedHashMap<>();
    for (Map.Entry<String, INDArray> entry : params.entrySet()) {
        String param = entry.getKey();
        if (param.isEmpty() || !Character.isDigit(param.charAt(0))) {
            continue;
        }

        int layer;
        try {
            // everything before the first underscore is expected to be the layer index
            layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
        } catch (NumberFormatException ignored) {
            // The name starts with a digit but is not a plain index (e.g. "2ndConv_W").
            // Skip it, like every other non-indexed name, instead of failing the listener.
            continue;
        }
        // everything after the first underscore is the parameter key
        String key = param.replaceAll("^.*?_", "");

        LayerParams layerParams = layerParamsMap.get(layer);
        if (layerParams == null) {
            layerParams = new LayerParams();
            layerParamsMap.put(layer, layerParams);
        }

        boolean isW = key.equalsIgnoreCase("w");
        boolean isRW = key.equalsIgnoreCase("rw");
        boolean isRWF = key.equalsIgnoreCase("rwf");
        boolean isB = key.equalsIgnoreCase("b");
        if (!(isW || isRW || isRWF || isB)) {
            // no setter for this key: avoid copying the array and binning it for nothing
            continue;
        }

        HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup())
                .setBinCount(14)
                .setRounding(6)
                .build();
        // TODO: something better would be nice to have here
        if (isW) {
            layerParams.setW(histogram.getData());
        } else if (isRW) {
            layerParams.setRW(histogram.getData());
        } else if (isRWF) {
            layerParams.setRWF(histogram.getData());
        } else {
            layerParams.setB(histogram.getData());
        }
    }
    modelState.setLayerParams(layerParamsMap);
}
Use of org.deeplearning4j.ui.weights.HistogramBin in project deeplearning4j by deeplearning4j.
From the class FlowIterationListener, method buildModelState.
/**
 * Refreshes {@code modelState} from the given model: throughput figures, the current score,
 * elapsed training time, per-layer learning rates, and per-layer parameter histograms.
 *
 * <p>Only parameter-table entries named {@code <layerIndex>_<key>} with a numeric layer index
 * are considered; of those, only the keys {@code w}, {@code rw}, {@code rwf} and {@code b}
 * (case-insensitive) produce a histogram.
 *
 * @param model the model whose input, score, layers and parameter table are read
 */
protected void buildModelState(Model model) {
    // first we update performance state
    long timeSpent = currTime - lastTime;
    float timeSec = timeSpent / 1000f;
    // NOTE(review): if timeSpent is 0 the two rates below evaluate to Infinity (or NaN);
    // confirm the consumers of modelState tolerate that before changing it.
    INDArray input = model.input();
    // Length of one slice over dimensions 1..rank-1; presumably dimension 0 is the
    // mini-batch dimension, making the quotient the number of examples - TODO confirm.
    long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
    long numSamples = input.lengthLong() / tadLength;
    modelState.addPerformanceSamples(numSamples / timeSec);
    modelState.addPerformanceBatches(1 / timeSec);
    modelState.setIterationTime(timeSpent);

    // now model score (read once, recorded both in the history and as the current value)
    float score = (float) model.score();
    modelState.addScore(score);
    modelState.setScore(score);
    modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));

    // learning rates are only available for the two known network types
    Layer[] layers = null;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    }
    if (layers != null) {
        List<Double> lrs = new ArrayList<>();
        for (Layer layer : layers) {
            lrs.add(layer.conf().getLayer().getLearningRate());
        }
        modelState.setLearningRates(lrs);
    }

    // and now update model params
    Map<String, INDArray> params = model.paramTable();
    Map<Integer, LayerParams> layerParamsMap = new LinkedHashMap<>();
    for (Map.Entry<String, INDArray> entry : params.entrySet()) {
        String param = entry.getKey();
        if (param.isEmpty() || !Character.isDigit(param.charAt(0))) {
            continue;
        }

        int layer;
        try {
            // everything before the first underscore is expected to be the layer index
            layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
        } catch (NumberFormatException ignored) {
            // The name starts with a digit but is not a plain index (e.g. "2ndConv_W").
            // Skip it, like every other non-indexed name, instead of failing the listener.
            continue;
        }
        // everything after the first underscore is the parameter key
        String key = param.replaceAll("^.*?_", "");

        LayerParams layerParams = layerParamsMap.get(layer);
        if (layerParams == null) {
            layerParams = new LayerParams();
            layerParamsMap.put(layer, layerParams);
        }

        boolean isW = key.equalsIgnoreCase("w");
        boolean isRW = key.equalsIgnoreCase("rw");
        boolean isRWF = key.equalsIgnoreCase("rwf");
        boolean isB = key.equalsIgnoreCase("b");
        if (!(isW || isRW || isRWF || isB)) {
            // no setter for this key: avoid copying the array and binning it for nothing
            continue;
        }

        HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup())
                .setBinCount(14)
                .setRounding(6)
                .build();
        // TODO: something better would be nice to have here
        if (isW) {
            layerParams.setW(histogram.getData());
        } else if (isRW) {
            layerParams.setRW(histogram.getData());
        } else if (isRWF) {
            layerParams.setRWF(histogram.getData());
        } else {
            layerParams.setB(histogram.getData());
        }
    }
    modelState.setLayerParams(layerParamsMap);
}
Aggregations