Search in sources:

Example 1 with MLDataPair

Use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.

Class AbstractTrainer, method setDataSet.

/**
 * Sets up the training data set and the validation data set from {@code masterDataSet}.
 * <p>
 * Step 1 (bagging) copies a sample of the master data set into an intermediate "sampled" data set:
 * <ul>
 * <li>with a fixed initial input, the record indices come from {@code loadSampleInput};</li>
 * <li>otherwise, with replacement, {@code masterSize * baggingSampleRate} records are drawn uniformly at
 * random;</li>
 * <li>otherwise, without replacement, each record is kept with probability {@code baggingSampleRate}.</li>
 * </ul>
 * Step 2 (cross validation) splits the sampled records into {@code trainSet} and {@code validSet}:
 * randomly per record (a record goes to the validation set with probability about
 * {@code crossValidationRate}), or, with a fixed initial input, the first
 * {@code sampleSize * (1 - crossValidationRate)} records go to training and the rest to validation.
 * <p>
 * Training option "M" keeps all sets in memory; "D" buffers them to EGB files under {@code Constants.TMP}.
 * The intermediate sampled set is always closed before returning.
 *
 * @param masterDataSet
 *            the full data set to sample from
 * @throws IOException
 *             declared by this method; presumably propagated from {@code loadSampleInput} (TODO confirm)
 * @throws RuntimeException
 *             if the training option is neither "M" nor "D"
 */
public void setDataSet(MLDataSet masterDataSet) throws IOException {
    log.info("Setting Data Set...");
    final boolean onDisk;
    MLDataSet sampledDataSet;
    if (this.trainingOption.equalsIgnoreCase("M")) {
        log.info("Loading to Memory ...");
        onDisk = false;
        sampledDataSet = new BasicMLDataSet();
        this.trainSet = new BasicMLDataSet();
        this.validSet = new BasicMLDataSet();
    } else if (this.trainingOption.equalsIgnoreCase("D")) {
        log.info("Loading to Disk ...");
        onDisk = true;
        // NOTE(review): these file names are fixed, so every trainer configured through this method writes
        // the same train.egb / valid.egb; confirm trainers never use these files concurrently.
        sampledDataSet = new BufferedMLDataSet(new File(Constants.TMP, "sampled.egb"));
        this.trainSet = new BufferedMLDataSet(new File(Constants.TMP, "train.egb"));
        this.validSet = new BufferedMLDataSet(new File(Constants.TMP, "valid.egb"));
        int inputSize = masterDataSet.getInputSize();
        int idealSize = masterDataSet.getIdealSize();
        ((BufferedMLDataSet) sampledDataSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) this.trainSet).beginLoad(inputSize, idealSize);
        ((BufferedMLDataSet) this.validSet).beginLoad(inputSize, idealSize);
    } else {
        throw new RuntimeException("Training Option is not Valid: " + this.trainingOption);
    }
    // Encog 3.0 has no MLDataSet.size(); getRecordCount() is its equivalent.
    int masterSize = (int) masterDataSet.getRecordCount();
    long selectedCount;
    try {
        // Bagging
        if (!modelConfig.isFixInitialInput()) {
            if (modelConfig.isBaggingWithReplacement()) {
                int sampledSize = (int) (masterSize * baggingSampleRate);
                for (int i = 0; i < sampledSize; i++) {
                    // A fresh pair per record: in-memory data sets keep a reference to the added pair.
                    MLDataPair pair = newRecordBuffer(masterDataSet);
                    masterDataSet.getRecord(random.nextInt(masterSize), pair);
                    sampledDataSet.add(pair);
                }
            } else {
                for (MLDataPair pair : masterDataSet) {
                    if (random.nextDouble() < baggingSampleRate) {
                        sampledDataSet.add(pair);
                    }
                }
            }
        } else {
            List<Integer> list = loadSampleInput((int) (masterSize * baggingSampleRate), masterSize,
                    modelConfig.isBaggingWithReplacement());
            for (Integer i : list) {
                MLDataPair pair = newRecordBuffer(masterDataSet);
                masterDataSet.getRecord(i, pair);
                sampledDataSet.add(pair);
            }
        }
        if (onDisk) {
            ((BufferedMLDataSet) sampledDataSet).endLoad();
        }

        // Cross Validation
        log.info("Generating Training Set and Validation Set ...");
        long sampleSize = sampledDataSet.getRecordCount();
        if (!modelConfig.isFixInitialInput()) {
            for (long i = 0; i < sampleSize; i++) {
                MLDataPair pair = newRecordBuffer(sampledDataSet);
                sampledDataSet.getRecord(i, pair);
                if (random.nextDouble() > crossValidationRate) {
                    trainSet.add(pair);
                } else {
                    validSet.add(pair);
                }
            }
        } else {
            // Deterministic split: head of the sample for training, tail for validation. Bounded by
            // sampleSize so an out-of-range rate cannot read past the last record.
            long trainSetSize = (long) (sampleSize * (1 - crossValidationRate));
            for (long i = 0; i < sampleSize; i++) {
                MLDataPair pair = newRecordBuffer(sampledDataSet);
                sampledDataSet.getRecord(i, pair);
                if (i < trainSetSize) {
                    trainSet.add(pair);
                } else {
                    validSet.add(pair);
                }
            }
        }
        if (onDisk) {
            ((BufferedMLDataSet) trainSet).endLoad();
            ((BufferedMLDataSet) validSet).endLoad();
        }
        selectedCount = sampleSize;
    } finally {
        // The sampled set is only an intermediate. In "D" mode it holds an open sampled.egb that was
        // previously never released.
        sampledDataSet.close();
    }
    log.info("    - # Records of the Master Data Set: " + masterSize);
    log.info("    - Bagging Sample Rate: " + baggingSampleRate);
    log.info("    - Bagging With Replacement: " + modelConfig.isBaggingWithReplacement());
    log.info("    - # Records of the Selected Data Set: " + selectedCount);
    log.info("        - Cross Validation Rate: " + crossValidationRate);
    log.info("        - # Records of the Training Set: " + this.getTrainSetSize());
    log.info("        - # Records of the Validation Set: " + this.getValidSetSize());
}

