
Example 1 with TrainingListener

Use of org.deeplearning4j.optimize.api.TrainingListener in project deeplearning4j by deeplearning4j.

From class ComputationGraph, method computeGradientAndScore.

@Override
public void computeGradientAndScore() {
    //Calculate activations (which are stored in each layer, and used in backprop)
    if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
        Map<String, INDArray> activations = rnnActivateUsingStoredState(inputs, true, true);
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onForwardPass(this, activations);
            }
        }
        calcBackpropGradients(true);
    } else {
        Map<String, INDArray> activations = feedForward(true, true);
        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onForwardPass(this, activations);
            }
        }
        calcBackpropGradients(false);
    }
    //Score: sum of the scores for the various output layers...
    double l1 = calcL1();
    double l2 = calcL2();
    score = 0.0;
    for (String s : configuration.getNetworkOutputs()) {
        GraphVertex gv = verticesMap.get(s);
        score += ((IOutputLayer) gv.getLayer()).computeScore(l1, l2, true);
        //Only want to add l1/l2 once...
        l1 = 0.0;
        l2 = 0.0;
    }
    //Listeners
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onBackwardPass(this);
        }
    }
}
Also used: GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray(org.nd4j.linalg.api.ndarray.INDArray), TrainingListener(org.deeplearning4j.optimize.api.TrainingListener)
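For reference, a minimal listener reacting to the forward- and backward-pass callbacks invoked above might look like the following. This is a sketch, not DL4J code: it assumes the 0.x-era API shown in these examples, where TrainingListener extends IterationListener (so invoked()/invoke()/iterationDone(...) must also be implemented), and the class name ActivationLoggingListener is made up for illustration.

import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ActivationLoggingListener implements TrainingListener {

    private boolean invoked = false;

    //IterationListener methods (TrainingListener extends IterationListener in this API version)
    @Override
    public boolean invoked() { return invoked; }

    @Override
    public void invoke() { this.invoked = true; }

    @Override
    public void iterationDone(Model model, int iteration) { /* no-op */ }

    @Override
    public void onEpochStart(Model model) { /* no-op */ }

    @Override
    public void onEpochEnd(Model model) { /* no-op */ }

    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
        //MultiLayerNetwork passes activations as a list, one entry per layer
        System.out.println("Forward pass: " + activations.size() + " activation arrays");
    }

    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        //ComputationGraph passes activations keyed by vertex name, as in Example 1 above
        System.out.println("Forward pass over vertices: " + activations.keySet());
    }

    @Override
    public void onGradientCalculation(Model model) { /* no-op */ }

    @Override
    public void onBackwardPass(Model model) {
        System.out.println("Backward pass complete");
    }
}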

Example 2 with TrainingListener

Use of org.deeplearning4j.optimize.api.TrainingListener in project deeplearning4j by deeplearning4j.

From class MultiLayerNetwork, method fit.

@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
    //            while (iter.hasNext()) {
    //                DataSet next = iter.next();
    //                if (next.getFeatureMatrix() == null || next.getLabels() == null)
    //                    break;
    //                setInput(next.getFeatureMatrix());
    //                setLabels(next.getLabels());
    //                finetune();
    //            }
    }
    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays)
                clearLayerMaskArrays();
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        log.warn("Warning: finetune is not applied.");
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used: Solver(org.deeplearning4j.optimize.Solver), DataSet(org.nd4j.linalg.dataset.DataSet), TrainingListener(org.deeplearning4j.optimize.api.TrainingListener), AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator), DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
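A minimal usage sketch for this method, assuming net is an already-initialized MultiLayerNetwork and trainData a DataSetIterator (both configured elsewhere). Registering listeners before calling fit(...) is enough to receive the epoch callbacks shown above; ScoreIterationListener is a stock listener from org.deeplearning4j.optimize.listeners.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class FitWithListeners {

    //Sketch: net and trainData are assumed to be configured elsewhere
    static void trainOneEpoch(MultiLayerNetwork net, DataSetIterator trainData) {
        //ScoreIterationListener logs the score every 10 iterations; any registered
        //listener that also implements TrainingListener additionally receives the
        //onEpochStart/onEpochEnd callbacks from fit(), as in the method above
        net.setListeners(new ScoreIterationListener(10));
        net.fit(trainData);
    }
}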

Example 3 with TrainingListener

Use of org.deeplearning4j.optimize.api.TrainingListener in project deeplearning4j by deeplearning4j.

From class ComputationGraph, method fit.

/**
     * Fit the ComputationGraph using a DataSetIterator.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output
     */
public void fit(DataSetIterator iterator) {
    if (flattenedGradients == null)
        initGradientsView();
    if (numInputArrays != 1 || numOutputArrays != 1)
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                + "multiple inputs or outputs using a DataSetIterator");
    DataSetIterator dataSetIterator;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        dataSetIterator = new AsyncDataSetIterator(iterator, 2);
    } else {
        dataSetIterator = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (configuration.isPretrain()) {
        pretrain(dataSetIterator);
    }
    if (configuration.isBackprop()) {
        update(TaskUtils.buildTask(dataSetIterator));
        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (hasMaskArrays) {
                INDArray[] fMask = (next.getFeaturesMaskArray() != null ? new INDArray[] { next.getFeaturesMaskArray() } : null);
                INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] { next.getLabelsMaskArray() } : null);
                setLayerMaskArrays(fMask, lMask);
            }
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(new INDArray[] { next.getFeatures() }, new INDArray[] { next.getLabels() }, (hasMaskArrays ? new INDArray[] { next.getFeaturesMaskArray() } : null), (hasMaskArrays ? new INDArray[] { next.getLabelsMaskArray() } : null));
            } else {
                setInput(0, next.getFeatures());
                setLabel(0, next.getLabels());
                if (solver == null) {
                    //TODO: don't like this
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used: Solver(org.deeplearning4j.optimize.Solver), INDArray(org.nd4j.linalg.api.ndarray.INDArray), MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet), DataSet(org.nd4j.linalg.dataset.api.DataSet), TrainingListener(org.deeplearning4j.optimize.api.TrainingListener), AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator), DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator), SingletonMultiDataSetIterator(org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator), MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)
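Since fit(DataSetIterator) throws for graphs with more than one input or output (see the UnsupportedOperationException above), multi-input/multi-output graphs go through the MultiDataSet path instead. A minimal sketch, assuming graph and multiIter are configured elsewhere:

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class FitMultiIO {

    //Sketch: graph and multiIter are assumed to be configured elsewhere
    static void train(ComputationGraph graph, MultiDataSetIterator multiIter) {
        //fit(MultiDataSetIterator) carries no single-input/single-output restriction
        graph.fit(multiIter);
    }
}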

Example 4 with TrainingListener

Use of org.deeplearning4j.optimize.api.TrainingListener in project deeplearning4j by deeplearning4j.

From class ComputationGraph, method setListeners.

/**
     * Set the IterationListeners for the ComputationGraph (and all layers in the network)
     */
public void setListeners(Collection<IterationListener> listeners) {
    this.listeners = listeners;
    if (layers == null)
        init();
    for (Layer l : layers) {
        l.setListeners(listeners);
    }
    if (solver != null) {
        solver.setListeners(listeners);
    }
    this.trainingListeners.clear();
    if (listeners != null) {
        for (IterationListener il : listeners) {
            if (il instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) il);
            }
        }
    }
}
Also used: IterationListener(org.deeplearning4j.optimize.api.IterationListener), TrainingListener(org.deeplearning4j.optimize.api.TrainingListener), Layer(org.deeplearning4j.nn.api.Layer), FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer), RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer), FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer), IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)
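To see the filtering above in action, a usage sketch: both listeners below are registered as IterationListeners, but only the one that also implements TrainingListener lands in trainingListeners and receives the training callbacks. ActivationLoggingListener is the hypothetical listener sketched after Example 1; ScoreIterationListener is a stock DL4J listener.

import java.util.Arrays;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

public class RegisterListeners {

    //Sketch: graph is assumed to be an initialized ComputationGraph
    static void register(ComputationGraph graph) {
        graph.setListeners(Arrays.<IterationListener>asList(
                //Plain IterationListener: receives iterationDone(...) only
                new ScoreIterationListener(10),
                //Also a TrainingListener: additionally receives epoch/forward/backward callbacks
                new ActivationLoggingListener()));
    }
}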

Example 5 with TrainingListener

Use of org.deeplearning4j.optimize.api.TrainingListener in project deeplearning4j by deeplearning4j.

From class VariationalAutoencoder, method computeGradientAndScore.

@Override
public void computeGradientAndScore() {
    //Forward pass through the encoder and mean for P(Z|X)
    VAEFwdHelper fwd = doForward(true, true);
    IActivation afn = conf().getLayer().getActivationFn();
    //Forward pass through logStd^2 for P(Z|X)
    INDArray pzxLogStd2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
    INDArray pzxLogStd2b = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
    INDArray pzxLogStd2Pre = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);
    INDArray meanZ = fwd.pzxMeanPreOut.dup();
    INDArray logStdev2Z = pzxLogStd2Pre.dup();
    pzxActivationFn.getActivation(meanZ, true);
    pzxActivationFn.getActivation(logStdev2Z, true);
    INDArray pzxSigmaSquared = Transforms.exp(logStdev2Z, true);
    INDArray pzxSigma = Transforms.sqrt(pzxSigmaSquared, true);
    int minibatch = input.size(0);
    int size = fwd.pzxMeanPreOut.size(1);
    Map<String, INDArray> gradientMap = new HashMap<>();
    double scaleFactor = 1.0 / numSamples;
    Level1 blasL1 = Nd4j.getBlasWrapper().level1();
    INDArray[] encoderActivationDerivs = (numSamples > 1 ? new INDArray[encoderLayerSizes.length] : null);
    for (int l = 0; l < numSamples; l++) {
        //Default (and in most cases) numSamples == 1
        //0 for first one (to get rid of previous buffer data), otherwise 1 (for adding)
        double gemmCConstant = (l == 0 ? 0.0 : 1.0);
        INDArray e = Nd4j.randn(minibatch, size);
        //z = mu + sigma * e, with e ~ N(0,1)
        INDArray z = pzxSigma.mul(e).addi(meanZ);
        //Need to do forward pass through decoder layers
        int nDecoderLayers = decoderLayerSizes.length;
        INDArray current = z;
        //Need pre-out for backprop later
        INDArray[] decoderPreOut = new INDArray[nDecoderLayers];
        INDArray[] decoderActivations = new INDArray[nDecoderLayers];
        for (int i = 0; i < nDecoderLayers; i++) {
            String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
            String bKey = "d" + i + BIAS_KEY_SUFFIX;
            INDArray weights = params.get(wKey);
            INDArray bias = params.get(bKey);
            current = current.mmul(weights).addiRowVector(bias);
            decoderPreOut[i] = current.dup();
            afn.getActivation(current, true);
            decoderActivations[i] = current;
        }
        INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
        INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
        if (l == 0) {
            //Need to add other component of score, in addition to negative log probability
            //Note the negative here vs. the equation in Kingma & Welling: this is because we are minimizing the negative of
            // variational lower bound, rather than maximizing the variational lower bound
            //Unlike log probability (which is averaged over samples) this should be calculated just once
            INDArray temp = meanZ.mul(meanZ).addi(pzxSigmaSquared).negi();
            temp.addi(logStdev2Z).addi(1.0);
            double scorePt1 = -0.5 / minibatch * temp.sumNumber().doubleValue();
            this.score = scorePt1 + (calcL1(false) + calcL2(false)) / minibatch;
        }
        INDArray pxzDistributionPreOut = current.mmul(pxzw).addiRowVector(pxzb);
        double logPTheta = reconstructionDistribution.negLogProbability(input, pxzDistributionPreOut, true);
        this.score += logPTheta / numSamples;
        //If we have any training listeners (for example, for UI StatsListener - pass on activations)
        if (trainingListeners != null && trainingListeners.size() > 0 && l == 0) {
            //Note: only doing this on the *first* sample
            Map<String, INDArray> activations = new LinkedHashMap<>();
            for (int i = 0; i < fwd.encoderActivations.length; i++) {
                activations.put("e" + i, fwd.encoderActivations[i]);
            }
            activations.put(VariationalAutoencoderParamInitializer.PZX_PREFIX, z);
            for (int i = 0; i < decoderActivations.length; i++) {
                activations.put("d" + i, decoderActivations[i]);
            }
            activations.put(VariationalAutoencoderParamInitializer.PXZ_PREFIX, reconstructionDistribution.generateAtMean(pxzDistributionPreOut));
            for (TrainingListener tl : trainingListeners) {
                tl.onForwardPass(this, activations);
            }
        }
        /////////////////////////////////////////////////////////
        //Backprop
        //First: calculate the gradients at the input to the reconstruction distribution
        INDArray dpdpxz = reconstructionDistribution.gradient(input, pxzDistributionPreOut);
        //Do backprop for output reconstruction distribution -> final decoder layer
        INDArray dLdxzw = gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_W);
        INDArray dLdxzb = gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_B);
        INDArray lastDecActivations = decoderActivations[decoderActivations.length - 1];
        Nd4j.gemm(lastDecActivations, dpdpxz, dLdxzw, true, false, scaleFactor, gemmCConstant);
        if (l == 0) {
            //TODO: do this without the assign
            dLdxzb.assign(dpdpxz.sum(0));
            if (numSamples > 1) {
                dLdxzb.muli(scaleFactor);
            }
        } else {
            blasL1.axpy(dLdxzb.length(), scaleFactor, dpdpxz.sum(0), dLdxzb);
        }
        gradientMap.put(VariationalAutoencoderParamInitializer.PXZ_W, dLdxzw);
        gradientMap.put(VariationalAutoencoderParamInitializer.PXZ_B, dLdxzb);
        INDArray epsilon = pxzw.mmul(dpdpxz.transpose()).transpose();
        //Next: chain derivatives backwards through the decoder layers
        for (int i = nDecoderLayers - 1; i >= 0; i--) {
            String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
            String bKey = "d" + i + BIAS_KEY_SUFFIX;
            //TODO activation functions with params
            INDArray currentDelta = afn.backprop(decoderPreOut[i], epsilon).getFirst();
            INDArray weights = params.get(wKey);
            INDArray dLdW = gradientViews.get(wKey);
            INDArray dLdB = gradientViews.get(bKey);
            INDArray actInput;
            if (i == 0) {
                actInput = z;
            } else {
                actInput = decoderActivations[i - 1];
            }
            Nd4j.gemm(actInput, currentDelta, dLdW, true, false, scaleFactor, gemmCConstant);
            if (l == 0) {
                //TODO: do this without the assign
                dLdB.assign(currentDelta.sum(0));
                if (numSamples > 1) {
                    dLdB.muli(scaleFactor);
                }
            } else {
                blasL1.axpy(dLdB.length(), scaleFactor, currentDelta.sum(0), dLdB);
            }
            gradientMap.put(wKey, dLdW);
            gradientMap.put(bKey, dLdB);
            epsilon = weights.mmul(currentDelta.transpose()).transpose();
        }
        //Do backprop through p(z|x)
        INDArray eZXMeanW = params.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
        INDArray eZXLogStdev2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
        INDArray dLdz = epsilon;
        //If we were maximizing the equation in Kingma and Welling, this would be a .sub(meanZ). Here: we are minimizing the negative instead
        INDArray dLdmu = dLdz.add(meanZ);
        INDArray dLdLogSigma2 = dLdz.mul(e).muli(pzxSigma).addi(pzxSigmaSquared).subi(1).muli(0.5);
        INDArray dLdPreMu = pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdmu).getFirst();
        INDArray dLdPreLogSigma2 = pzxActivationFn.backprop(pzxLogStd2Pre.dup(), dLdLogSigma2).getFirst();
        //Weight gradients for weights feeding into p(z|x)
        INDArray lastEncoderActivation = fwd.encoderActivations[fwd.encoderActivations.length - 1];
        INDArray dLdZXMeanW = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
        INDArray dLdZXLogStdev2W = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
        Nd4j.gemm(lastEncoderActivation, dLdPreMu, dLdZXMeanW, true, false, scaleFactor, gemmCConstant);
        Nd4j.gemm(lastEncoderActivation, dLdPreLogSigma2, dLdZXLogStdev2W, true, false, scaleFactor, gemmCConstant);
        //Bias gradients for p(z|x)
        INDArray dLdZXMeanb = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B);
        INDArray dLdZXLogStdev2b = gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
        //If we were maximizing the equation in Kingma and Welling, this would be a .sub(meanZ). Here: we are minimizing the negative instead
        if (l == 0) {
            dLdZXMeanb.assign(pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdz.add(meanZ)).getFirst().sum(0));
            dLdZXLogStdev2b.assign(dLdPreLogSigma2.sum(0));
            if (numSamples > 1) {
                dLdZXMeanb.muli(scaleFactor);
                dLdZXLogStdev2b.muli(scaleFactor);
            }
        } else {
            blasL1.axpy(dLdZXMeanb.length(), scaleFactor, pzxActivationFn.backprop(fwd.getPzxMeanPreOut().dup(), dLdz.add(meanZ)).getFirst().sum(0), dLdZXMeanb);
            blasL1.axpy(dLdZXLogStdev2b.length(), scaleFactor, dLdPreLogSigma2.sum(0), dLdZXLogStdev2b);
        }
        gradientMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, dLdZXMeanW);
        gradientMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, dLdZXMeanb);
        gradientMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, dLdZXLogStdev2W);
        gradientMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, dLdZXLogStdev2b);
        //Epsilon (dL/dActivation) at output of the last encoder layer:
        //Equivalent to: epsilon = eZXMeanW.mmul(dLdPreMu.transpose()).transpose(); using   (AxB^T)^T = BxA^T
        epsilon = Nd4j.gemm(dLdPreMu, eZXMeanW, false, true);
        //Next line: equivalent to epsilon.addi(eZXLogStdev2W.mmul(dLdPreLogSigma2.transpose()).transpose());       using: (AxB^T)^T = BxA^T
        Nd4j.gemm(dLdPreLogSigma2, eZXLogStdev2W, epsilon, false, true, 1.0, 1.0);
        //Backprop through encoder:
        int nEncoderLayers = encoderLayerSizes.length;
        for (int i = nEncoderLayers - 1; i >= 0; i--) {
            String wKey = "e" + i + WEIGHT_KEY_SUFFIX;
            String bKey = "e" + i + BIAS_KEY_SUFFIX;
            INDArray weights = params.get(wKey);
            INDArray dLdW = gradientViews.get(wKey);
            INDArray dLdB = gradientViews.get(bKey);
            INDArray preOut = fwd.encoderPreOuts[i];
            INDArray currentDelta;
            if (numSamples > 1) {
                //Activation derivatives don't change between samples; only the errors do
                if (l == 0) {
                    //Not the most elegant implementation (with the Nd4j.ones()), but it works...
                    encoderActivationDerivs[i] = afn.backprop(fwd.encoderPreOuts[i], Nd4j.ones(fwd.encoderPreOuts[i].shape())).getFirst();
                }
                currentDelta = epsilon.muli(encoderActivationDerivs[i]);
            } else {
                currentDelta = afn.backprop(preOut, epsilon).getFirst();
            }
            INDArray actInput;
            if (i == 0) {
                actInput = input;
            } else {
                actInput = fwd.encoderActivations[i - 1];
            }
            Nd4j.gemm(actInput, currentDelta, dLdW, true, false, scaleFactor, gemmCConstant);
            if (l == 0) {
                //TODO: do this without the assign
                dLdB.assign(currentDelta.sum(0));
                if (numSamples > 1) {
                    dLdB.muli(scaleFactor);
                }
            } else {
                blasL1.axpy(dLdB.length(), scaleFactor, currentDelta.sum(0), dLdB);
            }
            gradientMap.put(wKey, dLdW);
            gradientMap.put(bKey, dLdB);
            epsilon = weights.mmul(currentDelta.transpose()).transpose();
        }
    }
    //Insert the gradients into the Gradient map in the correct order, in case we need to flatten the gradient later
    // to match the parameters iteration order
    Gradient gradient = new DefaultGradient(gradientsFlattened);
    Map<String, INDArray> g = gradient.gradientForVariable();
    for (int i = 0; i < encoderLayerSizes.length; i++) {
        String w = "e" + i + VariationalAutoencoderParamInitializer.WEIGHT_KEY_SUFFIX;
        g.put(w, gradientMap.get(w));
        String b = "e" + i + VariationalAutoencoderParamInitializer.BIAS_KEY_SUFFIX;
        g.put(b, gradientMap.get(b));
    }
    g.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, gradientMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W));
    g.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, gradientMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B));
    g.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, gradientMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W));
    g.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, gradientMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B));
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        String w = "d" + i + VariationalAutoencoderParamInitializer.WEIGHT_KEY_SUFFIX;
        g.put(w, gradientMap.get(w));
        String b = "d" + i + VariationalAutoencoderParamInitializer.BIAS_KEY_SUFFIX;
        g.put(b, gradientMap.get(b));
    }
    g.put(VariationalAutoencoderParamInitializer.PXZ_W, gradientMap.get(VariationalAutoencoderParamInitializer.PXZ_W));
    g.put(VariationalAutoencoderParamInitializer.PXZ_B, gradientMap.get(VariationalAutoencoderParamInitializer.PXZ_B));
    this.gradient = gradient;
}
Also used: Gradient(org.deeplearning4j.nn.gradient.Gradient), DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient), TrainingListener(org.deeplearning4j.optimize.api.TrainingListener), IActivation(org.nd4j.linalg.activations.IActivation), INDArray(org.nd4j.linalg.api.ndarray.INDArray), Level1(org.nd4j.linalg.api.blas.Level1)
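For orientation, the score assembled in this method follows the negative variational lower bound of Kingma & Welling (2014): scorePt1 (computed once, at l == 0) is the closed-form KL divergence between the Gaussian posterior q(z|x) and the standard-normal prior, averaged over the minibatch, while the reconstruction term is a Monte Carlo estimate over numSamples draws of z = mu + sigma * e with e ~ N(0,1) (the reparameterization trick). In the paper's notation:

\[
-\mathcal{L}(x) \;=\; D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right) \;-\; \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right],
\qquad
D_{\mathrm{KL}} \;=\; -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\]

so the code's temp = -(mu^2 + sigma^2) + log(sigma^2) + 1 is exactly the summand of the KL term, and scorePt1 = -0.5/minibatch * sum(temp) recovers the per-example KL divergence; reconstructionDistribution.negLogProbability(...) supplies the reconstruction term, averaged over the numSamples draws.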

Aggregations

TrainingListener (org.deeplearning4j.optimize.api.TrainingListener): 8 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 5 usages
IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 3 usages
AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator): 2 usages
RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer): 2 usages
FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 2 usages
FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer): 2 usages
Solver (org.deeplearning4j.optimize.Solver): 2 usages
IterationListener (org.deeplearning4j.optimize.api.IterationListener): 2 usages
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 2 usages
Pair (org.deeplearning4j.berkeley.Pair): 1 usage
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 1 usage
SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator): 1 usage
Layer (org.deeplearning4j.nn.api.Layer): 1 usage
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 1 usage
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1 usage
GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex): 1 usage
IActivation (org.nd4j.linalg.activations.IActivation): 1 usage
Level1 (org.nd4j.linalg.api.blas.Level1): 1 usage
DataSet (org.nd4j.linalg.dataset.DataSet): 1 usage