Search in sources :

Example 1 with BasicMLData

Use of org.encog.ml.data.basic.BasicMLData in project shifu by ShifuML.

Class AbstractTrainer, method setDataSet.

/**
 * Sets up the training data set and the validation data set.
 *
 * <p>First a bagging sample is drawn from {@code masterDataSet}: with replacement, without
 * replacement, or from the fixed index list returned by {@code loadSampleInput} when
 * {@code modelConfig.isFixInitialInput()} is set. That sample is then split into
 * {@code trainSet} and {@code validSet} according to {@code crossValidationRate}: randomly in the
 * non-fixed case, and as a leading/trailing split in the fixed case.
 *
 * <p>{@code trainingOption} selects the storage: "M" keeps all sets in memory
 * ({@code BasicMLDataSet}), "D" buffers them in EGB files under {@code Constants.TMP}
 * ({@code BufferedMLDataSet}). The comparison is case-insensitive.
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException as declared; NOTE(review): presumably propagated from
 *     {@code loadSampleInput} — confirm
 * @throws IllegalStateException if {@code trainingOption} is neither "M" nor "D" (including null)
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    // Evaluate the option once. "X".equalsIgnoreCase(null) is false, so a null option reaches the
    // "not valid" error below instead of failing with a NullPointerException.
    final boolean inMemory = "M".equalsIgnoreCase(this.trainingOption);
    final boolean onDisk = "D".equalsIgnoreCase(this.trainingOption);
    MLDataSet sampledDataSet;
    if (inMemory) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (onDisk) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) this.trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) this.validSet).beginLoad(inputSize, idealSize);
    } else {
        // IllegalStateException is a RuntimeException, so existing catch sites still match.
        throw new IllegalStateException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 API: getRecordCount()/getRecord() (Encog 3.1 would use size()/get()).
    int masterSize = (int) masterDataSet.getRecordCount();
    if (!modelConfig.isFixInitialInput()) {
        if (modelConfig.isBaggingWithReplacement()) {
            // Bagging with replacement: draw sampledSize random indices (duplicates allowed).
            int sampledSize = (int) (masterSize * baggingSampleRate);
            for (int i = 0; i < sampledSize; i++) {
                sampledDataSet.add(fetchRecord(masterDataSet, random.nextInt(masterSize)));
            }
        } else {
            // Bagging without replacement: keep each record with probability baggingSampleRate.
            for (MLDataPair pair : masterDataSet) {
                if (random.nextDouble() < baggingSampleRate) {
                    sampledDataSet.add(pair);
                }
            }
        }
    } else {
        List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer i : list) {
            sampledDataSet.add(fetchRecord(masterDataSet, i));
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // Cross validation
    log.info("Generating Training Set and Validation Set ...");
    if (!modelConfig.isFixInitialInput()) {
        // Random split: each sampled record goes to validation with probability crossValidationRate.
        for (long i = 0; i < sampledDataSet.getRecordCount(); i++) {
            MLDataPair pair = fetchRecord(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        // Deterministic split: the first (1 - crossValidationRate) share trains, the rest validates.
        long sampleSize = sampledDataSet.getRecordCount();
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = fetchRecord(sampledDataSet, i);
            if (i < trainSetSize) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    }
    if (onDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }
    log.info("    - # Records of the Master Data Set: " + masterSize);
    log.info("    - Bagging Sample Rate: " + baggingSampleRate);
    log.info("    - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info("    - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info("        - Cross Validation Rate: " + crossValidationRate);
    log.info("        - # Records of the Training Set: " + this.getTrainSetSize());
    log.info("        - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Reads record {@code index} of {@code source} into a freshly allocated pair.
 *
 * <p>The input/ideal buffers are sized from the data set's own {@code getInputSize()} and
 * {@code getIdealSize()}. This keeps them consistent with the layout passed to
 * {@code BufferedMLDataSet.beginLoad} instead of assuming a single ideal value.
 */
private static MLDataPair fetchRecord(MLDataSet source, long index) {
    MLDataPair pair = new BasicMLDataPair(
            new BasicMLData(new double[source.getInputSize()]),
            new BasicMLData(new double[source.getIdealSize()]));
    source.getRecord(index, pair);
    return pair;
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet), MLDataSet (org.encog.ml.data.MLDataSet), BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet), BasicMLData (org.encog.ml.data.basic.BasicMLData), File (java.io.File)

Example 2 with BasicMLData

Use of org.encog.ml.data.basic.BasicMLData in project shifu by ShifuML.

Class AbstractTrainer, method calculateMSE.

/**
 * Computes the mean squared error of {@code network} over every record of {@code dataSet}.
 *
 * <p>Only the first network output and the first ideal value of each record are compared. The
 * method takes no locks; records are fetched and evaluated one after another on the calling
 * thread.
 *
 * @param network the network to evaluate
 * @param dataSet the records to evaluate it on
 * @return the mean of the squared first-output errors; {@code NaN} when {@code dataSet} is empty
 *     (0 / 0)
 */
