Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in the shifu project by ShifuML.
In class AbstractNNWorker, method mockRandomRepeatData:
/**
 * Mocks the Trainer's "random repeat" logic without loading the data a second time.
 * <p>
 * The Trainer randomly chooses items from the master data set. Here the data is not reloaded, to save memory, so the
 * behaviour differs somewhat from the Trainer. One record is picked uniformly at random from the combined training
 * and validation sets and copied into a new pair. The copy is appended to the validation set when
 * {@code Double.compare(random, crossValidationRate) < 0}; otherwise it is appended to the training set.
 * <p>
 * Does nothing when both data sets are empty.
 *
 * @param crossValidationRate
 *            threshold that {@code random} is compared against
 * @param random
 *            caller-supplied value; the copied record goes to the validation set when it is below
 *            {@code crossValidationRate}
 */
@SuppressWarnings("unused")
private void mockRandomRepeatData(double crossValidationRate, double random) {
    long trainingSize = this.trainingData.getRecordCount();
    long testingSize = this.validationData.getRecordCount();
    long size = trainingSize + testingSize;
    if (size <= 0L) {
        // RandomUtils.nextInt(n) rejects n <= 0, and there is no record to repeat anyway.
        return;
    }
    // Clamp instead of a bare (int) cast: a size above Integer.MAX_VALUE would wrap negative and make nextInt throw.
    // Records beyond Integer.MAX_VALUE are then never picked, which is acceptable for this approximate mock.
    int next = RandomUtils.nextInt((int) Math.min(size, Integer.MAX_VALUE));
    // NOTE(review): the load() paths size input arrays with featureInputsCnt, which can exceed subFeatures.size()
    // under one-hot normalization; confirm getRecord tolerates this buffer size.
    FloatMLDataPair dataPair = new BasicFloatMLDataPair(new BasicFloatMLData(new float[this.subFeatures.size()]),
            new BasicFloatMLData(new float[this.outputNodeCount]));
    if (next >= trainingSize) {
        // Index falls past the training records, so it addresses the validation set.
        this.validationData.getRecord(next - trainingSize, dataPair);
    } else {
        this.trainingData.getRecord(next, dataPair);
    }
    if (Double.compare(random, crossValidationRate) < 0) {
        this.validationData.add(dataPair);
    } else {
        this.trainingData.add(dataPair);
    }
}
Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in the shifu project by ShifuML.
In class NNParquetWorker, method load:
/**
 * Parses one Parquet record (a Pig {@link Tuple}) into an input/ideal pair and adds it to the training or validation
 * data set.
 * <p>
 * Tuple elements are read in {@code requiredFieldList} order:
 * <ul>
 * <li>target columns fill {@code ideal};</li>
 * <li>columns contained in {@code subFeatureSet} fill {@code inputs};</li>
 * <li>the weight column, if any, sets the record significance.</li>
 * </ul>
 * After parsing:
 * <ul>
 * <li>negative records may be dropped by the sample-negative-only logic;</li>
 * <li>positive regression records may be up-sampled;</li>
 * <li>records that land in the training set have their significance scaled by the bagging sample weight.</li>
 * </ul>
 *
 * @param currentKey
 *            record key (not used)
 * @param currentValue
 *            wrapper around the Pig tuple holding the record
 * @param workerContext
 *            worker context; its props supply the delimiter used in error messages, and a {@code Boolean}
 *            attachment, when present, marks the record as validation data
 * @throws GuaguaRuntimeException
 *             if a tuple element cannot be read
 * @throws RuntimeException
 *             if the number of parsed inputs differs from the expected feature input count
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Tuple> currentValue,
        WorkerContext<NNParams, NNParams> workerContext) {
    // init field list for later read
    // NOTE(review): this runs for every record; presumably initFieldList() is cheap after the first call - confirm.
    this.initFieldList();

    super.count += 1;
    if (super.count == 1) {
        // Log the (possibly large) feature set once, on the first record, instead of once per record.
        LOG.info("subFeatureSet size: {} ; subFeatureSet: {}", subFeatureSet.size(), subFeatureSet);
    }
    if ((super.count) % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }

    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];

    if (super.isDry) {
        // dry train, use empty data.
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }

    long hashcode = 0;
    float significance = 1f;
    int index = 0, inputsIndex = 0, outputIndex = 0;
    Tuple tuple = currentValue.getWritable();
    // Indexed loop rather than for-each: in earlier versions Tuple was not Iterable.
    for (int i = 0; i < tuple.size(); i++) {
        Object element;
        try {
            element = tuple.get(i);
        } catch (ExecException e) {
            throw new GuaguaRuntimeException(e);
        }
        float floatValue = tupleElementToFloat(element);

        if (index == (super.inputNodeCount + super.outputNodeCount)) {
            // Weight column: it is the last required column, so stop reading after it.
            significance = StringUtils.isBlank(modelConfig.getWeightColumnName()) ? 1f
                    : tupleElementToWeight(element);
            break;
        }

        int columnIndex = requiredFieldList.getFields().get(index).getIndex();
        if (columnIndex >= super.columnConfigList.size()) {
            // A required field beyond the column config list is treated as the weight column; stop reading here.
            significance = tupleElementToWeight(element);
            break;
        }

        ColumnConfig columnConfig = super.columnConfigList.get(columnIndex);
        if (columnConfig != null && columnConfig.isTarget()) {
            if (modelConfig.isRegression()) {
                ideal[outputIndex++] = floatValue;
            } else if (modelConfig.getTrain().isOneVsAll()) {
                // One-vs-all: in the trainer with id k, target k is treated as 1 and every other target as 0. Target
                // values are tag indexes, e.g. [0, 1, 2, 3] for ["a", "b", "c", "d"].
                ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
            } else if (modelConfig.getTags().size() == 2) {
                // Two classes use a single output node: target 0 is the positive index, so it maps to 1 and any
                // other target maps to 0.
                int ideaIndex = (int) floatValue;
                ideal[0] = ideaIndex == 0 ? 1f : 0f;
            } else {
                // Multi-class: one output node per tag index.
                int ideaIndex = (int) floatValue;
                ideal[ideaIndex] = 1f;
            }
        } else if (subFeatureSet.contains(columnIndex)) {
            inputs[inputsIndex++] = floatValue;
            hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
        }
        index += 1;
    }

    // Fail fast when the parsed input count does not match the expected feature input count.
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER,
                Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: "
                + inputs.length + ", parsing size:" + inputsIndex + ", delimiter:" + delimiter + ".");
    }

    // sample negative only logic here
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // If fixInitialInput is set, drop negative records whose hashcode falls in this trainer's
            // (1 - sampleRate) range.
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the share of data used in training and validation. If it is 0.8, take 1 - 0.8
            // to compute endHashCode.
            int endHashCode = startHashCode
                    + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() // regression or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // onevsall
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // Drop a negative record with probability (1 - baggingSampleRate).
            if ((modelConfig.isRegression() // regression or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // onevsall
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(),
                            this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }

    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up sampling logic
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // Double.compare(ideal[0], 1d) == 0 means positive tags; sample + 1 to avoid sample count to 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }

    Object attachment = workerContext.getAttachment();
    boolean isValidation = attachment instanceof Boolean && (Boolean) attachment;

    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
    // Bagging sampling is applied only to training data. Validation data is only used to compute the validation
    // error, so it is not sampled.
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // set weights to significance, if 0, significance will be 0, that is bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    }
}

/**
 * Converts a tuple element to a feature/target value.
 * <p>
 * {@code null} and empty strings become 0. {@link Float} values are used as-is. Other values are parsed from their
 * string form with a fallback of 0. NaN becomes 0.
 */
