Search in sources :

Example 1 with NFoldFrameExtractor

use of hex.NFoldFrameExtractor in project h2o-2 by h2oai.

the class CrossValUtils method crossValidate.

/**
 * Cross-validate a ValidatedJob.
 *
 * <p>For each of the {@code job.n_folds} folds, an {@link NFoldFrameExtractor} is run on
 * {@code job.source} with destination keys named {@code <destination_key>_xval<i>_train} and
 * {@code <destination_key>_xval<i>_holdout}, and the resulting frames are handed to
 * {@code job.crossValidate(...)}. Once all folds have been processed, the model stored under
 * {@code job.destination_key} scores the collected per-fold predictions via
 * {@code scoreCrossValidation}.
 *
 * <p>The method returns silently (without scoring) if the job is not, or stops being, in the
 * {@code RUNNING} state. Per-fold prediction frames are always deleted on exit; the per-fold
 * split frames are deleted unless {@code job.keep_cross_validation_splits} is set.
 *
 * @param job (must contain valid entries for n_folds, validation, destination_key, source, response)
 * @throws IllegalArgumentException if the job is running and either a validation dataset was
 *         supplied or {@code n_folds} is less than 2
 */
public static void crossValidate(Job.ValidatedJob job) {
    //don't do cross-validation if the full model builder failed
    if (job.state != Job.JobState.RUNNING)
        return;
    if (job.validation != null)
        throw new IllegalArgumentException("Cannot provide validation dataset and n_folds > 0 at the same time.");
    if (job.n_folds <= 1)
        throw new IllegalArgumentException("n_folds must be >= 2 for cross-validation.");
    final String basename = job.destination_key.toString();
    // Filled in by job.crossValidate(); one slot more than the number of folds.
    long[] offsets = new long[job.n_folds + 1];
    // Per-fold prediction frames; entries stay null for folds that never ran.
    Frame[] cv_preds = new Frame[job.n_folds];
    try {
        for (int i = 0; i < job.n_folds; ++i) {
            // Stop launching folds as soon as the job is cancelled or has failed.
            if (job.state != Job.JobState.RUNNING)
                break;
            Key[] destkeys = new Key[] { Key.make(basename + "_xval" + i + "_train"), Key.make(basename + "_xval" + i + "_holdout") };
            NFoldFrameExtractor nffe = new NFoldFrameExtractor(job.source, job.n_folds, i, destkeys, Key.make());
            H2O.submitTask(nffe);
            Frame[] splits = nffe.getResult();
            // Cross-validate individual splits
            try {
                //this removes the enum-ified response!
                job.crossValidate(splits, cv_preds, offsets, i);
                job._cv_count++;
            } finally {
                // clean-up the results
                if (!job.keep_cross_validation_splits)
                    for (Frame f : splits) f.delete();
            }
        }
        if (job.state != Job.JobState.RUNNING)
            return;
        final int resp_idx = job.source.find(job._responseName);
        Vec response = job.source.vecs()[resp_idx];
        // In the case of rebalance, rebalance response will be deleted
        boolean put_back = UKV.get(job.response._key) == null;
        if (put_back) {
            job.response = response;
            if (job.classification)
                job.response = job.response.toEnum();
            //put enum-ified response back to K-V store
            DKV.put(job.response._key, job.response);
        }
        try {
            ((Model) UKV.get(job.destination_key)).scoreCrossValidation(job, job.source, response, cv_preds, offsets);
        } finally {
            // Remove the temporarily re-published response even when scoring throws,
            // so a failed scoring pass does not leave it behind in the K-V store.
            if (put_back)
                UKV.remove(job.response._key);
        }
    } finally {
        // clean-up prediction frames for splits
        for (Frame f : cv_preds) if (f != null)
            f.delete();
    }
}
Also used : Frame(water.fvec.Frame) Vec(water.fvec.Vec) NFoldFrameExtractor(hex.NFoldFrameExtractor)

Example 2 with NFoldFrameExtractor

use of hex.NFoldFrameExtractor in project h2o-2 by h2oai.

the class NFoldFrameExtractPage method execImpl.

/**
 * Runs an {@link NFoldFrameExtractor} over {@code source} for fold {@code afold} of
 * {@code nfolds} and records, for every resulting frame, its key in {@code split_keys}
 * and its row count in {@code split_rows}.
 *
 * <p>With assertions enabled, verifies that the row counts of the resulting frames
 * add up to the row count of {@code source}.
 */
@Override
protected void execImpl() {
    final NFoldFrameExtractor nffe = new NFoldFrameExtractor(source, nfolds, afold, null, null);
    H2O.submitTask(nffe);
    final Frame[] parts = nffe.getResult();
    final int count = parts.length;
    split_keys = new Key[count];
    split_rows = new long[count];
    long totalRows = 0;
    int idx = 0;
    while (idx < count) {
        final Frame split = parts[idx];
        final long rows = split.numRows();
        totalRows += rows;
        split_keys[idx] = split._key;
        split_rows[idx] = rows;
        idx++;
    }
    // Every source row must land in exactly one of the extracted frames.
    assert totalRows == source.numRows() : "Frame split produced wrong number of rows: nrows(source) != sum(nrows(splits))";
}
Also used : Frame(water.fvec.Frame) NFoldFrameExtractor(hex.NFoldFrameExtractor)

Aggregations

NFoldFrameExtractor (hex.NFoldFrameExtractor): 2, Frame (water.fvec.Frame): 2, Vec (water.fvec.Vec): 1