public static Double calculateMSE(BasicNetwork network, MLDataSet dataSet) {
    final long numRecords = dataSet.getRecordCount();
    // Buffer sizes are loop-invariant, and the ideal buffer matches the data set's ideal width
    // rather than assuming exactly one value.
    final int inputSize = dataSet.getInputSize();
    final int idealSize = dataSet.getIdealSize();
    double sumSquaredError = 0;
    // long counter: getRecordCount() is a long, so an int index could overflow on huge data sets.
    for (long i = 0; i < numRecords; i++) {
        // Encog 3.0 API: random access through getRecord() into a caller-supplied pair.
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(new double[inputSize]),
                new BasicMLData(new double[idealSize]));
        dataSet.getRecord(i, pair);
        MLData result = network.compute(pair.getInput());
        double diff = result.getData()[0] - pair.getIdeal().getData()[0];
        sumSquaredError += diff * diff;
    }
    return sumSquaredError / numRecords;
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), BasicMLData (org.encog.ml.data.basic.BasicMLData), MLData (org.encog.ml.data.MLData)

Example 3 with BasicMLData

Use of org.encog.ml.data.basic.BasicMLData in project shifu by ShifuML.

Class DataLoadWorker, method readTrainingData.

/**
 * Reads the normalized training data for model training.
 *
 * <p>Each line is split with {@code DEFAULT_SPLITTER}. Field {@code k} belongs to
 * {@code columnConfigList.get(k)}, and one extra trailing field, if present, is the record
 * significance. Target columns fill the ideal vector. Input columns fill the input vector, where
 * an input column is:
 * <ul>
 *   <li>when {@code inputNodeCount == candidateCount}: any non-target column accepted by
 *       {@code CommonUtils.isGoodCandidate};</li>
 *   <li>otherwise: a non-null, non-meta, non-target, final-selected column.</li>
 * </ul>
 * The input vector is sized with exactly that rule, so it never overflows and is never
 * zero-padded. NOTE(review): only one target column is supported; a second one overflows the
 * single-element ideal array, as before.
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - is for test running? When true, a one-element dummy pair is added per line and the
 *            line is not parsed.
 * @return List of data, one pair per line read
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>();

    // Classify every column once, using the same rule that fills the input vector below.
    final int columnCount = this.columnConfigList.size();
    final boolean useAllCandidates = this.inputNodeCount == this.candidateCount;
    final boolean[] isTargetColumn = new boolean[columnCount];
    final boolean[] isInputColumn = new boolean[columnCount];
    int numInputs = 0;
    for (int i = 0; i < columnCount; i++) {
        ColumnConfig columnConfig = this.columnConfigList.get(i);
        if (columnConfig != null && columnConfig.isTarget()) {
            isTargetColumn[i] = true;
        } else if (useAllCandidates) {
            // all variables are not set final-select
            isInputColumn[i] = CommonUtils.isGoodCandidate(columnConfig);
        } else {
            // final select some variables
            isInputColumn[i] = columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget()
                    && columnConfig.isFinalSelect();
        }
        if (isInputColumn[i]) {
            numInputs++;
        }
    }

    int cnt = 0;
    while (scanner.hasNextLine()) {
        if ((cnt++) % 100000 == 0) {
            log.info("Read " + (cnt) + " Records.");
        }
        String line = scanner.nextLine();
        if (isDryRun) {
            MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1]));
            mlDataPairList.add(dummyPair);
            continue;
        }
        // the normalized training data is separated by | by default
        double[] inputs = new double[numInputs];
        double[] ideal = new double[1];
        double significance = 0.0d;
        int index = 0, inputsIndex = 0, outputIndex = 0;
        for (String field : DEFAULT_SPLITTER.split(line.trim())) {
            String value = field.trim();
            if (index == columnCount) {
                // The field after the last configured column carries the record significance.
                significance = NumberFormatUtils.getDouble(value, CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            }
            if (isTargetColumn[index]) {
                ideal[outputIndex++] = NumberFormatUtils.getDouble(value, 0.0d);
            } else if (isInputColumn[index]) {
                inputs[inputsIndex++] = NumberFormatUtils.getDouble(value, 0.0d);
            }
            index++;
        }
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        pair.setSignificance(significance);
        mlDataPairList.add(pair);
    }
    return mlDataPairList;
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), BasicMLData (org.encog.ml.data.basic.BasicMLData), ArrayList (java.util.ArrayList)

Aggregations

MLDataPair (org.encog.ml.data.MLDataPair): 3, BasicMLData (org.encog.ml.data.basic.BasicMLData): 3, BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair): 3, File (java.io.File): 1, ArrayList (java.util.ArrayList): 1, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1, MLData (org.encog.ml.data.MLData): 1, MLDataSet (org.encog.ml.data.MLDataSet): 1, BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet): 1, BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet): 1