
Example 41 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DataDictionaryCreator method build.

@Override
public DataDictionary build(BasicML basicML) {
    DataDictionary dict = new DataDictionary();
    List<DataField> fields = new ArrayList<DataField>();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    if (basicML instanceof BasicFloatNetwork) {
        BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
        Set<Integer> featureSet = bfn.getFeatureSet();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // configs are ordered by column number, so everything from here on lies beyond the raw dataset columns
                break;
            }
            if (isConcise) {
                if (columnConfig.isFinalSelect() && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum())) || columnConfig.isTarget()) {
                    fields.add(convertColumnToDataField(columnConfig));
                } else if (isSegExpansionMode) {
                    // even if the current column is not selected, keep the raw column when one of its segment columns is selected
                    for (int i = 0; i < segSize; i++) {
                        int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                        ColumnConfig cc = columnConfigList.get(newIndex);
                        if (cc.isFinalSelect()) {
                            // one selected segment feature is enough to include the raw column
                            fields.add(convertColumnToDataField(columnConfig));
                            break;
                        }
                    }
                }
            } else {
                fields.add(convertColumnToDataField(columnConfig));
            }
        }
    } else {
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnNum() >= datasetHeaders.length) {
                // configs are ordered by column number, so everything from here on lies beyond the raw dataset columns
                break;
            }
            if (isConcise) {
                if (columnConfig.isFinalSelect() || columnConfig.isTarget()) {
                    fields.add(convertColumnToDataField(columnConfig));
                } else if (isSegExpansionMode) {
                    // even if the current column is not selected, keep the raw column when one of its segment columns is selected
                    for (int i = 0; i < segSize; i++) {
                        int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
                        ColumnConfig cc = columnConfigList.get(newIndex);
                        if (cc.isFinalSelect()) {
                            // one selected segment feature is enough to include the raw column
                            fields.add(convertColumnToDataField(columnConfig));
                            break;
                        }
                    }
                }
            } else {
                fields.add(convertColumnToDataField(columnConfig));
            }
        }
    }
    dict.addDataFields(fields.toArray(new DataField[fields.size()]));
    dict.setNumberOfFields(fields.size());
    return dict;
}
Also used : DataField(org.dmg.pmml.DataField) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DataDictionary(org.dmg.pmml.DataDictionary)
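
The segment-expansion branch of build depends on index arithmetic: the copy of raw column c in segment i is looked up at datasetHeaders.length * (i + 1) + c. A minimal standalone sketch of that lookup follows, assuming the config list stores the raw columns first and then one full-width block per segment. The class SegmentIndexSketch, its method names, and the boolean array standing in for ColumnConfig.isFinalSelect() are hypothetical and not part of Shifu.

```java
final class SegmentIndexSketch {

    // Index of the copy of raw column `columnNum` inside segment `segment` (0-based).
    // Assumes raw columns occupy [0, headerCount) and each segment adds headerCount more entries.
    static int segmentColumnIndex(int headerCount, int segment, int columnNum) {
        return headerCount * (segment + 1) + columnNum;
    }

    // True if any segment copy of the raw column is flagged as selected.
    // `finalSelect` is a stand-in for calling isFinalSelect() on each config.
    static boolean anySegmentSelected(boolean[] finalSelect, int headerCount, int segSize, int columnNum) {
        for (int i = 0; i < segSize; i++) {
            if (finalSelect[segmentColumnIndex(headerCount, i, columnNum)]) {
                return true;
            }
        }
        return false;
    }
}
```

With 3 raw columns and 2 segments the flag array has 9 entries; the copies of raw column 1 sit at indices 4 and 7.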

Example 42 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class PostTrainWorker method handleMsg.

/*
 * (non-Javadoc)
 *
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof ColumnScoreMessage) {
        ColumnScoreMessage msg = (ColumnScoreMessage) message;
        colScoreList.addAll(msg.getColScoreList());
        receivedMsgCnt++;
        log.debug("Received " + receivedMsgCnt + " messages, total message count is:" + msg.getTotalMsgCnt());
        if (receivedMsgCnt == msg.getTotalMsgCnt()) {
            // all messages received, start calculating
            int columnNum = msg.getColumnNum();
            ColumnConfig config = columnConfigList.get(columnNum);
            Double[] binScore = new Double[config.getBinLength()];
            Integer[] binCount = new Integer[config.getBinLength()];
            for (int i = 0; i < binScore.length; i++) {
                binScore[i] = 0.0;
                binCount[i] = 0;
            }
            for (ColumnScoreObject colScore : colScoreList) {
                int binNum = BinUtils.getBinNum(config, colScore.getColumnVal());
                binScore[binNum] += Double.valueOf(colScore.getAvgScore());
                binCount[binNum]++;
            }
            List<Integer> binAvgScore = new ArrayList<Integer>();
            for (int i = 0; i < binScore.length; i++) {
                binScore[i] /= binCount[i];
                binAvgScore.add((int) Math.round(binScore[i]));
            }
            config.setBinAvgScore(binAvgScore);
            nextActorRef.tell(new StatsResultMessage(config), getSelf());
        }
    } else {
        unhandled(message);
    }
}
Also used : ColumnScoreObject(ml.shifu.shifu.container.ColumnScoreObject) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) StatsResultMessage(ml.shifu.shifu.message.StatsResultMessage) ArrayList(java.util.ArrayList) ColumnScoreMessage(ml.shifu.shifu.message.ColumnScoreMessage)
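
handleMsg divides each bin's accumulated score by its record count. When a bin received no records, that division is 0.0 / 0, which is NaN in Java double arithmetic, and Math.round maps NaN to 0. The sketch below makes the empty-bin case explicit instead of relying on that behaviour. It is a standalone illustration: BinAverageSketch and its parameters are hypothetical, and plain arrays replace ColumnScoreObject and BinUtils.getBinNum.

```java
final class BinAverageSketch {

    // binOfRecord[r] is the bin index of record r, scoreOfRecord[r] its score.
    // Returns the rounded average score per bin; empty bins are reported as 0.
    static int[] averageScores(int[] binOfRecord, double[] scoreOfRecord, int binLength) {
        double[] sum = new double[binLength];
        int[] count = new int[binLength];
        for (int r = 0; r < binOfRecord.length; r++) {
            sum[binOfRecord[r]] += scoreOfRecord[r];
            count[binOfRecord[r]]++;
        }
        int[] avg = new int[binLength];
        for (int i = 0; i < binLength; i++) {
            // Explicit guard: avoid dividing by a zero count.
            avg[i] = count[i] == 0 ? 0 : (int) Math.round(sum[i] / count[i]);
        }
        return avg;
    }
}
```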

Example 43 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DataLoadWorker method readTrainingData.

/**
 * Read the normalized training data for model training
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - whether this is a dry (test) run
 * @return List of data
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>();
    int numSelected = 0;
    for (ColumnConfig config : columnConfigList) {
        if (config.isFinalSelect()) {
            numSelected++;
        }
    }
    int cnt = 0;
    while (scanner.hasNextLine()) {
        if ((cnt++) % 100000 == 0) {
            log.info("Read " + (cnt) + " Records.");
        }
        String line = scanner.nextLine();
        if (isDryRun) {
            MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1]));
            mlDataPairList.add(dummyPair);
            continue;
        }
        // the normalized training data is separated by | by default
        double[] inputs = new double[numSelected];
        double[] ideal = new double[1];
        double significance = 0.0d;
        int index = 0, inputsIndex = 0, outputIndex = 0;
        for (String input : DEFAULT_SPLITTER.split(line.trim())) {
            double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            if (index == this.columnConfigList.size()) {
                significance = NumberFormatUtils.getDouble(input.trim(), CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            } else {
                ColumnConfig columnConfig = this.columnConfigList.get(index);
                if (columnConfig != null && columnConfig.isTarget()) {
                    ideal[outputIndex++] = doubleValue;
                } else {
                    if (this.inputNodeCount == this.candidateCount) {
                        // no variable is marked final-select, so use every good candidate
                        if (CommonUtils.isGoodCandidate(columnConfig, super.hasCandidates)) {
                            inputs[inputsIndex++] = doubleValue;
                        }
                    } else {
                        // some variables are final-selected, so use only those
                        if (columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                            inputs[inputsIndex++] = doubleValue;
                        }
                    }
                }
            }
            index++;
        }
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        pair.setSignificance(significance);
        mlDataPairList.add(pair);
    }
    return mlDataPairList;
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicMLData(org.encog.ml.data.basic.BasicMLData) ArrayList(java.util.ArrayList)
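
readTrainingData treats each line as one value per column config, with one extra trailing field holding the record's significance. The sketch below reproduces that layout on a pipe-separated line under simplifying assumptions: boolean arrays replace the ColumnConfig flags, String.split replaces DEFAULT_SPLITTER, and the fallback significance is a parameter because the value of CommonConstants.DEFAULT_SIGNIFICANCE_VALUE is not shown here. RecordParseSketch, Parsed, and parseOrDefault are hypothetical names.

```java
final class RecordParseSketch {

    // Hypothetical holder for one parsed record.
    static final class Parsed {
        double[] inputs;
        double ideal;
        double significance = 0.0d;
    }

    static Parsed parse(String line, boolean[] isTarget, boolean[] isSelected, double defaultSignificance) {
        int numSelected = 0;
        for (boolean selected : isSelected) {
            if (selected) {
                numSelected++;
            }
        }
        Parsed parsed = new Parsed();
        parsed.inputs = new double[numSelected];
        int inputsIndex = 0;
        String[] fields = line.trim().split("\\|");
        for (int index = 0; index < fields.length; index++) {
            if (index == isTarget.length) {
                // The field after the last column is the significance (record weight).
                parsed.significance = parseOrDefault(fields[index], defaultSignificance);
                break;
            }
            double value = parseOrDefault(fields[index], 0.0d);
            if (isTarget[index]) {
                parsed.ideal = value;
            } else if (isSelected[index]) {
                parsed.inputs[inputsIndex++] = value;
            }
        }
        return parsed;
    }

    // Falls back to `fallback` when the text is not a valid double.
    static double parseOrDefault(String text, double fallback) {
        try {
            return Double.parseDouble(text.trim());
        } catch (NumberFormatException e) {
            return fallback;
        }
    }
}
```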

Example 44 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DataNormalizeWorker method normalizeRecord.

/**
 * Normalize the training data record
 *
 * @param rfs
 *            - record fields
 * @return the data after normalization
 */
private List<Double> normalizeRecord(String[] rfs) {
    List<Double> retDouList = new ArrayList<Double>();
    if (rfs == null || rfs.length == 0) {
        return null;
    }
    String tag = CommonUtils.trimTag(rfs[this.targetColumnNum]);
    boolean isNotSampled = DataSampler.isNotSampled(modelConfig.getPosTags(), modelConfig.getNegTags(), modelConfig.getNormalizeSampleRate(), modelConfig.isNormalizeSampleNegOnly(), tag);
    if (isNotSampled) {
        return null;
    }
    JexlContext jc = new MapContext();
    Double cutoff = modelConfig.getNormalizeStdDevCutOff();
    for (int i = 0; i < rfs.length; i++) {
        ColumnConfig config = columnConfigList.get(i);
        if (weightExpr != null) {
            jc.set(config.getColumnName(), rfs[i]);
        }
        if (this.targetColumnNum == i) {
            if (modelConfig.getPosTags().contains(tag)) {
                retDouList.add(Double.valueOf(1));
            } else if (modelConfig.getNegTags().contains(tag)) {
                retDouList.add(Double.valueOf(0));
            } else {
                log.error("Invalid data! The target value is not listed - " + tag);
                // Return null to skip such record.
                return null;
            }
        } else if (!CommonUtils.isGoodCandidate(config, super.hasCandidates)) {
            retDouList.add(null);
        } else {
            String val = (rfs[i] == null) ? "" : rfs[i];
            retDouList.addAll(Normalizer.normalize(config, val, cutoff, modelConfig.getNormalizeType()));
        }
    }
    double weight = 1.0d;
    if (weightExpr != null) {
        Object result = weightExpr.evaluate(jc);
        if (result instanceof Integer) {
            weight = ((Integer) result).doubleValue();
        } else if (result instanceof Double) {
            weight = ((Double) result).doubleValue();
        } else if (result instanceof String) {
            // the weight may also arrive as a String, so try to parse it
            try {
                weight = Double.parseDouble((String) result);
            } catch (NumberFormatException e) {
                // Not a number, use default
                if (System.currentTimeMillis() % 100 == 0) {
                    log.warn("Weight column type is String and value cannot be parsed with {}, use default 1.0d.", result);
                }
                weight = 1.0d;
            }
        }
    }
    retDouList.add(weight);
    return retDouList;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) MapContext(org.apache.commons.jexl2.MapContext) JexlContext(org.apache.commons.jexl2.JexlContext)
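
The tail of normalizeRecord coerces the result of the weight expression into a double and falls back to 1.0 when that fails. A compact variation of that coercion is sketched below. It is not the project's code: it accepts any Number (so Long and Float results are also handled, which the original does not do) and omits the sampled warning log. WeightParseSketch and toWeight are hypothetical names.

```java
final class WeightParseSketch {

    // Default weight used when the expression result cannot be interpreted.
    static final double DEFAULT_WEIGHT = 1.0d;

    static double toWeight(Object result) {
        if (result instanceof Number) {
            // Covers Integer and Double as in the original, plus other numeric types.
            return ((Number) result).doubleValue();
        }
        if (result instanceof String) {
            try {
                return Double.parseDouble((String) result);
            } catch (NumberFormatException e) {
                return DEFAULT_WEIGHT;
            }
        }
        // null or an unsupported type
        return DEFAULT_WEIGHT;
    }
}
```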

Example 45 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class VarSelWorker method getNormalizedColumnIdList.

private List<Integer> getNormalizedColumnIdList() {
    List<Integer> normalizedColumnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (CommonUtils.isGoodCandidate(config, hasCandidates)) {
            normalizedColumnIdList.add(config.getColumnNum());
        }
    }
    return normalizedColumnIdList;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig) 131
ArrayList (java.util.ArrayList) 36
Test (org.testng.annotations.Test) 17
IOException (java.io.IOException) 16
HashMap (java.util.HashMap) 12
Tuple (org.apache.pig.data.Tuple) 10
File (java.io.File) 8
NSColumn (ml.shifu.shifu.column.NSColumn) 8
ModelConfig (ml.shifu.shifu.container.obj.ModelConfig) 8
ShifuException (ml.shifu.shifu.exception.ShifuException) 8
Path (org.apache.hadoop.fs.Path) 8
List (java.util.List) 7
Scanner (java.util.Scanner) 7
DataBag (org.apache.pig.data.DataBag) 7
SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType) 5
BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) 5
TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) 5
BasicMLData (org.encog.ml.data.basic.BasicMLData) 5
BufferedWriter (java.io.BufferedWriter) 3
FileInputStream (java.io.FileInputStream) 3