Example 1 with Propagation

Use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

The class NNTrainer, method train.

@Override
public double train() throws IOException {
    if (toLoggingProcess) {
        LOG.info("Using neural network algorithm...");
        if (this.dryRun) {
            LOG.info("Start Training(Dry Run)... Model #" + this.trainerID);
        } else {
            LOG.info("Start Training... Model #" + this.trainerID);
        }
        LOG.info("    - Input Size: " + trainSet.getInputSize());
        LOG.info("    - Ideal Size: " + trainSet.getIdealSize());
        LOG.info("    - Training Records Count: " + trainSet.getRecordCount());
        LOG.info("    - Validation Records Count: " + validSet.getRecordCount());
    }
    // set up the model
    buildNetwork();
    Propagation mlTrain = getMLTrain();
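    // A thread count of 0 tells Encog to size its thread pool automatically.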
    mlTrain.setThreadCount(0);
    if (this.dryRun) {
        return 0.0;
    }
    int epochs = this.modelConfig.getNumTrainEpochs();
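    // Checkpoint interval: save a temporary model roughly every 2% of the run, but never more often than every 10 epochs.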
    int factor = Math.max(epochs / 50, 10);
    // Get convergence threshold from modelConfig.
    double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
    String formatedThreshold = df.format(threshold);
    setBaseMSE(Double.MAX_VALUE);
    for (int i = 0; i < epochs; i++) {
        mlTrain.iteration();
        if (i % factor == 0) {
            this.saveTmpNN(i);
        }
        double validMSE = (this.validSet.getRecordCount() > 0) ? getValidSetError() : mlTrain.getError();
        String extra = "";
        if (validMSE < getBaseMSE()) {
            setBaseMSE(validMSE);
            saveNN();
            extra = " <-- NN saved: ./models/model" + this.trainerID + ".nn";
        }
        if (toLoggingProcess)
            LOG.info("  Trainer-" + trainerID + "> Epoch #" + (i + 1) + " Train Error: " + df.format(mlTrain.getError()) + " Validation Error: " + ((this.validSet.getRecordCount() > 0) ? df.format(validMSE) : "N/A") + " " + extra);
        // Convergence judging.
        double avgErr = (mlTrain.getError() + validMSE) / 2;
        if (judger.judge(avgErr, threshold)) {
            LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
            break;
        } else {
            if (toLoggingProcess) {
                LOG.info("Trainer-{}> Epoch #{} Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
            }
        }
    }
    mlTrain.finishTraining();
    if (toLoggingProcess)
        LOG.info("Trainer #" + this.trainerID + " is Finished!");
    return getBaseMSE();
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation)
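
The getMLTrain() call above returns whichever concrete Propagation subclass the model configuration asks for. A minimal sketch of such a factory, assuming a hypothetical algorithm code and learning rate plus the network/trainSet fields used above (the real Shifu implementation derives everything from modelConfig):

// Hypothetical factory: maps an algorithm code to an Encog Propagation trainer.
// "algorithm" and "learningRate" are illustrative parameters, not Shifu's actual config keys.
private Propagation buildTrainer(String algorithm, double learningRate) {
    switch (algorithm) {
        case "Q":
            return new QuickPropagation(network, trainSet, learningRate);
        case "M":
            return new ManhattanPropagation(network, trainSet, learningRate);
        case "R":
        default:
            // Resilient propagation adapts per-weight step sizes, so it needs no learning rate.
            return new ResilientPropagation(network, trainSet);
    }
}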

Example 2 with Propagation

Use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

The class LogisticRegressionTrainer, method train.

/**
 * {@inheritDoc}
 * <p>
 * No <code>regularization</code> is applied yet; regularization support will be provided later.
 *
 * @throws IOException
 *             in case of any I/O error
 */
@Override
public double train() throws IOException {
    classifier = new BasicNetwork();
    classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
    classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
    classifier.getStructure().finalizeStructure();
    // resetParams(classifier);
    classifier.reset();
    // Propagation mlTrain = getMLTrain();
    Propagation propagation = new QuickPropagation(classifier, trainSet, (Double) modelConfig.getParams().get("LearningRate"));
    int epochs = modelConfig.getNumTrainEpochs();
    // Get convergence threshold from modelConfig.
    double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
    String formatedThreshold = df.format(threshold);
    LOG.info("Using " + (Double) modelConfig.getParams().get("LearningRate") + " training rate");
    for (int i = 0; i < epochs; i++) {
        propagation.iteration();
        double trainError = propagation.getError();
        double validError = classifier.calculateError(this.validSet);
        LOG.info("Epoch #" + (i + 1) + " Train Error:" + df.format(trainError) + " Validation Error:" + df.format(validError));
        // Convergence judging.
        double avgErr = (trainError + validError) / 2;
        if (judger.judge(avgErr, threshold)) {
            LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
            break;
        }
    }
    propagation.finishTraining();
    LOG.info("#" + this.trainerID + " finish training");
    saveLR();
    return 0.0d;
}
Also used : BasicNetwork(org.encog.neural.networks.BasicNetwork) Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ActivationLinear(org.encog.engine.network.activation.ActivationLinear) ActivationSigmoid(org.encog.engine.network.activation.ActivationSigmoid) BasicLayer(org.encog.neural.networks.layers.BasicLayer)
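
Because the network above is just a linear input layer feeding a single sigmoid output, the trained classifier emits a probability-like score in (0, 1). A minimal sketch of scoring one record with the standard Encog API (the feature values are invented for illustration; MLData and BasicMLData come from org.encog.ml.data and org.encog.ml.data.basic):

// Feed one input vector through the trained logistic regression network.
MLData input = new BasicMLData(new double[] { 0.3, 1.2, -0.7 });
MLData output = classifier.compute(input);
// A single sigmoid output unit yields a score strictly between 0 and 1.
double score = output.getData(0);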

Example 3 with Propagation

Use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

The class DTrainTest, method manhantTest.

@Test
public void manhantTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    log.info("Starting manhattan propagation testing!");
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // each worker computes gradients on its own shard
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);
        // master: aggregate worker gradients and training sizes
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.MANHATTAN_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);
        // set weights
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }
    // encog
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);
    Propagation p = new ManhattanPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        // System.out.println("the #" + i + " training error: " + p.getError());
        ecogError[i] = p.getError();
    }
    // assert
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.3);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)
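
The test mirrors Shifu's distributed loop: every worker computes gradients on its own shard, the master sums them and derives new weights, and the updated vector is pushed back to all workers. NNParams.accumulateGradients hides the summation step; a minimal self-contained sketch of it in plain Java (the helper name is hypothetical):

// Hypothetical helper: sum per-worker gradient vectors into one global gradient.
static double[] sumGradients(double[][] workerGradients) {
    double[] global = new double[workerGradients[0].length];
    for (double[] gradients : workerGradients) {
        for (int k = 0; k < gradients.length; k++) {
            global[k] += gradients[k];
        }
    }
    return global;
}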

Example 4 with Propagation

Use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

The class DTrainTest, method resilientPropagationTest.

@Test
public void resilientPropagationTest() {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];
    network.reset();
    weights = network.getFlat().getWeights();
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }
    log.info("Starting resilient propagation testing!");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;
        // each worker computes gradients on its own shard
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);
        // master: aggregate worker gradients and training sizes
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.RESILIENTPROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);
        // set weights
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }
    // encog
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);
    Propagation p = new ResilientPropagation(network, training);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        ecogError[i] = p.getError();
    }
    // assert
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)
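
Both tests hand each worker its own shard via splitDataSet(training). A minimal sketch of such a split, assuming a simple round-robin strategy (this helper is hypothetical; the real implementation lives in DTrainTest):

// Hypothetical round-robin split of an MLDataSet into n roughly equal shards.
static MLDataSet[] splitDataSet(MLDataSet data, int n) {
    MLDataSet[] shards = new MLDataSet[n];
    for (int i = 0; i < n; i++) {
        shards[i] = new BasicMLDataSet();
    }
    int idx = 0;
    for (MLDataPair pair : data) {
        shards[idx++ % n].add(pair);
    }
    return shards;
}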

Example 5 with Propagation

Use of org.encog.neural.networks.training.propagation.Propagation in project shifu by ShifuML.

The class NNTrainerTest, method testAndOperation.

@Test
public void testAndOperation() throws IOException {
    MLDataPair dataPair0 = BasicMLDataPair.createPair(2, 1);
    dataPair0.setInputArray(new double[] { 0.0, 0.0 });
    dataPair0.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair0);
    MLDataPair dataPair1 = BasicMLDataPair.createPair(2, 1);
    dataPair1.setInputArray(new double[] { 0.0, 1.0 });
    dataPair1.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair1);
    MLDataPair dataPair2 = BasicMLDataPair.createPair(2, 1);
    dataPair2.setInputArray(new double[] { 1.0, 0.0 });
    dataPair2.setIdealArray(new double[] { 0.0 });
    trainSet.add(dataPair2);
    MLDataPair dataPair3 = BasicMLDataPair.createPair(2, 1);
    dataPair3.setInputArray(new double[] { 1.0, 1.0 });
    dataPair3.setIdealArray(new double[] { 1.0 });
    trainSet.add(dataPair3);
    Propagation propagation = new QuickPropagation(network, trainSet, 0.1);
    double error = 0.0;
    double lastError = Double.MAX_VALUE;
    int iterCnt = 0;
    do {
        propagation.iteration();
        lastError = error;
        error = propagation.getError();
        System.out.println("The #" + (++iterCnt) + " error is " + error);
    } while (Math.abs(lastError - error) > 0.001);
    propagation.finishTraining();
    File tmp = new File("model_folder");
    if (!tmp.exists()) {
        FileUtils.forceMkdir(tmp);
    }
    File modelFile = new File("model_folder/model6.nn");
    EncogDirectoryPersistence.saveObject(modelFile, network);
    Assert.assertTrue(modelFile.exists());
    FileUtils.deleteQuietly(modelFile);
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) File(java.io.File) Test(org.testng.annotations.Test)
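
A model written with EncogDirectoryPersistence can be read back the same way and queried directly. A minimal sketch that round-trips the file saved above and scores the four AND patterns, assuming training has converged (BasicMLData comes from org.encog.ml.data.basic; run this before the deleteQuietly call):

// Reload the persisted network and evaluate it on all four AND patterns.
BasicNetwork restored = (BasicNetwork) EncogDirectoryPersistence.loadObject(modelFile);
double[][] inputs = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
for (double[] in : inputs) {
    double out = restored.compute(new BasicMLData(in)).getData(0);
    System.out.println(in[0] + " AND " + in[1] + " -> " + out);
}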

Aggregations

Propagation (org.encog.neural.networks.training.propagation.Propagation): 7 usages
QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation): 7 usages
ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation): 5 usages
ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation): 5 usages
Test (org.testng.annotations.Test): 5 usages
NNParams (ml.shifu.shifu.core.dtrain.nn.NNParams): 4 usages
MLDataSet (org.encog.ml.data.MLDataSet): 4 usages
BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet): 4 usages
BeforeTest (org.testng.annotations.BeforeTest): 4 usages
File (java.io.File): 1 usage
ActivationLinear (org.encog.engine.network.activation.ActivationLinear): 1 usage
ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid): 1 usage
MLDataPair (org.encog.ml.data.MLDataPair): 1 usage
BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair): 1 usage
BasicNetwork (org.encog.neural.networks.BasicNetwork): 1 usage
BasicLayer (org.encog.neural.networks.layers.BasicLayer): 1 usage
Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation): 1 usage