Use of ml.shifu.shifu.message.ColumnScoreMessage in project shifu by ShifuML.
Class DataPrepareWorker, method handleMsg.
/**
 * Dispatches one incoming actor message.
 *
 * <p>For a {@code StatsPartRawDataMessage}, the raw records are converted into per-column
 * value objects. Each column's list is then sent, together with that column's missing count
 * (0 when the column is absent from the missing map) and the overall total, as a
 * {@code StatsValueObjectMessage} to the actor registered for the column.
 *
 * <p>For a {@code RunModelResultMessage}, the score results are converted into per-column
 * score objects. Each column's list is sent as a {@code ColumnScoreMessage} to the actor
 * registered for the column.
 *
 * <p>Every other message is passed to {@code unhandled}.
 *
 * @param message the message received by this actor
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage rawDataMsg = (StatsPartRawDataMessage) message;
        Map<Integer, List<ValueObject>> voListByColumn = buildColumnVoListMap(rawDataMsg.getRawDataList().size());
        DataPrepareStatsResult prepareResult = convertRawDataIntoValueObject(rawDataMsg.getRawDataList(),
                voListByColumn);
        int expectedMsgCnt = rawDataMsg.getTotalMsgCnt();
        for (Map.Entry<Integer, List<ValueObject>> columnEntry : voListByColumn.entrySet()) {
            Integer colIdx = columnEntry.getKey();
            List<ValueObject> valueObjects = columnEntry.getValue();
            log.info("send {} with {} value object", colIdx, valueObjects.size());
            // Columns without an entry in the missing map are reported with 0 missing values.
            StatsValueObjectMessage statsMsg = new StatsValueObjectMessage(expectedMsgCnt, colIdx, valueObjects,
                    prepareResult.getMissingMap().containsKey(colIdx) ? prepareResult.getMissingMap().get(colIdx) : 0,
                    prepareResult.getTotal());
            columnNumToActorMap.get(colIdx).tell(statsMsg, getSelf());
        }
        return;
    }

    if (message instanceof RunModelResultMessage) {
        RunModelResultMessage modelResultMsg = (RunModelResultMessage) message;
        Map<Integer, List<ColumnScoreObject>> scoresByColumn = buildColumnScoreListMap();
        convertModelResultIntoColScore(modelResultMsg.getScoreResultList(), scoresByColumn);
        int expectedStreamCnt = modelResultMsg.getTotalStreamCnt();
        for (Map.Entry<Integer, List<ColumnScoreObject>> columnEntry : scoresByColumn.entrySet()) {
            Integer colIdx = columnEntry.getKey();
            ColumnScoreMessage scoreMsg = new ColumnScoreMessage(expectedStreamCnt, colIdx, columnEntry.getValue());
            columnNumToActorMap.get(colIdx).tell(scoreMsg, getSelf());
        }
        return;
    }

    unhandled(message);
}
Use of ml.shifu.shifu.message.ColumnScoreMessage in project shifu by ShifuML.
Class PostTrainWorker, method handleMsg.
/**
 * Buffers per-column score objects and, once every expected message has arrived, computes the
 * rounded average score of each bin of the column and forwards the updated config downstream.
 *
 * <p>Each {@code ColumnScoreMessage} appends its score objects to {@code colScoreList} and
 * increments {@code receivedMsgCnt}. When {@code receivedMsgCnt} equals the message's total
 * count, the method does the following:
 * <ul>
 *   <li>Maps every buffered score to a bin via {@code CommonUtils.getBinNum}.</li>
 *   <li>Averages the scores per bin and rounds to the nearest integer. A bin with no scores
 *       gets 0.</li>
 *   <li>Stores the result with {@code ColumnConfig.setBinAvgScore}.</li>
 *   <li>Sends a {@code StatsResultMessage} to {@code nextActorRef}.</li>
 * </ul>
 * Scores whose bin index falls outside {@code [0, binLength)} are skipped and counted in a log
 * line, so they no longer abort the whole column.
 *
 * <p>Every other message is passed to {@code unhandled}.
 *
 * @param message the message received by this actor
 */
@Override
public void handleMsg(Object message) {
    if (!(message instanceof ColumnScoreMessage)) {
        unhandled(message);
        return;
    }

    ColumnScoreMessage msg = (ColumnScoreMessage) message;
    colScoreList.addAll(msg.getColScoreList());
    receivedMsgCnt++;
    log.debug("Received {} messages, total message count is:{}", receivedMsgCnt, msg.getTotalMsgCnt());
    if (receivedMsgCnt != msg.getTotalMsgCnt()) {
        // Still waiting for the remaining partial results.
        return;
    }

    // All messages received: compute the per-bin average score for this column.
    int columnNum = msg.getColumnNum();
    ColumnConfig config = columnConfigList.get(columnNum);
    int binLength = config.getBinLength();
    // Primitive arrays (zero-initialized by the JVM) avoid per-update boxing of Double/Integer.
    double[] binScoreSum = new double[binLength];
    int[] binCount = new int[binLength];
    int skipped = 0;
    for (ColumnScoreObject colScore : colScoreList) {
        int binNum = CommonUtils.getBinNum(config, colScore.getColumnVal());
        // NOTE(review): getBinNum's contract is not visible here. An index outside
        // [0, binLength) (e.g. -1 for an unmatched value) used to throw
        // ArrayIndexOutOfBoundsException and drop the whole column; skip that value instead.
        if (binNum < 0 || binNum >= binLength) {
            skipped++;
            continue;
        }
        binScoreSum[binNum] += Double.valueOf(colScore.getAvgScore());
        binCount[binNum]++;
    }
    if (skipped > 0) {
        log.info("skipped {} score(s) of column {} with out-of-range bin index", skipped, columnNum);
    }

    List<Integer> binAvgScore = new ArrayList<Integer>(binLength);
    for (int i = 0; i < binLength; i++) {
        // An empty bin has no average. Report 0, which the old 0.0/0 -> NaN -> Math.round(NaN)
        // path also produced, but only implicitly.
        binAvgScore.add(binCount[i] == 0 ? 0 : (int) Math.round(binScoreSum[i] / binCount[i]));
    }
    config.setBinAvgScore(binAvgScore);
    nextActorRef.tell(new StatsResultMessage(config), getSelf());
}
Aggregations