Use of ml.shifu.shifu.message.StatsPartRawDataMessage in project shifu by ShifuML.
Class TrainDataPrepWorker, method handleMsg.
/**
 * Collects incoming data parts and, once the expected number of parts has
 * been received, sends every trainer to the next actor.
 *
 * <p>For a {@code TrainPartDataMessage} each data pair is added to
 * {@code masterDataSet}. When the model is configured to train on disk, the
 * data set is put into load mode on the first pair seen (using that pair's
 * input and ideal sizes) and taken out of load mode after the last part.
 * For a {@code StatsPartRawDataMessage} only the message counter is advanced.
 * Both message kinds share the same {@code receivedMsgCnt} counter. Any other
 * message is passed to {@code unhandled}.
 *
 * @param message the message delivered to this worker
 * @throws IOException if an underlying call fails with an I/O error
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        final TrainPartDataMessage trainPart = (TrainPartDataMessage) message;

        for (MLDataPair pair : trainPart.getMlDataPairList()) {
            // The disk-backed data set must be opened for loading exactly once,
            // and it needs the dimensions of a real pair to do so.
            if (modelConfig.isTrainOnDisk() && !initialized) {
                int inputCount = pair.getInput().size();
                int idealCount = pair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputCount, idealCount);
                initialized = true;
            }
            masterDataSet.add(pair);
        }

        receivedMsgCnt++;
        log.debug("Expected " + trainPart.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");

        // Still waiting for more parts: nothing else to do for this message.
        if (receivedMsgCnt != trainPart.getTotalMsgCnt()) {
            return;
        }

        if (modelConfig.isTrainOnDisk() && initialized) {
            ((BufferedMLDataSet) masterDataSet).endLoad();
        }

        for (AbstractTrainer eachTrainer : trainers) {
            // Training on disk is signalled to the trainer through the "D" option.
            if (modelConfig.isTrainOnDisk()) {
                eachTrainer.setTrainingOption("D");
            }
            eachTrainer.setDataSet(masterDataSet);
            nextActorRef.tell(new TrainInstanceMessage(eachTrainer), this.getSelf());
        }

        // NOTE(review): the data set is closed right after it has been handed to
        // the trainers - confirm the trainers do not rely on this instance staying open.
        if (modelConfig.isTrainOnDisk() && initialized) {
            masterDataSet.close();
            masterDataSet = null;
        }
        return;
    }

    if (message instanceof StatsPartRawDataMessage) {
        final StatsPartRawDataMessage statsPart = (StatsPartRawDataMessage) message;

        receivedMsgCnt++;
        log.debug("Expected " + statsPart.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");

        if (receivedMsgCnt == statsPart.getTotalMsgCnt()) {
            // No data set is attached to the trainers on this path.
            for (AbstractTrainer eachTrainer : trainers) {
                nextActorRef.tell(new TrainInstanceMessage(eachTrainer), this.getSelf());
            }
        }
        return;
    }

    unhandled(message);
}
Use of ml.shifu.shifu.message.StatsPartRawDataMessage in project shifu by ShifuML.
Class DataLoadWorker, method handleMsg.
/**
 * Reads data from the scanner carried by a scan message and forwards it to
 * the next actor, wrapped in the message type that matches the request.
 *
 * <ul>
 * <li>{@code ScanStatsRawDataMessage}: sends a {@code StatsPartRawDataMessage}</li>
 * <li>{@code ScanNormInputDataMessage}: sends a {@code NormPartRawDataMessage}</li>
 * <li>{@code ScanTrainDataMessage}: sends a {@code TrainPartDataMessage}</li>
 * <li>{@code ScanEvalDataMessage}: delegates to {@code splitDataIntoMultiMessages}</li>
 * </ul>
 * Any other message is passed to {@code unhandled}.
 *
 * @param message the message delivered to this worker
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof ScanStatsRawDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanStatsRawDataMessage statsScan = (ScanStatsRawDataMessage) message;
        Scanner input = statsScan.getScanner();
        int expectedMsgCnt = statsScan.getTotalMsgCnt();

        List<String> records = readDataIntoList(input);
        log.info("DataLoaderActor Finished: Loaded " + records.size() + " Records.");
        nextActorRef.tell(new StatsPartRawDataMessage(expectedMsgCnt, records), getSelf());
        return;
    }

    if (message instanceof ScanNormInputDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanNormInputDataMessage normScan = (ScanNormInputDataMessage) message;
        Scanner input = normScan.getScanner();
        int expectedMsgCnt = normScan.getTotalMsgCnt();

        List<String> records = readDataIntoList(input);
        log.info("DataLoaderActor Finished: Loaded " + records.size() + " Records.");
        nextActorRef.tell(new NormPartRawDataMessage(expectedMsgCnt, records), getSelf());
        return;
    }

    if (message instanceof ScanTrainDataMessage) {
        ScanTrainDataMessage trainScan = (ScanTrainDataMessage) message;
        Scanner input = trainScan.getScanner();
        int expectedMsgCnt = trainScan.getTotalMsgCnt();

        List<MLDataPair> trainingPairs = readTrainingData(input, trainScan.isDryRun());
        log.info("DataLoaderActor Finished: Loaded " + trainingPairs.size() + " Records for Training.");
        nextActorRef.tell(new TrainPartDataMessage(expectedMsgCnt, trainScan.isDryRun(), trainingPairs), getSelf());
        return;
    }

    if (message instanceof ScanEvalDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanEvalDataMessage evalScan = (ScanEvalDataMessage) message;
        Scanner input = evalScan.getScanner();
        int streamId = evalScan.getStreamId();
        int streamTotal = evalScan.getTotalStreamCnt();

        // Evaluation data is not read into a single list here; the helper is
        // given a per-message record count taken from the environment
        // (falling back to 100000).
        splitDataIntoMultiMessages(
                streamId,
                streamTotal,
                input,
                Environment.getInt(Environment.RECORD_CNT_PER_MESSAGE, 100000));
        return;
    }

    unhandled(message);
}
Use of ml.shifu.shifu.message.StatsPartRawDataMessage in project shifu by ShifuML.
Class DataPrepareWorker, method handleMsg.
/**
 * Converts an incoming data part into per-column collections and sends each
 * collection to the actor registered for that column in
 * {@code columnNumToActorMap}.
 *
 * <p>A {@code StatsPartRawDataMessage} yields one {@code StatsValueObjectMessage}
 * per column; a {@code RunModelResultMessage} yields one
 * {@code ColumnScoreMessage} per column. Any other message is passed to
 * {@code unhandled}.
 *
 * @param message the message delivered to this worker
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage statsPart = (StatsPartRawDataMessage) message;

        // The per-column map is filled in by the conversion call below.
        Map<Integer, List<ValueObject>> voByColumn = buildColumnVoListMap(statsPart.getRawDataList().size());
        DataPrepareStatsResult statsResult = convertRawDataIntoValueObject(statsPart.getRawDataList(), voByColumn);
        int expectedMsgCnt = statsPart.getTotalMsgCnt();

        for (Map.Entry<Integer, List<ValueObject>> columnEntry : voByColumn.entrySet()) {
            Integer columnNum = columnEntry.getKey();
            List<ValueObject> columnValues = columnEntry.getValue();
            log.info("send {} with {} value object", columnNum, columnValues.size());

            // Columns absent from the missing map are reported with a missing count of 0.
            columnNumToActorMap.get(columnNum).tell(
                    new StatsValueObjectMessage(
                            expectedMsgCnt,
                            columnNum,
                            columnValues,
                            statsResult.getMissingMap().containsKey(columnNum) ? statsResult.getMissingMap().get(columnNum) : 0,
                            statsResult.getTotal()),
                    getSelf());
        }
        return;
    }

    if (message instanceof RunModelResultMessage) {
        RunModelResultMessage modelResult = (RunModelResultMessage) message;

        Map<Integer, List<ColumnScoreObject>> scoresByColumn = buildColumnScoreListMap();
        convertModelResultIntoColScore(modelResult.getScoreResultList(), scoresByColumn);
        int expectedMsgCnt = modelResult.getTotalStreamCnt();

        for (Map.Entry<Integer, List<ColumnScoreObject>> columnEntry : scoresByColumn.entrySet()) {
            Integer columnNum = columnEntry.getKey();
            columnNumToActorMap.get(columnNum).tell(
                    new ColumnScoreMessage(expectedMsgCnt, columnNum, columnEntry.getValue()),
                    getSelf());
        }
        return;
    }

    unhandled(message);
}
Use of ml.shifu.shifu.message.StatsPartRawDataMessage in project shifu by ShifuML.
Class DataFilterWorker, method handleMsg.
/**
 * Runs {@code purifyData} over the data list carried by the message and then
 * forwards the same message object to the next actor.
 *
 * <p>Handled message types are {@code StatsPartRawDataMessage},
 * {@code NormPartRawDataMessage} and {@code RunModelDataMessage}; anything
 * else is passed to {@code unhandled} and is not forwarded.
 *
 * @param message the message delivered to this worker
 * @throws Exception if an underlying call fails
 * @see ml.shifu.shifu.actor.worker.AbstractWorkerActor#handleMsg(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws Exception {
    // Step 1: pick the list to purify according to the concrete message type.
    if (message instanceof StatsPartRawDataMessage) {
        purifyData(((StatsPartRawDataMessage) message).getRawDataList());
    } else if (message instanceof NormPartRawDataMessage) {
        purifyData(((NormPartRawDataMessage) message).getRawDataList());
    } else if (message instanceof RunModelDataMessage) {
        purifyData(((RunModelDataMessage) message).getEvalDataList());
    } else {
        unhandled(message);
        return;
    }

    // Step 2: every recognised message is passed on after purification.
    nextActorRef.tell(message, getSelf());
}
Aggregations