Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.
Example from class LogisticRegressionMaster, method initOrRecoverParams.
private LogisticRegressionParams initOrRecoverParams(
        MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    LOG.info("read from existing model");
    LogisticRegressionParams params = null;
    // read existing model weights
    try {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        LR existingModel = (LR) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath,
                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        if (existingModel == null) {
            params = initWeights();
            LOG.info("Starting to train model from scratch.");
        } else {
            params = initModelParams(existingModel);
            LOG.info("Starting to train model from existing model {}.", modelPath);
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
    return params;
}
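Every catch block in these examples follows the same convention: wrap the checked exception in the unchecked ml.shifu.guagua.GuaguaRuntimeException, so that Guagua master and worker callbacks, whose signatures declare no checked exceptions, can still fail loudly. A minimal self-contained sketch of the pattern, with a hypothetical loadBytes helper standing in for Shifu's model-loading utilities:

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import ml.shifu.guagua.GuaguaRuntimeException;

public class WrapCheckedExceptionSketch {

    // Hypothetical helper: reads a local file and, like initOrRecoverParams,
    // rethrows the checked IOException as a GuaguaRuntimeException.
    public static byte[] loadBytes(String path) {
        try {
            return Files.readAllBytes(Paths.get(path));
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }
}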
Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.
Example from class CorrelationReducer, method objectToBytes.
public byte[] objectToBytes(Writable result) {
    ByteArrayOutputStream out = null;
    DataOutputStream dataOut = null;
    try {
        out = new ByteArrayOutputStream();
        dataOut = new DataOutputStream(out);
        result.write(dataOut);
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    } finally {
        if (dataOut != null) {
            try {
                dataOut.close();
            } catch (IOException e) {
                throw new GuaguaRuntimeException(e);
            }
        }
    }
    return out.toByteArray();
}
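The serialized bytes can be turned back into a Writable via the standard Hadoop readFields contract. Below is a round-trip sketch; bytesToObject is a hypothetical counterpart for illustration (not a CorrelationReducer method), using org.apache.hadoop.io.Text as a concrete Writable:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

import ml.shifu.guagua.GuaguaRuntimeException;

public class WritableRoundTripSketch {

    // Hypothetical inverse of objectToBytes: fills an existing Writable
    // instance from the serialized bytes via readFields.
    public static <T extends Writable> T bytesToObject(byte[] bytes, T instance) {
        try (DataInputStream dataIn = new DataInputStream(new ByteArrayInputStream(bytes))) {
            instance.readFields(dataIn);
            return instance;
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        // serialize the same way objectToBytes does
        Text original = new Text("correlation payload");
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        original.write(new DataOutputStream(out));
        // ... and restore
        Text restored = bytesToObject(out.toByteArray(), new Text());
        System.out.println(restored); // prints: correlation payload
    }
}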
Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.
Example from class NNParquetWorker, method load.
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Tuple> currentValue,
        WorkerContext<NNParams, NNParams> workerContext) {
    // init field list for later read
    this.initFieldList();
    LOG.info("subFeatureSet size: {} ; subFeatureSet: {}", subFeatureSet.size(), subFeatureSet);
    super.count += 1;
    if (super.count % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }

    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];
    if (super.isDry) {
        // dry train, use empty data
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }

    long hashcode = 0;
    float significance = 1f;
    // use guava Splitter to iterate only once
    // use NNConstants.NN_DEFAULT_COLUMN_SEPARATOR instead of getModelConfig().getDataSetDelimiter();
    // this follows the same function in akka mode
    int index = 0, inputsIndex = 0, outputIndex = 0;
    Tuple tuple = currentValue.getWritable();
    // use an indexed for loop instead of foreach because in earlier versions Tuple was not iterable
    for (int i = 0; i < tuple.size(); i++) {
        Object element = null;
        try {
            element = tuple.get(i);
        } catch (ExecException e) {
            throw new GuaguaRuntimeException(e);
        }
        float floatValue = 0f;
        if (element != null) {
            if (element instanceof Float) {
                floatValue = (Float) element;
            } else {
                // check the length first to avoid the bad performance of a failed NumberFormatUtils.getFloat(input, 0f)
                floatValue = element.toString().length() == 0 ? 0f : NumberFormatUtils.getFloat(element.toString(), 0f);
            }
        }
        // NaN in input data is treated as a missing value; TODO process it according to norm type
        floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
        if (index == (super.inputNodeCount + super.outputNodeCount)) {
            // weight column handling
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1f;
                // break here since the weight column is the last column
                break;
            }
            assert element != null;
            if (element != null && element instanceof Float) {
                significance = (Float) element;
            } else {
                // check the length first to avoid the bad performance of a failed NumberFormatUtils.getFloat(input, 1f)
                significance = element.toString().length() == 0 ? 1f : NumberFormatUtils.getFloat(element.toString(), 1f);
            }
            // if the weight is invalid, set it to 1f and warn in the log
            if (Float.compare(significance, 0f) < 0) {
                LOG.warn("Record {} in current worker has weight {} less than 0f, which is invalid; setting it to 1.",
                        count, significance);
                significance = 1f;
            }
            // break here since the weight column is the last column
            break;
        } else {
            int columnIndex = requiredFieldList.getFields().get(index).getIndex();
            if (columnIndex >= super.columnConfigList.size()) {
                assert element != null;
                if (element != null && element instanceof Float) {
                    significance = (Float) element;
                } else {
                    // check the length first to avoid the bad performance of a failed NumberFormatUtils.getFloat(input, 1f)
                    significance = element.toString().length() == 0 ? 1f
                            : NumberFormatUtils.getFloat(element.toString(), 1f);
                }
                break;
            } else {
                ColumnConfig columnConfig = super.columnConfigList.get(columnIndex);
                if (columnConfig != null && columnConfig.isTarget()) {
                    if (modelConfig.isRegression()) {
                        ideal[outputIndex++] = floatValue;
                    } else {
                        if (modelConfig.getTrain().isOneVsAll()) {
                            // for one-vs-all, set the target according to trainerId: in the trainer with id 0,
                            // target 0 is treated as 1 and all others as 0. Target values are the indexes of the
                            // tags, e.g. [0, 1, 2, 3] for tags ["a", "b", "c", "d"]
                            ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
                        } else {
                            if (modelConfig.getTags().size() == 2) {
                                // with only 2 classes there is a single output node; target 0 is the index of the
                                // positive class, so set positive to 1 and negative to 0
                                int ideaIndex = (int) floatValue;
                                ideal[0] = ideaIndex == 0 ? 1f : 0f;
                            } else {
                                // multi-class classification: one-hot encode the target index
                                int ideaIndex = (int) floatValue;
                                ideal[ideaIndex] = 1f;
                            }
                        }
                    }
                } else {
                    if (subFeatureSet.contains(columnIndex)) {
                        inputs[inputsIndex++] = floatValue;
                        hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
                    }
                }
            }
        }
        index += 1;
    }

    // if the parsed input size is inconsistent with the expected size, fail fast; this helps to quickly find such issues
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER,
                Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: "
                + inputs.length + ", parsing size: " + inputsIndex + ", delimiter: " + delimiter + ".");
    }

    // sample-negative-only logic
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, sample out negative records whose hashcode falls in the (1 - sampleRate) range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the fraction of data used for training and validation; if it is 0.8,
            // use 1 - 0.8 to compute endHashCode
            int endHashCode = startHashCode
                    + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one-vs-all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // otherwise drop a negative record at random with probability 1 - baggingSampleRate
            if ((modelConfig.isRegression() // regression, or
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) // one-vs-all
                    && (int) (ideal[0] + 0.01d) == 0 // negative record
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(),
                            this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }

    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up-sampling logic
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // Double.compare(ideal[0], 1d) == 0 means a positive tag; sample() + 1 avoids a sample count of 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }

    boolean isValidation = false;
    if (workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) workerContext.getAttachment();
    }
    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);

    // do bagging sampling only for training data
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // multiply significance by the sampling weight; if it is 0, significance becomes 0, which is bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    } else {
        // for validation data, bagging sampling logic would also apply, but since validation data is only used
        // to compute the validation error, it is fine not to do the real sampling
    }
}
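In the fixInitialInput branch above, negative-record sampling is keyed off a hashcode accumulated from the input features, so the same records are dropped in every epoch. The following is a hypothetical sketch of a range check consistent with the percent-based startHashCode/endHashCode arithmetic; the real isInRange comes from the worker superclass and may differ:

// Hypothetical range check, for illustration only. Assumes startHashCode and
// endHashCode are percentages in [0, 100): bucket the feature hashcode into
// 0..99 and test whether it falls inside the drop range.
private static boolean isInRange(long hashcode, int startHashCode, int endHashCode) {
    int bucket = (int) Math.floorMod(hashcode, 100L); // always non-negative
    return bucket >= startHashCode && bucket < endHashCode;
}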
Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.
Example from class AbstractNNWorker, method init.
@Override
public void init(WorkerContext<NNParams, NNParams> context) {
    // load props first
    this.props = context.getProps();
    loadConfigFiles(context.getProps());
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));

    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(),
            modelConfig.getTrain().getGridConfigFileContent());
    this.validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        this.validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }

    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }

    this.poissonSampler = Boolean.TRUE.toString()
            .equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER));
    this.rng = new PoissonDistribution(1.0d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (Double.compare(upSampleWeight, 1d) != 0
            && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set the mean to upSampleWeight - 1 and use sample() + 1 to make sure the sample value is never zero
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }

    Integer epochsPerIterationInteger = this.modelConfig.getTrain().getEpochsPerIteration();
    this.epochsPerIteration = epochsPerIterationInteger == null ? 1 : epochsPerIterationInteger.intValue();
    LOG.info("epochsPerIteration in worker is: {}", epochsPerIteration);

    // Object elmObject = validParams.get(DTrainUtils.IS_ELM);
    // isELM = elmObject == null ? false : "true".equalsIgnoreCase(elmObject.toString());
    // LOG.info("Check isELM: {}", isELM);

    Object dropoutRateObj = validParams.get(CommonConstants.DROPOUT_RATE);
    if (dropoutRateObj != null) {
        this.dropoutRate = Double.valueOf(dropoutRateObj.toString());
    }
    LOG.info("'dropoutRate' in worker is: {}", this.dropoutRate);

    Object miniBatchO = validParams.get(CommonConstants.MINI_BATCH);
    if (miniBatchO != null) {
        int miniBatchs;
        try {
            miniBatchs = Integer.parseInt(miniBatchO.toString());
        } catch (Exception e) {
            miniBatchs = 1;
        }
        // clamp the configured mini-batch count into [1, 1000]
        if (miniBatchs < 0) {
            this.batchs = 1;
        } else if (miniBatchs > 1000) {
            this.batchs = 1000;
        } else {
            this.batchs = miniBatchs;
        }
        LOG.info("'miniBatchs' in worker is: {}, batchs is {}", miniBatchs, batchs);
    }

    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            this.columnConfigList);
    this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    // for one-vs-all classification outputNodeCount is set to 1; with 2 classes it is also 1
    int classes = modelConfig.getTags().size();
    this.outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputOutputIndex[1]
            : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : (classes == 2 ? 1 : classes));
    this.candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    LOG.info("isAfterVarSelect {}: Input count {}, output count {}, candidate count {}", isAfterVarSelect,
            inputNodeCount, outputNodeCount, candidateCount);

    // cache the full feature list for sampling features
    this.allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if (StringUtils.isBlank(subsetStr)) {
        this.subFeatures = this.allFeatures;
    } else {
        String[] splits = subsetStr.split(",");
        this.subFeatures = new ArrayList<Integer>(splits.length);
        for (String split : splits) {
            int featureIndex = Integer.parseInt(split);
            this.subFeatures.add(featureIndex);
        }
    }
    this.subFeatureSet = new HashSet<Integer>(this.subFeatures);
    LOG.info("subFeatures size is {}", subFeatures.size());
    this.featureInputsCnt = DTrainUtils.getFeatureInputsCnt(this.modelConfig, this.columnConfigList,
            this.subFeatureSet);

    this.wgtInit = "default";
    Object wgtInitObj = validParams.get(CommonConstants.WEIGHT_INITIALIZER);
    if (wgtInitObj != null) {
        this.wgtInit = wgtInitObj.toString();
    }

    Object lossObj = validParams.get("Loss");
    this.lossStr = lossObj != null ? lossObj.toString() : "squared";
    LOG.info("Loss str is {}", this.lossStr);

    this.isDry = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
    this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null
            && !"".equals(modelConfig.getValidationDataSetRawPath()));
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();

    if (isOnDisk()) {
        LOG.info("NNWorker is loading data into disk.");
        try {
            initDiskDataSet();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        // no good place to close these two data sets; use a shutdown hook
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
            }
        }));
    } else {
        LOG.info("NNWorker is loading data into memory.");
        double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
        long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
        LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
        double crossValidationRate = this.modelConfig.getValidSetRate();
        try {
            if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
                // fixed 0.6 and 0.4 of max memory for trainingData and validationData
                this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.6),
                        DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.4),
                        DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
            } else {
                this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * (1 - crossValidationRate)),
                        DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * crossValidationRate),
                        DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
            }
            // no good place to close these two data sets; use a shutdown hook
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
                }
            }));
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }

    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
}
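The up-sampling configuration above boosts the significance of positive records by a Poisson draw. A minimal runnable sketch of that weight logic, assuming the PoissonDistribution in scope is org.apache.commons.math3.distribution.PoissonDistribution (which matches the constructor and sample() calls used here):

import org.apache.commons.math3.distribution.PoissonDistribution;

public class UpSampleWeightSketch {

    public static void main(String[] args) {
        double upSampleWeight = 4d; // hypothetical configured weight
        float significance = 1f;    // hypothetical record weight
        // mean = upSampleWeight - 1, so sample() + 1 is a multiplier that is
        // never zero and averages upSampleWeight
        PoissonDistribution upSampleRng = new PoissonDistribution(upSampleWeight - 1);
        float boosted = significance * (upSampleRng.sample() + 1);
        System.out.println("boosted significance: " + boosted);
    }
}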
Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.
Example from class GuaguaParquetRecordReader, method initialize.
/*
 * (non-Javadoc)
 *
 * @see ml.shifu.guagua.io.GuaguaRecordReader#initialize(ml.shifu.guagua.io.GuaguaFileSplit)
 */
@Override
public void initialize(GuaguaFileSplit split) throws IOException {
    ReadSupport<Tuple> readSupport = getReadSupportInstance(this.conf);
    this.parquetRecordReader = new ParquetRecordReader<Tuple>(readSupport, getFilter(this.conf));
    ParquetInputSplit parquetInputSplit = new ParquetInputSplit(new Path(split.getPath()), split.getOffset(),
            split.getOffset() + split.getLength(), split.getLength(), null, null);
    try {
        this.parquetRecordReader.initialize(parquetInputSplit, buildContext());
    } catch (InterruptedException e) {
        throw new GuaguaRuntimeException(e);
    }
}
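Once initialized, the reader is typically drained with the usual key/value iteration. A hypothetical consumption sketch; it assumes the Hadoop-style nextKeyValue()/getCurrentValue()/close() contract that GuaguaRecordReader mirrors, and elides reader construction and configuration:

// Hypothetical driver, for illustration only; method names assumed from the
// Hadoop-style reader contract that GuaguaRecordReader mirrors.
public void readAll(GuaguaParquetRecordReader reader, GuaguaFileSplit split) throws IOException {
    reader.initialize(split); // InterruptedException surfaces as GuaguaRuntimeException
    try {
        while (reader.nextKeyValue()) {
            // each value is a pig Tuple materialized from the parquet file
            Tuple row = reader.getCurrentValue().getWritable();
            // ... hand the row to the worker's load(...) logic
        }
    } finally {
        reader.close();
    }
}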