use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.
the class AbstractTrainer method setDataSet.
/**
 * Set up the training dataset and validation dataset.
 *
 * <p>The method works in three steps:
 * <ol>
 * <li>Allocate the sampled, training and validation sets, either in memory (training option
 * "M") or as {@code sampled.egb}, {@code train.egb} and {@code valid.egb} under
 * {@code Constants.TMP} (training option "D").</li>
 * <li>Fill the sampled set from the master set: from the indices returned by
 * {@code loadSampleInput} when the initial input is fixed, otherwise by random bagging with
 * or without replacement.</li>
 * <li>Split the sampled set into training and validation records: randomly per record, or as
 * a leading/trailing cut when the initial input is fixed.</li>
 * </ol>
 *
 * @param masterDataSet
 *            - the data set to sample from
 * @throws IOException
 *             declared on the original signature and kept for callers
 * @throws RuntimeException
 *             if the training option is neither "M" nor "D"
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    final boolean onDisk;
    MLDataSet sampledDataSet;
    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        onDisk = false;
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (this.trainingOption.equalsIgnoreCase("D")) {
        log.info("Loading to Disk ...");
        onDisk = true;
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Records are copied with getRecord(index, pair) (Encog 3.0 style access).
    int masterSize = (int) masterDataSet.getRecordCount();
    if (modelConfig.isFixInitialInput()) {
        // Fixed input: the record indices come from loadSampleInput instead of the random source.
        List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer i : list) {
            sampledDataSet.add(copyRecord(masterDataSet, i));
        }
    } else if (modelConfig.isBaggingWithReplacement()) {
        // Bagging With Replacement: draw sampledSize random indices, duplicates allowed.
        int sampledSize = (int) (masterSize * baggingSampleRate);
        for (int i = 0; i < sampledSize; i++) {
            sampledDataSet.add(copyRecord(masterDataSet, random.nextInt(masterSize)));
        }
    } else {
        // Bagging Without Replacement: keep each record with probability baggingSampleRate.
        for (MLDataPair pair : masterDataSet) {
            if (random.nextDouble() < baggingSampleRate) {
                sampledDataSet.add(pair);
            }
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross Validation
    log.info("Generating Training Set and Validation Set ...");
    long sampleSize = sampledDataSet.getRecordCount();
    if (!modelConfig.isFixInitialInput()) {
        // Random split: a record goes to validation when the draw is <= crossValidationRate.
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = copyRecord(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        // Fixed split: the first trainSetSize records train, the remainder validate.
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = copyRecord(sampledDataSet, i);
            if (i < trainSetSize) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info(" - # Records of the Master Data Set: " + masterSize);
    log.info(" - Bagging Sample Rate: " + baggingSampleRate);
    log.info(" - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info(" - # Records of the Selected Data Set: " + sampleSize);
    log.info(" - Cross Validation Rate: " + crossValidationRate);
    log.info(" - # Records of the Training Set: " + this.getTrainSetSize());
    log.info(" - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Copies the record at {@code index} of {@code source} into a freshly allocated pair.
 *
 * <p>Both buffers are sized from the data set itself so they agree with the sizes passed to
 * {@code beginLoad}; the ideal buffer is no longer assumed to hold exactly one value. The sizes
 * are queried per call (not hoisted) so an empty data set is never asked for its sizes.
 *
 * @param source
 *            - data set to read from
 * @param index
 *            - zero-based record index
 * @return a new pair holding the record
 */
private static MLDataPair copyRecord(MLDataSet source, long index) {
    double[] input = new double[source.getInputSize()];
    double[] ideal = new double[source.getIdealSize()];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    source.getRecord(index, pair);
    return pair;
}
use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.
the class AbstractTrainer method calculateMSE.
/**
 * Non-synchronized version of the error update: computes the mean squared error between the
 * first output of the network and the first ideal value over every record of the data set.
 *
 * @param network
 *            - network used to compute the output for each record's input
 * @param dataSet
 *            - records to evaluate
 * @return the sum of squared differences divided by the record count; NaN when the data set
 *         has no records (0.0 divided by 0)
 */
