Use of org.encog.ml.data.basic.BasicMLDataPair in the project shifu by ShifuML.
From the class AbstractTrainer, the method setDataSet.
/**
 * Builds the training set and the validation set from the master data set.
 *
 * <p>The work happens in two phases: a bagging sample is drawn from
 * {@code masterDataSet}, then every sampled record is assigned to either the
 * training set or the validation set. With training option "M" all three sets are
 * in-memory {@code BasicMLDataSet}s; with "D" they are {@code BufferedMLDataSet}s
 * backed by the files sampled.egb, train.egb and valid.egb under
 * {@code Constants.TMP}.
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException declared for callers; NOTE(review): presumably raised by
 *         loadSampleInput - confirm against its signature
 * @throws RuntimeException if the training option is neither "M" nor "D"
 *         (compared case-insensitively)
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");

    MLDataSet sampledDataSet;
    final boolean onDisk;
    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        onDisk = false;
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (this.trainingOption.equalsIgnoreCase("D")) {
        log.info("Loading to Disk ...");
        onDisk = true;
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        // The buffered files are laid out with the master's record shape.
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 exposes the size as a long record count; Random.nextInt needs an int bound.
    int masterSize = (int) masterDataSet.getRecordCount();

    drawSample(masterDataSet, sampledDataSet, masterSize);
    if (onDisk) {
        // The sample must be finalized before it can be read back record by record.
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross Validation
    log.info("Generating Training Set and Validation Set ...");
    splitSample(sampledDataSet);
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info(" - # Records of the Master Data Set: " + masterSize);
    log.info(" - Bagging Sample Rate: " + baggingSampleRate);
    log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info(" - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info(" - Cross Validation Rate: " + crossValidationRate);
    log.info(" - # Records of the Training Set: " + this.getTrainSetSize());
    log.info(" - # Records of the Validation Set: " + this.getValidSetSize());
}

/*
 * Draws the bagging sample from the master data set into sampledDataSet.
 *
 * When the model is configured to fix its initial input, the record indices come
 * from loadSampleInput; otherwise records are picked with the trainer's random
 * source, either with replacement (a fixed number of random picks) or without
 * (each record kept with probability baggingSampleRate).
 */
private void drawSample(MLDataSet masterDataSet, MLDataSet sampledDataSet, int masterSize) throws IOException {
    if (modelConfig.isFixInitialInput()) {
        List<Integer> indices = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer index : indices) {
            sampledDataSet.add(fetchRecord(masterDataSet, index));
        }
    } else if (modelConfig.isBaggingWithReplacement()) {
        // Bagging With Replacement
        int sampledSize = (int) (masterSize * baggingSampleRate);
        for (int i = 0; i < sampledSize; i++) {
            sampledDataSet.add(fetchRecord(masterDataSet, random.nextInt(masterSize)));
        }
    } else {
        // Bagging Without Replacement
        for (MLDataPair pair : masterDataSet) {
            if (random.nextDouble() < baggingSampleRate) {
                sampledDataSet.add(pair);
            }
        }
    }
}

/*
 * Distributes every sampled record to either trainSet or validSet.
 *
 * With a fixed initial input the split is positional: the first
 * sampleSize * (1 - crossValidationRate) records go to training and the rest to
 * validation. Otherwise each record goes to training when a random draw exceeds
 * crossValidationRate, and to validation if not.
 */