/**
 * Allocates an empty pair sized for the records of {@code dataSet}, ready to be filled by
 * {@code MLDataSet.getRecord(long, MLDataPair)}. The ideal array uses the data set's ideal size so that
 * buffered (EGB) reads stay aligned with the record layout written by {@code beginLoad}.
 */
private static MLDataPair newRecordBuffer(MLDataSet dataSet) {
    return new BasicMLDataPair(new BasicMLData(new double[dataSet.getInputSize()]),
            new BasicMLData(new double[dataSet.getIdealSize()]));
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet), MLDataSet (org.encog.ml.data.MLDataSet), BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet), BasicMLData (org.encog.ml.data.basic.BasicMLData), File (java.io.File)

Example 2 with MLDataPair

Use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.

Class AbstractTrainer, method calculateMSE.

/**
 * Computes the mean squared error of {@code network} over every record of {@code dataSet}.
 * <p>
 * Only the first network output and the first ideal value of each record are compared. This method holds
 * no locks.
 *
 * @param network
 *            the network to evaluate
 * @param dataSet
 *            the records to evaluate it on
 * @return the mean squared error; {@code NaN} when {@code dataSet} has no records (0.0 / 0)
 */
public static Double calculateMSE(BasicNetwork network, MLDataSet dataSet) {
    long numRecords = dataSet.getRecordCount();
    // One buffer is reused for every record: getRecord() only fills it and compute() only reads it.
    // Encog 3.0 has no MLDataSet.get(i), hence the getRecord(i, pair) pattern.
    MLDataPair pair = new BasicMLDataPair(new BasicMLData(new double[dataSet.getInputSize()]),
            new BasicMLData(new double[dataSet.getIdealSize()]));
    double sumOfSquares = 0;
    for (long i = 0; i < numRecords; i++) {
        dataSet.getRecord(i, pair);
        MLData result = network.compute(pair.getInput());
        double diff = result.getData()[0] - pair.getIdeal().getData()[0];
        sumOfSquares += diff * diff;
    }
    return sumOfSquares / numRecords;
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), BasicMLData (org.encog.ml.data.basic.BasicMLData), MLData (org.encog.ml.data.MLData)

Example 3 with MLDataPair

Use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.

Class DataLoadWorker, method handleMsg.

