Search in sources:

Example 1 with BufferedMLDataSet

use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.

Class MemoryDiskMLDataSet, method add.

/*
 * Adds an (input, ideal) pair, keeping it in memory while the estimated byte budget allows and
 * spilling it to the disk-backed data set otherwise.
 *
 * The estimated size of the pair is added to byteSize on both paths. Assuming SizeEstimator never
 * returns a negative estimate and nothing else lowers byteSize (not visible here - confirm), once a
 * pair has spilled to disk all later pairs spill as well.
 *
 * @see org.encog.ml.data.MLDataSet#add(org.encog.ml.data.MLData, org.encog.ml.data.MLData)
 */
@Override
public void add(MLData inputData, MLData idealData) {
    final long pairBytes = SizeEstimator.estimate(inputData) + SizeEstimator.estimate(idealData);
    final boolean fitsInMemory = this.byteSize + pairBytes < this.maxByteSize;
    if (fitsInMemory) {
        this.byteSize += pairBytes;
        this.memoryCount++;
        this.memoryDataSet.add(inputData, idealData);
        return;
    }
    // First spill: create the buffered file lazily and open a load session with the configured widths.
    if (this.diskDataSet == null) {
        this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
        ((BufferedMLDataSet) this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
    }
    this.byteSize += pairBytes;
    this.diskCount++;
    this.diskDataSet.add(inputData, idealData);
}
Also used : File(java.io.File) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet)

Example 2 with BufferedMLDataSet

use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.

Class MemoryDiskMLDataSet, method add.

/*
 * Adds a complete data pair, keeping it in memory while the estimated byte budget allows and
 * spilling it to the disk-backed data set otherwise.
 *
 * The estimated size of the pair is added to byteSize on both paths. Assuming SizeEstimator never
 * returns a negative estimate and nothing else lowers byteSize (not visible here - confirm), once a
 * pair has spilled to disk all later pairs spill as well.
 *
 * @see org.encog.ml.data.MLDataSet#add(org.encog.ml.data.MLDataPair)
 */
@Override
public void add(MLDataPair inputData) {
    final long pairBytes = SizeEstimator.estimate(inputData);
    final boolean fitsInMemory = this.byteSize + pairBytes < this.maxByteSize;
    if (fitsInMemory) {
        this.byteSize += pairBytes;
        this.memoryCount++;
        this.memoryDataSet.add(inputData);
        return;
    }
    // First spill: create the buffered file lazily and open a load session with the configured widths.
    if (this.diskDataSet == null) {
        this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
        ((BufferedMLDataSet) this.diskDataSet).beginLoad(this.inputCount, this.outputCount);
    }
    this.byteSize += pairBytes;
    this.diskCount++;
    this.diskDataSet.add(inputData);
}
Also used : File(java.io.File) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet)

Example 3 with BufferedMLDataSet

use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.

Class AbstractTrainer, method setDataSet.

