Usage of `ml.shifu.shifu.message.TrainInstanceMessage` in the shifu project by ShifuML.
Class `TrainDataPrepWorker`, method `handleMsg`.
/**
 * Collects training data from upstream actors and, once every expected part has
 * arrived, sends one {@code TrainInstanceMessage} per trainer to the next actor.
 *
 * <p>Handled message types:
 * <ul>
 * <li>{@code TrainPartDataMessage}: every pair in the part is appended to
 * {@code masterDataSet}. In train-on-disk mode the data set is used as a
 * {@code BufferedMLDataSet}: loading begins (sized from the first pair seen) the
 * first time a pair arrives, ends when the last part is received, and after the
 * trainers are dispatched the data set is closed and the field set to null.
 * Each trainer also gets the "D" training option in that mode.</li>
 * <li>{@code StatsPartRawDataMessage}: only counted; when all parts are in, the
 * trainers are dispatched without a data set being assigned here.</li>
 * </ul>
 * Anything else is forwarded to {@code unhandled}.
 *
 * <p>Both branches increment the shared {@code receivedMsgCnt} field and dispatch
 * when it equals the total count carried by the current message.
 *
 * @param message the incoming actor message
 * @throws IOException declared by this handler's signature
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        TrainPartDataMessage partMsg = (TrainPartDataMessage) message;
        boolean onDisk = modelConfig.isTrainOnDisk();

        for (MLDataPair pair : partMsg.getMlDataPairList()) {
            // The on-disk buffer is opened once, sized from the first pair received.
            if (onDisk && !initialized) {
                int inputCount = pair.getInput().size();
                int idealCount = pair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputCount, idealCount);
                initialized = true;
            }
            masterDataSet.add(pair);
        }

        receivedMsgCnt++;
        log.debug("Expected " + partMsg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt != partMsg.getTotalMsgCnt()) {
            // More parts are still on their way.
            return;
        }

        boolean bufferedLoad = onDisk && initialized;
        if (bufferedLoad) {
            ((BufferedMLDataSet) masterDataSet).endLoad();
        }
        for (AbstractTrainer trainer : trainers) {
            // Train-on-disk mode is signalled to each trainer through the "D" option.
            if (onDisk) {
                trainer.setTrainingOption("D");
            }
            trainer.setDataSet(masterDataSet);
            nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
        }
        if (bufferedLoad) {
            masterDataSet.close();
            masterDataSet = null;
        }
        return;
    }

    if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage rawMsg = (StatsPartRawDataMessage) message;
        receivedMsgCnt++;
        log.debug("Expected " + rawMsg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == rawMsg.getTotalMsgCnt()) {
            // NOTE(review): no data set is handed to the trainers on this path
            // (the raw instance list is not attached) — confirm this is intended.
            for (AbstractTrainer trainer : trainers) {
                nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
            }
        }
        return;
    }

    unhandled(message);
}
Usage of `ml.shifu.shifu.message.TrainInstanceMessage` in the shifu project by ShifuML.
Class `TrainModelWorker`, method `handleMsg`.
/**
 * Runs training for the trainer carried by a {@code TrainInstanceMessage}, then
 * sends a {@code TrainResultMessage} to the next actor.
 *
 * <p>{@code train()} is invoked directly inside this handler; the result message
 * is sent only after it returns normally. Any other message type is forwarded to
 * {@code unhandled}.
 *
 * @param message the incoming actor message
 * @throws IOException declared by this handler's signature
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (!(message instanceof TrainInstanceMessage)) {
        unhandled(message);
        return;
    }

    log.info("Received train data for model training");
    TrainInstanceMessage instanceMsg = (TrainInstanceMessage) message;
    instanceMsg.getTrainer().train();
    nextActorRef.tell(new TrainResultMessage(), getSelf());
}
Aggregations