public static Double calculateMSE(BasicNetwork network, MLDataSet dataSet) {
    double squaredErrorSum = 0;
    long numRecords = dataSet.getRecordCount();
    // long index: an int counter would wrap before reaching a record count above Integer.MAX_VALUE
    for (long i = 0; i < numRecords; i++) {
        // Records are copied with getRecord(index, pair) (Encog 3.0 style access);
        // buffers are sized from the data set rather than assuming a single ideal value.
        double[] input = new double[dataSet.getInputSize()];
        double[] ideal = new double[dataSet.getIdealSize()];
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
        dataSet.getRecord(i, pair);
        MLData result = network.compute(pair.getInput());
        double diff = result.getData()[0] - pair.getIdeal().getData()[0];
        squaredErrorSum += diff * diff;
    }
    return squaredErrorSum / numRecords;
}
use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.
the class DataLoadWorker method handleMsg.
/*
 * (non-Javadoc)
 *
 * Loads the data behind the scanner carried by the message and forwards the result to the
 * next actor. Stats and norm scans are read as raw lines, train scans are parsed into data
 * pairs, and eval scans are split into several messages. Any other message is reported as
 * unhandled.
 *
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof ScanStatsRawDataMessage) {
        log.info("DataLoaderActor Starting ...");
        final ScanStatsRawDataMessage statsMsg = (ScanStatsRawDataMessage) message;
        final Scanner input = statsMsg.getScanner();
        final int expectedMsgCnt = statsMsg.getTotalMsgCnt();
        final List<String> lines = readDataIntoList(input);
        log.info("DataLoaderActor Finished: Loaded " + lines.size() + " Records.");
        nextActorRef.tell(new StatsPartRawDataMessage(expectedMsgCnt, lines), getSelf());
        return;
    }

    if (message instanceof ScanNormInputDataMessage) {
        log.info("DataLoaderActor Starting ...");
        final ScanNormInputDataMessage normMsg = (ScanNormInputDataMessage) message;
        final Scanner input = normMsg.getScanner();
        final int expectedMsgCnt = normMsg.getTotalMsgCnt();
        final List<String> lines = readDataIntoList(input);
        log.info("DataLoaderActor Finished: Loaded " + lines.size() + " Records.");
        nextActorRef.tell(new NormPartRawDataMessage(expectedMsgCnt, lines), getSelf());
        return;
    }

    if (message instanceof ScanTrainDataMessage) {
        final ScanTrainDataMessage trainMsg = (ScanTrainDataMessage) message;
        final Scanner input = trainMsg.getScanner();
        final int expectedMsgCnt = trainMsg.getTotalMsgCnt();
        final List<MLDataPair> pairs = readTrainingData(input, trainMsg.isDryRun());
        log.info("DataLoaderActor Finished: Loaded " + pairs.size() + " Records for Training.");
        nextActorRef.tell(new TrainPartDataMessage(expectedMsgCnt, trainMsg.isDryRun(), pairs), getSelf());
        return;
    }

    if (message instanceof ScanEvalDataMessage) {
        log.info("DataLoaderActor Starting ...");
        final ScanEvalDataMessage evalMsg = (ScanEvalDataMessage) message;
        final Scanner input = evalMsg.getScanner();
        final int stream = evalMsg.getStreamId();
        final int streamTotal = evalMsg.getTotalStreamCnt();
        // Eval data is forwarded in chunks rather than as one list; the chunk size is read
        // from the environment with a fallback of 100000 records per message.
        splitDataIntoMultiMessages(stream, streamTotal, input,
                Environment.getInt(Environment.RECORD_CNT_PER_MESSAGE, 100000));
        return;
    }

    unhandled(message);
}
use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.
the class DataLoadWorker method readTrainingData.
/**
 * Read the normalized training data for model training.
 *
 * <p>Each line is split with {@code DEFAULT_SPLITTER}; field {@code i} belongs to column
 * {@code i} of {@code columnConfigList}. Target columns go to the ideal array, selected input
 * columns go to the input array, and one extra trailing field (if present) is taken as the
 * record's significance. Fields that do not parse fall back to {@code 0.0} (or the default
 * significance for the trailing field).
 *
 * <p>In a dry run no line is parsed: every line yields a dummy pair with one zero input and
 * one zero ideal value.
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - is for test running?
 * @return List of data, one pair per line read
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>();

    // When true, no final selection is in effect and every good candidate column is an input.
    final boolean useAllCandidates = this.inputNodeCount == this.candidateCount;

    int numSelected = 0;
    int numInputColumns = 0;
    for (ColumnConfig config : columnConfigList) {
        if (config.isFinalSelect()) {
            numSelected++;
        }
        if (!config.isTarget() && isTrainingInput(config, useAllCandidates)) {
            numInputColumns++;
        }
    }
    // The array used to be sized by the final-select count alone, which overflows when the
    // good-candidate rule selects more columns than are final-selected. Taking the larger
    // value keeps the old width wherever it was already sufficient.
    final int inputWidth = Math.max(numSelected, numInputColumns);

    int cnt = 0;
    while (scanner.hasNextLine()) {
        if ((cnt++) % 100000 == 0) {
            log.info("Read " + (cnt) + " Records.");
        }
        String line = scanner.nextLine();
        if (isDryRun) {
            MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1]));
            mlDataPairList.add(dummyPair);
            continue;
        }

        // the normalized training data is separated by | by default
        double[] inputs = new double[inputWidth];
        double[] ideal = new double[1];
        double significance = 0.0d;
        int index = 0, inputsIndex = 0, outputIndex = 0;
        for (String input : DEFAULT_SPLITTER.split(line.trim())) {
            double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            if (index == this.columnConfigList.size()) {
                // one field past the last column: the record's significance
                significance = NumberFormatUtils.getDouble(input.trim(), CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            }
            ColumnConfig columnConfig = this.columnConfigList.get(index);
            if (columnConfig != null && columnConfig.isTarget()) {
                ideal[outputIndex++] = doubleValue;
            } else if (isTrainingInput(columnConfig, useAllCandidates)) {
                inputs[inputsIndex++] = doubleValue;
            }
            index++;
        }
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        pair.setSignificance(significance);
        mlDataPairList.add(pair);
    }
    return mlDataPairList;
}

/**
 * Tells whether the value of a non-target column is copied into the input array.
 *
 * @param columnConfig
 *            - column to test, may be null
 * @param useAllCandidates
 *            - true to accept every good candidate, false to accept only final-selected,
 *            non-meta, non-target columns
 * @return true if the column's value is a network input
 */
