Example 1 with QuickPropagation

Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.

The class LogisticRegressionTrainer, method train.

/**
 * {@inheritDoc}
 * <p>
 * No <code>regularization</code> is applied yet.
 * <p>
 * Regularization will be provided later.
 * <p>
 *
 * @throws IOException
 *             e
 */
@Override
public double train() throws IOException {
    classifier = new BasicNetwork();
    classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
    classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
    classifier.getStructure().finalizeStructure();
    // resetParams(classifier);
    classifier.reset();
    // Propagation mlTrain = getMLTrain();
    Propagation propagation = new QuickPropagation(classifier, trainSet, (Double) modelConfig.getParams().get("LearningRate"));
    int epochs = modelConfig.getNumTrainEpochs();
    // Get convergence threshold from modelConfig.
    double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
    String formattedThreshold = df.format(threshold);
    LOG.info("Using " + (Double) modelConfig.getParams().get("LearningRate") + " learning rate");
    for (int i = 0; i < epochs; i++) {
        propagation.iteration();
        double trainError = propagation.getError();
        double validError = classifier.calculateError(this.validSet);
        LOG.info("Epoch #" + (i + 1) + " Train Error:" + df.format(trainError) + " Validation Error:" + df.format(validError));
        // Convergence check on the averaged train/validation error.
        double avgErr = (trainError + validError) / 2;
        if (judger.judge(avgErr, threshold)) {
            LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
            break;
        }
    }
    propagation.finishTraining();
    LOG.info("#" + this.trainerID + " finish training");
    saveLR();
    return 0.0d;
}
Also used : BasicNetwork(org.encog.neural.networks.BasicNetwork) Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ActivationLinear(org.encog.engine.network.activation.ActivationLinear) ActivationSigmoid(org.encog.engine.network.activation.ActivationSigmoid) BasicLayer(org.encog.neural.networks.layers.BasicLayer)
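
For reference, judger above is shifu's convergence judger. A minimal sketch of the comparison the loop relies on, assuming plain threshold semantics (the class and method names mirror the snippet; the body is an assumption, not shifu's actual implementation):

public class ConvergeJudger {

    /**
     * Hypothetical convergence check: training is treated as converged once
     * the averaged error drops to or below the configured threshold. With
     * the default threshold of 0.0 this never fires, so train() runs all
     * epochs, matching the fallback in the snippet above.
     */
    public boolean judge(double error, double threshold) {
        return threshold > 0.0 && error <= threshold;
    }
}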

Example 2 with QuickPropagation

Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.

The class NNTrainerTest, method testAndOperation.

@Test
public void testAndOperation() throws IOException {
    MLDataPair dataPair0 = BasicMLDataPair.createPair(2, 1);
    dataPair0.setInputArray(new double[] { 0.0, 0.0 });
    dataPair0.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair0);
    MLDataPair dataPair1 = BasicMLDataPair.createPair(2, 1);
    dataPair1.setInputArray(new double[] { 0.0, 1.0 });
    dataPair1.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair1);
    MLDataPair dataPair2 = BasicMLDataPair.createPair(2, 1);
    dataPair2.setInputArray(new double[] { 1.0, 0.0 });
    dataPair2.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair2);
    MLDataPair dataPair3 = BasicMLDataPair.createPair(2, 1);
    dataPair3.setInputArray(new double[] { 1.0, 1.0 });
    dataPair3.setIdealArray(new double[] { 1.0 });
    trainSet.add(dataPair3);
    Propagation propagation = new QuickPropagation(network, trainSet, 0.1);
    double error = 0.0;
    double lastError = Double.MAX_VALUE;
    int iterCnt = 0;
    do {
        propagation.iteration();
        lastError = error;
        error = propagation.getError();
        System.out.println("The #" + (++iterCnt) + " error is " + error);
    } while (Math.abs(lastError - error) > 0.001);
    propagation.finishTraining();
    File tmp = new File("model_folder");
    if (!tmp.exists()) {
        FileUtils.forceMkdir(tmp);
    }
    File modelFile = new File("model_folder/model6.nn");
    EncogDirectoryPersistence.saveObject(modelFile, network);
    Assert.assertTrue(modelFile.exists());
    FileUtils.deleteQuietly(modelFile);
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) File(java.io.File) Test(org.testng.annotations.Test)
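
A natural follow-up to this test is loading the persisted model back and querying it. A minimal sketch using Encog's standard persistence and compute APIs (the file path matches the test; the surrounding class is illustrative):

import java.io.File;

import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.neural.networks.BasicNetwork;
import org.encog.persist.EncogDirectoryPersistence;

public class LoadModelExample {

    public static void main(String[] args) {
        // Load the network saved by testAndOperation().
        BasicNetwork loaded = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File("model_folder/model6.nn"));
        // Query the trained AND gate: only the input {1.0, 1.0} should yield an output close to 1.0.
        MLData output = loaded.compute(new BasicMLData(new double[] { 1.0, 1.0 }));
        System.out.println("AND(1,1) ~= " + output.getData(0));
    }
}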

Example 3 with QuickPropagation

Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.

The class DTrainTest, method quickTest.

@Test
public void quickTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Running QuickPropagtaion testing! ");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // each worker computes gradients and error on its own subset
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);
        // master step: aggregate gradients and train-set sizes from all workers
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);
        // push the updated weights back to every worker
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }
    // Encog baseline: train the same network with Encog's own QuickPropagation
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);
    Propagation p = new QuickPropagation(network, training, rate);
    // p = new ManhattanPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        // System.out.println("the #" + i + " training error: " + p.getError());
        ecogError[i] = p.getError();
    }
    // assert: the average per-epoch gap between the two error curves stays small
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)
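
splitDataSet(training) is not shown in the snippet. A plausible round-robin partitioning over Encog's MLDataSet API, so that each Gradient worker sees a disjoint share of the records (this sketch is an assumption, not shifu's actual test helper):

private MLDataSet[] splitDataSet(MLDataSet full) {
    MLDataSet[] subsets = new MLDataSet[numSplit];
    for (int i = 0; i < numSplit; i++) {
        subsets[i] = new BasicMLDataSet();
    }
    // Distribute records round-robin so subset sizes differ by at most one.
    for (long i = 0; i < full.getRecordCount(); i++) {
        MLDataPair pair = BasicMLDataPair.createPair(full.getInputSize(), full.getIdealSize());
        full.getRecord(i, pair);
        subsets[(int) (i % numSplit)].add(pair);
    }
    return subsets;
}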

Example 4 with QuickPropagation

Use of org.encog.neural.networks.training.propagation.quick.QuickPropagation in project shifu by ShifuML.

The class NNTrainer, method getMLTrain.

private Propagation getMLTrain() {
    // String alg = this.modelConfig.getLearningAlgorithm();
    String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
    if (!defaultLearningRate.containsKey(alg)) {
        throw new RuntimeException("Learning algorithm is invalid: " + alg);
    }
    // Double rate = this.modelConfig.getLearningRate();
    double rate = defaultLearningRate.get(alg);
    Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
    if (rateObj instanceof Double) {
        rate = (Double) rateObj;
    } else if (rateObj instanceof Integer) {
        // the user may have configured the rate as an integer
        rate = ((Integer) rateObj).doubleValue();
    } else if (rateObj instanceof Float) {
        rate = ((Float) rateObj).doubleValue();
    }
    if (toLoggingProcess) {
        LOG.info("    - Learning Algorithm: " + learningAlgMap.get(alg));
    }
    if (alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
        if (toLoggingProcess) {
            LOG.info("    - Learning Rate: " + rate);
        }
    }
    if (alg.equals("B")) {
        return new Backpropagation(network, trainSet, rate, 0);
    } else if (alg.equals("Q")) {
        return new QuickPropagation(network, trainSet, rate);
    } else if (alg.equals("M")) {
        return new ManhattanPropagation(network, trainSet, rate);
    } else if (alg.equals("R")) {
        return new ResilientPropagation(network, trainSet);
    } else if (alg.equals("S")) {
        return new ScaledConjugateGradient(network, trainSet);
    } else {
        return null;
    }
}
Also used : Backpropagation(org.encog.neural.networks.training.propagation.back.Backpropagation) ScaledConjugateGradient(org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) ModelInitInputObject(ml.shifu.shifu.container.ModelInitInputObject) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)
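
defaultLearningRate and learningAlgMap are fields initialized elsewhere in NNTrainer. A hedged sketch of how they might be populated to cover the codes handled above (the rate values are illustrative assumptions; requires java.util.HashMap and java.util.Map):

// Hypothetical initialization of the lookup maps consulted by getMLTrain().
// Codes: B=Backpropagation, Q=QuickPropagation, M=ManhattanPropagation,
// R=ResilientPropagation, S=ScaledConjugateGradient.
private static final Map<String, Double> defaultLearningRate = new HashMap<String, Double>();
private static final Map<String, String> learningAlgMap = new HashMap<String, String>();
static {
    // Illustrative default rates; R and S ignore the rate entirely.
    defaultLearningRate.put("B", 0.1);
    defaultLearningRate.put("Q", 2.0);
    defaultLearningRate.put("M", 0.00001);
    defaultLearningRate.put("R", 0.1);
    defaultLearningRate.put("S", 0.1);
    learningAlgMap.put("B", "Backpropagation");
    learningAlgMap.put("Q", "QuickPropagation");
    learningAlgMap.put("M", "ManhattanPropagation");
    learningAlgMap.put("R", "ResilientPropagation");
    learningAlgMap.put("S", "ScaledConjugateGradient");
}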

Aggregations

QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation) 4
Propagation (org.encog.neural.networks.training.propagation.Propagation) 3
ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) 2
ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) 2
Test (org.testng.annotations.Test) 2
File (java.io.File) 1
ModelInitInputObject (ml.shifu.shifu.container.ModelInitInputObject) 1
NNParams (ml.shifu.shifu.core.dtrain.nn.NNParams) 1
ActivationLinear (org.encog.engine.network.activation.ActivationLinear) 1
ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid) 1
MLDataPair (org.encog.ml.data.MLDataPair) 1
MLDataSet (org.encog.ml.data.MLDataSet) 1
BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair) 1
BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet) 1
BasicNetwork (org.encog.neural.networks.BasicNetwork) 1
BasicLayer (org.encog.neural.networks.layers.BasicLayer) 1
Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation) 1
ScaledConjugateGradient (org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient) 1
BeforeTest (org.testng.annotations.BeforeTest) 1