Use of org.encog.ml.data.basic.BasicMLDataSet in project shifu by ShifuML.
From the class AbstractTrainer, method setDataSet.
/**
 * Sets up the training data set and the validation data set from a master data set.
 *
 * <p>The method works in three stages:
 * <ol>
 *   <li>Allocate the sampled, training and validation sets, either in memory
 *       (training option {@code "M"}) or as buffered files under {@code Constants.TMP}
 *       (training option {@code "D"}).</li>
 *   <li>Draw a bagging sample from {@code masterDataSet}: randomly (with or without
 *       replacement) or, when {@code modelConfig.isFixInitialInput()} is set, from the
 *       index list returned by {@code loadSampleInput}.</li>
 *   <li>Split the sample into {@code trainSet} and {@code validSet} according to
 *       {@code crossValidationRate}: per-record random assignment in the random case,
 *       a contiguous head/tail split in the fixed-input case.</li>
 * </ol>
 *
 * <p>Records are read through the index based {@code getRecordCount()} /
 * {@code getRecord(index, pair)} API (the original code notes this as the Encog 3.0 style).
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException declared by the signature; nothing in this method throws it explicitly
 * @throws RuntimeException if the training option is neither {@code "M"} nor {@code "D"}
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");

    final boolean onDisk;
    MLDataSet sampledDataSet;
    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
        onDisk = false;
    } else if (this.trainingOption.equalsIgnoreCase("D")) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));

        // Buffered sets must be opened for loading with the record layout before add() is used.
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
        onDisk = true;
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // NOTE(review): the record count is narrowed to int because Random.nextInt needs an int
    // bound; a master set with more than Integer.MAX_VALUE records would overflow here.
    int masterSize = (int) masterDataSet.getRecordCount();

    // ---- Stage 2: bagging ----
    if (!modelConfig.isFixInitialInput()) {
        if (modelConfig.isBaggingWithReplacement()) {
            // With replacement: draw a fixed number of records, duplicates allowed.
            int sampledSize = (int) (masterSize * baggingSampleRate);
            for (int drawn = 0; drawn < sampledSize; drawn++) {
                sampledDataSet.add(readRecordCopy(masterDataSet, random.nextInt(masterSize)));
            }
        } else {
            // Without replacement: keep each record independently with probability
            // baggingSampleRate, so the sample size is only approximately rate * masterSize.
            for (MLDataPair pair : masterDataSet) {
                if (random.nextDouble() < baggingSampleRate) {
                    sampledDataSet.add(pair);
                }
            }
        }
    } else {
        // Fixed initial input: the record indices come from loadSampleInput instead of
        // from this trainer's random generator.
        List<Integer> indices = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer index : indices) {
            sampledDataSet.add(readRecordCopy(masterDataSet, index));
        }
    }
    if (onDisk) {
        // Finish writing the sample so that it can be read back by index below.
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // ---- Stage 3: cross validation split ----
    log.info("Generating Training Set and Validation Set ...");
    long sampleSize = sampledDataSet.getRecordCount();
    if (!modelConfig.isFixInitialInput()) {
        // Random split: each record goes to the validation set with probability
        // crossValidationRate, otherwise to the training set.
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = readRecordCopy(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        // Deterministic split: the leading (1 - crossValidationRate) share of the sample
        // is used for training, the remainder for validation.
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        long i = 0;
        for (; i < trainSetSize; i++) {
            trainSet.add(readRecordCopy(sampledDataSet, i));
        }
        for (; i < sampleSize; i++) {
            validSet.add(readRecordCopy(sampledDataSet, i));
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info(" - # Records of the Master Data Set: " + masterSize);
    log.info(" - Bagging Sample Rate: " + baggingSampleRate);
    log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info(" - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info(" - Cross Validation Rate: " + crossValidationRate);
    log.info(" - # Records of the Training Set: " + this.getTrainSetSize());
    log.info(" - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Reads one record of {@code source} into a freshly allocated pair.
 *
 * <p>The pair's input and ideal arrays are sized from {@code source} itself
 * ({@code getInputSize()} / {@code getIdealSize()}) rather than assuming a single
 * ideal value, so the copy matches the layout the buffered sets are opened with.
 *
 * @param source the data set to read from
 * @param index  zero-based index of the record to read
 * @return a new pair holding the record's values
 */
private static MLDataPair readRecordCopy(MLDataSet source, long index) {
    double[] input = new double[source.getInputSize()];
    double[] ideal = new double[source.getIdealSize()];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    source.getRecord(index, pair);
    return pair;
}
Aggregations