Search in sources:

Example 81 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class VarSelWorker, method init.

@Override
public void init(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) {
    Properties properties = workerContext.getProps();
    try {
        // Resolve the data source type first (defaults to HDFS), since both
        // config files are loaded from that source.
        RawSourceData.SourceType srcType = RawSourceData.SourceType.valueOf(
                properties.getProperty(CommonConstants.MODELSET_SOURCE_TYPE,
                        RawSourceData.SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(
                properties.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), srcType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(
                properties.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), srcType);
        // The worker conductor class is configurable; instantiate it reflectively
        // through its (ModelConfig, List) constructor.
        String conductorClassName = properties.getProperty(Constants.VAR_SEL_WORKER_CONDUCTOR);
        this.workerConductor = (AbstractWorkerConductor) Class.forName(conductorClassName)
                .getDeclaredConstructor(ModelConfig.class, List.class)
                .newInstance(this.modelConfig, this.columnConfigList);
    } catch (IOException e) {
        throw new RuntimeException("Fail to load ModelConfig or List<ColumnConfig>", e);
    } catch (ClassNotFoundException e) {
        throw new RuntimeException("Invalid Master Conductor class", e);
    } catch (InstantiationException e) {
        throw new RuntimeException("Fail to create instance", e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException("Illegal access when creating instance", e);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException("Fail to call method when creating instance", e);
    } catch (InvocationTargetException e) {
        throw new RuntimeException("Fail to invoke when creating instance", e);
    }
    // Network topology: one input per normalized column, outputs per target column.
    List<Integer> normalizedColumnIdList = this.getNormalizedColumnIdList();
    this.inputNodeCount = normalizedColumnIdList.size();
    this.outputNodeCount = this.getTargetColumnCount();
    trainingDataSet = new TrainingDataSet(normalizedColumnIdList);
    try {
        dataPurifier = new DataPurifier(modelConfig, false);
    } catch (IOException e) {
        throw new RuntimeException("Fail to create DataPurifier", e);
    }
    this.targetColumnId = CommonUtils.getTargetColumnNum(this.columnConfigList);
    // Locate the optional weight column by case-insensitive name match.
    String weightColumnName = modelConfig.getWeightColumnName();
    if (StringUtils.isNotBlank(weightColumnName)) {
        String trimmedWeightName = weightColumnName.trim();
        for (ColumnConfig candidate : columnConfigList) {
            if (candidate.getColumnName().equalsIgnoreCase(trimmedWeightName)) {
                this.weightColumnId = candidate.getColumnNum();
                break;
            }
        }
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) IOException(java.io.IOException) Properties(java.util.Properties) RawSourceData(ml.shifu.shifu.container.obj.RawSourceData) InvocationTargetException(java.lang.reflect.InvocationTargetException) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) DataPurifier(ml.shifu.shifu.core.DataPurifier) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) ArrayList(java.util.ArrayList) List(java.util.List)

Example 82 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class MiningSchemaCreator, method build.

/**
 * Builds the PMML {@link MiningSchema} from the column config list.
 *
 * <p>When {@code basicML} is a {@link BasicFloatNetwork}, a column is included if it
 * is in the network's feature set; otherwise a column is included if it is
 * final-selected or is the target. In segment-expansion mode, a raw column whose
 * selection check fails is still included when any of its segment-expanded copies
 * is final-selected, because the expanded features are derived from the raw column.
 *
 * @param basicML the trained model, possibly a {@code BasicFloatNetwork}; may be null
 * @return the assembled mining schema
 */
@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema miningSchema = new MiningSchema();
    // Segment-expansion mode: the config list holds extra expanded copies of each
    // raw column beyond the first datasetHeaders.length entries.
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    // Non-null only when the model exposes an explicit feature set.
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // Configs are in column order; stop once past the raw columns.
            break;
        }
        // FIXME, if no variable is selected
        boolean selected = (featureSet != null) ? isActiveColumn(featureSet, columnConfig)
                : (columnConfig.isFinalSelect() || columnConfig.isTarget());
        if (selected) {
            addMiningFieldsFor(miningSchema, columnConfig);
        } else if (isSegExpansionMode && isAnySegmentSelected(columnConfig, segSize)) {
            // Even if the raw column was not selected, keep it when one of its
            // segment-expanded features is selected.
            addMiningFieldsFor(miningSchema, columnConfig);
        }
    }
    return miningSchema;
}

/**
 * Adds the mining field(s) for one column: a target column may expand to several
 * mining fields, an active column to its own field(s).
 */
private void addMiningFieldsFor(MiningSchema miningSchema, ColumnConfig columnConfig) {
    if (columnConfig.isTarget()) {
        List<MiningField> miningFields = createTargetMingFields(columnConfig);
        miningSchema.addMiningFields(miningFields.toArray(new MiningField[0]));
    } else {
        miningSchema.addMiningFields(createActiveMingFields(columnConfig));
    }
}

/**
 * Returns true when any of the {@code segSize} segment-expanded copies of the given
 * raw column is final-selected. Expanded copy i of a raw column lives at index
 * {@code datasetHeaders.length * (i + 1) + rawColumnNum}.
 */