/*
 * Sets up the training and validation data sets from a master data set.
 *
 * Two stages:
 *  1. Bagging: records of masterDataSet are copied into a temporary "sampled" data set, either by
 *     drawing (int) (masterSize * baggingSampleRate) random indexes with replacement, by keeping
 *     each record with probability baggingSampleRate (without replacement), or - when
 *     modelConfig.isFixInitialInput() is true - by the index list returned from loadSampleInput.
 *  2. Cross validation: sampled records are routed into this.trainSet / this.validSet, either
 *     randomly (a record goes to validation when random.nextDouble() <= crossValidationRate) or,
 *     with fixed initial input, the first (1 - crossValidationRate) fraction to training and the
 *     remainder to validation.
 *
 * Training option "M" (case-insensitive) keeps every set in memory as BasicMLDataSet; "D" writes
 * them to sampled.egb / train.egb / valid.egb under Constants.TMP via BufferedMLDataSet. In "D"
 * mode the temporary sampled set is closed before this method returns.
 *
 * Records are read with buffers sized from the source set's input AND ideal widths, so targets
 * wider than one value are copied intact.
 *
 * @param masterDataSet the full data set to sample from
 * @throws IOException as declared by the original signature (presumably from loadSampleInput -
 *         confirm)
 * @throws RuntimeException if the training option is neither "M" nor "D" (including null)
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    final boolean onDisk;
    MLDataSet sampledDataSet;
    // Constant-first comparisons: a null option reaches the "not Valid" error instead of an NPE.
    if ("M".equalsIgnoreCase(this.trainingOption)) {
        log.info("Loading to Memory ...");
        onDisk = false;
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if ("D".equalsIgnoreCase(this.trainingOption)) {
        log.info("Loading to Disk ...");
        onDisk = true;
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }

    // Encog 3.0 API: getRecordCount()/getRecord() rather than the 3.1 size()/get().
    int masterSize = (int) masterDataSet.getRecordCount();
    long sampledCount;
    try {
        sampleMasterDataSet(masterDataSet, sampledDataSet, masterSize);
        if (onDisk) {
            ((BufferedMLDataSet) sampledDataSet).endLoad();
        }

        log.info("Generating Training Set and Validation Set ...");
        splitTrainAndValidation(sampledDataSet);
        if (onDisk) {
            ((BufferedMLDataSet) trainSet).endLoad();
            ((BufferedMLDataSet) validSet).endLoad();
        }
        // Captured now because the sampled set may be closed below.
        sampledCount = sampledDataSet.getRecordCount();
    } finally {
        if (onDisk) {
            // sampled.egb is a local temporary nobody else references; release its file handle.
            sampledDataSet.close();
        }
    }

    log.info("    - # Records of the Master Data Set: " + masterSize);
    log.info("    - Bagging Sample Rate: " + baggingSampleRate);
    log.info("    - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info("    - # Records of the Selected Data Set: " + sampledCount);
    log.info("        - Cross Validation Rate: " + crossValidationRate);
    log.info("        - # Records of the Training Set: " + this.getTrainSetSize());
    log.info("        - # Records of the Validation Set: " + this.getValidSetSize());
}

/*
 * Bagging stage of setDataSet: copies records of masterDataSet into sampledDataSet, with
 * replacement, without replacement, or from the fixed index list given by loadSampleInput.
 */
private void sampleMasterDataSet(MLDataSet masterDataSet, MLDataSet sampledDataSet, int masterSize)
        throws IOException {
    if (modelConfig.isFixInitialInput()) {
        List<Integer> indexes = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                modelConfig.isBaggingWithReplacement());
        for (Integer index : indexes) {
            sampledDataSet.add(readPairAt(masterDataSet, index));
        }
    } else if (modelConfig.isBaggingWithReplacement()) {
        int sampledSize = (int) (masterSize * baggingSampleRate);
        for (int i = 0; i < sampledSize; i++) {
            sampledDataSet.add(readPairAt(masterDataSet, random.nextInt(masterSize)));
        }
    } else {
        for (MLDataPair pair : masterDataSet) {
            if (random.nextDouble() < baggingSampleRate) {
                sampledDataSet.add(pair);
            }
        }
    }
}

/*
 * Cross-validation stage of setDataSet: routes every sampled record into trainSet or validSet.
 */
private void splitTrainAndValidation(MLDataSet sampledDataSet) {
    long sampleSize = sampledDataSet.getRecordCount();
    if (!modelConfig.isFixInitialInput()) {
        for (long i = 0; i < sampleSize; i++) {
            MLDataPair pair = readPairAt(sampledDataSet, i);
            if (random.nextDouble() > crossValidationRate) {
                trainSet.add(pair);
            } else {
                validSet.add(pair);
            }
        }
        return;
    }
    // Deterministic split: leading records to training, the tail to validation.
    long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
    for (long i = 0; i < sampleSize; i++) {
        MLDataPair pair = readPairAt(sampledDataSet, i);
        if (i < trainSetSize) {
            trainSet.add(pair);
        } else {
            validSet.add(pair);
        }
    }
}

/*
 * Reads the record at index from dataSet into a freshly allocated pair whose input and ideal
 * buffers are sized from dataSet's own input and ideal widths.
 */
