Search in sources :

Example 1 with MLDataSet

use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.

the class AbstractTrainer method setDataSet.

/**
 * Set up the training dataset and validation dataset.
 *
 * <p>Works in three steps: allocate the sampled/train/valid storage according to
 * {@code trainingOption} ("M" for memory, "D" for disk), draw a bagging sample from
 * {@code masterDataSet}, then split that sample into {@code trainSet} and {@code validSet}
 * using {@code crossValidationRate}. Summary statistics are logged at the end.
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException declared by the signature; presumably raised while loading the fixed
 *         sample input - confirm against {@code loadSampleInput}
 * @throws RuntimeException if {@code trainingOption} is neither "M" nor "D" (case-insensitive)
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    final boolean onDisk = this.trainingOption.equalsIgnoreCase("D");
    final MLDataSet sampledDataSet;
    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (onDisk) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        // Buffered data sets must be put into load mode before records can be added.
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 API: the record count is a long; it is narrowed to int here because it is
    // later used as the bound for random.nextInt(...).
    int masterSize = (int) masterDataSet.getRecordCount();

    drawSample(masterDataSet, sampledDataSet, masterSize);
    if (onDisk) {
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross Validation
    log.info("Generating Training Set and Validation Set ...");
    splitSample(sampledDataSet);
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info("    - # Records of the Master Data Set: " + masterSize);
    log.info("    - Bagging Sample Rate: " + baggingSampleRate);
    log.info("    - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info("    - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info("        - Cross Validation Rate: " + crossValidationRate);
    log.info("        - # Records of the Training Set: " + this.getTrainSetSize());
    log.info("        - # Records of the Validation Set: " + this.getValidSetSize());
    // NOTE(review): in disk mode the sampled buffer ("sampled.egb") is never closed by this
    // method - confirm whether it should be closed once the split is complete.
}

/**
 * Fills {@code sampledDataSet} with the bagging sample drawn from {@code masterDataSet}.
 *
 * <p>When the initial input is not fixed, records are either drawn at random with
 * replacement ({@code masterSize * baggingSampleRate} draws) or each record is kept with
 * probability {@code baggingSampleRate}. When the initial input is fixed, the record
 * indices come from {@code loadSampleInput}.
 */
private void drawSample(MLDataSet masterDataSet, MLDataSet sampledDataSet, int masterSize) throws IOException {
    if (!modelConfig.isFixInitialInput()) {
        if (modelConfig.isBaggingWithReplacement()) {
            // Bagging With Replacement
            int sampledSize = (int) (masterSize * baggingSampleRate);
            for (int i = 0; i < sampledSize; i++) {
                sampledDataSet.add(copyRecord(masterDataSet, random.nextInt(masterSize)));
            }
        } else {
            // Bagging Without Replacement
            for (MLDataPair pair : masterDataSet) {
                if (random.nextDouble() < baggingSampleRate) {
                    sampledDataSet.add(pair);
                }
            }
        }
    } else {
        List<Integer> indices = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize, modelConfig.isBaggingWithReplacement());
        for (Integer index : indices) {
            sampledDataSet.add(copyRecord(masterDataSet, index));
        }
    }
}

/**
 * Distributes the records of {@code sampledDataSet} between {@code trainSet} and
 * {@code validSet}.
 *
 * <p>When the initial input is not fixed, each record goes to the validation set when a
 * random draw is at or below {@code crossValidationRate}, otherwise to the training set.
 * When it is fixed, the first {@code sampleSize * (1 - crossValidationRate)} records form
 * the training set and the remainder the validation set.
 */
private void splitSample(MLDataSet sampledDataSet) {
    // The sampled set is not modified while splitting, so the count is read once.
    final long sampleSize = sampledDataSet.getRecordCount();
    if (!modelConfig.isFixInitialInput()) {
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = copyRecord(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        final long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        long i = 0;
        for (; i < trainSetSize; i++) {
            trainSet.add(copyRecord(sampledDataSet, i));
        }
        for (; i < sampleSize; i++) {
            validSet.add(copyRecord(sampledDataSet, i));
        }
    }
}

/**
 * Reads the record at {@code index} of {@code dataSet} into a newly allocated pair.
 */
private static MLDataPair copyRecord(MLDataSet dataSet, long index) {
    double[] input = new double[dataSet.getInputSize()];
    // NOTE(review): the ideal buffer is fixed at one value, while the disk buffers are
    // initialised with masterDataSet.getIdealSize() - assumes a single output; confirm.
    double[] ideal = new double[1];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    dataSet.getRecord(index, pair);
    return pair;
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet) BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) BasicMLData(org.encog.ml.data.basic.BasicMLData) File(java.io.File) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet)

