
Example 1 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class ParallelWrapper method fit.

/**
     * This method takes a DataSetIterator and starts training over it, scheduling DataSets to the different executors.
     *
     * @param source DataSetIterator that supplies the training data
     */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // if we're using MagicQueue here, attach each worker thread to a specific device
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt], cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder()
                    .setCapacityPerFlow(8)
                    .setMode(MagicQueue.Mode.SEQUENTIAL)
                    .setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices())
                    .build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
             now dataSet should be dispatched to the next free worker, until all workers are busy; then we block until all have finished.
            */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
                if all workers have been dispatched, join until all are finished
            */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                    average model parameters, and propagate them to all workers
                */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity check: warn if parameters were never averaged during this fit() call
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataSet(org.nd4j.linalg.dataset.api.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)
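
For context, here is a minimal usage sketch of the fit() method above, assuming the 0.7.x-era ParallelWrapper.Builder API; conf, trainIterator, and the builder values are illustrative:

MultiLayerNetwork model = new MultiLayerNetwork(conf);   // conf: an existing MultiLayerConfiguration
model.init();
ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
        .workers(4)              // size of the Trainer[] "zoo" above
        .prefetchBuffer(8)       // > 0 enables the AsyncDataSetIterator branch
        .averagingFrequency(3)   // average params/updater state every 3 iterations
        .build();
wrapper.fit(trainIterator);      // trainIterator: any DataSetIterator
wrapper.shutdown();              // only after fit() returns (see the zoo == null check)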

Example 2 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class GradientCheckUtil method checkGradients.

/**
     * Check backprop gradients for a MultiLayerNetwork.
     * @param mln MultiLayerNetwork to test. This must be initialized.
     * @param epsilon Usually on the order of 1e-4 or so.
     * @param maxRelError Maximum relative error. Usually < 1e-5 or so, though this may need to be larger for deep networks or those with nonlinear activations
     * @param minAbsoluteError Minimum absolute error to cause a failure. Numerical gradients can be non-zero due to precision issues.
     *                         For example, 0.0 vs. 1e-18: relative error is 1.0, but not really a failure
     * @param print Whether to print full pass/failure details for each parameter gradient
     * @param exitOnFirstError If true: return upon first failure. If false: continue checking even if
     *  one parameter gradient has failed. Typically use false for debugging, true for unit tests.
     * @param input Input array to use for forward pass. May be mini-batch data.
     * @param labels Labels/targets to use to calculate backprop gradient. May be mini-batch data.
     * @return true if the gradient checks pass, false otherwise.
     */
public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
    //Basic sanity checks on input:
    if (epsilon <= 0.0 || epsilon > 0.1)
        throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
    if (maxRelError <= 0.0 || maxRelError > 0.25)
        throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
    if (!(mln.getOutputLayer() instanceof IOutputLayer))
        throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
    //Check network configuration:
    int layerCount = 0;
    for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
        org.deeplearning4j.nn.conf.Updater u = n.getLayer().getUpdater();
        if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
            //Must have LR of 1.0
            double lr = n.getLayer().getLearningRate();
            if (lr != 1.0) {
                throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" + n.getLayer().getLayerName() + "\"");
            }
        } else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
            throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + u);
        }
        double dropout = n.getLayer().getDropOut();
        if (n.isUseRegularization() && dropout != 0.0) {
            throw new IllegalStateException("Must have dropout == 0.0 for gradient checks - got dropout = " + dropout + " for layer " + layerCount);
        }
        IActivation activation = n.getLayer().getActivationFn();
        if (activation != null) {
            if (!VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) {
                log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not " + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
            }
        }
    }
    mln.setInput(input);
    mln.setLabels(labels);
    mln.computeGradientAndScore();
    Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
    Updater updater = UpdaterCreator.getUpdater(mln);
    updater.update(mln, gradAndScore.getFirst(), 0, mln.batchSize());
    //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done)
    INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
    //need dup: params are a *view* of full parameters
    INDArray originalParams = mln.params().dup();
    int nParams = originalParams.length();
    Map<String, INDArray> paramTable = mln.paramTable();
    List<String> paramNames = new ArrayList<>(paramTable.keySet());
    int[] paramEnds = new int[paramNames.size()];
    paramEnds[0] = paramTable.get(paramNames.get(0)).length();
    for (int i = 1; i < paramEnds.length; i++) {
        paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length();
    }
    int totalNFailures = 0;
    double maxError = 0.0;
    DataSet ds = new DataSet(input, labels);
    int currParamNameIdx = 0;
    //Assumption here: params is a view that we can modify in-place
    INDArray params = mln.params();
    for (int i = 0; i < nParams; i++) {
        //Get param name
        if (i >= paramEnds[currParamNameIdx]) {
            currParamNameIdx++;
        }
        String paramName = paramNames.get(currParamNameIdx);
        //(w+epsilon): Do forward pass and score
        double origValue = params.getDouble(i);
        params.putScalar(i, origValue + epsilon);
        double scorePlus = mln.score(ds, true);
        //(w-epsilon): Do forward pass and score
        params.putScalar(i, origValue - epsilon);
        double scoreMinus = mln.score(ds, true);
        //Reset original param value
        params.putScalar(i, origValue);
        //Calculate numerical parameter gradient:
        double scoreDelta = scorePlus - scoreMinus;
        double numericalGradient = scoreDelta / (2 * epsilon);
        if (Double.isNaN(numericalGradient))
            throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
        double backpropGradient = gradientToCheck.getDouble(i);
        //http://cs231n.github.io/neural-networks-3/#gradcheck
        //use mean centered
        double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
        if (backpropGradient == 0.0 && numericalGradient == 0.0)
            //Edge case: for example, RNNs with a time series length of 1
            relError = 0.0;
        if (relError > maxError)
            maxError = relError;
        if (relError > maxRelError || Double.isNaN(relError)) {
            double absError = Math.abs(backpropGradient - numericalGradient);
            if (absError < minAbsoluteError) {
                log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
            } else {
                if (print)
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                if (exitOnFirstError)
                    return false;
                totalNFailures++;
            }
        } else if (print) {
            log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
        }
    }
    if (print) {
        int nPass = nParams - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ArrayList(java.util.ArrayList) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IActivation(org.nd4j.linalg.activations.IActivation) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)
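
A sketch of how checkGradients might be invoked, honoring the configuration rules enforced above (Updater.NONE or SGD with lr = 1.0, no dropout, smooth activations); the layer sizes, tolerances, and data are illustrative:

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .learningRate(1.0)
        .updater(org.deeplearning4j.nn.conf.Updater.NONE)
        .list()
        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation("tanh").build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(3).nOut(2).activation("softmax").build())
        .backprop(true).pretrain(false).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray input = Nd4j.rand(10, 4);                    // mini-batch of 10 examples
INDArray labels = Nd4j.zeros(10, 2);                  // one-hot targets
for (int i = 0; i < 10; i++) labels.putScalar(new int[] { i, i % 2 }, 1.0);
boolean ok = GradientCheckUtil.checkGradients(mln, 1e-6, 1e-3, 1e-8,
        true,   // print per-parameter results
        false,  // keep checking after a failure
        input, labels);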

Example 3 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class TestDecayPolicies method testLearningRateSigmoidDecaySingleLayer.

@Test
public void testLearningRateSigmoidDecaySingleLayer() {
    int iterations = 2;
    double lr = 1e-2;
    double decayRate = 2;
    double steps = 3;
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .learningRateDecayPolicy(LearningRatePolicy.Sigmoid)
            .lrPolicyDecayRate(decayRate)
            .lrPolicySteps(steps)
            .iterations(iterations)
            .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
            .build();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    Gradient gradientActual = new DefaultGradient();
    gradientActual.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientActual.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);
    for (int i = 0; i < iterations; i++) {
        updater.update(layer, gradientActual, i, 1);
        double expectedLr = calcSigmoidDecay(layer.conf().getLearningRateByParam("W"), decayRate, i, steps);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("W"), 1e-4);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("b"), 1e-4);
    }
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Updater(org.deeplearning4j.nn.api.Updater) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Test(org.junit.Test)
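
The expected learning rate above comes from a local calcSigmoidDecay helper; a plausible sketch, mirroring the formula commonly given for LearningRatePolicy.Sigmoid (an assumption, not the test's actual helper):

// Sigmoid learning-rate decay: lr / (1 + exp(-decayRate * (iteration - steps)))
static double calcSigmoidDecay(double lr, double decayRate, int iteration, double steps) {
    return lr / (1 + Math.exp(-decayRate * (iteration - steps)));
}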

Example 4 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class TestDecayPolicies method testMomentumScheduleMLN.

@Test
public void testMomentumScheduleMLN() {
    double lr = 1e-2;
    double mu = 0.6;
    Map<Integer, Double> momentumAfter = new HashMap<>();
    momentumAfter.put(1, 0.2);
    int iterations = 2;
    int[] nIns = { 4, 2 };
    int[] nOuts = { 2, 3 };
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .momentum(mu)
            .momentumAfter(momentumAfter)
            .iterations(iterations)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0])
                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
            .layer(1, new OutputLayer.Builder().nIn(nIns[1]).nOut(nOuts[1])
                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
            .backprop(true).pretrain(false)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    Updater updater = UpdaterCreator.getUpdater(net);
    int stateSize = updater.stateSizeForLayer(net);
    updater.setStateViewArray(net, Nd4j.create(1, stateSize), true);
    String wKey, bKey;
    Gradient gradientExpected = new DefaultGradient();
    for (int k = 0; k < net.getnLayers(); k++) {
        wKey = String.valueOf(k) + "_" + DefaultParamInitializer.WEIGHT_KEY;
        gradientExpected.setGradientFor(wKey, Nd4j.ones(nIns[k], nOuts[k]));
        bKey = String.valueOf(k) + "_" + DefaultParamInitializer.BIAS_KEY;
        gradientExpected.setGradientFor(bKey, Nd4j.ones(1, nOuts[k]));
    }
    Gradient gradientMLN = new DefaultGradient();
    for (int j = 0; j < 2; j++) {
        wKey = String.valueOf(j) + "_" + DefaultParamInitializer.WEIGHT_KEY;
        gradientMLN.setGradientFor(wKey, Nd4j.ones(nIns[j], nOuts[j]));
        bKey = String.valueOf(j) + "_" + DefaultParamInitializer.BIAS_KEY;
        gradientMLN.setGradientFor(bKey, Nd4j.ones(1, nOuts[j]));
    }
    for (int i = 0; i < 2; i++) {
        updater.update(net, gradientMLN, i, 1);
        mu = testNesterovsComputation(gradientMLN, gradientExpected, lr, mu, momentumAfter, i);
        assertEquals(mu, net.getLayer(1).conf().getLayer().getMomentum(), 1e-4);
    }
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) HashMap(java.util.HashMap) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Updater(org.deeplearning4j.nn.api.Updater) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)
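
The testNesterovsComputation helper presumably checks the classical Nesterov momentum step (with mu switching to 0.2 at iteration 1 via the momentumAfter schedule). A minimal per-array sketch, assuming standard INDArray ops; the method and variable names are illustrative:

// Classical Nesterov momentum: vPrev = v; v = mu*v - lr*g; delta = -mu*vPrev + (1+mu)*v
static void nesterovStep(INDArray params, INDArray v, INDArray g, double lr, double mu) {
    INDArray vPrev = v.dup();                            // keep the previous velocity
    v.muli(mu).subi(g.mul(lr));                          // update velocity in place
    params.addi(vPrev.muli(-mu).addi(v.mul(1.0 + mu)));  // apply the lookahead delta
}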

Example 5 with Updater

use of org.deeplearning4j.nn.api.Updater in project deeplearning4j by deeplearning4j.

the class TestDecayPolicies method testLearningRateExponentialDecaySingleLayer.

@Test
public void testLearningRateExponentialDecaySingleLayer() {
    int iterations = 2;
    double lr = 1e-2;
    double decayRate = 2;
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .learningRateDecayPolicy(LearningRatePolicy.Exponential)
            .lrPolicyDecayRate(decayRate)
            .iterations(iterations)
            .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
            .build();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    Gradient gradientActual = new DefaultGradient();
    gradientActual.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientActual.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);
    for (int i = 0; i < iterations; i++) {
        updater.update(layer, gradientActual, i, 1);
        double expectedLr = calcExponentialDecay(lr, decayRate, i);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("W"), 1e-4);
        assertEquals(expectedLr, layer.conf().getLearningRateByParam("b"), 1e-4);
    }
}
Also used : Gradient(org.deeplearning4j.nn.gradient.Gradient) DefaultGradient(org.deeplearning4j.nn.gradient.DefaultGradient) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Updater(org.deeplearning4j.nn.api.Updater) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) Test(org.junit.Test)
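
As with the sigmoid test, the expected value comes from a local helper; a plausible sketch of calcExponentialDecay, mirroring the usual exponential schedule (an assumption, not the test's actual helper):

// Exponential learning-rate decay: lr * decayRate^iteration
static double calcExponentialDecay(double lr, double decayRate, int iteration) {
    return lr * Math.pow(decayRate, iteration);
}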

Aggregations

Updater (org.deeplearning4j.nn.api.Updater): 37 uses
Test (org.junit.Test): 28 uses
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 27 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 27 uses
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 25 uses
Gradient (org.deeplearning4j.nn.gradient.Gradient): 25 uses
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 23 uses
Layer (org.deeplearning4j.nn.api.Layer): 21 uses
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 18 uses
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 9 uses
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 8 uses
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 5 uses
HashMap (java.util.HashMap): 4 uses
Solver (org.deeplearning4j.optimize.Solver): 4 uses
ArrayList (java.util.ArrayList): 2 uses
Field (java.lang.reflect.Field): 1 use
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1 use
ZipEntry (java.util.zip.ZipEntry): 1 use
ZipFile (java.util.zip.ZipFile): 1 use
Persistable (org.deeplearning4j.api.storage.Persistable): 1 use