Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DataDictionaryCreator, method build:
/**
 * Builds the PMML {@link DataDictionary} describing the raw dataset columns.
 *
 * <p>Only columns whose column number is below {@code datasetHeaders.length} are considered; iteration stops at
 * the first column outside that range (columns are assumed to be ordered by column number). When
 * {@code isConcise} is false every such column is emitted. When it is true a raw column is emitted if it is the
 * target, if it is final-selected (and, for a {@code BasicFloatNetwork} with a non-empty feature set, also
 * contained in that feature set), or, in segment expansion mode, if any of its segment-expanded copies is
 * final-selected, because the raw value is needed to derive those segment columns.
 *
 * @param basicML the trained model, may be {@code null}; only a {@code BasicFloatNetwork} contributes a feature set
 * @return a data dictionary holding the emitted fields, with its field count set accordingly
 */
@Override
public DataDictionary build(BasicML basicML) {
    DataDictionary dict = new DataDictionary();
    List<DataField> fields = new ArrayList<DataField>();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    // instanceof is already false for null, so no separate null check is needed. Other model types have no
    // feature set, which makes the filter below degrade to plain isFinalSelect() || isTarget().
    Set<Integer> featureSet = (basicML instanceof BasicFloatNetwork)
            ? ((BasicFloatNetwork) basicML).getFeatureSet()
            : null;
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // in order: everything from here on lies beyond the raw dataset headers
            break;
        }
        if (!isConcise
                || isSelectedOrTarget(columnConfig, featureSet)
                || (isSegExpansionMode && hasSelectedSegmentColumn(columnConfig, segSize))) {
            fields.add(convertColumnToDataField(columnConfig));
        }
    }
    dict.addDataFields(fields.toArray(new DataField[fields.size()]));
    dict.setNumberOfFields(fields.size());
    return dict;
}

/**
 * Returns whether the column is the target, or is final-selected and accepted by the model feature set
 * (a {@code null} or empty feature set accepts every final-selected column).
 */
private boolean isSelectedOrTarget(ColumnConfig columnConfig, Set<Integer> featureSet) {
    return (columnConfig.isFinalSelect()
            && (featureSet == null || featureSet.isEmpty() || featureSet.contains(columnConfig.getColumnNum())))
            || columnConfig.isTarget();
}

/**
 * Returns whether any segment-expanded copy of the given raw column is final-selected. Copy {@code i} (0-based)
 * of raw column {@code c} is looked up at index {@code datasetHeaders.length * (i + 1) + c}.
 */