private static boolean isTrainingInput(ColumnConfig columnConfig, boolean useAllCandidates) {
    if (useAllCandidates) {
        return CommonUtils.isGoodCandidate(columnConfig);
    }
    return columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget()
            && columnConfig.isFinalSelect();
}
use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.
the class TrainDataPrepWorker method handleMsg.
/*
 * (non-Javadoc)
 *
 * Collects training data parts into the master data set. Once the expected number of
 * messages has arrived, every trainer is given the master data set and sent to the next
 * actor. When training on disk, the master data set is opened for loading on the first
 * received pair, finalized after the last message, and closed once the trainers have been
 * set up - also when setting up a trainer fails.
 *
 * For raw stats parts only the message count is tracked; the trainers are forwarded
 * unchanged after the last part.
 *
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        TrainPartDataMessage msg = (TrainPartDataMessage) message;
        for (MLDataPair mlDataPair : msg.getMlDataPairList()) {
            if (modelConfig.isTrainOnDisk() && !initialized) {
                // the first pair fixes the record layout of the disk-backed data set
                int inputSize = mlDataPair.getInput().size();
                int idealSize = mlDataPair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputSize, idealSize);
                initialized = true;
            }
            masterDataSet.add(mlDataPair);
        }
        receivedMsgCnt++;
        log.debug("Expected " + msg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == msg.getTotalMsgCnt()) {
            if (modelConfig.isTrainOnDisk() && initialized) {
                ((BufferedMLDataSet) masterDataSet).endLoad();
            }
            try {
                for (AbstractTrainer trainer : trainers) {
                    // if the trainOnDisk is true, setting the "D" option
                    if (modelConfig.isTrainOnDisk()) {
                        trainer.setTrainingOption("D");
                    }
                    trainer.setDataSet(masterDataSet);
                    nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
                }
            } finally {
                // setDataSet may throw; release the disk-backed master data set either way
                if (modelConfig.isTrainOnDisk() && initialized) {
                    masterDataSet.close();
                    masterDataSet = null;
                }
            }
        }
    } else if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage msg = (StatsPartRawDataMessage) message;
        receivedMsgCnt++;
        log.debug("Expected " + msg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == msg.getTotalMsgCnt()) {
            for (AbstractTrainer trainer : trainers) {
                // no data set is handed to the trainer on this path
                nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
            }
        }
    } else {
        unhandled(message);
    }
}
Aggregations