private boolean isAnySegmentSelected(ColumnConfig columnConfig, int segSize) {
    for (int i = 0; i < segSize; i++) {
        int expandedIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
        if (columnConfigList.get(expandedIndex).isFinalSelect()) {
            return true;
        }
    }
    return false;
}
Also used : MiningField(org.dmg.pmml.MiningField) MiningSchema(org.dmg.pmml.MiningSchema) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) List(java.util.List) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 83 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class PersistWideAndDeep, method readColumnConfig.

/**
 * Deserializes a single {@link ColumnConfig} from the stream using the
 * config's own {@code read} codec.
 *
 * @param dis the stream positioned at a serialized column config
 * @return the populated column config
 * @throws IOException if reading from the stream fails
 */
private static ColumnConfig readColumnConfig(DataInputStream dis) throws IOException {
    final ColumnConfig config = new ColumnConfig();
    config.read(dis);
    return config;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 84 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class WDLWorker, method load.

/**
 * Logic to load data into memory list which includes float array for numerical features and sparse object array for
 * categorical features.
 *
 * Parses one delimited input record into a {@code Data} instance and routes it into the
 * training or validation set. Columns are consumed positionally in columnConfigList order;
 * the field after the last column, if present, is the record weight (significance).
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<WDLParams, WDLParams> context) {
    // Progress logging only; count tracks total records seen by this worker.
    if ((++this.count) % 5000 == 0) {
        LOG.info("Read {} records.", this.count);
    }
    // hashcode for fixed input split in train and validation
    long hashcode = 0;
    float[] inputs = new float[this.numInputs];
    // NOTE(review): this recounts categorical columns on every record; presumably
    // cheap, but it looks loop-invariant — could be computed once at init. Verify.
    this.cateInputs = (int) this.columnConfigList.stream().filter(ColumnConfig::isCategorical).count();
    // Local array intentionally shadows the field name `cateInputs` (the field holds
    // the count, this local holds the sparse values for the current record).
    SparseInput[] cateInputs = new SparseInput[this.cateInputs];
    float ideal = 0f, significance = 1f;
    // index: position in the raw record; numIndex/cateIndex: write cursors into
    // inputs[] and cateInputs[] respectively.
    int index = 0, numIndex = 0, cateIndex = 0;
    // use guava Splitter to iterate only once
    for (String input : this.splitter.split(currentValue.getWritable().toString())) {
        if (index == this.columnConfigList.size()) {
            significance = getWeightValue(input);
            // the last field is significance, break here
            break;
        } else {
            ColumnConfig config = this.columnConfigList.get(index);
            if (config != null && config.isTarget()) {
                // Target column becomes the label (ideal output).
                ideal = getFloatValue(input);
            } else {
                // final select some variables but meta and target are not included
                if (validColumn(config)) {
                    if (config.isNumerical()) {
                        inputs[numIndex] = getFloatValue(input);
                        this.inputIndexMap.putIfAbsent(config.getColumnNum(), numIndex++);
                    } else if (config.isCategorical()) {
                        cateInputs[cateIndex] = new SparseInput(config.getColumnNum(), getCateIndex(input, config));
                        this.inputIndexMap.putIfAbsent(config.getColumnNum(), cateIndex++);
                    }
                    // Accumulate a per-record hash over valid feature values so the
                    // train/validation split is deterministic for identical records.
                    hashcode = hashcode * 31 + input.hashCode();
                }
            }
        }
        index += 1;
    }
    // output delimiter in norm can be set by user now and if user set a special one later changed, this exception
    // is helped to quick find such issue.
    validateInputLength(context, inputs, numIndex);
    // sample negative only logic here
    if (sampleNegOnly(hashcode, ideal)) {
        return;
    }
    // up sampling logic, just add more weights while bagging sampling rate is still not changed
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) {
        // ideal == 1 means positive tags; sample + 1 to avoid sample count to 0
        significance = significance * (this.upSampleRng.sample() + 1);
    }
    Data data = new Data(inputs, cateInputs, significance, ideal);
    // split into validation and training data set according to validation rate
    boolean isInTraining = this.addDataPairToDataSet(hashcode, data, context.getAttachment());
    // update some positive or negative selected count in metrics
    this.updateMetrics(data, isInTraining);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 85 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class WDLWorker, method initCateIndexMap.

/**
 * Builds, for every categorical column, a reverse lookup from category value to
 * its bin index, stored keyed by column number in columnCategoryIndexMapping.
 */
private void initCateIndexMap() {
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        // Only categorical columns with computed bins contribute a mapping.
        if (!config.isCategorical() || config.getBinCategory() == null) {
            continue;
        }
        Map<String, Integer> valueToBin = new HashMap<String, Integer>();
        int binCount = config.getBinCategory().size();
        for (int binIndex = 0; binIndex < binCount; binIndex++) {
            // A bin entry may be a grouped value; flatten it so every individual
            // category value maps to the same bin index.
            for (String value : CommonUtils.flattenCatValGrp(config.getBinCategory().get(binIndex))) {
                valueToBin.put(value, binIndex);
            }
        }
        this.columnCategoryIndexMapping.put(config.getColumnNum(), valueToBin);
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ConcurrentMap(java.util.concurrent.ConcurrentMap) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3