Example 2 with MLDataSet

use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.

the class ValidationConductor method runValidate.

/**
 * Trains a throw-away neural network on the working column set and reports its
 * validation error.
 *
 * @return the value returned by {@code NNTrainer.train()}, or the trainer's base MSE when
 *         training raises an {@link IOException}
 */
public double runValidate() {
    // 1. split the working columns into a training part and a testing part
    final MLDataSet trainPart = new BasicMLDataSet();
    final MLDataSet testPart = new BasicMLDataSet();
    this.trainingDataSet.generateValidateData(
            this.workingColumnSet, this.modelConfig.getValidSetRate(), trainPart, testPart);

    // 2. configure a trainer with model persistence and logging disabled
    final NNTrainer nnTrainer = new NNTrainer(this.modelConfig, 1, false);
    nnTrainer.setTrainSet(trainPart);
    nnTrainer.setValidSet(testPart);
    nnTrainer.disableModelPersistence();
    nnTrainer.disableLogging();

    // 3. train; an I/O failure is deliberately tolerated and answered with the base MSE
    try {
        return nnTrainer.train();
    } catch (IOException ignored) {
        return nnTrainer.getBaseMSE();
    }
}
Also used : NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) IOException(java.io.IOException)

Example 3 with MLDataSet

use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.

the class DTrainTest method splitDataSet.

/**
 * Deals the records of {@code data} round-robin into {@code numSplit} new in-memory
 * data sets: record {@code k} lands in subset {@code k % numSplit}.
 *
 * @param data the data set to partition
 * @return an array of {@code numSplit} data sets holding copies of the records
 */
private MLDataSet[] splitDataSet(MLDataSet data) {
    final MLDataSet[] partitions = new MLDataSet[numSplit];
    int slot = 0;
    while (slot < partitions.length) {
        partitions[slot++] = new BasicMLDataSet();
    }

    final long total = data.getRecordCount();
    for (int row = 0; row < total; row++) {
        // A fresh pair per record, so each subset holds its own copy.
        final MLDataPair record = BasicMLDataPair.createPair(INPUT_COUNT, OUTPUT_COUNT);
        data.getRecord(row, record);
        partitions[row % numSplit].add(record);
    }
    return partitions;
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet)

Example 4 with MLDataSet

use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.

the class DTrainTest method manhantTest.

/**
 * Runs the split-worker gradient loop with Manhattan propagation for {@code NUM_EPOCHS}
 * epochs, runs Encog's {@code ManhattanPropagation} for the same number of epochs, and
 * asserts that the mean absolute difference between the per-epoch errors is below 0.3.
 */
@Test
public void manhantTest() throws IOException {
    final double[] workerErrors = new double[NUM_EPOCHS];
    final double[] encogErrors = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    // One gradient worker per data partition, all starting from the same weights.
    final MLDataSet[] partitions = splitDataSet(training);
    final Gradient[] gradientWorkers = new Gradient[numSplit];
    for (int w = 0; w < gradientWorkers.length; w++) {
        gradientWorkers[w] = initGradient(partitions[w]);
        gradientWorkers[w].setWeights(weights);
    }

    final NNParams master = new NNParams();
    master.setWeights(weights);
    log.info("Starting manhattan propagation testing!");

    Weight updater = null;
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        // workers: compute gradients and collect their errors
        double errorSum = 0.0;
        for (Gradient worker : gradientWorkers) {
            worker.run();
            errorSum += worker.getError();
        }
        workerErrors[epoch] = errorSum / gradientWorkers.length;
        log.info("The #" + epoch + " training error: " + workerErrors[epoch]);

        // master: aggregate gradients and training sizes from every worker
        master.reset();
        for (int w = 0; w < gradientWorkers.length; w++) {
            master.accumulateGradients(gradientWorkers[w].getGradients());
            master.accumulateTrainSize(partitions[w].getRecordCount());
        }

        // The weight calculator is created lazily, once the aggregated sizes are known.
        if (updater == null) {
            updater = new Weight(master.getGradients().length, master.getTrainSize(), this.rate, DTrainUtils.MANHATTAN_PROPAGATION, 0, RegulationLevel.NONE);
        }
        final double[] nextWeights = updater.calculateWeights(master.getWeights(), master.getGradients(), -1);
        master.setWeights(nextWeights);

        // push the new weights back to every worker
        for (Gradient worker : gradientWorkers) {
            worker.setWeights(nextWeights);
        }
    }

    // Encog reference run
    network.reset();
    network.getFlat().setWeights(weights);
    final Propagation reference = new ManhattanPropagation(network, training, rate);
    reference.setThreadCount(numSplit);
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        reference.iteration(1);
        encogErrors[epoch] = reference.getError();
    }

    // compare the two error curves
    double totalGap = 0.0;
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        totalGap += Math.abs(encogErrors[epoch] - workerErrors[epoch]);
    }
    Assert.assertTrue(totalGap / NUM_EPOCHS < 0.3);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)