/**
 * Dispatches a scan message to the matching loader routine and forwards the result to {@code nextActorRef}.
 * <ul>
 * <li>{@code ScanStatsRawDataMessage}: raw lines are sent on as a {@code StatsPartRawDataMessage};</li>
 * <li>{@code ScanNormInputDataMessage}: raw lines are sent on as a {@code NormPartRawDataMessage};</li>
 * <li>{@code ScanTrainDataMessage}: parsed training pairs are sent on as a {@code TrainPartDataMessage};</li>
 * <li>{@code ScanEvalDataMessage}: the scanner is handed to {@code splitDataIntoMultiMessages}, with the
 * per-message record count read from {@code Environment.RECORD_CNT_PER_MESSAGE} (default 100000).</li>
 * </ul>
 * Any other message is passed to {@code unhandled}.
 *
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void handleMsg(Object message) {
    if (message instanceof ScanStatsRawDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanStatsRawDataMessage statsMessage = (ScanStatsRawDataMessage) message;
        List<String> lines = readDataIntoList(statsMessage.getScanner());
        log.info("DataLoaderActor Finished: Loaded " + lines.size() + " Records.");
        nextActorRef.tell(new StatsPartRawDataMessage(statsMessage.getTotalMsgCnt(), lines), getSelf());
        return;
    }
    if (message instanceof ScanNormInputDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanNormInputDataMessage normMessage = (ScanNormInputDataMessage) message;
        List<String> lines = readDataIntoList(normMessage.getScanner());
        log.info("DataLoaderActor Finished: Loaded " + lines.size() + " Records.");
        nextActorRef.tell(new NormPartRawDataMessage(normMessage.getTotalMsgCnt(), lines), getSelf());
        return;
    }
    if (message instanceof ScanTrainDataMessage) {
        ScanTrainDataMessage trainMessage = (ScanTrainDataMessage) message;
        boolean dryRun = trainMessage.isDryRun();
        List<MLDataPair> pairs = readTrainingData(trainMessage.getScanner(), dryRun);
        log.info("DataLoaderActor Finished: Loaded " + pairs.size() + " Records for Training.");
        nextActorRef.tell(new TrainPartDataMessage(trainMessage.getTotalMsgCnt(), dryRun, pairs), getSelf());
        return;
    }
    if (message instanceof ScanEvalDataMessage) {
        log.info("DataLoaderActor Starting ...");
        ScanEvalDataMessage evalMessage = (ScanEvalDataMessage) message;
        // Evaluation data is split into several downstream messages instead of one big list.
        int recordsPerMessage = Environment.getInt(Environment.RECORD_CNT_PER_MESSAGE, 100000);
        splitDataIntoMultiMessages(evalMessage.getStreamId(), evalMessage.getTotalStreamCnt(),
                evalMessage.getScanner(), recordsPerMessage);
        return;
    }
    unhandled(message);
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), Scanner (java.util.Scanner), StatsPartRawDataMessage (ml.shifu.shifu.message.StatsPartRawDataMessage), ScanTrainDataMessage (ml.shifu.shifu.message.ScanTrainDataMessage), ScanNormInputDataMessage (ml.shifu.shifu.message.ScanNormInputDataMessage), ScanStatsRawDataMessage (ml.shifu.shifu.message.ScanStatsRawDataMessage), TrainPartDataMessage (ml.shifu.shifu.message.TrainPartDataMessage), NormPartRawDataMessage (ml.shifu.shifu.message.NormPartRawDataMessage), ScanEvalDataMessage (ml.shifu.shifu.message.ScanEvalDataMessage), ArrayList (java.util.ArrayList), LinkedList (java.util.LinkedList), List (java.util.List)

Example 4 with MLDataPair

Use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.

Class DataLoadWorker, method readTrainingData.

/**
 * Read the normalized training data for model training.
 * <p>
 * Each line is split with {@code DEFAULT_SPLITTER}. Token {@code i} (for {@code i < columnConfigList.size()})
 * belongs to column {@code i}:
 * <ul>
 * <li>the target column fills the one-element ideal array;</li>
 * <li>input columns fill the input array in column order;</li>
 * <li>all other columns are skipped.</li>
 * </ul>
 * The token right after the last column, if present, is parsed as the record significance. Unparsable
 * numbers become 0.0; an unparsable significance becomes {@code CommonConstants.DEFAULT_SIGNIFICANCE_VALUE}.
 * <p>
 * Input columns are every good candidate when {@code inputNodeCount == candidateCount}; otherwise they are
 * the final-selected, non-meta, non-target columns.
 *
 * @param scanner
 *            - input partition
 * @param isDryRun
 *            - is for test running? When true, each line yields a 1x1 dummy pair and is not parsed
 * @return List of data, one pair per line read
 */
