Example 6 with GuaguaRuntimeException

Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

The class LogisticRegressionMaster, method initOrRecoverParams:

private LogisticRegressionParams initOrRecoverParams(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    LOG.info("read from existing model");
    LogisticRegressionParams params = null;
    // read existing model weights
    try {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        LR existingModel = (LR) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        if (existingModel == null) {
            params = initWeights();
            LOG.info("Starting to train model from scratch.");
        } else {
            params = initModelParams(existingModel);
            LOG.info("Starting to train model from existing model {}.", modelPath);
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
    return params;
}
Also used : Path(org.apache.hadoop.fs.Path) LR(ml.shifu.shifu.core.LR) IOException(java.io.IOException) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
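
This method shows the recurring idiom in these examples: a checked IOException is caught at the framework boundary and rethrown as the unchecked GuaguaRuntimeException, because Guagua's master/worker lifecycle methods do not declare checked exceptions. A minimal self-contained sketch of the same wrap-and-rethrow pattern (the class and helper names here are hypothetical, not part of shifu):

import java.io.IOException;

import ml.shifu.guagua.GuaguaRuntimeException;

// Hypothetical helper illustrating the wrap-and-rethrow idiom above.
public final class CheckedToUnchecked {

    // Any I/O-bound computation that may throw a checked IOException.
    interface IoAction<T> {
        T run() throws IOException;
    }

    // Runs the action, converting checked I/O failures into the framework's
    // unchecked GuaguaRuntimeException so callers need no throws clause.
    static <T> T unchecked(IoAction<T> action) {
        try {
            return action.run();
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }
}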

Example 7 with GuaguaRuntimeException

Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

The class CorrelationReducer, method objectToBytes:

public byte[] objectToBytes(Writable result) {
    ByteArrayOutputStream out = null;
    DataOutputStream dataOut = null;
    try {
        // serialize the Writable into an in-memory byte stream
        out = new ByteArrayOutputStream();
        dataOut = new DataOutputStream(out);
        result.write(dataOut);
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    } finally {
        if (dataOut != null) {
            try {
                dataOut.close();
            } catch (IOException e) {
                throw new GuaguaRuntimeException(e);
            }
        }
    }
    return out.toByteArray();
}
Also used : DataOutputStream(java.io.DataOutputStream) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
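
Closing dataOut in a finally block that itself rethrows can mask the original write failure. On Java 7+ the same method can be written with try-with-resources, which closes the stream automatically and attaches any close failure to the primary exception as suppressed; this is a sketch of that variant, not the shifu implementation:

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.hadoop.io.Writable;

import ml.shifu.guagua.GuaguaRuntimeException;

public final class WritableBytes {

    // Serializes any Hadoop Writable to a byte array.
    static byte[] objectToBytes(Writable result) {
        // ByteArrayOutputStream.close() is a no-op, so only dataOut needs managing.
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (DataOutputStream dataOut = new DataOutputStream(out)) {
            result.write(dataOut);
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
        return out.toByteArray();
    }
}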

Example 8 with GuaguaRuntimeException

Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

The class NNParquetWorker, method load:

@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Tuple> currentValue, WorkerContext<NNParams, NNParams> workerContext) {
    // init field list for later read
    this.initFieldList();
    LOG.info("subFeatureSet size: {} ; subFeatureSet: {}", subFeatureSet.size(), subFeatureSet);
    super.count += 1;
    if ((super.count) % 5000 == 0) {
        LOG.info("Read {} records.", super.count);
    }
    float[] inputs = new float[super.featureInputsCnt];
    float[] ideal = new float[super.outputNodeCount];
    if (super.isDry) {
        // dry train, use empty data.
        addDataPairToDataSet(0, new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal)));
        return;
    }
    long hashcode = 0;
    float significance = 1f;
    // use guava Splitter to iterate only once
    // use NNConstants.NN_DEFAULT_COLUMN_SEPARATOR instead of getModelConfig().getDataSetDelimiter();
    // this follows the same convention as akka mode.
    int index = 0, inputsIndex = 0, outputIndex = 0;
    Tuple tuple = currentValue.getWritable();
    // use an index-based for loop instead of foreach because, in earlier Pig versions, Tuple is not iterable.
    for (int i = 0; i < tuple.size(); i++) {
        Object element = null;
        try {
            element = tuple.get(i);
        } catch (ExecException e) {
            throw new GuaguaRuntimeException(e);
        }
        float floatValue = 0f;
        if (element != null) {
            if (element instanceof Float) {
                floatValue = (Float) element;
            } else {
                // check length first to avoid the cost of a failing NumberFormatUtils.getFloat call
                floatValue = element.toString().length() == 0 ? 0f : NumberFormatUtils.getFloat(element.toString(), 0f);
            }
        }
        // NaN in input data is unexpected; treat it as a missing value (TODO: handle according to norm type)
        floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
        if (index == (super.inputNodeCount + super.outputNodeCount)) {
            // reached the weight column (the last column); decide how to process it
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1f;
                // break here if we reach weight column which is last column
                break;
            }
            assert element != null;
            if (element != null && element instanceof Float) {
                significance = (Float) element;
            } else {
                // check length first to avoid the cost of a failing NumberFormatUtils.getFloat call
                significance = element.toString().length() == 0 ? 1f : NumberFormatUtils.getFloat(element.toString(), 1f);
            }
            // if the weight is invalid (negative), reset it to 1f and log a warning
            if (Float.compare(significance, 0f) < 0) {
                LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
                significance = 1f;
            }
            // break here if we reach weight column which is last column
            break;
        } else {
            int columnIndex = requiredFieldList.getFields().get(index).getIndex();
            if (columnIndex >= super.columnConfigList.size()) {
                assert element != null;
                if (element != null && element instanceof Float) {
                    significance = (Float) element;
                } else {
                    // check length first to avoid the cost of a failing NumberFormatUtils.getFloat call
                    significance = element.toString().length() == 0 ? 1f : NumberFormatUtils.getFloat(element.toString(), 1f);
                }
                break;
            } else {
                ColumnConfig columnConfig = super.columnConfigList.get(columnIndex);
                if (columnConfig != null && columnConfig.isTarget()) {
                    if (modelConfig.isRegression()) {
                        ideal[outputIndex++] = floatValue;
                    } else {
                        if (modelConfig.getTrain().isOneVsAll()) {
                            // for one-vs-all, set the ideal value according to trainerId: in the trainer
                            // with id 0, target 0 is treated as 1 and all others as 0. Target values are
                            // the indices of the tags, e.g. [0, 1, 2, 3] for tags ["a", "b", "c", "d"]
                            ideal[outputIndex++] = Float.compare(floatValue, trainerId) == 0 ? 1f : 0f;
                        } else {
                            if (modelConfig.getTags().size() == 2) {
                                // with only 2 classes there is a single output node; index 0 is the
                                // positive class, so set the output to 1 for positive and 0 for negative
                                int ideaIndex = (int) floatValue;
                                ideal[0] = ideaIndex == 0 ? 1f : 0f;
                            } else {
                                // for multiple classification
                                int ideaIndex = (int) floatValue;
                                ideal[ideaIndex] = 1f;
                            }
                        }
                    }
                } else {
                    if (subFeatureSet.contains(columnIndex)) {
                        inputs[inputsIndex++] = floatValue;
                        hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
                    }
                }
            }
        }
        index += 1;
    }
    // sanity check: a mismatch between parsed and expected input size usually means a wrong
    // delimiter; failing fast helps find such issues quickly.
    if (inputsIndex != inputs.length) {
        String delimiter = workerContext.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputs.length + ", parsing size:" + inputsIndex + ", delimiter:" + delimiter + ".");
    }
    // sample negative only logic here
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, sample out negative records whose hashcode falls in the (1 - sampleRate) range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // here BaggingSampleRate means the fraction of data used for training and validation; if it is
            // 0.8, we take 1 - 0.8 to compute endHashCode
            int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            // if negative record (regression or one-vs-all) and hashcode falls in the sample-out range
            if ((modelConfig.isRegression()
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))
                    && (int) (ideal[0] + 0.01d) == 0
                    && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // if negative record (regression or one-vs-all) and the random draw falls outside the sample rate
            if ((modelConfig.isRegression()
                    || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))
                    && (int) (ideal[0] + 0.01d) == 0
                    && Double.compare(super.sampelNegOnlyRandom.nextDouble(),
                            this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }
    FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
    // up sampling logic
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
        // Double.compare(ideal[0], 1d) == 0 means a positive tag; add 1 to the sample so the count is never 0
        pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
    } else {
        pair.setSignificance(significance);
    }
    boolean isValidation = false;
    if (workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) workerContext.getAttachment();
    }
    boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
    // do bagging sampling only for training data
    if (isInTraining) {
        float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
        if (isPositive(pair.getIdealArray()[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // fold the sampling weight into significance; a weight of 0 zeroes the record out, which is bagging sampling
        pair.setSignificance(pair.getSignificance() * subsampleWeights);
    } else {
        // for validation data, bagging sampling logic would suggest sampling as well, but since the
        // validation set is only used to compute validation error, skipping the real sampling is fine.
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ExecException(org.apache.pig.backend.executionengine.ExecException) BasicFloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair) FloatMLDataPair(ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair) BasicFloatMLData(ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Tuple(org.apache.pig.data.Tuple)
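
Pig's Tuple.get(int) declares the checked ExecException, which is why the field read inside the loop above needs its own try/catch. The same access pattern in isolation, as a hypothetical helper (getField is not a shifu method):

import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.Tuple;

import ml.shifu.guagua.GuaguaRuntimeException;

final class TupleAccess {

    // Reads one field from a Pig Tuple, converting the checked ExecException
    // into Guagua's unchecked exception so the caller's loop stays flat.
    static Object getField(Tuple tuple, int i) {
        try {
            return tuple.get(i);
        } catch (ExecException e) {
            throw new GuaguaRuntimeException(e);
        }
    }
}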

Example 9 with GuaguaRuntimeException

Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

The class AbstractNNWorker, method init:

@Override
public void init(WorkerContext<NNParams, NNParams> context) {
    // load props firstly
    this.props = context.getProps();
    loadConfigFiles(context.getProps());
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    this.validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        this.validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }
    this.poissonSampler = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(NNConstants.NN_POISON_SAMPLER));
    this.rng = new PoissonDistribution(1.0d);
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set the mean to upSampleWeight - 1 and use sample() + 1 so the sample value is never zero
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    Integer epochsPerIterationInteger = this.modelConfig.getTrain().getEpochsPerIteration();
    this.epochsPerIteration = epochsPerIterationInteger == null ? 1 : epochsPerIterationInteger.intValue();
    LOG.info("epochsPerIteration in worker is :{}", epochsPerIteration);
    // Object elmObject = validParams.get(DTrainUtils.IS_ELM);
    // isELM = elmObject == null ? false : "true".equalsIgnoreCase(elmObject.toString());
    // LOG.info("Check isELM: {}", isELM);
    Object dropoutRateObj = validParams.get(CommonConstants.DROPOUT_RATE);
    if (dropoutRateObj != null) {
        this.dropoutRate = Double.valueOf(dropoutRateObj.toString());
    }
    LOG.info("'dropoutRate' in worker is :{}", this.dropoutRate);
    Object miniBatchO = validParams.get(CommonConstants.MINI_BATCH);
    if (miniBatchO != null) {
        int miniBatchs;
        try {
            miniBatchs = Integer.parseInt(miniBatchO.toString());
        } catch (Exception e) {
            miniBatchs = 1;
        }
        if (miniBatchs < 0) {
            this.batchs = 1;
        } else if (miniBatchs > 1000) {
            this.batchs = 1000;
        } else {
            this.batchs = miniBatchs;
        }
        LOG.info("'miniBatchs' in worker is : {}, batchs is {} ", miniBatchs, batchs);
    }
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    // for one-vs-all classification outputNodeCount is set to 1; if classes == 2, outputNodeCount is also 1
    int classes = modelConfig.getTags().size();
    this.outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputOutputIndex[1] : (modelConfig.getTrain().isOneVsAll() ? inputOutputIndex[1] : (classes == 2 ? 1 : classes));
    this.candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    LOG.info("isAfterVarSelect {}: Input count {}, output count {}, candidate count {}", isAfterVarSelect, inputNodeCount, outputNodeCount, candidateCount);
    // cache all feature list for sampling features
    this.allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if (StringUtils.isBlank(subsetStr)) {
        this.subFeatures = this.allFeatures;
    } else {
        String[] splits = subsetStr.split(",");
        this.subFeatures = new ArrayList<Integer>(splits.length);
        for (String split : splits) {
            int featureIndex = Integer.parseInt(split);
            this.subFeatures.add(featureIndex);
        }
    }
    this.subFeatureSet = new HashSet<Integer>(this.subFeatures);
    LOG.info("subFeatures size is {}", subFeatures.size());
    this.featureInputsCnt = DTrainUtils.getFeatureInputsCnt(this.modelConfig, this.columnConfigList, this.subFeatureSet);
    this.wgtInit = "default";
    Object wgtInitObj = validParams.get(CommonConstants.WEIGHT_INITIALIZER);
    if (wgtInitObj != null) {
        this.wgtInit = wgtInitObj.toString();
    }
    Object lossObj = validParams.get("Loss");
    this.lossStr = lossObj != null ? lossObj.toString() : "squared";
    LOG.info("Loss str is {}", this.lossStr);
    this.isDry = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
    this.isSpecificValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    if (isOnDisk()) {
        LOG.info("NNWorker is loading data into disk.");
        try {
            initDiskDataSet();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        // there is no good place to close these two data sets, so register a shutdown hook
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {

            @Override
            public void run() {
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                ((BufferedFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
            }
        }));
    } else {
        LOG.info("NNWorker is loading data into memory.");
        double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
        long memoryStoreSize = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
        LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
        double crossValidationRate = this.modelConfig.getValidSetRate();
        try {
            if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
                // fixed split of the memory store: 0.6 for trainingData and 0.4 for validationData
                this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.6), DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * 0.4), DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
            } else {
                this.trainingData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * (1 - crossValidationRate)), DTrainUtils.getTrainingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
                this.validationData = new MemoryDiskFloatMLDataSet((long) (memoryStoreSize * crossValidationRate), DTrainUtils.getTestingFile().toString(), this.featureInputsCnt, this.outputNodeCount);
            }
            // there is no good place to close these two data sets, so register a shutdown hook
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {

                @Override
                public void run() {
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.trainingData)).close();
                    ((MemoryDiskFloatMLDataSet) (AbstractNNWorker.this.validationData)).close();
                }
            }));
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
    }
    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
}
Also used : PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) MemoryDiskFloatMLDataSet(ml.shifu.shifu.core.dtrain.dataset.MemoryDiskFloatMLDataSet) IOException(java.io.IOException) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) BufferedFloatMLDataSet(ml.shifu.shifu.core.dtrain.dataset.BufferedFloatMLDataSet)
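
The up-sampling setup relies on commons-math3's PoissonDistribution: the mean is set to upSampleWeight - 1, and each positive record's significance is later multiplied by sample() + 1, so positives are boosted by upSampleWeight on average and never zeroed out. A runnable sketch of that weighting, with an illustrative weight of 3.0 (normally read from modelConfig):

import org.apache.commons.math3.distribution.PoissonDistribution;

public final class UpSampleSketch {
    public static void main(String[] args) {
        double upSampleWeight = 3.0d; // illustrative value
        // mean = upSampleWeight - 1, so sample() + 1 is >= 1 with mean upSampleWeight
        PoissonDistribution upSampleRng = new PoissonDistribution(upSampleWeight - 1);
        float significance = 1f;
        float boosted = significance * (upSampleRng.sample() + 1);
        System.out.println("boosted significance: " + boosted);
    }
}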

Example 10 with GuaguaRuntimeException

Use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

The class GuaguaParquetRecordReader, method initialize:

/*
 * (non-Javadoc)
 * 
 * @see ml.shifu.guagua.io.GuaguaRecordReader#initialize(ml.shifu.guagua.io.GuaguaFileSplit)
 */
@Override
public void initialize(GuaguaFileSplit split) throws IOException {
    ReadSupport<Tuple> readSupport = getReadSupportInstance(this.conf);
    this.parquetRecordReader = new ParquetRecordReader<Tuple>(readSupport, getFilter(this.conf));
    ParquetInputSplit parquetInputSplit = new ParquetInputSplit(new Path(split.getPath()), split.getOffset(), split.getOffset() + split.getLength(), split.getLength(), null, null);
    try {
        this.parquetRecordReader.initialize(parquetInputSplit, buildContext());
    } catch (InterruptedException e) {
        throw new GuaguaRuntimeException(e);
    }
}
Also used : Path(org.apache.hadoop.fs.Path) ParquetInputSplit(parquet.hadoop.ParquetInputSplit) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Tuple(org.apache.pig.data.Tuple)
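
Wrapping InterruptedException in a RuntimeException without restoring the interrupt flag hides the interruption from callers further up the stack. A common defensive variant, shown here as a hypothetical helper rather than what shifu does, re-sets the flag before rethrowing:

import ml.shifu.guagua.GuaguaRuntimeException;

final class InterruptSafe {

    // Any action that may be interrupted while blocking.
    interface Interruptible {
        void run() throws InterruptedException;
    }

    // Runs the action; on interruption, restores the thread's interrupt flag
    // before wrapping, so upstream code can still observe the interrupt.
    static void runWrapped(Interruptible action) {
        try {
            action.run();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new GuaguaRuntimeException(e);
        }
    }
}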

Aggregations

GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException): 10
IOException (java.io.IOException): 7
Path (org.apache.hadoop.fs.Path): 4
GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch): 2
FSDataInputStream (org.apache.hadoop.fs.FSDataInputStream): 2
FileSystem (org.apache.hadoop.fs.FileSystem): 2
Tuple (org.apache.pig.data.Tuple): 2
ByteArrayOutputStream (java.io.ByteArrayOutputStream): 1
DataOutputStream (java.io.DataOutputStream): 1
Comparator (java.util.Comparator): 1
Properties (java.util.Properties): 1
ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1
SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 1
LR (ml.shifu.shifu.core.LR): 1
TreeModel (ml.shifu.shifu.core.TreeModel): 1
BasicFloatMLData (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData): 1
BasicFloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair): 1
BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 1
BufferedFloatMLDataSet (ml.shifu.shifu.core.dtrain.dataset.BufferedFloatMLDataSet): 1
FloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair): 1