Example 5 with MLDataSet

use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.

the class DTrainTest method resilientPropagationTest.

/**
 * Runs the split-worker gradient loop with resilient propagation for {@code NUM_EPOCHS}
 * epochs, runs Encog's {@code ResilientPropagation} for the same number of epochs, and
 * asserts that the mean absolute difference between the per-epoch errors is below 0.2.
 */
@Test
public void resilientPropagationTest() {
    final double[] workerErrors = new double[NUM_EPOCHS];
    final double[] encogErrors = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    // One gradient worker per data partition, all starting from the same weights.
    final MLDataSet[] partitions = splitDataSet(training);
    final Gradient[] gradientWorkers = new Gradient[numSplit];
    for (int w = 0; w < gradientWorkers.length; w++) {
        gradientWorkers[w] = initGradient(partitions[w]);
        gradientWorkers[w].setWeights(weights);
    }

    log.info("Starting resilient propagation testing!");
    final NNParams master = new NNParams();
    master.setWeights(weights);

    Weight updater = null;
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        // workers: compute gradients and collect their errors
        double errorSum = 0.0;
        for (Gradient worker : gradientWorkers) {
            worker.run();
            errorSum += worker.getError();
        }
        workerErrors[epoch] = errorSum / gradientWorkers.length;
        log.info("The #" + epoch + " training error: " + workerErrors[epoch]);

        // master: aggregate gradients and training sizes from every worker
        master.reset();
        for (int w = 0; w < gradientWorkers.length; w++) {
            master.accumulateGradients(gradientWorkers[w].getGradients());
            master.accumulateTrainSize(partitions[w].getRecordCount());
        }

        // The weight calculator is created lazily, once the aggregated sizes are known.
        if (updater == null) {
            updater = new Weight(master.getGradients().length, master.getTrainSize(), this.rate, DTrainUtils.RESILIENTPROPAGATION, 0, RegulationLevel.NONE);
        }
        final double[] nextWeights = updater.calculateWeights(master.getWeights(), master.getGradients(), -1);
        master.setWeights(nextWeights);

        // push the new weights back to every worker
        for (Gradient worker : gradientWorkers) {
            worker.setWeights(nextWeights);
        }
    }

    // Encog reference run
    network.reset();
    network.getFlat().setWeights(weights);
    final Propagation reference = new ResilientPropagation(network, training);
    reference.setThreadCount(numSplit);
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        reference.iteration(1);
        encogErrors[epoch] = reference.getError();
    }

    // compare the two error curves
    double totalGap = 0.0;
    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
        totalGap += Math.abs(encogErrors[epoch] - workerErrors[epoch]);
    }
    Assert.assertTrue(totalGap / NUM_EPOCHS < 0.2);
}
Also used : Propagation(org.encog.neural.networks.training.propagation.Propagation) ManhattanPropagation(org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) ResilientPropagation(org.encog.neural.networks.training.propagation.resilient.ResilientPropagation) NNParams(ml.shifu.shifu.core.dtrain.nn.NNParams) Test(org.testng.annotations.Test) BeforeTest(org.testng.annotations.BeforeTest)

Aggregations

MLDataSet (org.encog.ml.data.MLDataSet)8 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)8 Test (org.testng.annotations.Test)5 NNParams (ml.shifu.shifu.core.dtrain.nn.NNParams)4 Propagation (org.encog.neural.networks.training.propagation.Propagation)4 ManhattanPropagation (org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation)4 QuickPropagation (org.encog.neural.networks.training.propagation.quick.QuickPropagation)4 ResilientPropagation (org.encog.neural.networks.training.propagation.resilient.ResilientPropagation)4 BeforeTest (org.testng.annotations.BeforeTest)4 MLDataPair (org.encog.ml.data.MLDataPair)3 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)3 NNTrainer (ml.shifu.shifu.core.alg.NNTrainer)2 BasicMLData (org.encog.ml.data.basic.BasicMLData)2 File (java.io.File)1 IOException (java.io.IOException)1 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)1 BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet)1 Backpropagation (org.encog.neural.networks.training.propagation.back.Backpropagation)1