public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) {
    List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>();
    // Size the input array with the same predicate that fills it. Counting only isFinalSelect() columns
    // overflowed when all good candidates are used as inputs, and over-sized the array when meta/target
    // columns were flagged final-select.
    int numInputs = 0;
    for (ColumnConfig config : this.columnConfigList) {
        if (isTrainingInputColumn(config)) {
            numInputs++;
        }
    }
    int cnt = 0;
    while (scanner.hasNextLine()) {
        if ((cnt++) % 100000 == 0) {
            log.info("Read " + (cnt) + " Records.");
        }
        String line = scanner.nextLine();
        if (isDryRun) {
            MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData(new double[1]));
            mlDataPairList.add(dummyPair);
            continue;
        }
        // the normalized training data is separated by | by default
        double[] inputs = new double[numInputs];
        double[] ideal = new double[1];
        // NOTE(review): a line without a significance token keeps 0.0 here, while an unparsable token gets
        // DEFAULT_SIGNIFICANCE_VALUE; confirm whether missing weights should also use the default.
        double significance = 0.0d;
        int index = 0, inputsIndex = 0, outputIndex = 0;
        for (String input : DEFAULT_SPLITTER.split(line.trim())) {
            if (index == this.columnConfigList.size()) {
                significance = NumberFormatUtils.getDouble(input.trim(), CommonConstants.DEFAULT_SIGNIFICANCE_VALUE);
                break;
            }
            ColumnConfig columnConfig = this.columnConfigList.get(index);
            if (columnConfig != null && columnConfig.isTarget()) {
                ideal[outputIndex++] = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            } else if (isTrainingInputColumn(columnConfig)) {
                inputs[inputsIndex++] = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            }
            index++;
        }
        MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal));
        pair.setSignificance(significance);
        mlDataPairList.add(pair);
    }
    return mlDataPairList;
}

/**
 * Returns whether {@code columnConfig} feeds the network input. This single predicate both sizes and fills
 * the input array in {@code readTrainingData}, so the two can never disagree.
 */
private boolean isTrainingInputColumn(ColumnConfig columnConfig) {
    if (columnConfig != null && columnConfig.isTarget()) {
        return false;
    }
    if (this.inputNodeCount == this.candidateCount) {
        // all variables are not set final-select: every good candidate is an input
        return CommonUtils.isGoodCandidate(columnConfig);
    }
    // final select some variables
    return columnConfig != null && !columnConfig.isMeta() && columnConfig.isFinalSelect();
}
Also used: BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair), MLDataPair (org.encog.ml.data.MLDataPair), ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), BasicMLData (org.encog.ml.data.basic.BasicMLData), ArrayList (java.util.ArrayList)

Example 5 with MLDataPair

Use of org.encog.ml.data.MLDataPair in project shifu by ShifuML.

Class TrainDataPrepWorker, method handleMsg.

/**
 * Collects training data and starts the trainers once all expected parts have arrived.
 * <p>
 * {@code TrainPartDataMessage}:
 * <ul>
 * <li>Every pair is appended to {@code masterDataSet}, and {@code receivedMsgCnt} is incremented.</li>
 * <li>When the count reaches {@code msg.getTotalMsgCnt()}, each trainer receives the master data set via
 * {@code setDataSet} (with training option "D" when training on disk) and is forwarded to
 * {@code nextActorRef} as a {@code TrainInstanceMessage}.</li>
 * <li>When {@code modelConfig.isTrainOnDisk()} is set, {@code masterDataSet} is treated as a
 * {@code BufferedMLDataSet}. Loading begins with the input/ideal sizes of the first received pair and ends
 * after the last message.</li>
 * <li>The buffered data set is closed and released after the trainers are set up. This happens in a
 * finally block, so the file is released even if a trainer fails.</li>
 * </ul>
 * <p>
 * {@code StatsPartRawDataMessage}: counts messages (sharing {@code receivedMsgCnt}) and, on the last one,
 * forwards every trainer.
 * <p>
 * Anything else is passed to {@code unhandled}.
 *
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 * @throws IOException
 *             if a trainer fails to set up its data sets
 */