private static MLDataPair readPairAt(MLDataSet dataSet, long index) {
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(new double[dataSet.getInputSize()]),
            new BasicMLData(new double[dataSet.getIdealSize()]));
    dataSet.getRecord(index, pair);
    return pair;
}
Also used: BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) MLDataSet(org.encog.ml.data.MLDataSet) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet) BasicMLData(org.encog.ml.data.basic.BasicMLData) File(java.io.File)

Example 4 with BufferedMLDataSet

use of org.encog.ml.data.buffer.BufferedMLDataSet in project shifu by ShifuML.

Class TrainDataPrepWorker, method handleMsg.

/*
 * Message handler for the training-data preparation stage.
 *
 * TrainPartDataMessage: appends every pair to masterDataSet. When training on disk, the
 * BufferedMLDataSet load is opened with the widths of the first pair seen. After the last expected
 * part arrives (receivedMsgCnt equals the message's total count), it ends the disk load, gives
 * masterDataSet to every trainer (switching them to option "D" when training on disk), forwards
 * each trainer to nextActorRef as a TrainInstanceMessage, and finally closes and drops a
 * disk-backed master set.
 *
 * StatsPartRawDataMessage: only counts parts. After the last one, it forwards each trainer
 * without assigning a data set.
 *
 * Any other message is passed to unhandled(message).
 *
 * @throws IOException propagated from AbstractTrainer.setDataSet
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        TrainPartDataMessage trainMsg = (TrainPartDataMessage) message;
        for (MLDataPair pair : trainMsg.getMlDataPairList()) {
            if (modelConfig.isTrainOnDisk() && !initialized) {
                // The first pair fixes the record widths of the on-disk buffer.
                int inputWidth = pair.getInput().size();
                int idealWidth = pair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputWidth, idealWidth);
                initialized = true;
            }
            masterDataSet.add(pair);
        }
        receivedMsgCnt++;
        log.debug("Expected " + trainMsg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt != trainMsg.getTotalMsgCnt()) {
            // More parts are still on their way.
            return;
        }
        if (modelConfig.isTrainOnDisk() && initialized) {
            ((BufferedMLDataSet) masterDataSet).endLoad();
        }
        for (AbstractTrainer trainer : trainers) {
            if (modelConfig.isTrainOnDisk()) {
                trainer.setTrainingOption("D");
            }
            trainer.setDataSet(masterDataSet);
            nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
        }
        if (modelConfig.isTrainOnDisk() && initialized) {
            masterDataSet.close();
            masterDataSet = null;
        }
    } else if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage statsMsg = (StatsPartRawDataMessage) message;
        receivedMsgCnt++;
        log.debug("Expected " + statsMsg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == statsMsg.getTotalMsgCnt()) {
            // NOTE(review): no raw-instance data set is handed to the trainers on this path.
            for (AbstractTrainer trainer : trainers) {
                nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
            }
        }
    } else {
        unhandled(message);
    }
}
Also used : MLDataPair(org.encog.ml.data.MLDataPair) TrainInstanceMessage(ml.shifu.shifu.message.TrainInstanceMessage) StatsPartRawDataMessage(ml.shifu.shifu.message.StatsPartRawDataMessage) TrainPartDataMessage(ml.shifu.shifu.message.TrainPartDataMessage) AbstractTrainer(ml.shifu.shifu.core.AbstractTrainer) BufferedMLDataSet(org.encog.ml.data.buffer.BufferedMLDataSet)

Aggregations

BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet): 4; File (java.io.File): 3; MLDataPair (org.encog.ml.data.MLDataPair): 2; AbstractTrainer (ml.shifu.shifu.core.AbstractTrainer): 1; StatsPartRawDataMessage (ml.shifu.shifu.message.StatsPartRawDataMessage): 1; TrainInstanceMessage (ml.shifu.shifu.message.TrainInstanceMessage): 1; TrainPartDataMessage (ml.shifu.shifu.message.TrainPartDataMessage): 1; MLDataSet (org.encog.ml.data.MLDataSet): 1; BasicMLData (org.encog.ml.data.basic.BasicMLData): 1; BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair): 1; BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet): 1