Example usage of org.encog.ml.data.MLDataSet in project shifu (ShifuML):
class AbstractTrainerTest, method testLoad1.
@Test
public void testLoad1() throws IOException {
    // Build a synthetic data set of 1000 random records whose input width
    // matches the model's variable-select filter count.
    MLDataSet set = new BasicMLDataSet();
    ModelConfig modelConfig = CommonUtils.loadModelConfig(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json",
            SourceType.LOCAL);
    int inputSize = modelConfig.getVarSelectFilterNum();
    for (int j = 0; j < 1000; j++) {
        // Allocate a fresh input array per record so pairs never alias a
        // shared buffer (the original reused one array across all pairs).
        double[] input = new double[inputSize];
        for (int i = 0; i < inputSize; i++) {
            input[i] = random.nextDouble();
        }
        double[] ideal = new double[] { random.nextInt(2) };
        set.add(new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)));
    }

    // Case 1: in-memory training with default bagging (with replacement).
    modelConfig.getTrain().setTrainOnDisk(false);
    AbstractTrainer trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);

    // Case 2: fixed initial input sampling.
    modelConfig.getTrain().setFixInitInput(true);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);

    // Case 3: bagging without replacement.
    modelConfig.getTrain().setFixInitInput(false);
    modelConfig.getTrain().setBaggingWithReplacement(false);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);
}

/**
 * Asserts the trainer's training-set record count lies within +/-5% of the
 * expected size: (1 - validSetRate) * baggingSampleRate * totalRecords.
 */
private void assertTrainSetSizeWithinTolerance(AbstractTrainer trainer, ModelConfig modelConfig, MLDataSet set) {
    double expected = (1 - modelConfig.getValidSetRate()) * modelConfig.getBaggingSampleRate()
            * set.getRecordCount();
    Assert.assertTrue(trainer.getTrainSet().getRecordCount() <= expected * 1.05);
    Assert.assertTrue(trainer.getTrainSet().getRecordCount() >= expected * 0.95);
}
Example usage of org.encog.ml.data.MLDataSet in project shifu (ShifuML):
class DTrainTest, method quickTest.
/**
 * Compares the distributed QuickPropagation implementation (worker/master
 * gradient aggregation) against Encog's built-in QuickPropagation on the same
 * data; the mean per-epoch error difference must stay below 0.1.
 */
@Test
public void quickTest() throws IOException {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] encogError = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    // Split the data across workers to emulate distributed gradient computation.
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }

    log.info("Running QuickPropagation testing! ");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    for (int i = 0; i < NUM_EPOCHS; i++) {
        // Each worker computes gradients and error on its own subset.
        double error = 0.0;
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // Master step: aggregate worker gradients and record counts.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        // Lazily create the weight updater once gradient length is known.
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                    this.rate, DTrainUtils.QUICK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // Broadcast the updated weights back to every worker.
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Reference run: Encog's single-process QuickPropagation from the same
    // starting weights.
    network.reset();
    network.getFlat().setWeights(weights);
    Propagation p = new QuickPropagation(network, training, rate);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        encogError[i] = p.getError();
    }

    // The distributed implementation should track Encog closely on average.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(encogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.1);
}
Example usage of org.encog.ml.data.MLDataSet in project shifu (ShifuML):
class DTrainTest, method backTest.
/**
 * Compares the distributed back-propagation implementation (worker/master
 * gradient aggregation) against Encog's built-in Backpropagation on the same
 * data; the mean per-epoch error difference must stay below 0.2.
 */
@Test
public void backTest() {
    double[] gradientError = new double[NUM_EPOCHS];
    double[] encogError = new double[NUM_EPOCHS];

    network.reset();
    weights = network.getFlat().getWeights();

    // Split the data across workers to emulate distributed gradient computation.
    MLDataSet[] subsets = splitDataSet(training);
    Gradient[] workers = new Gradient[numSplit];
    Weight weightCalculator = null;
    for (int i = 0; i < workers.length; i++) {
        workers[i] = initGradient(subsets[i]);
        workers[i].setWeights(weights);
    }

    log.info("Starting back propagation testing!");
    NNParams globalParams = new NNParams();
    globalParams.setWeights(weights);

    for (int i = 0; i < NUM_EPOCHS; i++) {
        // Each worker computes gradients and error on its own subset.
        double error = 0.0;
        for (int j = 0; j < workers.length; j++) {
            workers[j].run();
            error += workers[j].getError();
        }
        gradientError[i] = error / workers.length;
        log.info("The #" + i + " training error: " + gradientError[i]);

        // Master step: aggregate worker gradients and record counts.
        globalParams.reset();
        for (int j = 0; j < workers.length; j++) {
            globalParams.accumulateGradients(workers[j].getGradients());
            globalParams.accumulateTrainSize(subsets[j].getRecordCount());
        }
        // Lazily create the weight updater once gradient length is known.
        if (weightCalculator == null) {
            weightCalculator = new Weight(globalParams.getGradients().length, globalParams.getTrainSize(),
                    this.rate, DTrainUtils.BACK_PROPAGATION, 0, RegulationLevel.NONE);
        }
        double[] interWeight = weightCalculator.calculateWeights(globalParams.getWeights(),
                globalParams.getGradients(), -1);
        globalParams.setWeights(interWeight);

        // Broadcast the updated weights back to every worker.
        for (int j = 0; j < workers.length; j++) {
            workers[j].setWeights(interWeight);
        }
    }

    // Reference run: Encog's single-process Backpropagation (momentum 0.5)
    // from the same starting weights.
    network.reset();
    network.getFlat().setWeights(weights);
    Propagation p = new Backpropagation(network, training, rate, 0.5);
    p.setThreadCount(numSplit);
    for (int i = 0; i < NUM_EPOCHS; i++) {
        p.iteration(1);
        encogError[i] = p.getError();
    }

    // The distributed implementation should track Encog closely on average.
    double diff = 0.0;
    for (int i = 0; i < NUM_EPOCHS; i++) {
        diff += Math.abs(encogError[i] - gradientError[i]);
    }
    Assert.assertTrue(diff / NUM_EPOCHS < 0.2);
}
Aggregations