@Override
public void handleMsg(Object message) throws IOException {
    if (message instanceof TrainPartDataMessage) {
        log.debug("Received value object list for training model.");
        TrainPartDataMessage msg = (TrainPartDataMessage) message;
        boolean onDisk = modelConfig.isTrainOnDisk();
        for (MLDataPair dataPair : msg.getMlDataPairList()) {
            if (onDisk && !initialized) {
                // The EGB layout is fixed by the first pair received.
                int inputSize = dataPair.getInput().size();
                int idealSize = dataPair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputSize, idealSize);
                initialized = true;
            }
            masterDataSet.add(dataPair);
        }
        receivedMsgCnt++;
        log.debug("Expected " + msg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == msg.getTotalMsgCnt()) {
            boolean buffered = onDisk && initialized;
            if (buffered) {
                ((BufferedMLDataSet) masterDataSet).endLoad();
            }
            try {
                for (AbstractTrainer trainer : trainers) {
                    // if the trainOnDisk is true, setting the "D" option
                    if (onDisk) {
                        trainer.setTrainingOption("D");
                    }
                    trainer.setDataSet(masterDataSet);
                    nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
                }
            } finally {
                // Trainers copy what they need in setDataSet, so the on-disk master can be released,
                // also when a trainer throws.
                if (buffered) {
                    masterDataSet.close();
                    masterDataSet = null;
                }
            }
        }
    } else if (message instanceof StatsPartRawDataMessage) {
        StatsPartRawDataMessage msg = (StatsPartRawDataMessage) message;
        receivedMsgCnt++;
        log.debug("Expected " + msg.getTotalMsgCnt() + " messages, received " + receivedMsgCnt + " message(s).");
        if (receivedMsgCnt == msg.getTotalMsgCnt()) {
            for (AbstractTrainer trainer : trainers) {
                nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
            }
        }
    } else {
        unhandled(message);
    }
}
Also used: MLDataPair (org.encog.ml.data.MLDataPair), TrainInstanceMessage (ml.shifu.shifu.message.TrainInstanceMessage), StatsPartRawDataMessage (ml.shifu.shifu.message.StatsPartRawDataMessage), TrainPartDataMessage (ml.shifu.shifu.message.TrainPartDataMessage), AbstractTrainer (ml.shifu.shifu.core.AbstractTrainer), BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet)

Aggregations

Usage counts: MLDataPair (org.encog.ml.data.MLDataPair): 5; BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair): 4; BasicMLData (org.encog.ml.data.basic.BasicMLData): 3; ArrayList (java.util.ArrayList): 2; StatsPartRawDataMessage (ml.shifu.shifu.message.StatsPartRawDataMessage): 2; TrainPartDataMessage (ml.shifu.shifu.message.TrainPartDataMessage): 2; BufferedMLDataSet (org.encog.ml.data.buffer.BufferedMLDataSet): 2; File (java.io.File): 1; LinkedList (java.util.LinkedList): 1; List (java.util.List): 1; Scanner (java.util.Scanner): 1; ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1; AbstractTrainer (ml.shifu.shifu.core.AbstractTrainer): 1; NormPartRawDataMessage (ml.shifu.shifu.message.NormPartRawDataMessage): 1; ScanEvalDataMessage (ml.shifu.shifu.message.ScanEvalDataMessage): 1; ScanNormInputDataMessage (ml.shifu.shifu.message.ScanNormInputDataMessage): 1; ScanStatsRawDataMessage (ml.shifu.shifu.message.ScanStatsRawDataMessage): 1; ScanTrainDataMessage (ml.shifu.shifu.message.ScanTrainDataMessage): 1; TrainInstanceMessage (ml.shifu.shifu.message.TrainInstanceMessage): 1; MLData (org.encog.ml.data.MLData): 1