Search in sources :

Example 1 with BasicMLDataSet

Use of org.encog.ml.data.basic.BasicMLDataSet in project shifu by ShifuML.

In the class AbstractTrainer, the method setDataSet.

/**
 * Sets up the training data set and the validation data set from a master data set.
 *
 * <p>The work happens in two stages:
 * <ol>
 *   <li><b>Sampling (bagging):</b> records are drawn from {@code masterDataSet} into an
 *       intermediate sampled set. Depending on the model configuration the records are drawn
 *       randomly with replacement, filtered randomly without replacement, or selected through
 *       the index list returned by {@code loadSampleInput} (fixed initial input).</li>
 *   <li><b>Train/validation split:</b> every sampled record is placed in either
 *       {@code trainSet} or {@code validSet}, randomly per record or, for fixed initial
 *       input, by position (the leading part goes to training, the rest to validation).</li>
 * </ol>
 *
 * <p>With training option {@code "M"} all sets are held in memory; with {@code "D"} they are
 * backed by {@code sampled.egb}, {@code train.egb} and {@code valid.egb} under
 * {@code Constants.TMP}. The option is compared case-insensitively.
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException declared by the signature; no statement visible in this method throws
 *         it directly
 * @throws RuntimeException if the training option is neither {@code "M"} nor {@code "D"}
 *         (including when it is {@code null})
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    MLDataSet sampledDataSet;
    // Literal-first comparison so that a null option falls through to the explicit
    // "not valid" error below instead of failing with a bare NullPointerException.
    final boolean loadToDisk = "D".equalsIgnoreCase(this.trainingOption);
    if ("M".equalsIgnoreCase(this.trainingOption)) {
        log.info("Loading to Memory ...");
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (loadToDisk) {
        log.info("Loading to Disk ...");
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 API: the record count is a long; it is narrowed to int because it is used as
    // the bound for Random.nextInt and for the sample-size arithmetic below.
    int masterSize = (int) masterDataSet.getRecordCount();
    final boolean fixedInitialInput = modelConfig.isFixInitialInput();

    // ---- Stage 1: sampling (bagging) ----
    if (!fixedInitialInput) {
        if (modelConfig.isBaggingWithReplacement()) {
            // Bagging with replacement: draw sampledSize random records; duplicates allowed.
            int sampledSize = (int) (masterSize * baggingSampleRate);
            for (int i = 0; i < sampledSize; i++) {
                sampledDataSet.add(readRecordCopy(masterDataSet, random.nextInt(masterSize)));
            }
        } else {
            // Bagging without replacement: keep each record with probability baggingSampleRate.
            for (MLDataPair pair : masterDataSet) {
                if (random.nextDouble() < baggingSampleRate) {
                    sampledDataSet.add(pair);
                }
            }
        }
    } else {
        // Fixed initial input: the record indexes come from loadSampleInput.
        List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize, modelConfig.isBaggingWithReplacement());
        for (Integer index : list) {
            sampledDataSet.add(readRecordCopy(masterDataSet, index));
        }
    }
    if (loadToDisk) {
        // Finish writing the sampled set before it is read back in stage 2.
        ((BufferedMLDataSet) sampledDataSet).endLoad();
    }

    // ---- Stage 2: train / validation split ----
    log.info("Generating Training Set and Validation Set ...");
    final long sampleSize = sampledDataSet.getRecordCount();
    if (!fixedInitialInput) {
        // Random split: a record goes to validation with probability crossValidationRate.
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = readRecordCopy(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
    } else {
        // Positional split: the first trainSetSize records train, the remainder validate.
        long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
        long i = 0;
        for (; i < trainSetSize; i++) {
            trainSet.add(readRecordCopy(sampledDataSet, i));
        }
        for (; i < sampleSize; i++) {
            validSet.add(readRecordCopy(sampledDataSet, i));
        }
    }
    if (loadToDisk) {
        ((BufferedMLDataSet) trainSet).endLoad();
        ((BufferedMLDataSet) validSet).endLoad();
    }

    log.info("    - # Records of the Master Data Set: " + masterSize);
    log.info("    - Bagging Sample Rate: " + baggingSampleRate);
    log.info("    - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info("    - # Records of the Selected Data Set: " + sampledDataSet.getRecordCount());
    log.info("        - Cross Validation Rate: " + crossValidationRate);
    log.info("        - # Records of the Training Set: " + this.getTrainSetSize());
    log.info("        - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Reads one record of {@code source} into a freshly allocated pair.
 *
 * <p>The buffers are sized from {@code source.getInputSize()} and
 * {@code source.getIdealSize()} rather than assuming a single ideal column, so the pair
 * matches the sizes passed to {@code beginLoad} for disk-backed sets. The sizes are queried
 * only when a record is actually read, so an empty data set is never asked for its
 * dimensions.
 *
 * @param source the data set to read from
 * @param index  zero-based index of the record to read
 * @return a new pair holding the values {@code source.getRecord} wrote for that record
 */
private static MLDataPair readRecordCopy(MLDataSet source, long index) {
    double[] input = new double[source.getInputSize()];
    double[] ideal = new double[source.getIdealSize()];
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
    source.getRecord(index, pair);
    return pair;
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet) BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) BasicMLData(org.encog.ml.data.basic.BasicMLData) File(java.io.File) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet)

Aggregations

File (java.io.File)1 MLDataPair (org.encog.ml.data.MLDataPair)1 MLDataSet (org.encog.ml.data.MLDataSet)1 BasicMLData (org.encog.ml.data.basic.BasicMLData)1 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)1 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)1 BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet)1