
Example 86 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTWorker method predictNodeIndex.

private Node predictNodeIndex(Node node, Data data, boolean isForErr) {
    Node currNode = node;
    Split split = currNode.getSplit();
    // leaf node (no split or no children): stop descending
    if (split == null || (currNode.getLeft() == null && currNode.getRight() == null)) {
        return currNode;
    }
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    Node nextNode = null;
    Integer inputIndex = this.inputIndexMap.get(split.getColumnNum());
    if (inputIndex == null) {
        throw new IllegalStateException("InputIndex should not be null: Split is " + split + ", inputIndexMap is " + this.inputIndexMap + ", data is " + data);
    }
    short value = 0;
    if (columnConfig.isNumerical()) {
        short binIndex = data.inputs[inputIndex];
        value = binIndex;
        double valueToBinLowestValue = columnConfig.getBinBoundary().get(binIndex);
        if (valueToBinLowestValue < split.getThreshold()) {
            nextNode = currNode.getLeft();
        } else {
            nextNode = currNode.getRight();
        }
    } else if (columnConfig.isCategorical()) {
        short indexValue = (short) (columnConfig.getBinCategory().size());
        value = indexValue;
        if (data.inputs[inputIndex] >= 0 && data.inputs[inputIndex] < (short) (columnConfig.getBinCategory().size())) {
            indexValue = data.inputs[inputIndex];
        } else {
            // for invalid category, set to last one
            indexValue = (short) (columnConfig.getBinCategory().size());
        }
        Set<Short> childCategories = split.getLeftOrRightCategories();
        if (split.isLeft()) {
            if (childCategories.contains(indexValue)) {
                nextNode = currNode.getLeft();
            } else {
                nextNode = currNode.getRight();
            }
        } else {
            if (childCategories.contains(indexValue)) {
                nextNode = currNode.getRight();
            } else {
                nextNode = currNode.getLeft();
            }
        }
    }
    if (nextNode == null) {
        throw new IllegalStateException("NextNode is null, parent id is " + currNode.getId() + "; parent split is " + split + "; left is " + currNode.getLeft() + "; right is " + currNode.getRight() + "; value is " + value);
    }
    return predictNodeIndex(nextNode, data, isForErr);
}
Also used : Set(java.util.Set) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) GuaguaFileSplit(ml.shifu.guagua.io.GuaguaFileSplit)
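
The numerical branch above routes a record by comparing the lower boundary of its precomputed bin against the split threshold. Below is a minimal standalone sketch of that rule with made-up bin boundaries; the class name and values are hypothetical and not part of DTWorker.

import java.util.Arrays;
import java.util.List;

public class NumericalSplitSketch {
    public static void main(String[] args) {
        // illustrative bin boundaries; in Shifu these come from columnConfig.getBinBoundary()
        List<Double> binBoundaries = Arrays.asList(Double.NEGATIVE_INFINITY, 10.0, 20.0, 30.0);
        double threshold = 20.0; // split threshold stored on the tree node
        short binIndex = 1;      // bin index precomputed for this record's raw value
        double binLowestValue = binBoundaries.get(binIndex);
        // same rule as predictNodeIndex: compare the bin's lower bound with the threshold
        String branch = binLowestValue < threshold ? "left" : "right";
        System.out.println("bin " + binIndex + " goes " + branch);
    }
}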

Example 87 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTWorker method initTodoNodeStats.

private Map<Integer, NodeStats> initTodoNodeStats(Map<Integer, TreeNode> todoNodes) {
    Map<Integer, NodeStats> statistics = new HashMap<Integer, NodeStats>(todoNodes.size(), 1f);
    for (Map.Entry<Integer, TreeNode> entry : todoNodes.entrySet()) {
        List<Integer> features = entry.getValue().getFeatures();
        if (features.isEmpty()) {
            features = getAllValidFeatures();
        }
        Map<Integer, double[]> featureStatistics = new HashMap<Integer, double[]>(features.size(), 1f);
        for (Integer columnNum : features) {
            ColumnConfig columnConfig = this.columnConfigList.get(columnNum);
            if (columnConfig.isNumerical()) {
                // TODO, how to process null bin
                int featureStatsSize = columnConfig.getBinBoundary().size() * this.impurity.getStatsSize();
                featureStatistics.put(columnNum, new double[featureStatsSize]);
            } else if (columnConfig.isCategorical()) {
                // the last one is for invalid value category like ?, *, ...
                int featureStatsSize = (columnConfig.getBinCategory().size() + 1) * this.impurity.getStatsSize();
                featureStatistics.put(columnNum, new double[featureStatsSize]);
            }
        }
        NodeStats nodeStats = new NodeStats(entry.getValue().getTreeId(), entry.getValue().getNode().getId(), featureStatistics);
        statistics.put(entry.getKey(), nodeStats);
    }
    return statistics;
}
Also used : NodeStats(ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) Map(java.util.Map) ConcurrentMap(java.util.concurrent.ConcurrentMap)
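
The sizing rule above allocates one block of impurity statistics per bin: numerical features get getBinBoundary().size() blocks, categorical features get getBinCategory().size() + 1 blocks, with the extra slot reserved for invalid or missing categories. A small arithmetic sketch with assumed sizes; the class name and numbers are illustrative only.

public class NodeStatsSizeSketch {
    public static void main(String[] args) {
        int statsSize = 3;        // assumed impurity.getStatsSize(), e.g. count/sum/sumSquares
        int numericalBins = 10;   // assumed columnConfig.getBinBoundary().size()
        int categoricalBins = 5;  // assumed columnConfig.getBinCategory().size()
        int numericalStats = numericalBins * statsSize;           // 30 doubles
        int categoricalStats = (categoricalBins + 1) * statsSize; // 18 doubles, +1 slot for invalid categories
        System.out.println("numerical=" + numericalStats + ", categorical=" + categoricalStats);
    }
}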

Example 88 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTWorker method init.

@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    try {
        SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isCategorical()) {
            if (config.getBinCategory() != null) {
                Map<String, Integer> tmpMap = new HashMap<String, Integer>();
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : catVals) {
                        tmpMap.put(cval, i);
                    }
                }
                this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
            }
        }
    }
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set the Poisson mean to upSampleWeight - 1; sampling later adds 1 so the sample value is never zero
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
    this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
    // enable shut down logic
    context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {

        @Override
        public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
            DTWorker.this.threadPool.shutdownNow();
            try {
                DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    });
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(this.trainerId);
        LOG.info("Start grid search worker with params: {}", validParams);
    }
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6/0.4 split of the memory budget (maxMemory * memoryFraction) between trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        if (Double.compare(validationRate, 0d) != 0) {
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>());
        } else {
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // outputNodeCount is 1 for regression, binary classification, one-vs-all and native classification;
    // for native classification, class indexes 0, 1, 2, 3, ... denote the different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
    int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
    String imStr = validParams.get("Impurity").toString();
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("friedmanmse")) {
        impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    String lossStr = validParams.get("Loss").toString();
    if (lossStr.equalsIgnoreCase("log")) {
        this.loss = new LogLoss();
    } else if (lossStr.equalsIgnoreCase("absolute")) {
        this.loss = new AbsoluteLoss();
    } else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
        this.loss = new HalfGradSquaredLoss();
    } else if (lossStr.equalsIgnoreCase("squared")) {
        this.loss = new SquaredLoss();
    } else {
        try {
            this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
        } catch (ClassNotFoundException e) {
            LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
            this.loss = new SquaredLoss();
        }
    }
    if (this.isGBDT) {
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        Object swrObj = validParams.get("GBTSampleWithReplacement");
        if (swrObj != null) {
            this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
        }
        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropOutRate = Double.valueOf(dropoutObj.toString());
        }
    }
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    LOG.info("Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}", isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(), this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate);
    // for fail over, load existing trees
    if (!context.isFirstIteration()) {
        if (this.isGBDT) {
            // set the flag here and recover later in #doCompute; this makes sure recovery runs after the
            // data-loading part, which can load the latest trees in #doCompute
            isNeedRecoverGBDTPredict = true;
        } else {
            // RF: trees are recovered from the last master result
            recoverTrees = context.getLastMasterResult().getTrees();
        }
    }
    if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        TreeModel existingModel = null;
        try {
            existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        } catch (IOException e) {
            LOG.error("Error in get existing model, will ignore and start from scratch", e);
        }
        if (existingModel == null) {
            LOG.warn("No model is found even set to continuous model training.");
            return;
        } else {
            recoverTrees = existingModel.getTrees();
            LOG.info("Loading existing {} trees", recoverTrees.size());
        }
    }
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) ArrayList(java.util.ArrayList) Properties(java.util.Properties) TreeModel(ml.shifu.shifu.core.TreeModel) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) Map(java.util.Map) ConcurrentMap(java.util.concurrent.ConcurrentMap)
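
One detail worth isolating from init above is the memory budgeting: guagua.data.memoryFraction of the max heap is the total budget, split 0.6/0.4 when a separate validation data path is configured, or (1 - validSetRate)/validSetRate otherwise. A standalone sketch of the second case; the class name and values are hypothetical.

public class MemoryBudgetSketch {
    public static void main(String[] args) {
        double memoryFraction = 0.6;   // assumed guagua.data.memoryFraction
        double validationRate = 0.2;   // assumed modelConfig.getValidSetRate()
        long maxHeap = Runtime.getRuntime().maxMemory();
        long trainBytes = (long) (maxHeap * memoryFraction * (1 - validationRate));
        long validBytes = (long) (maxHeap * memoryFraction * validationRate);
        System.out.println("training budget=" + trainBytes + " bytes, validation budget=" + validBytes + " bytes");
    }
}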

Example 89 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class FastCorrelationMapper method setup.

@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    this.dataSetDelimiter = modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(modelConfig, false);
    this.isComputeAll = Boolean.valueOf(context.getConfiguration().get(Constants.SHIFU_CORRELATION_COMPUTE_ALL, "false"));
    this.outputKey = new IntWritable();
    this.correlationMap = new HashMap<Integer, CorrelationWritable>();
    for (ColumnConfig config : columnConfigList) {
        if (config.isCategorical()) {
            Map<String, Integer> map = new HashMap<String, Integer>();
            if (config.getBinCategory() != null) {
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> cvals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : cvals) {
                        map.put(cval, i);
                    }
                }
            }
            this.categoricalIndexMap.put(config.getColumnNum(), map);
        }
    }
    if (modelConfig != null && modelConfig.getPosTags() != null) {
        this.posTagSet = new HashSet<String>(modelConfig.getPosTags());
    }
    if (modelConfig != null && modelConfig.getNegTags() != null) {
        this.negTagSet = new HashSet<String>(modelConfig.getNegTags());
    }
    if (modelConfig != null && modelConfig.getFlattenTags() != null) {
        this.tagSet = new HashSet<String>(modelConfig.getFlattenTags());
    }
    if (modelConfig != null) {
        this.tags = modelConfig.getSetTags();
    }
}
Also used : DataPurifier(ml.shifu.shifu.core.DataPurifier) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap) IntWritable(org.apache.hadoop.io.IntWritable)
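
The category-to-index map built in setup above flattens grouped category values so every raw value points at its bin index. A minimal sketch of that mapping, assuming grouped categories are joined with '^' (which is what CommonUtils.flattenCatValGrp appears to split on; verify against the Shifu source before relying on it). The class and sample values are hypothetical.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CategoryIndexSketch {
    public static void main(String[] args) {
        // assumed grouped bin categories; bin 0 groups "US" and "CA", bin 2 groups "DE", "FR", "IT"
        List<String> binCategory = Arrays.asList("US^CA", "UK", "DE^FR^IT");
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (int i = 0; i < binCategory.size(); i++) {
            for (String cval : binCategory.get(i).split("\\^")) {
                map.put(cval, i);
            }
        }
        // US and CA map to 0, UK to 1, DE/FR/IT to 2
        System.out.println(map);
    }
}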

Example 90 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTrainUtils method getNumericAndCategoricalInputAndOutputCounts.

/**
 * Get the numbers of final-selected numeric and categorical input columns, the output column count, and a flag
 * telling whether variable selection has already been run, from the column config list.
 *
 * <p>
 * If no column is final-selected, all non-meta, non-target candidate columns are counted as inputs instead and
 * the variable-selection flag is set to 0.
 *
 * @param columnConfigList
 *            the column config list
 * @return an int array of [numericInput, categoricalInput, output, isVarSelect]
 * @throws NullPointerException
 *             if columnConfigList or any ColumnConfig object in columnConfigList is null.
 */
public static int[] getNumericAndCategoricalInputAndOutputCounts(List<ColumnConfig> columnConfigList) {
    int numericInput = 0, categoricalInput = 0, output = 0, numericCandidateInput = 0, categoricalCandidateInput = 0;
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (!config.isTarget() && !config.isMeta() && CommonUtils.isGoodCandidate(config, hasCandidates)) {
            if (config.isNumerical()) {
                numericCandidateInput += 1;
            }
            if (config.isCategorical()) {
                categoricalCandidateInput += 1;
            }
        }
        if (config.isFinalSelect() && !config.isTarget() && !config.isMeta()) {
            if (config.isNumerical()) {
                numericInput += 1;
            }
            if (config.isCategorical()) {
                categoricalInput += 1;
            }
        }
        if (config.isTarget()) {
            output += 1;
        }
    }
    // check whether variable selection has been done: if no variable is finalSelect, all good candidate
    // variables are treated as selected. TODO, bad practice, refactor me
    int isVarSelect = 1;
    if (numericInput == 0 && categoricalInput == 0) {
        numericInput = numericCandidateInput;
        categoricalInput = categoricalCandidateInput;
        isVarSelect = 0;
    }
    return new int[] { numericInput, categoricalInput, output, isVarSelect };
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)
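
A short usage sketch for the method above, matching how DTWorker.init reads the returned array (see Example 88). It assumes a List<ColumnConfig> was already loaded, e.g. via CommonUtils.loadColumnConfigList; the wrapper class is hypothetical and the DTrainUtils import path is assumed from the Shifu source layout.

import java.util.List;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.core.dtrain.DTrainUtils;

public class InputCountSketch {
    public static void describe(List<ColumnConfig> columnConfigList) {
        int[] counts = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(columnConfigList);
        int inputCount = counts[0] + counts[1];      // numerical + categorical inputs
        boolean isAfterVarSelect = counts[3] == 1;   // 1 means finalSelect columns were found
        System.out.println("inputs=" + inputCount + ", outputs=" + counts[2] + ", afterVarSelect=" + isAfterVarSelect);
    }
}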

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 131
ArrayList (java.util.ArrayList): 36
Test (org.testng.annotations.Test): 17
IOException (java.io.IOException): 16
HashMap (java.util.HashMap): 12
Tuple (org.apache.pig.data.Tuple): 10
File (java.io.File): 8
NSColumn (ml.shifu.shifu.column.NSColumn): 8
ModelConfig (ml.shifu.shifu.container.obj.ModelConfig): 8
ShifuException (ml.shifu.shifu.exception.ShifuException): 8
Path (org.apache.hadoop.fs.Path): 8
List (java.util.List): 7
Scanner (java.util.Scanner): 7
DataBag (org.apache.pig.data.DataBag): 7
SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 5
BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 5
TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet): 5
BasicMLData (org.encog.ml.data.basic.BasicMLData): 5
BufferedWriter (java.io.BufferedWriter): 3
FileInputStream (java.io.FileInputStream): 3