private boolean hasSelectedSegmentColumn(ColumnConfig rawColumn, int segSize) {
    for (int i = 0; i < segSize; i++) {
        int newIndex = datasetHeaders.length * (i + 1) + rawColumn.getColumnNum();
        if (columnConfigList.get(newIndex).isFinalSelect()) {
            // one selected segment feature is enough to keep the raw column
            return true;
        }
    }
    return false;
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class PostTrainWorker, method handleMsg:
/**
 * Handles an incoming actor message (the Akka {@code UntypedActor#onReceive} entry point delegates here).
 *
 * <p>Every {@code ColumnScoreMessage} contributes its column scores to {@code colScoreList}. When the number of
 * received messages equals the total message count carried by the message, the accumulated scores are bucketed
 * into the column's bins via {@code BinUtils.getBinNum}, averaged per bin, rounded to integers, stored with
 * {@code ColumnConfig#setBinAvgScore}, and a {@code StatsResultMessage} is sent to {@code nextActorRef}.
 * Any other message type is passed to {@code unhandled}.
 *
 * @param message the received message
 */
@Override
public void handleMsg(Object message) {
    if (!(message instanceof ColumnScoreMessage)) {
        unhandled(message);
        return;
    }
    ColumnScoreMessage msg = (ColumnScoreMessage) message;
    colScoreList.addAll(msg.getColScoreList());
    receivedMsgCnt++;
    log.debug("Received " + receivedMsgCnt + " messages, total message count is:" + msg.getTotalMsgCnt());
    if (receivedMsgCnt != msg.getTotalMsgCnt()) {
        return;
    }
    // received all message, start to calculate
    int columnNum = msg.getColumnNum();
    ColumnConfig config = columnConfigList.get(columnNum);
    int binLength = config.getBinLength();
    // Primitive arrays are zero-initialised and avoid boxing on every accumulation.
    double[] binScore = new double[binLength];
    int[] binCount = new int[binLength];
    for (ColumnScoreObject colScore : colScoreList) {
        int binNum = BinUtils.getBinNum(config, colScore.getColumnVal());
        binScore[binNum] += Double.valueOf(colScore.getAvgScore());
        binCount[binNum]++;
    }
    List<Integer> binAvgScore = new ArrayList<Integer>(binLength);
    for (int i = 0; i < binLength; i++) {
        // An empty bin used to compute 0.0 / 0 = NaN, which Math.round maps to 0; keep that result, explicitly.
        binAvgScore.add(binCount[i] == 0 ? 0 : (int) Math.round(binScore[i] / binCount[i]));
    }
    config.setBinAvgScore(binAvgScore);
    nextActorRef.tell(new StatsResultMessage(config), getSelf());
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DataLoadWorker, method readTrainingData:
/**
 * Read the normalized training data for model training.
 *
 * <p>Each line is split with {@code DEFAULT_SPLITTER}. Field {@code k} corresponds to
 * {@code columnConfigList.get(k)}; a field immediately after the last column, if present, is parsed as the record
 * significance (default {@code CommonConstants.DEFAULT_SIGNIFICANCE_VALUE}). Target columns fill the ideal vector.
 * Input columns fill the input vector: good-candidate columns when {@code inputNodeCount == candidateCount},
 * otherwise final-selected, non-meta, non-target columns. Unparseable values become 0.0.
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - is for test running? If true, a dummy 1-element pair is added per line without parsing it.
 * @return List of data
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>();
    int columnCount = this.columnConfigList.size();

    // Decide once which columns feed the input vector, using exactly the predicate applied per field below, so the
    // input array length always equals the number of values written into it. Sizing by the count of final-selected
    // columns (the previous approach) did not match that predicate: with no final-selected column the array was
    // empty while the candidate path still wrote into it (ArrayIndexOutOfBoundsException), and final-selected
    // meta/target columns padded it with unused zeros.
    boolean useCandidates = (this.inputNodeCount == this.candidateCount);
    boolean[] isInputColumn = new boolean[columnCount];
    int numInputs = 0;
    for (int i = 0; i < columnCount; i++) {
        ColumnConfig columnConfig = this.columnConfigList.get(i);
        if (columnConfig != null && columnConfig.isTarget()) {
            continue; // target values go to the ideal vector
        }
        if (useCandidates) {
            // all variables are not set final-select
            isInputColumn[i] = CommonUtils.isGoodCandidate(columnConfig, super.hasCandidates);
        } else {
            // final select some variables
            isInputColumn[i] = columnConfig != null && !columnConfig.isMeta() && columnConfig.isFinalSelect();
        }
        if (isInputColumn[i]) {
            numInputs++;
        }
    }

    int cnt = 0;
    while (scanner.hasNextLine()) {
        if ((cnt++) % 100000 == 0) {
            log.info("Read " + (cnt) + " Records.");
        }
        String line = scanner.nextLine();
        if (isDryRun) {
            MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1]));
            mlDataPairList.add(dummyPair);
            continue;
        }
        // the normalized training data is separated by | by default
        double[] inputs = new double[numInputs];
        double[] ideal = new double[1];
        double significance = 0.0d;
        int index = 0, inputsIndex = 0, outputIndex = 0;
        for (String input : DEFAULT_SPLITTER.split(line.trim())) {
            if (index == columnCount) {
                // the trailing field after all columns carries the record significance
                significance = NumberFormatUtils.getDouble(input.trim(), CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            }
            double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            ColumnConfig columnConfig = this.columnConfigList.get(index);
            if (columnConfig != null && columnConfig.isTarget()) {
                ideal[outputIndex++] = doubleValue;
            } else if (isInputColumn[index]) {
                inputs[inputsIndex++] = doubleValue;
            }
            index++;
        }
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        pair.setSignificance(significance);
        mlDataPairList.add(pair);
    }
    return mlDataPairList;
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DataNormalizeWorker, method normalizeRecord:
/**
 * Normalize the training data record.
 *
 * <p>The target field is mapped to 1 (tag in the positive tags) or 0 (tag in the negative tags). Fields of columns
 * that are not good candidates contribute a {@code null} placeholder. Every other field is expanded by
 * {@code Normalizer.normalize} with the configured std-dev cutoff and normalize type. The record weight is appended
 * last: the value of {@code weightExpr} when one is configured, otherwise 1.0.
 *
 * @param rfs
 *            - record fields
 * @return the data after normalization, or {@code null} when the record is empty, is not sampled, or has a
 *         target tag that is neither positive nor negative
 */
private List<Double> normalizeRecord(String[] rfs) {
    if (rfs == null || rfs.length == 0) {
        return null;
    }
    String tag = CommonUtils.trimTag(rfs[this.targetColumnNum]);
    boolean isNotSampled = DataSampler.isNotSampled(modelConfig.getPosTags(), modelConfig.getNegTags(), modelConfig.getNormalizeSampleRate(), modelConfig.isNormalizeSampleNegOnly(), tag);
    if (isNotSampled) {
        return null;
    }
    List<Double> retDouList = new ArrayList<Double>();
    JexlContext jc = new MapContext();
    Double cutoff = modelConfig.getNormalizeStdDevCutOff();
    for (int i = 0; i < rfs.length; i++) {
        ColumnConfig config = columnConfigList.get(i);
        if (weightExpr != null) {
            // expose every raw field to the weight expression under its column name
            jc.set(config.getColumnName(), rfs[i]);
        }
        if (this.targetColumnNum == i) {
            if (modelConfig.getPosTags().contains(tag)) {
                retDouList.add(Double.valueOf(1));
            } else if (modelConfig.getNegTags().contains(tag)) {
                retDouList.add(Double.valueOf(0));
            } else {
                log.error("Invalid data! The target value is not listed - " + tag);
                // Return null to skip such record.
                return null;
            }
        } else if (!CommonUtils.isGoodCandidate(config, super.hasCandidates)) {
            // keep positional alignment with the column list
            retDouList.add(null);
        } else {
            String val = (rfs[i] == null) ? "" : rfs[i];
            retDouList.addAll(Normalizer.normalize(config, val, cutoff, modelConfig.getNormalizeType()));
        }
    }
    double weight = 1.0d;
    if (weightExpr != null) {
        Object result = weightExpr.evaluate(jc);
        if (result instanceof Number) {
            // Any numeric result counts (Integer, Long, Float, Double, BigDecimal, ...); previously only Integer and
            // Double were honored and every other numeric type silently fell back to 1.0.
            weight = ((Number) result).doubleValue();
        } else if (result instanceof String) {
            // add to parse String data
            try {
                weight = Double.parseDouble((String) result);
            } catch (NumberFormatException e) {
                // Not a number, use default; the time-based check only thins out repeated warnings.
                if (System.currentTimeMillis() % 100 == 0) {
                    log.warn("Weight column type is String and value cannot be parsed with {}, use default 1.0d.", result);
                }
                weight = 1.0d;
            }
        }
    }
    retDouList.add(weight);
    return retDouList;
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class VarSelWorker, method getNormalizedColumnIdList:
/**
 * Collects the column numbers of every column that {@code CommonUtils.isGoodCandidate} accepts, in the order the
 * columns appear in {@code columnConfigList}. Whether the column list declares any candidate columns is
 * determined once up front and passed to each check.
 *
 * @return a new mutable list of accepted column numbers
 */
private List<Integer> getNormalizedColumnIdList() {
    final boolean candidatesPresent = CommonUtils.hasCandidateColumns(columnConfigList);
    final List<Integer> columnIds = new ArrayList<Integer>();
    for (ColumnConfig column : columnConfigList) {
        if (!CommonUtils.isGoodCandidate(column, candidatesPresent)) {
            continue;
        }
        columnIds.add(column.getColumnNum());
    }
    return columnIds;
}
Aggregations