Use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.
The class AbstractTrainer, method setDataSet.
/*
 * Set up the training dataset and validation dataset
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");

    MLDataSet sampledDataSet;

    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (this.trainingOption.equalsIgnoreCase("D")) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));

        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.1
    // int masterSize = masterDataSet.size();
    // Encog 3.0
    int masterSize = (int) masterDataSet.getRecordCount();

    if (!modelConfig.isFixInitialInput()) {
        // Bagging
        if (modelConfig.isBaggingWithReplacement()) {
            // Bagging With Replacement
            int sampledSize = (int) (masterSize * baggingSampleRate);
            for (int i = 0; i < sampledSize; i++) {
                // Encog 3.1
                // sampledDataSet.add(masterDataSet.get(random.nextInt(masterSize)));
                // Encog 3.0
                double[] input = new double[masterDataSet.getInputSize()];
                double[] ideal = new double[1];
                MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
                masterDataSet.getRecord(random.nextInt(masterSize), pair);
                sampledDataSet.add(pair);
            }
        } else {
            // Bagging Without Replacement
            for (MLDataPair pair : masterDataSet) {
                if (random.nextDouble() < baggingSampleRate) {
                    sampledDataSet.add(pair);
                }
            }
        }
    } else {
        List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize, modelConfig.isBaggingWithReplacement());
        for (Integer i : list) {
            double[] input = new double[masterDataSet.getInputSize()];
            double[] ideal = new double[1];
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
            masterDataSet.getRecord(i, pair);
            sampledDataSet.add(pair);
        }
    }

    if (this.trainingOption.equalsIgnoreCase("D")) {
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross Validation
    log.info("Generating Training Set and Validation Set ...");
    if (!modelConfig.isFixInitialInput()) {
        // Encog 3.0
        for (int i = 0; i < sampledDataSet.getRecordCount(); i++) {
            double[] input = new double[sampledDataSet.getInputSize()];
            double[] ideal = new double[1];
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
            sampledDataSet.getRecord(i, pair);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        long sampleSize = sampledDataSet.getRecordCount();
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        int i = 0;
        for (; i < trainSetSize; i++) {
            double[] input = new double[sampledDataSet.getInputSize()];
            double[] ideal = new double[1];
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
            sampledDataSet.getRecord(i, pair);
            trainSet.add(pair);
        }
        for (; i < sampleSize; i++) {
            double[] input = new double[sampledDataSet.getInputSize()];
            double[] ideal = new double[1];
            MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
            sampledDataSet.getRecord(i, pair);
            validSet.add(pair);
        }
    }

    if (this.trainingOption.equalsIgnoreCase("D")) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info(" - # Records of the Master Data Set: " + masterSize);
    log.info(" - Bagging Sample Rate: " + baggingSampleRate);
    log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info(" - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info(" - Cross Validation Rate: " + crossValidationRate);
    log.info(" - # Records of the Training Set: " + this.getTrainSetSize());
    log.info(" - # Records of the Validation Set: " + this.getValidSetSize());
}
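The bagging branches above rely on Encog's record-copy idiom: allocate an MLDataPair of the right shape, fill it with getRecord, then add it to the destination set. Below is a minimal, self-contained sketch of bagging with replacement over a purely in-memory BasicMLDataSet; the toy data, the sampleRate value, and the class name are illustrative assumptions, not Shifu code.

import java.util.Random;

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;

public class BaggingSketch {

    public static void main(String[] args) {
        // Toy master set: XOR-style records with one ideal output (illustrative data only).
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideal = { { 0 }, { 1 }, { 1 }, { 0 } };
        MLDataSet master = new BasicMLDataSet(input, ideal);

        Random random = new Random();
        double sampleRate = 0.8; // assumed bagging sample rate
        int masterSize = (int) master.getRecordCount();
        int sampledSize = (int) (masterSize * sampleRate);

        MLDataSet sampled = new BasicMLDataSet();
        for (int i = 0; i < sampledSize; i++) {
            // Allocate a pair of the right shape, then copy a randomly chosen record into it.
            MLDataPair pair = BasicMLDataPair.createPair(master.getInputSize(), master.getIdealSize());
            master.getRecord(random.nextInt(masterSize), pair);
            sampled.add(pair);
        }
        System.out.println("Sampled " + sampled.getRecordCount() + " of " + masterSize + " records");
    }
}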
Use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.
The class ValidationConductor, method runValidate.
public double runValidate() {
    // 1. prepare training data
    MLDataSet trainingData = new BasicMLDataSet();
    MLDataSet testingData = new BasicMLDataSet();
    this.trainingDataSet.generateValidateData(this.workingColumnSet, this.modelConfig.getValidSetRate(), trainingData, testingData);

    // 2. build NNTrainer
    NNTrainer trainer = new NNTrainer(this.modelConfig, 1, false);
    trainer.setTrainSet(trainingData);
    trainer.setValidSet(testingData);
    trainer.disableModelPersistence();
    trainer.disableLogging();

    // 3. train and get validation error
    double validateError = Double.MAX_VALUE;
    try {
        validateError = trainer.train();
    } catch (IOException e) {
        // Ignore the IOException from training and fall back to the base MSE
        validateError = trainer.getBaseMSE();
    }

    return validateError;
}
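The generateValidateData call is not shown in this snippet. Below is a minimal sketch of the kind of random train/test split such a method performs, assuming a simple per-record coin flip at the given rate; the helper name splitByRate is hypothetical and this is not the Shifu implementation.

import java.util.Random;

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;

public final class SplitSketch {

    private SplitSketch() {
    }

    // Route each record of source into train or test, sending roughly
    // testRate of the records to the test set. Hypothetical helper, shown
    // only to illustrate the role generateValidateData plays above.
    public static void splitByRate(MLDataSet source, double testRate, MLDataSet train, MLDataSet test) {
        Random random = new Random();
        for (int i = 0; i < source.getRecordCount(); i++) {
            MLDataPair pair = BasicMLDataPair.createPair(source.getInputSize(), source.getIdealSize());
            source.getRecord(i, pair);
            if (random.nextDouble() < testRate) {
                test.add(pair);
            } else {
                train.add(pair);
            }
        }
    }
}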
Use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.
The class DTrainTest, method splitDataSet.
private MLDataSet[] splitDataSet(MLDataSet data) {
    MLDataSet[] subsets = new MLDataSet[numSplit];
    for (int i = 0; i < subsets.length; i++) {
        subsets[i] = new BasicMLDataSet();
    }
    for (int i = 0; i < data.getRecordCount(); i++) {
        MLDataPair pair = BasicMLDataPair.createPair(INPUT_COUNT, OUTPUT_COUNT);
        data.getRecord(i, pair);
        subsets[i % numSplit].add(pair);
    }
    return subsets;
}
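The helper deals records round-robin (i % numSplit), so the subsets differ in size by at most one record. A small, self-contained driver for the same idea is sketched below with a toy in-memory dataset so the per-subset sizes are easy to see; the data, the numSplit value, and the class name are illustrative assumptions.

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;

public class SplitDataSetSketch {

    public static void main(String[] args) {
        // Toy dataset: 10 records with 2 inputs and 1 ideal value each (illustrative data).
        MLDataSet data = new BasicMLDataSet();
        for (int i = 0; i < 10; i++) {
            MLDataPair pair = BasicMLDataPair.createPair(2, 1);
            pair.getInputArray()[0] = i;
            pair.getInputArray()[1] = i * 2;
            pair.getIdealArray()[0] = i % 2;
            data.add(pair);
        }

        int numSplit = 3; // assumed number of workers
        MLDataSet[] subsets = new MLDataSet[numSplit];
        for (int i = 0; i < subsets.length; i++) {
            subsets[i] = new BasicMLDataSet();
        }
        // Deal records round-robin, exactly as splitDataSet does above.
        for (int i = 0; i < data.getRecordCount(); i++) {
            MLDataPair pair = BasicMLDataPair.createPair(data.getInputSize(), data.getIdealSize());
            data.getRecord(i, pair);
            subsets[i % numSplit].add(pair);
        }
        for (int i = 0; i < numSplit; i++) {
            System.out.println("Subset " + i + ": " + subsets[i].getRecordCount() + " records");
        }
    }
}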
Use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.
The class DTrainTest, method manhantTest.
@Test
public void manhantTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;

    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }

    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    log.info("Starting manhattan propagation testing!");
    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;

        // each worker computes gradients and error on its own subset
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // master: accumulate gradients and record counts from all workers
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }

        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.MANHATTAN_PROPAGATION, 0, RegulationLevel.NONE);
        }

        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // push the updated weights back to every worker
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Encog baseline: train the same network with ManhattanPropagation for comparison
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);

    Propagation p = new ManhattanPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        // System.out.println("the #" + i + " training error: " + p.getError());
        ecogError[i] = p.getError();
    }

    // assert: the average per-epoch error difference between the two implementations stays small
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.3);
}
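The test drives Shifu's Gradient workers and Weight calculator in lockstep with Encog's ManhattanPropagation. The defining step of Manhattan propagation is a fixed-size weight move in the direction given by the sign of the accumulated gradient. A framework-free sketch of that per-weight rule follows; the sign convention and naming are illustrative, and this is not Shifu's Weight implementation.

public final class ManhattanUpdateSketch {

    private ManhattanUpdateSketch() {
    }

    // One Manhattan-style update: each weight moves by a constant learning
    // rate whose direction is taken from the sign of its accumulated gradient.
    // Illustrative only; Shifu's Weight class may differ in detail.
    public static double[] update(double[] weights, double[] gradients, double learningRate) {
        double[] next = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            next[i] = weights[i] + Math.signum(gradients[i]) * learningRate;
        }
        return next;
    }
}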
Use of org.encog.ml.data.MLDataSet in project shifu by ShifuML.
The class DTrainTest, method resilientPropagationTest.
@Test
public void resilientPropagationTest() {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] ecogError = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;

    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }

    log.info("Starting resilient propagation testing!");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    for (int i = 0; i < NUM_EPOCHS; i++) {
        double error = 0.0;

        // each worker computes gradients and error on its own subset
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // master: accumulate gradients and record counts from all workers
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }

        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(), this.rate, DTrainUtils.RESILIENTPROPAGATION, 0, RegulationLevel.NONE);
        }

        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(), globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // push the updated weights back to every worker
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Encog baseline: train the same network with ResilientPropagation for comparison
    network.reset();
    // NNUtils.randomize(numSplit, weights);
    network.getFlat().setWeights(weights);

    Propagation p = new ResilientPropagation(network, training);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        ecogError[i] = p.getError();
    }

    // assert: the average per-epoch error difference between the two implementations stays small
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(ecogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
}
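For reference, the single-process Encog baseline that both tests compare against can be exercised on its own. The sketch below trains a small network with ResilientPropagation one epoch at a time, mirroring the p.iteration(1) loop above; the toy XOR data, the 2-3-1 topology, and the epoch count are illustrative assumptions rather than the values used in DTrainTest.

import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

public class RpropBaselineSketch {

    public static void main(String[] args) {
        // Toy XOR data (illustrative only).
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideal = { { 0 }, { 1 }, { 1 }, { 0 } };
        MLDataSet training = new BasicMLDataSet(input, ideal);

        // Small 2-3-1 network, standing in for the "network" field used in the test.
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        // Encog's single-process RPROP, iterated one epoch at a time as the test does.
        ResilientPropagation p = new ResilientPropagation(network, training);
        for (int i = 0; i < 10; i++) {
            p.iteration(1);
            System.out.println("Epoch " + i + " error: " + p.getError());
        }
        p.finishTraining();
    }
}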