Use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.
Method add(MLData, MLData) of class MemoryDiskMLDataSet.
/**
 * Adds one record given as separate input and ideal vectors.
 *
 * <p>The estimated size of the record is added to the running byte total. While that
 * projected total stays strictly below {@code maxByteSize} the record is stored in the
 * in-memory data set; otherwise it is stored in the disk data set, which is created from
 * {@code fileName} and put into load mode the first time it is needed.
 *
 * @param inputData the input vector of the record
 * @param idealData the ideal vector of the record
 * @see org.encog.ml.data.MLDataSet#add(org.encog.ml.data.MLData, org.encog.ml.data.MLData)
 */
@Override
public void add(MLData inputData, MLData idealData) {
    final long recordBytes = SizeEstimator.estimate(inputData) + SizeEstimator.estimate(idealData);

    // Fast path: the record still fits into the memory budget.
    if (this.byteSize + recordBytes < this.maxByteSize) {
        this.byteSize += recordBytes;
        this.memoryCount += 1L;
        this.memoryDataSet.add(inputData, idealData);
        return;
    }

    // Budget exhausted: spill to disk, creating the backing file lazily.
    if (this.diskDataSet == null) {
        final BufferedMLDataSet spill = new BufferedMLDataSet(new File(this.fileName));
        this.diskDataSet = spill;
        spill.beginLoad(this.inputCount, this.outputCount);
    }
    this.byteSize += recordBytes;
    this.diskCount += 1L;
    this.diskDataSet.add(inputData, idealData);
}
Use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.
Method add(MLDataPair) of class MemoryDiskMLDataSet.
/**
 * Adds one record given as an input/ideal pair.
 *
 * <p>The estimated size of the pair is added to the running byte total. While that
 * projected total stays strictly below {@code maxByteSize} the pair is stored in the
 * in-memory data set; otherwise it is stored in the disk data set, which is created from
 * {@code fileName} and put into load mode the first time it is needed.
 *
 * @param inputData the record to add
 * @see org.encog.ml.data.MLDataSet#add(org.encog.ml.data.MLDataPair)
 */
@Override
public void add(MLDataPair inputData) {
    final long recordBytes = SizeEstimator.estimate(inputData);

    // Fast path: the record still fits into the memory budget.
    if (this.byteSize + recordBytes < this.maxByteSize) {
        this.byteSize += recordBytes;
        this.memoryCount += 1L;
        this.memoryDataSet.add(inputData);
        return;
    }

    // Budget exhausted: spill to disk, creating the backing file lazily.
    if (this.diskDataSet == null) {
        final BufferedMLDataSet spill = new BufferedMLDataSet(new File(this.fileName));
        this.diskDataSet = spill;
        spill.beginLoad(this.inputCount, this.outputCount);
    }
    this.byteSize += recordBytes;
    this.diskCount += 1L;
    this.diskDataSet.add(inputData);
}
Use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.
Method setDataSet of class AbstractTrainer.
/**
 * Sets up the training data set and the validation data set from the master data set.
 *
 * <ol>
 * <li>Creates the sampled, training and validation containers: in memory for training
 * option {@code "M"}, or as buffered files under {@code Constants.TMP} for option
 * {@code "D"}.</li>
 * <li>Fills the sampled set: by random bagging (with or without replacement), or, when the
 * model config asks for a fixed initial input, from the indices returned by
 * {@code loadSampleInput}.</li>
 * <li>Splits the sampled set into training and validation sets using
 * {@code crossValidationRate}: randomly per record, or as a contiguous head/tail split when
 * the initial input is fixed.</li>
 * <li>Logs a summary of the resulting sizes.</li>
 * </ol>
 *
 * <p>NOTE(review): in disk mode the sampled set is never closed here and the three file
 * names are fixed, so several trainers would appear to share the same files. Confirm
 * whether that is intended.
 *
 * @param masterDataSet the complete data set to sample from
 * @throws IOException declared on the signature; presumably raised by
 *         {@code loadSampleInput} (TODO confirm)
 * @throws RuntimeException if the training option is neither {@code "M"} nor {@code "D"}
 *         (including {@code null})
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");

    // Constant-first comparison so a null option reaches the explicit error below
    // instead of failing with a NullPointerException.
    final boolean onDisk = "D".equalsIgnoreCase(this.trainingOption);

    MLDataSet sampledDataSet;
    if ("M".equalsIgnoreCase(this.trainingOption)) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (onDisk) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 API: record count is a long; it is narrowed because Random#nextInt and
    // loadSampleInput work on ints.
    final int masterSize = (int) masterDataSet.getRecordCount();
    final boolean fixedInitialInput = modelConfig.isFixInitialInput();

    if (fixedInitialInput) {
        // Fixed sample: take exactly the record indices supplied by loadSampleInput.
        List<Integer> indices = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer index : indices) {
            sampledDataSet.add(copyRecord(masterDataSet, index));
        }
    } else if (modelConfig.isBaggingWithReplacement()) {
        // Bagging with replacement: draw sampledSize random records, duplicates allowed.
        final int sampledSize = (int) (masterSize * baggingSampleRate);
        for (int i = 0; i < sampledSize; i++) {
            sampledDataSet.add(copyRecord(masterDataSet, random.nextInt(masterSize)));
        }
    } else {
        // Bagging without replacement: keep each record with probability baggingSampleRate.
        for (MLDataPair pair : masterDataSet) {
            if (random.nextDouble() < baggingSampleRate) {
                sampledDataSet.add(pair);
            }
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross Validation
    log.info("Generating Training Set and Validation Set ...");
    final long sampledCount = sampledDataSet.getRecordCount();
    if (fixedInitialInput) {
        // Deterministic split: the first (1 - crossValidationRate) share trains, the rest validates.
        final long trainSetSize = (long) (sampledCount * (1 - crossValidationRate));
        long i = 0;
        for (; i < trainSetSize; i++) {
            trainSet.add(copyRecord(sampledDataSet, i));
        }
        for (; i < sampledCount; i++) {
            validSet.add(copyRecord(sampledDataSet, i));
        }
    } else {
        // Random split: each record goes to validation with probability crossValidationRate.
        for (long i = 0; i < sampledCount; i++) {
            MLDataPair pair = copyRecord(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info(" - # Records of the Master Data Set: " + masterSize);
    log.info(" - Bagging Sample Rate: " + baggingSampleRate);
    log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info(" - # Records of the Selected Data Set: " + sampledCount);
    log.info(" - Cross Validation Rate: " + crossValidationRate);
    log.info(" - # Records of the Training Set: " + this.getTrainSetSize());
    log.info(" - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Reads one record of {@code source} into a freshly allocated pair.
 *
 * <p>Both buffers are sized from the data set itself. The ideal buffer was previously
 * hard-coded to length 1, which did not match data sets created with a different
 * {@code getIdealSize()}.
 *
 * @param source the data set to read from
 * @param index the zero-based record index
 * @return a new pair holding the record
 */