private static float tupleElementToFloat(Object element) {
    if (element == null) {
        return 0f;
    }
    float value;
    if (element instanceof Float) {
        value = (Float) element;
    } else {
        // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
        String text = element.toString();
        value = text.length() == 0 ? 0f : NumberFormatUtils.getFloat(text, 0f);
    }
    // NaN in input data is treated as a missing value (TODO: handle according to norm type).
    return Float.isNaN(value) ? 0f : value;
}

/**
 * Converts a weight-column tuple element to a record significance.
 * <p>
 * {@code null} and empty strings become 1. {@link Float} values are used as-is. Other values are parsed from their
 * string form with a fallback of 1. A negative weight is logged and replaced by 1.
 */
private float tupleElementToWeight(Object element) {
    float weight;
    if (element == null) {
        // Do not rely on an assert here: assertions are disabled by default, so a null weight would throw an NPE.
        weight = 1f;
    } else if (element instanceof Float) {
        weight = (Float) element;
    } else {
        // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 1f)
        String text = element.toString();
        weight = text.length() == 0 ? 1f : NumberFormatUtils.getFloat(text, 1f);
    }
    // if invalid weight, set it to 1f and warning in log
    if (Float.compare(weight, 0f) < 0) {
        LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count,
                weight);
        weight = 1f;
    }
    return weight;
}
Usage of ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair in the shifu project by ShifuML.
In class NNWorker, method load:
/**
 * Parses one delimited text record into an input/ideal pair and adds it to the training or validation data set.
 * <p>
 * Fields are walked in column-config order, and the last field is always treated as the weight column:
 * <ul>
 * <li>target columns fill {@code ideal};</li>
 * <li>columns contained in {@code subFeatureSet} fill {@code inputs}, where a one-hot normalized column occupies
 * several consecutive fields;</li>
 * <li>all other columns are skipped over by the number of fields they occupy.</li>
 * </ul>
 * After parsing:
 * <ul>
 * <li>negative records may be dropped by the sample-negative-only logic;</li>
 * <li>positive regression records may be up-sampled;</li>
 * <li>records that land in the training set have their significance scaled by the bagging sample weight.</li>
 * </ul>
 *
 * @param currentKey
 *            record key (not used)
 * @param currentValue
 *            text line holding the delimited record
 * @param workerContext
 *            worker context; its props supply the delimiter used in error messages, and a {@code Boolean}
 *            attachment, when present, marks the record as validation data
 * @throws RuntimeException
 *             if the fields do not line up with the column configs, or the parsed input count differs from the
 *             expected feature input count
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue,
        WorkerContext<NNParams, NNParams> workerContext) {
    super.count += 1;
    if ((super.count) % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }

    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];

    if (super.isDry) {
        // dry train, use empty data.
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }

    long hashcode = 0;
    float significance = 1f;
    int index = 0, inputsIndex = 0, outputIndex = 0;
    String[] fields = Lists.newArrayList(this.splitter.split(currentValue.getWritable().toString())).toArray(
            new String[0]);
    // Loop-invariant normalization checks. The constant-first equals() is null-safe.
    boolean oneHotNorm = ModelNormalizeConf.NormType.ONEHOT.equals(modelConfig.getNormalizeType());
    boolean categoricalOneHotNorm = oneHotNorm
            || ModelNormalizeConf.NormType.ZSCALE_ONEHOT.equals(modelConfig.getNormalizeType());

    int pos = 0;
    while (pos < fields.length) {
        String input = fields[pos];
        float floatValue = parseNormalizedField(input);

        if (pos == fields.length - 1) {
            // The last field is always the weight column; stop reading after it.
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1f;
                break;
            }
            // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 1f)
            significance = input.length() == 0 ? 1f : NumberFormatUtils.getFloat(input, 1f);
            // if invalid weight, set it to 1f and warning in log
            if (Float.compare(significance, 0f) < 0) {
                LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.",
                        count, significance);
                significance = 1f;
            }
            break;
        }

        if (index >= super.columnConfigList.size()) {
            // More data fields than column configs: stop here, so the indexing check below reports the mismatch
            // instead of get(index) failing with an IndexOutOfBoundsException.
            break;
        }

        ColumnConfig columnConfig = super.columnConfigList.get(index);
        if (columnConfig != null && columnConfig.isTarget()) {
            if (isLinearTarget || modelConfig.isRegression()) {
                ideal[outputIndex++] = floatValue;
            } else if (modelConfig.getTrain().isOneVsAll()) {
                // One-vs-all: in the trainer with id k, target k is treated as 1 and every other target as 0. Target
                // values are tag indexes, e.g. [0, 1, 2, 3] for ["a", "b", "c", "d"].
                ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
            } else if (modelConfig.getTags().size() == 2) {
                // Two classes use a single output node: target 0 is the positive index, so it maps to 1 and any
                // other target maps to 0.
                int ideaIndex = (int) floatValue;
                ideal[0] = ideaIndex == 0 ? 1f : 0f;
            } else {
                // Multi-class: one output node per tag index.
                int ideaIndex = (int) floatValue;
                ideal[ideaIndex] = 1f;
            }
            pos++;
        } else if (subFeatureSet.contains(index)) {
            if (columnConfig.isMeta() || columnConfig.isForceRemove()) {
                // it shouldn't happen here
                pos += 1;
            } else if (columnConfig.isNumerical() && oneHotNorm && columnConfig.getBinBoundary() != null
                    && columnConfig.getBinBoundary().size() > 0) {
                // One-hot numerical column: bins + 1 consecutive fields. The null/size guard matches the skip
                // branch below. An empty bin list is equivalent to the single-field default.
                int copied = copyExpandedFields(fields, pos, columnConfig.getBinBoundary().size() + 1, inputs,
                        inputsIndex);
                pos += copied;
                inputsIndex += copied;
            } else if (columnConfig.isCategorical() && categoricalOneHotNorm && columnConfig.getBinCategory() != null
                    && columnConfig.getBinCategory().size() > 0) {
                // One-hot categorical column: categories + 1 consecutive fields.
                int copied = copyExpandedFields(fields, pos, columnConfig.getBinCategory().size() + 1, inputs,
                        inputsIndex);
                pos += copied;
                inputsIndex += copied;
            } else {
                inputs[inputsIndex++] = floatValue;
                pos++;
            }
            // The hash uses only the first field of the column, even for one-hot columns.
            hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
        } else {
            // Not a selected feature: skip over however many fields this column occupies.
            if (!CommonUtils.isToNormVariable(columnConfig, hasCandidates, modelConfig.isRegression())) {
                pos += 1;
            } else if (columnConfig.isNumerical() && oneHotNorm && columnConfig.getBinBoundary() != null
                    && columnConfig.getBinBoundary().size() > 0) {
                pos += (columnConfig.getBinBoundary().size() + 1);
            } else if (columnConfig.isCategorical() && categoricalOneHotNorm && columnConfig.getBinCategory() != null
                    && columnConfig.getBinCategory().size() > 0) {
                pos += (columnConfig.getBinCategory().size() + 1);
            } else {
                pos += 1;
            }
        }
        index += 1;
    }

    if (index != this.columnConfigList.size() || pos != fields.length - 1) {
        throw new RuntimeException("Wrong data indexing. ColumnConfig index = " + index + ", while it should be "
                + columnConfigList.size() + ". " + "Data Pos = " + pos + ", while it should be " + (fields.length - 1));
    }

    // Fail fast when the parsed input count does not match the expected feature input count.
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER,
                Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: "
                + inputs.length + ", parsing size:" + inputsIndex + ", delimiter:" + delimiter + ".");
    }

    // sample negative only logic here
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // If fixInitialInput is set, drop negative records whose hashcode falls in this trainer's
            // (1 - sampleRate) range.
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the share of data used in training and validation. If it is 0.8, take 1 - 0.8
            // to compute endHashCode.
            int endHashCode = startHashCode
                    + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() // regression or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // onevsall
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // Drop a negative record with probability (1 - baggingSampleRate).
            if ((modelConfig.isRegression() // regression or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // onevsall
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(),
                            this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }

    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up sampling logic, just add more weights while bagging sampling rate is still not changed
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // Double.compare(ideal[0], 1d) == 0 means positive tags; sample + 1 to avoid sample count to 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }

    Object attachment = workerContext.getAttachment();
    boolean isValidation = attachment instanceof Boolean && (Boolean) attachment;

    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
    // Bagging sampling is applied only to training data. Validation data is only used to compute the validation
    // error, so it is not sampled.
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // set weights to significance, if 0, significance will be 0, that is bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    }
}

/**
 * Parses a text field as a float.
 * <p>
 * Empty strings and NaN become 0. Otherwise the result is {@code NumberFormatUtils.getFloat(raw, 0f)}, whose default
 * of 0f is presumably used for unparsable text.
 */
private static float parseNormalizedField(String raw) {
    // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
    float value = raw.length() == 0 ? 0f : NumberFormatUtils.getFloat(raw, 0f);
    // NaN in input data is treated as a missing value (TODO: handle according to norm type).
    return Float.isNaN(value) ? 0f : value;
}

/**
 * Copies up to {@code width} consecutive fields, starting at {@code start}, into {@code inputs} beginning at
 * {@code inputsStart}. Each field is parsed with {@link #parseNormalizedField(String)}.
 * <p>
 * Stops early if the record has fewer fields left, so a short record is reported by the caller's indexing check
 * rather than by an ArrayIndexOutOfBoundsException on {@code fields}.
 *
 * @return the number of fields consumed, which equals the number of input slots written
 */
private static int copyExpandedFields(String[] fields, int start, int width, float[] inputs, int inputsStart) {
    int copied = 0;
    while (copied < width && start + copied < fields.length) {
        inputs[inputsStart + copied] = parseNormalizedField(fields[start + copied]);
        copied++;
    }
    return copied;
}
Aggregations