Search in sources :

Example 1 with HistogramBin

Use of org.deeplearning4j.ui.weights.HistogramBin in the project deeplearning4j by deeplearning4j.

From the class RemoteFlowIterationListener, method buildModelState.

/**
 * Refreshes {@code modelState} from the given model: iteration timing and
 * throughput, the current score, per-layer learning rates and per-layer
 * parameter histograms.
 *
 * <p>Reads {@code currTime}, {@code lastTime} and {@code initTime}, which are
 * declared outside this method.
 *
 * @param model the model whose input, score, layers and parameter table are read
 */
protected void buildModelState(Model model) {
    // first we update performance state
    long timeSpent = currTime - lastTime;
    float timeSec = timeSpent / 1000f;
    INDArray input = model.input();
    // samples in the batch = total input length / TAD length over the dimension range starting at 1
    long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
    long numSamples = input.lengthLong() / tadLength;
    // NOTE(review): when currTime == lastTime, timeSec is 0 and these float
    // divisions produce Infinity/NaN instead of throwing - confirm that the
    // consumers of modelState tolerate such values.
    modelState.addPerformanceSamples(numSamples / timeSec);
    modelState.addPerformanceBatches(1 / timeSec);
    modelState.setIterationTime(timeSpent);

    // now model score (read once so the history and the current value agree)
    float score = (float) model.score();
    modelState.addScore(score);
    modelState.setScore(score);
    modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));

    // learning rates are only available for the two network types handled here
    Layer[] layers = null;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    }
    if (layers != null) {
        List<Double> lrs = new ArrayList<>(layers.length);
        for (Layer layer : layers) {
            lrs.add(layer.conf().getLayer().getLearningRate());
        }
        modelState.setLearningRates(lrs);
    }

    // and now update model params
    Map<String, INDArray> params = model.paramTable();
    Map<Integer, LayerParams> layerParamsMap = new LinkedHashMap<>();
    for (Map.Entry<String, INDArray> entry : params.entrySet()) {
        String param = entry.getKey();
        // only keys starting with a digit are processed; the emptiness check
        // keeps charAt(0) from throwing on an empty key
        if (param.isEmpty() || !Character.isDigit(param.charAt(0))) {
            continue;
        }
        // text before the first '_' is parsed as the layer index, text after it is the parameter name
        int layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
        String key = param.replaceAll("^.*?_", "").toLowerCase();

        // a LayerParams entry is registered for the layer even if this key is not recorded below
        LayerParams layerParams = layerParamsMap.get(layer);
        if (layerParams == null) {
            layerParams = new LayerParams();
            layerParamsMap.put(layer, layerParams);
        }

        boolean isW = key.equalsIgnoreCase("w");
        boolean isRW = key.equalsIgnoreCase("rw");
        boolean isRWF = key.equalsIgnoreCase("rwf");
        boolean isB = key.equalsIgnoreCase("b");
        if (!(isW || isRW || isRWF || isB)) {
            // nothing is stored for other parameter names, so skip copying
            // the array and building a histogram for them
            continue;
        }

        // histogram is computed on a copy of the parameter array
        HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup()).setBinCount(14).setRounding(6).build();
        // TODO: something better would be nice to have here
        if (isW) {
            layerParams.setW(histogram.getData());
        } else if (isRW) {
            layerParams.setRW(histogram.getData());
        } else if (isRWF) {
            layerParams.setRWF(histogram.getData());
        } else {
            layerParams.setB(histogram.getData());
        }
    }
    modelState.setLayerParams(layerParamsMap);
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) BaseOutputLayer(org.deeplearning4j.nn.conf.layers.BaseOutputLayer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) HistogramBin(org.deeplearning4j.ui.weights.HistogramBin) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 2 with HistogramBin

Use of org.deeplearning4j.ui.weights.HistogramBin in the project deeplearning4j by deeplearning4j.

From the class FlowIterationListener, method buildModelState.

/**
 * Refreshes {@code modelState} from the given model: iteration timing and
 * throughput, the current score, per-layer learning rates and per-layer
 * parameter histograms.
 *
 * <p>Reads {@code currTime}, {@code lastTime} and {@code initTime}, which are
 * declared outside this method.
 *
 * @param model the model whose input, score, layers and parameter table are read
 */
protected void buildModelState(Model model) {
    // first we update performance state
    long timeSpent = currTime - lastTime;
    float timeSec = timeSpent / 1000f;
    INDArray input = model.input();
    // samples in the batch = total input length / TAD length over the dimension range starting at 1
    long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
    long numSamples = input.lengthLong() / tadLength;
    // NOTE(review): when currTime == lastTime, timeSec is 0 and these float
    // divisions produce Infinity/NaN instead of throwing - confirm that the
    // consumers of modelState tolerate such values.
    modelState.addPerformanceSamples(numSamples / timeSec);
    modelState.addPerformanceBatches(1 / timeSec);
    modelState.setIterationTime(timeSpent);

    // now model score (read once so the history and the current value agree)
    float score = (float) model.score();
    modelState.addScore(score);
    modelState.setScore(score);
    modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));

    // learning rates are only available for the two network types handled here
    Layer[] layers = null;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    }
    if (layers != null) {
        List<Double> lrs = new ArrayList<>(layers.length);
        for (Layer layer : layers) {
            lrs.add(layer.conf().getLayer().getLearningRate());
        }
        modelState.setLearningRates(lrs);
    }

    // and now update model params
    Map<String, INDArray> params = model.paramTable();
    Map<Integer, LayerParams> layerParamsMap = new LinkedHashMap<>();
    for (Map.Entry<String, INDArray> entry : params.entrySet()) {
        String param = entry.getKey();
        // only keys starting with a digit are processed; the emptiness check
        // keeps charAt(0) from throwing on an empty key
        if (param.isEmpty() || !Character.isDigit(param.charAt(0))) {
            continue;
        }
        // text before the first '_' is parsed as the layer index, text after it is the parameter name
        int layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
        String key = param.replaceAll("^.*?_", "").toLowerCase();

        // a LayerParams entry is registered for the layer even if this key is not recorded below
        LayerParams layerParams = layerParamsMap.get(layer);
        if (layerParams == null) {
            layerParams = new LayerParams();
            layerParamsMap.put(layer, layerParams);
        }

        boolean isW = key.equalsIgnoreCase("w");
        boolean isRW = key.equalsIgnoreCase("rw");
        boolean isRWF = key.equalsIgnoreCase("rwf");
        boolean isB = key.equalsIgnoreCase("b");
        if (!(isW || isRW || isRWF || isB)) {
            // nothing is stored for other parameter names, so skip copying
            // the array and building a histogram for them
            continue;
        }

        // histogram is computed on a copy of the parameter array
        HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup()).setBinCount(14).setRounding(6).build();
        // TODO: something better would be nice to have here
        if (isW) {
            layerParams.setW(histogram.getData());
        } else if (isRW) {
            layerParams.setRW(histogram.getData());
        } else if (isRWF) {
            layerParams.setRWF(histogram.getData());
        } else {
            layerParams.setB(histogram.getData());
        }
    }
    modelState.setLayerParams(layerParamsMap);
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) BaseOutputLayer(org.deeplearning4j.nn.conf.layers.BaseOutputLayer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) HistogramBin(org.deeplearning4j.ui.weights.HistogramBin) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Aggregations

AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 2; Layer (org.deeplearning4j.nn.api.Layer): 2; BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer): 2; FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 2; SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer): 2; ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 2; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 2; HistogramBin (org.deeplearning4j.ui.weights.HistogramBin): 2; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2