Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class DataPrepareWorker, method convertRawDataIntoValueObject.
/**
 * Converts raw delimited training records into {@link ValueObject}s, grouped by column, for stats calculation.
 *
 * <p>Each line is split with the model's data-set delimiter. A line is dropped from sampling when
 * {@code random.nextDouble()} exceeds the binning sample rate; if {@code isBinningSampleNegOnly()} is set, only
 * lines whose tag is a negative tag are subject to that drop. Only columns that have an entry in
 * {@code columnNumToActorMap} are converted.
 *
 * @param rawDataList
 *            raw data lines for training
 * @param columnVoListMap
 *            output map of column id to its {@link ValueObject} list; lists are created on demand
 * @return the number of lines read (counted before sampling) together with the per-column missing/invalid count
 * @throws ShifuException
 *             if a line's field count is not equal to the number of configured columns
 */
private DataPrepareStatsResult convertRawDataIntoValueObject(List<String> rawDataList,
        Map<Integer, List<ValueObject>> columnVoListMap) throws ShifuException {
    double sampleRate = modelConfig.getBinningSampleRate();
    long total = 0L;
    Map<Integer, Long> missingMap = new HashMap<Integer, Long>();
    for (String line : rawDataList) {
        total++;
        String[] raw = CommonUtils.split(line, modelConfig.getDataSetDelimiter());
        if (raw.length != columnConfigList.size()) {
            log.error("Expected Columns: " + columnConfigList.size() + ", but got: " + raw.length);
            throw new ShifuException(ShifuErrorCode.ERROR_NO_EQUAL_COLCONFIG);
        }
        String tag = CommonUtils.trimTag(raw[targetColumnNum]);
        if (modelConfig.isBinningSampleNegOnly()) {
            // random is only drawn for negative-tag lines, so positives are always kept
            if (modelConfig.getNegTags().contains(tag) && random.nextDouble() > sampleRate) {
                continue;
            }
        } else {
            if (random.nextDouble() > sampleRate) {
                continue;
            }
        }
        // raw.length == columnConfigList.size() is guaranteed above, so every index is a valid config index.
        for (int i = 0; i < raw.length; i++) {
            if (!columnNumToActorMap.containsKey(i)) {
                // ignore non-used columns
                continue;
            }
            ValueObject vo = new ValueObject();
            ColumnConfig config = columnConfigList.get(i);
            if (config.isNumerical()) {
                // NUMERICAL: unparsable values are counted as missing and the cell is skipped entirely
                try {
                    vo.setValue(Double.valueOf(raw[i].trim()));
                    vo.setRaw(null);
                } catch (Exception e) {
                    log.debug("Column " + config.getColumnNum() + ": " + config.getColumnName()
                            + " is expected to be NUMERICAL, however received: " + raw[i]);
                    incMap(i, missingMap);
                    continue;
                }
            } else if (config.isCategorical()) {
                // CATEGORICAL: missing/invalid values are counted but the cell is still kept
                String value = raw[i];
                if (StringUtils.isEmpty(value) || modelConfig.getDataSet().getMissingOrInvalidValues()
                        .contains(value.toLowerCase(java.util.Locale.ROOT).trim())) {
                    incMap(i, missingMap);
                }
                vo.setRaw(value == null ? null : value.trim());
                vo.setValue(null);
            } else {
                // AUTO TYPE: numeric if parsable, otherwise keep the raw text and count it as missing
                try {
                    vo.setValue(Double.valueOf(raw[i]));
                    vo.setRaw(null);
                } catch (Exception e) {
                    incMap(i, missingMap);
                    vo.setRaw(raw[i]);
                    vo.setValue(null);
                }
            }
            if (this.weightedColumnNum != -1) {
                // Use the record's weight; fall back to 1.0 only when it cannot be parsed.
                try {
                    vo.setWeight(Double.valueOf(raw[weightedColumnNum]));
                } catch (NumberFormatException e) {
                    vo.setWeight(1.0);
                }
            }
            vo.setTag(tag);
            List<ValueObject> voList = columnVoListMap.get(i);
            if (voList == null) {
                voList = new ArrayList<ValueObject>();
                columnVoListMap.put(i, voList);
            }
            voList.add(vo);
        }
    }
    return new DataPrepareStatsResult(total, missingMap);
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class DataPrepareWorker, method convertModelResultIntoColScore.
/**
 * Pivots per-record model scores into per-column score lists.
 *
 * <p>For every score result, the original input record is turned into a column-name-to-value map using the
 * data-set delimiter and {@code trainDataHeader}. Then, for each final-selected column, a
 * {@link ColumnScoreObject} is appended to that column's list. It holds the column's raw value plus the record's
 * scores (all, max, min and average).
 *
 * @param scoreResultList
 *            scored evaluation records
 * @param columnScoreListMap
 *            output map of column number to its {@link ColumnScoreObject} list; lists are created on demand
 */
private void convertModelResultIntoColScore(List<CaseScoreResult> scoreResultList,
        Map<Integer, List<ColumnScoreObject>> columnScoreListMap) {
    for (CaseScoreResult caseResult : scoreResultList) {
        Map<String, String> valueByColumnName = CommonUtils.convertDataIntoMap(caseResult.getInputData(),
                super.modelConfig.getDataSetDelimiter(), this.trainDataHeader);
        for (ColumnConfig columnConfig : columnConfigList) {
            if (!columnConfig.isFinalSelect()) {
                continue;
            }
            ColumnScoreObject scoreObject = new ColumnScoreObject(columnConfig.getColumnNum(),
                    valueByColumnName.get(columnConfig.getColumnName()));
            scoreObject.setScores(caseResult.getScores());
            scoreObject.setMaxScore(caseResult.getMaxScore());
            scoreObject.setMinScore(caseResult.getMinScore());
            scoreObject.setAvgScore(caseResult.getAvgScore());

            List<ColumnScoreObject> scoresForColumn = columnScoreListMap.get(columnConfig.getColumnNum());
            if (scoresForColumn == null) {
                scoresForColumn = new ArrayList<ColumnScoreObject>();
                columnScoreListMap.put(columnConfig.getColumnNum(), scoresForColumn);
            }
            scoresForColumn.add(scoreObject);
        }
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class PostTrainActor, method onReceive.
/*
 * (non-Javadoc)
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 *
 * Dispatches on message type:
 * - AkkaActorInputMessage: resets the result counter and sends one ScanEvalDataMessage per scanner to dataLoadRef.
 * - StatsResultMessage: stores the returned ColumnConfig; once expectedResultCnt results have arrived, writes the
 *   column config list to the ColumnConfig path and shuts the actor system down.
 * - ExceptionMessage: shuts the actor system down and records the exception.
 * - anything else is passed to unhandled().
 */
@Override
public void onReceive(Object message) throws Exception {
    if (message instanceof AkkaActorInputMessage) {
        resultCnt = 0;
        List<Scanner> scannerList = ((AkkaActorInputMessage) message).getScanners();
        log.debug("Num of Scanners: " + scannerList.size());
        int nextStreamId = 0;
        for (Scanner scanner : scannerList) {
            dataLoadRef.tell(new ScanEvalDataMessage(nextStreamId, scannerList.size(), scanner), getSelf());
            nextStreamId++;
        }
        return;
    }

    if (message instanceof StatsResultMessage) {
        ColumnConfig updatedConfig = ((StatsResultMessage) message).getColumnConfig();
        columnConfigList.set(updatedConfig.getColumnNum(), updatedConfig);
        resultCnt++;
        log.debug("Received " + resultCnt + " messages, expected message count is:" + expectedResultCnt);
        if (resultCnt == expectedResultCnt) {
            log.info("Finished post-train.");
            File columnConfigFile = new File(new PathFinder(modelConfig).getColumnConfigPath());
            JSONUtils.writeValue(columnConfigFile, columnConfigList);
            getContext().system().shutdown();
        }
        return;
    }

    if (message instanceof ExceptionMessage) {
        // a child actor failed: stop the whole system, then surface the failure through the return status
        ExceptionMessage failure = (ExceptionMessage) message;
        getContext().system().shutdown();
        addExceptionIntoCondition(failure.getException());
        return;
    }

    unhandled(message);
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class StatsCalculateWorker, method handleMsg.
/**
 * Accumulates the value objects and missing/total counters carried by {@link StatsValueObjectMessage}s.
 *
 * <p>When the number of received messages reaches the message's expected total, the column's stats are
 * calculated. The missing count, total count and missing percentage are then set on its {@link ColumnConfig},
 * and the result is sent to the parent actor as a {@link StatsResultMessage}. Other message types are passed
 * to {@code unhandled}.
 *
 * @param message
 *            the incoming actor message
 */
@Override
public void handleMsg(Object message) {
    if (!(message instanceof StatsValueObjectMessage)) {
        unhandled(message);
        return;
    }
    StatsValueObjectMessage statsVoMessage = (StatsValueObjectMessage) message;
    voList.addAll(statsVoMessage.getVoList());
    this.missing += statsVoMessage.getMissing();
    this.total += statsVoMessage.getTotal();
    receivedMsgCnt++;
    if (receivedMsgCnt == statsVoMessage.getTotalMsgCnt()) {
        ColumnConfig columnConfig = columnConfigList.get(statsVoMessage.getColumnNum());
        calculateColumnStats(columnConfig, voList);
        columnConfig.setMissingCnt(this.missing);
        columnConfig.setTotalCount(this.total);
        // Guard against 0/0 (e.g. an empty data set), which would otherwise store NaN as the percentage.
        columnConfig.setMissingPercentage(total == 0 ? 0.0 : (double) missing / total);
        parentActorRef.tell(new StatsResultMessage(columnConfig), this.getSelf());
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Class InitStep, method initColumnConfigList.
/**
 * Builds {@code columnConfigList} and applies the INIT-step column flags.
 *
 * <p>Column names come from the configured header file when a header path is set. Otherwise they come from the
 * first line of the data set, which is treated as a header only if one of its fields contains the target column
 * name. If no field does, columns are named by index ("0", "1", ...).
 *
 * @return 0 on success, 1 if no column was flagged as the target
 * @throws IOException
 *             propagated from reading the header / data set files
 * @throws IllegalArgumentException
 *             if the header length differs from the length of the first data line
 */
private int initColumnConfigList() throws IOException {
    String[] fields;
    boolean isSchemaProvided = true;
    if (StringUtils.isNotBlank(modelConfig.getHeaderPath())) {
        fields = CommonUtils.getHeaders(modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter(),
                modelConfig.getDataSet().getSource());
        String[] dataInFirstLine = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(),
                modelConfig.getDataSetDelimiter(), modelConfig.getDataSet().getSource());
        if (fields.length != dataInFirstLine.length) {
            throw new IllegalArgumentException("Header length and data length are not consistent, please check you header setting and data set setting.");
        }
    } else {
        // Without an explicit header delimiter, the first line is split with the data-set delimiter.
        String firstLineDelimiter = StringUtils.isBlank(modelConfig.getHeaderDelimiter())
                ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter();
        fields = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), firstLineDelimiter,
                modelConfig.getDataSet().getSource());
        if (firstLineHasTargetColumn(fields, modelConfig.getTargetColumnName())) {
            // If the first line contains the target column name, we guess it is csv format and the first line is
            // the header. "First line of data" then means the second line of the data files.
            String[] dataInFirstLine = CommonUtils.takeFirstTwoLines(modelConfig.getDataSetRawPath(),
                    firstLineDelimiter, modelConfig.getDataSet().getSource())[1];
            if (dataInFirstLine != null && fields.length != dataInFirstLine.length) {
                throw new IllegalArgumentException("Header length and data length are not consistent, please check you header setting and data set setting.");
            }
            LOG.warn("No header path is provided, we will try to read first line and detect schema.");
            LOG.warn("Schema in ColumnConfig.json are named as first line of data set path.");
        } else {
            isSchemaProvided = false;
            LOG.warn("No header path is provided, we will try to read first line and detect schema.");
            LOG.warn("Schema in ColumnConfig.json are named as index 0, 1, 2, 3 ...");
            LOG.warn("Please make sure weight column and tag column are also taking index as name.");
        }
    }

    columnConfigList = new ArrayList<ColumnConfig>();
    for (int i = 0; i < fields.length; i++) {
        ColumnConfig config = new ColumnConfig();
        config.setColumnNum(i);
        if (isSchemaProvided) {
            fields[i] = CommonUtils.normColumnName(fields[i]);
            config.setColumnName(CommonUtils.getRelativePigHeaderColumnName(fields[i]));
        } else {
            config.setColumnName(i + "");
        }
        columnConfigList.add(config);
    }

    ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columnConfigList, ModelStep.INIT);

    boolean hasTarget = false;
    for (ColumnConfig config : columnConfigList) {
        if (config.isTarget()) {
            hasTarget = true;
            break;
        }
    }
    if (!hasTarget) {
        LOG.error("Target is not valid: " + modelConfig.getTargetColumnName());
        LOG.error("Please check your header file {} and your header delimiter {}", modelConfig.getHeaderPath(),
                modelConfig.getHeaderDelimiter());
        return 1;
    }
    return 0;
}

/**
 * Returns whether any single field of the first line contains the target column name. Each field is checked on
 * its own, so a name that only appears across two adjacent fields does not count. A null target name never
 * matches.
 */
private static boolean firstLineHasTargetColumn(String[] fields, String targetColumnName) {
    if (targetColumnName == null) {
        return false;
    }
    for (String field : fields) {
        if (field != null && field.contains(targetColumnName)) {
            return true;
        }
    }
    return false;
}
Aggregations