
Example 1 with BasicFloatMLDataPair

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in project shifu by ShifuML.

From class AbstractNNWorker, method mockRandomRepeatData:

/**
 * In Trainer, the original logic randomly chooses items from the master data set, but we don't want to load the
 * data twice, to save memory. This method mocks that raw random-repeat logic. Behavior may differ slightly from
 * the original, because here the data are not loaded into the data set and the choice is not truly random.
 */
@SuppressWarnings("unused")
private void mockRandomRepeatData(double crossValidationRate, double random) {
    long trainingSize = this.trainingData.getRecordCount();
    long testingSize = this.validationData.getRecordCount();
    long size = trainingSize + testingSize;
    // a narrowing cast from long to int is acceptable here, since this is just a random selection
    int next = RandomUtils.nextInt((int) size);
    FloatMLDataPair dataPair = new BasicFloatMLDataPair(new BasicFloatMLData(new float[this.subFeatures.size()]),
            new BasicFloatMLData(new float[this.outputNodeCount]));
    if (next >= trainingSize) {
        this.validationData.getRecord(next - trainingSize, dataPair);
    } else {
        this.trainingData.getRecord(next, dataPair);
    }
    if (Double.compare(random, crossValidationRate) < 0) {
        this.validationData.add(dataPair);
    } else {
        this.trainingData.add(dataPair);
    }
}
Also used : BasicFloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair) FloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair) BasicFloatMLData(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData)
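
The construction pattern shared by all three examples is small: wrap the input and ideal float arrays in BasicFloatMLData, combine them into a BasicFloatMLDataPair, and attach a record weight via setSignificance. A minimal sketch of that pattern, using only calls that appear in the examples on this page (the helper name buildPair is hypothetical):

import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair;

// Hypothetical helper illustrating the construction pattern used in the examples.
static FloatMLDataPair buildPair(float[] inputs, float[] ideal, float significance) {
    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs),
            new BasicFloatMLData(ideal));
    // significance is the record weight that bagging / up-sampling logic later scales
    pair.setSignificance(significance);
    return pair;
}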

Example 2 with BasicFloatMLDataPair

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in project shifu by ShifuML.

From class NNParquetWorker, method load:

@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Tuple> currentValue, WorkerContext<NNParams, NNParams> workerContext) {
    // initialize the field list for later reads
    this.initFieldList();
    LOG.info("subFeatureSet size: {} ; subFeatureSet: {}", subFeatureSet.size(), subFeatureSet);
    super.count += 1;
    if ((super.count) % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }
    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];
    if (super.isDry) {
        // dry train, use empty data.
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }
    long hashcode = 0;
    float significance = 1f;
    // use Guava Splitter so the record is iterated only once
    // NNConstants.NN_DEFAULT_COLUMN_SEPARATOR is used instead of getModelConfig().getDataSetDelimiter();
    // this follows the corresponding function in akka mode.
    int index = 0, inputsIndex = 0, outputIndex = 0;
    Tuple tuple = currentValue.getWritable();
    // use an indexed for loop instead of foreach, because Tuple is not iterable in earlier Pig versions
    for (int i = 0; i < tuple.size(); i++) {
        Object element = null;
        try {
            element = tuple.get(i);
        } catch (ExecException e) {
            throw new GuaguaRuntimeException(e);
        }
        float floatValue = 0f;
        if (element != null) {
            if (element instanceof Float) {
                floatValue = (Float) element;
            } else {
                // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
                floatValue = element.toString().length() == 0 ? 0f : NumberFormatUtils.getFloat(element.toString(), 0f);
            }
        }
        // NaN sometimes appears in the input data; treat it as a missing value (TODO: handle according to norm type)
        floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
        if (index == (super.inputNodeCount + super.outputNodeCount)) {
            // weight column: the last column in the record
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1f;
                // break here, since the weight column is the last column
                break;
            }
            assert element != null;
            if (element instanceof Float) {
                significance = (Float) element;
            } else {
                // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
                significance = element.toString().length() == 0 ? 1f : NumberFormatUtils.getFloat(element.toString(), 1f);
            }
            // if the weight is invalid (negative), reset it to 1f and log a warning
            if (Float.compare(significance, 0f) < 0) {
                LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
                significance = 1f;
            }
            // break here, since the weight column is the last column
            break;
        } else {
            int columnIndex = requiredFieldList.getFields().get(index).getIndex();
            if (columnIndex >= super.columnConfigList.size()) {
                assert element != null;
                if (element instanceof Float) {
                    significance = (Float) element;
                } else {
                    // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
                    significance = element.toString().length() == 0 ? 1f : NumberFormatUtils.getFloat(element.toString(), 1f);
                }
                break;
            } else {
                ColumnConfig columnConfig = super.columnConfigList.get(columnIndex);
                if (columnConfig != null && columnConfig.isTarget()) {
                    if (modelConfig.isRegression()) {
                        ideal[outputIndex++] = floatValue;
                    } else {
                        if (modelConfig.getTrain().isOneVsAll()) {
                            // for one-vs-all, set the ideal value according to trainerId: in the trainer with id 0,
                            // target 0 is treated as 1 and all others as 0. Target values are the indexes of the
                            // tags, e.g. [0, 1, 2, 3] for tags ["a", "b", "c", "d"]
                            ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
                        } else {
                            if (modelConfig.getTags().size() == 2) {
                                // with only 2 classes there is a single output node; target 0 is the index of the
                                // positive class, so set positive to 1 and negative to 0
                                int ideaIndex = (int) floatValue;
                                ideal[0] = ideaIndex == 0 ? 1f : 0f;
                            } else {
                                // for multi-class classification, set the output node at the tag index to 1
                                int ideaIndex = (int) floatValue;
                                ideal[ideaIndex] = 1f;
                            }
                        }
                    }
                } else {
                    if (subFeatureSet.contains(columnIndex)) {
                        inputs[inputsIndex++] = floatValue;
                        hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
                    }
                }
            }
        }
        index += 1;
    }
    // validate that the number of parsed inputs matches the expected input size; this helps to quickly locate data issues.
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputs.length + ", parsing size:" + inputsIndex + ", delimiter:" + delimiter + ".");
    }
    // sample-negative-only logic
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, drop negative records whose hashcode falls into the (1 - sampleRate) range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the fraction of data used for training and validation; if it is 0.8, use
            // 1 - 0.8 to compute endHashCode
            int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one vs all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // drop negative records at random according to the bagging sample rate
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one vs all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(), this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }
    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up sampling logic
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // ideal[0] == 1 means a positive tag; add 1 to the sample so the weight multiplier is never 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }
    boolean isValidation = false;
    if (workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) workerContext.getAttachment();
    }
    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
    // do bagging sampling only for training data
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // scale significance by the subsample weight; if the weight is 0, significance becomes 0, which implements bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    } else {
        // for validation data, the bagging sampling logic would suggest sampling as well, but since the
        // validation set is only used to compute validation error, skipping the sampling is fine.
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ExecException(org.apache.pig.backend.executionengine.ExecException) BasicFloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair) FloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair) BasicFloatMLData(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Tuple(org.apache.pig.data.Tuple)
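
The helper isInRange(hashcode, startHashCode, endHashCode) used in the sample-negative-only branch is not shown in this snippet. Below is a minimal sketch of one plausible implementation, assuming the hashcode is reduced to a percent bucket in [0, 100) and that the [start, end) window may wrap past 100; the actual shifu helper may differ:

// Plausible sketch only, not the verified shifu implementation.
static boolean isInRange(long hashcode, int startHashCode, int endHashCode) {
    // reduce the hashcode to a percent bucket in [0, 100)
    int bucket = (int) Math.abs(hashcode % 100L);
    if (endHashCode <= 100) {
        return bucket >= startHashCode && bucket < endHashCode;
    }
    // the window wraps around 100: [startHashCode, 100) plus [0, endHashCode - 100)
    return bucket >= startHashCode || bucket < endHashCode - 100;
}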

Example 3 with BasicFloatMLDataPair

Use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in project shifu by ShifuML.

From class NNWorker, method load:

@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<NNParams, NNParams> workerContext) {
    super.count += 1;
    if ((super.count) % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }
    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];
    if (super.isDry) {
        // dry train, use empty data.
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }
    long hashcode = 0;
    float significance = 1f;
    // use Guava Splitter so the record is iterated only once
    // NNConstants.NN_DEFAULT_COLUMN_SEPARATOR is used instead of getModelConfig().getDataSetDelimiter();
    // this follows the corresponding function in akka mode.
    int index = 0, inputsIndex = 0, outputIndex = 0;
    String[] fields = Lists.newArrayList(this.splitter.split(currentValue.getWritable().toString())).toArray(new String[0]);
    int pos = 0;
    for (pos = 0; pos < fields.length; ) {
        String input = fields[pos];
        // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
        float floatValue = input.length() == 0 ? 0f : NumberFormatUtils.getFloat(input, 0f);
        // NaN sometimes appears in the input data; treat it as a missing value (TODO: handle according to norm type)
        floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
        if (pos == fields.length - 1) {
            // weight column: the last column in the record
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1f;
                // break here, since the weight column is the last column
                break;
            }
            // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(input, 1f)
            significance = input.length() == 0 ? 1f : NumberFormatUtils.getFloat(input, 1f);
            // if the weight is invalid (negative), reset it to 1f and log a warning
            if (Float.compare(significance, 0f) < 0) {
                LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
                significance = 1f;
            }
            // the last field is significance, break here
            break;
        } else {
            ColumnConfig columnConfig = super.columnConfigList.get(index);
            if (columnConfig != null && columnConfig.isTarget()) {
                if (isLinearTarget || modelConfig.isRegression()) {
                    ideal[outputIndex++] = floatValue;
                } else {
                    if (modelConfig.getTrain().isOneVsAll()) {
                        // for one-vs-all, set the ideal value according to trainerId: in the trainer with id 0,
                        // target 0 is treated as 1 and all others as 0. Target values are the indexes of the tags,
                        // e.g. [0, 1, 2, 3] for tags ["a", "b", "c", "d"]
                        ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
                    } else {
                        if (modelConfig.getTags().size() == 2) {
                            // with only 2 classes there is a single output node; target 0 is the index of the
                            // positive class, so set positive to 1 and negative to 0
                            int ideaIndex = (int) floatValue;
                            ideal[0] = ideaIndex == 0 ? 1f : 0f;
                        } else {
                            // for multi-class classification, set the output node at the tag index to 1
                            int ideaIndex = (int) floatValue;
                            ideal[ideaIndex] = 1f;
                        }
                    }
                }
                pos++;
            } else {
                if (subFeatureSet.contains(index)) {
                    if (columnConfig.isMeta() || columnConfig.isForceRemove()) {
                        // meta or force-removed columns should not appear here; just skip the field
                        pos += 1;
                    } else if (columnConfig != null && columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) {
                        for (int k = 0; k < columnConfig.getBinBoundary().size() + 1; k++) {
                            String tval = fields[pos];
                            // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
                            float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
                            // NaN sometimes appears in the input data; treat it as a missing value
                            // (TODO: handle according to norm type)
                            fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
                            inputs[inputsIndex++] = fval;
                            pos++;
                        }
                    } else if (columnConfig != null && columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT))) {
                        for (int k = 0; k < columnConfig.getBinCategory().size() + 1; k++) {
                            String tval = fields[pos];
                            // check for an empty string first to avoid the performance cost of a failed NumberFormatUtils.getFloat(...)
                            float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
                            // NaN sometimes appears in the input data; treat it as a missing value
                            // (TODO: handle according to norm type)
                            fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
                            inputs[inputsIndex++] = fval;
                            pos++;
                        }
                    } else {
                        inputs[inputsIndex++] = floatValue;
                        pos++;
                    }
                    hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
                } else {
                    if (!CommonUtils.isToNormVariable(columnConfig, hasCandidates, modelConfig.isRegression())) {
                        pos += 1;
                    } else if (columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT) && columnConfig.getBinBoundary() != null && columnConfig.getBinBoundary().size() > 0) {
                        pos += (columnConfig.getBinBoundary().size() + 1);
                    } else if (columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) && columnConfig.getBinCategory().size() > 0) {
                        pos += (columnConfig.getBinCategory().size() + 1);
                    } else {
                        pos += 1;
                    }
                }
            }
        }
        index += 1;
    }
    if (index != this.columnConfigList.size() || pos != fields.length - 1) {
        throw new RuntimeException("Wrong data indexing. ColumnConfig index = " + index + ", while it should be " + columnConfigList.size() + ". " + "Data Pos = " + pos + ", while it should be " + (fields.length - 1));
    }
    // validate that the number of parsed inputs matches the expected input size; this helps to quickly locate data issues.
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputs.length + ", parsing size:" + inputsIndex + ", delimiter:" + delimiter + ".");
    }
    // sample-negative-only logic
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, drop negative records whose hashcode falls into the (1 - sampleRate) range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the fraction of data used for training and validation; if it is 0.8, use
            // 1 - 0.8 to compute endHashCode
            int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one vs all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // drop negative records at random according to the bagging sample rate
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one vs all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(), this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }
    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up-sampling logic: just add more weight; the bagging sampling rate itself is unchanged
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // ideal[0] == 1 means a positive tag; add 1 to the sample so the weight multiplier is never 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }
    boolean isValidation = false;
    if (workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) workerContext.getAttachment();
    }
    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
    // do bagging sampling only for training data
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // scale significance by the subsample weight; if the weight is 0, significance becomes 0, which implements bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    } else {
        // for validation data, the bagging sampling logic would suggest sampling as well, but since the
        // validation set is only used to compute validation error, skipping the sampling is fine.
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicFloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair) FloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair) BasicFloatMLData(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData)
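
In this example the parser advances pos by a column-dependent number of normalized fields. The width rule implied by the branches above can be isolated into a small helper; the method name normalizedFieldWidth is hypothetical, and reading the extra +1 slot as a missing-value bin is an inference from the loop bounds, not from shifu documentation:

// Hypothetical helper; the width rule is read directly off the branches in Example 3.
static int normalizedFieldWidth(ColumnConfig columnConfig, ModelNormalizeConf.NormType normType) {
    boolean oneHot = ModelNormalizeConf.NormType.ONEHOT.equals(normType);
    boolean zscaleOneHot = ModelNormalizeConf.NormType.ZSCALE_ONEHOT.equals(normType);
    if (columnConfig.isNumerical() && oneHot && columnConfig.getBinBoundary() != null
            && columnConfig.getBinBoundary().size() > 0) {
        // one field per numeric bin plus one extra slot (presumably the missing-value bin)
        return columnConfig.getBinBoundary().size() + 1;
    }
    if (columnConfig.isCategorical() && (oneHot || zscaleOneHot)
            && columnConfig.getBinCategory().size() > 0) {
        // one field per category plus one extra slot
        return columnConfig.getBinCategory().size() + 1;
    }
    // all other normalization types emit a single field per column
    return 1;
}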

Aggregations

BasicFloatMLData (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData): 3 uses
BasicFloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair): 3 uses
FloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair): 3 uses
ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 2 uses
GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException): 1 use
ExecException (org.apache.pig.backend.executionengine.ExecException): 1 use
Tuple (org.apache.pig.data.Tuple): 1 use