private void splitSample(MLDataSet sampledDataSet) {
    long sampleSize = sampledDataSet.getRecordCount();
    if (modelConfig.isFixInitialInput()) {
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = fetchRecord(sampledDataSet, i);
            if (i < trainSetSize) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = fetchRecord(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    }
}

/*
 * Reads one record of a data set into a freshly allocated pair (the Encog 3.0
 * access pattern: the caller supplies the pair that getRecord fills).
 *
 * Both arrays are sized from the data set itself, so the pair always matches the
 * record shape that the disk-backed sets were opened with in beginLoad.
 */
private static MLDataPair fetchRecord(MLDataSet source, long index) {
    double[] input = new double[source.getInputSize()];
    double[] ideal = new double[source.getIdealSize()];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    source.getRecord(index, pair);
    return pair;
}
Use of org.encog.ml.data.basic.BasicMLDataPair in the project shifu by ShifuML.
From the class AbstractTrainer, the method calculateMSE.
/**
 * Computes the mean squared error of a network over a data set.
 *
 * <p>For every record the network is evaluated on the record's input, and the
 * difference between the first output value and the first ideal value is squared
 * and accumulated; the sum is then divided by the number of records. Only the first
 * output/ideal component is considered. The method takes no locks of its own
 * (the non-synchronized variant).
 *
 * @param network the network to evaluate
 * @param dataSet the records to evaluate the network on
 * @return the mean squared error; {@code NaN} when the data set has no records,
 *         because the zero sum is divided by a zero record count
 */
public static Double calculateMSE(BasicNetwork network, MLDataSet dataSet) {
    double squaredErrorSum = 0;
    long numRecords = dataSet.getRecordCount();
    // The index is a long, like the record count, so it cannot overflow on data
    // sets holding more than Integer.MAX_VALUE records.
    for (long i = 0; i < numRecords; i++) {
        // Encog 3.0: the caller allocates the pair that getRecord fills.
        double[] input = new double[dataSet.getInputSize()];
        double[] ideal = new double[1];
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
        dataSet.getRecord(i, pair);
        MLData result = network.compute(pair.getInput());
        double error = result.getData()[0] - pair.getIdeal().getData()[0];
        squaredErrorSum += error * error;
    }
    return squaredErrorSum / numRecords;
}
Use of org.encog.ml.data.basic.BasicMLDataPair in the project shifu by ShifuML.
From the class DataLoadWorker, the method readTrainingData.
/**
 * Parses normalized training records, one per scanner line, into data pairs.
 *
 * <p>Each line is split with {@code DEFAULT_SPLITTER} (per the original note the
 * normalized data is separated by '|' by default). Fields are matched to
 * {@code columnConfigList} by position: a target column's value goes to the ideal
 * array, an accepted input column's value goes to the input array, and the field
 * that follows the last configured column is read as the record's significance.
 * Any fields after that one are ignored.
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - when true, every line is consumed but replaced by a dummy pair with
 *            one zero input and one zero ideal value
 * @return one data pair per line read from the scanner
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    // Number of columns flagged as final-select; this sizes the input array below.
    int selectedCount = 0;
    for (ColumnConfig candidate : columnConfigList) {
        if (candidate.isFinalSelect()) {
            selectedCount++;
        }
    }

    List<MLDataPair> pairs = new ArrayList<MLDataPair>();
    int linesSeen = 0;
    while (scanner.hasNextLine()) {
        // Progress is reported on the 1st, 100001st, ... line, using the already
        // incremented counter.
        boolean reportProgress = linesSeen % 100000 == 0;
        linesSeen++;
        if (reportProgress) {
            log.info("Read " + (linesSeen) + " Records.");
        }

        String line = scanner.nextLine();
        if (isDryRun) {
            pairs.add(new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1])));
            continue;
        }

        // NOTE(review): inputs is sized by the final-select count even when the
        // good-candidate path below is taken - confirm the two counts agree.
        double[] inputs = new double[selectedCount];
        // NOTE(review): room for a single target value only; a second target column
        // would index past the end of this array.
        double[] ideal = new double[1];
        double significance = 0.0d;
        int inputCursor = 0;
        int idealCursor = 0;
        int column = 0;
        for (String rawField : DEFAULT_SPLITTER.split(line.trim())) {
            String field = rawField.trim();
            double value = NumberFormatUtils.getDouble(field, 0.0d);
            if (column == this.columnConfigList.size()) {
                // First field past the configured columns: the record's significance.
                significance = NumberFormatUtils.getDouble(field, CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            }

            ColumnConfig columnConfig = this.columnConfigList.get(column++);
            if (columnConfig != null && columnConfig.isTarget()) {
                ideal[idealCursor++] = value;
                continue;
            }

            boolean useAsInput;
            if (this.inputNodeCount == this.candidateCount) {
                // no variable is set final-select: take every good candidate
                useAsInput = CommonUtils.isGoodCandidate(columnConfig);
            } else {
                // some variables are final-select: take only those
                useAsInput = columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget()
                        && columnConfig.isFinalSelect();
            }
            if (useAsInput) {
                inputs[inputCursor++] = value;
            }
        }

        MLDataPair record = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        record.setSignificance(significance);
        pairs.add(record);
    }
    return pairs;
}
Aggregations