private static MLDataPair copyRecord(MLDataSet source, long index) {
    double[] input = new double[source.getInputSize()];
    double[] ideal = new double[source.getIdealSize()];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    source.getRecord(index, pair);
    return pair;
}
Use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.
Method handleMsg of class TrainDataPrepWorker.
/**
 * Dispatches one incoming message.
 *
 * <ul>
 * <li>{@code TrainPartDataMessage}: appends every contained pair to the master data set
 * (starting the buffered load on the first pair when training on disk). Once the expected
 * number of messages has arrived, finishes the load, hands the master data set to every
 * trainer, forwards each trainer to the next actor, and finally closes and drops the master
 * data set when training on disk.</li>
 * <li>{@code StatsPartRawDataMessage}: only counts the message; once the expected number
 * has arrived, forwards every trainer to the next actor.</li>
 * <li>Anything else is passed to {@code unhandled}.</li>
 * </ul>
 *
 * @param message the received message
 * @throws IOException if setting the data set on a trainer fails
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        final TrainPartDataMessage trainPart = (TrainPartDataMessage) message;

        for (MLDataPair record : trainPart.getMlDataPairList()) {
            // The buffered file needs its dimensions before the first add; take them
            // from the first record seen.
            if (modelConfig.isTrainOnDisk() && !initialized) {
                final int inputWidth = record.getInput().size();
                final int idealWidth = record.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputWidth, idealWidth);
                initialized = true;
            }
            masterDataSet.add(record);
        }

        receivedMsgCnt++;
        log.debug("Expected " + trainPart.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt != trainPart.getTotalMsgCnt()) {
            // More parts are still on their way.
            return;
        }

        if (modelConfig.isTrainOnDisk() && initialized) {
            ((BufferedMLDataSet) masterDataSet).endLoad();
        }
        for (AbstractTrainer trainer : trainers) {
            // if the trainOnDisk is true, setting the "D" option
            if (modelConfig.isTrainOnDisk()) {
                trainer.setTrainingOption("D");
            }
            trainer.setDataSet(masterDataSet);
            nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
        }
        if (modelConfig.isTrainOnDisk() && initialized) {
            masterDataSet.close();
            masterDataSet = null;
        }
        return;
    }

    if (message instanceof StatsPartRawDataMessage) {
        final StatsPartRawDataMessage statsPart = (StatsPartRawDataMessage) message;
        receivedMsgCnt++;
        log.debug("Expected " + statsPart.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt != statsPart.getTotalMsgCnt()) {
            return;
        }
        for (AbstractTrainer trainer : trainers) {
            nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
        }
        return;
    }

    unhandled(message);
}
Aggregations