Search in sources :

Example 1 with DTree

Example of using hex.gbm.DTree in the project h2o-2 by h2oai.

From the class DRF, the method buildNextKTrees:

// --------------------------------------------------------------------------
/** Builds the next group of K trees (K = number of classes) representing the
 *  tid-th tree of the forest.
 *
 *  Each of the K trees corrects errors for a single response class.  All K
 *  trees share one RNG seed so they see the same row sample, but each tree
 *  picks its own random subset of columns per split — which is why the
 *  2-class "boolean optimization" used by GBM does NOT apply here.
 *
 *  @param fr          training frame (with working nid/response vectors appended)
 *  @param mtrys       number of columns randomly sampled per split
 *  @param sample_rate per-tree row sampling rate (for OOB marking)
 *  @param rand        RNG used to derive the shared per-iteration tree seed
 *  @param tid         index of the tree (iteration) being built
 *  @return the K new trees (entries may be null for absent classes), or
 *          null if the job was canceled mid-build
 */
private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
    // We're going to build K (nclass) trees - each focused on correcting
    // errors for a single class.
    final DTree[] ktrees = new DTree[_nclass];
    // Initial set of histograms.  All trees; one leaf per tree (the root
    // leaf); all columns
    DHistogram[][][] hcs = new DHistogram[_nclass][1][_ncols];
    // Adjust nbins for the top-levels (at least 1<<10 bins for the root)
    int adj_nbins = Math.max((1 << (10 - 0)), nbins);
    // Use for all k-trees the same seed. NOTE: this is only to make a fair
    // view for all k-trees
    long rseed = rand.nextLong();
    // Initially setup as-if an empty-split had just happened
    for (int k = 0; k < _nclass; k++) {
        assert (_distribution != null && classification) || (_distribution == null && !classification);
        if (_distribution == null || _distribution[k] != 0) {
            // Ignore missing classes
            // The Boolean Optimization cannot be applied here for RF !
            // This optimization assumes the 2nd tree of a 2-class system is the
            // inverse of the first.  This is false for DRF (and true for GBM) -
            // DRF picks a random different set of columns for the 2nd tree.
            //if( k==1 && _nclass==2 ) continue;
            ktrees[k] = new DRFTree(fr, _ncols, (char) nbins, (char) _nclass, min_rows, mtrys, rseed);
            boolean isBinom = classification;
            // The "root" node
            new DRFUndecidedNode(ktrees[k], -1, DHistogram.initialHist(fr, _ncols, adj_nbins, hcs[k][0], min_rows, do_grpsplit, isBinom));
        }
    }
    // Sample - mark the lines by putting 'OUT_OF_BAG' into nid(<klass>) vector
    Timer t_1 = new Timer();
    Sample[] ss = new Sample[_nclass];
    for (int k = 0; k < _nclass; k++) if (ktrees[k] != null)
        ss[k] = new Sample((DRFTree) ktrees[k], sample_rate).dfork(0, new Frame(vec_nids(fr, k), vec_resp(fr, k)), build_tree_one_node);
    for (int k = 0; k < _nclass; k++) if (ss[k] != null)
        ss[k].getResult();
    // BUGFIX: the message previously contained a stray "+ " inside the
    // string literal ("Sampling took: + ") — now consistent with the other
    // timer log lines below.
    Log.debug(Sys.DRF__, "Sampling took: " + t_1);
    // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
    int[] leafs = new int[_nclass];
    // ----
    // One Big Loop till the ktrees are of proper depth.
    // Adds a layer to the trees each pass.
    Timer t_2 = new Timer();
    int depth = 0;
    for (; depth < max_depth; depth++) {
        if (!Job.isRunning(self()))
            return null;
        hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_one_node);
        // If we did not make any new splits, then the tree is split-to-death
        if (hcs == null)
            break;
    }
    Log.debug(Sys.DRF__, "Tree build took: " + t_2);
    // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
    // LeafNodes to hold predictions.
    Timer t_3 = new Timer();
    for (int k = 0; k < _nclass; k++) {
        DTree tree = ktrees[k];
        if (tree == null)
            continue;
        int leaf = leafs[k] = tree.len();
        for (int nid = 0; nid < leaf; nid++) {
            if (tree.node(nid) instanceof DecidedNode) {
                DecidedNode dn = tree.decided(nid);
                for (int i = 0; i < dn._nids.length; i++) {
                    int cnid = dn._nids[i];
                    if (cnid == -1 || // Bottomed out (predictors or responses known constant)
                        tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth
                        (tree.node(cnid) instanceof DecidedNode && ((DecidedNode) tree.node(cnid))._split.col() == -1)) { // Or not possible to split
                        LeafNode ln = new DRFLeafNode(tree, nid);
                        // Set prediction into the leaf
                        ln._pred = dn.pred(i);
                        // Mark a leaf here
                        dn._nids[i] = ln.nid();
                    }
                }
                // Handle the trivial non-splitting tree
                if (nid == 0 && dn._split.col() == -1)
                    new DRFLeafNode(tree, -1, 0);
            }
        }
    }
    // -- k-trees are done
    Log.debug(Sys.DRF__, "Nodes propagation: " + t_3);
    // ----
    // Move rows into the final leaf rows
    Timer t_4 = new Timer();
    CollectPreds cp = new CollectPreds(ktrees, leafs).doAll(fr, build_tree_one_node);
    if (importance) {
        if (classification)
            // Track right votes over OOB rows for this tree
            asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows);
        else
            /* regression */
            asSSE(_treeMeasuresOnOOB).append(cp.sse, cp.allRows);
    }
    Log.debug(Sys.DRF__, "CollectPreds done: " + t_4);
    // Collect leaves stats
    for (int i = 0; i < ktrees.length; i++) if (ktrees[i] != null)
        ktrees[i].leaves = ktrees[i].len() - leafs[i];
    return ktrees;
}
Also used : Frame(water.fvec.Frame) UndecidedNode(hex.gbm.DTree.UndecidedNode) DTree(hex.gbm.DTree) DecidedNode(hex.gbm.DTree.DecidedNode) DHistogram(hex.gbm.DHistogram) LeafNode(hex.gbm.DTree.LeafNode)

Example 2 with DTree

Example of using hex.gbm.DTree in the project h2o-2 by h2oai.

From the class DRF, the method doVarImpCalc:

//  /** On-the-fly version for varimp. After generation a new tree, its tree votes are collected on shuffled
//   * OOB rows and variable importance is recomputed.
//   * <p>
//   * The <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
//   * <cite>
//   * "In every tree grown in the forest, put down the oob cases and count the number of votes cast for the correct class.
//   * Now randomly permute the values of variable m in the oob cases and put these cases down the tree.
//   * Subtract the number of votes for the correct class in the variable-m-permuted oob data from the number of votes
//   * for the correct class in the untouched oob data.
//   * The average of this number over all trees in the forest is the raw importance score for variable m."
//   * </cite>
//   * </p>
//   * */
//  @Override
//  protected VarImp doVarImpCalc(final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
//    // Check if we have already serialized 'ktrees'-trees in the model
//    assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid="+tid;
//    assert _treeMeasuresOnOOB.npredictors()-1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
//    // Compute tree votes over shuffled data
//    final CompressedTree[/*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys
//    final int nclasses = model.nclasses();
//    Futures fs = new Futures();
//    for (int var=0; var<_ncols; var++) {
//      final int variable = var;
//      H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() {
//        @Override public void compute2() {
//          // Compute this tree votes over all data over given variable
//          TreeVotes cd = TreeMeasuresCollector.collectVotes(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
//          assert cd.npredictors() == 1;
//          asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
//          tryComplete();
//        }
//      } : /* regression */ new H2OCountedCompleter() {
//        @Override public void compute2() {
//          // Compute this tree votes over all data over given variable
//          TreeSSE cd = TreeMeasuresCollector.collectSSE(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
//          assert cd.npredictors() == 1;
//          asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
//          tryComplete();
//        }
//      };
//      fs.add(task4var);
//      H2O.submitTask(task4var); // Fork computation
//    }
//    fs.blockForPending(); // Wait for results
//    // Compute varimp for individual features (_ncols)
//    final float[] varimp   = new float[_ncols]; // output variable importance
//    final float[] varimpSD = new float[_ncols]; // output variable importance sd
//    for (int var=0; var<_ncols; var++) {
//      double[/*2*/] imp = classification ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB)) :  asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB));
//      varimp  [var] = (float) imp[0];
//      varimpSD[var] = (float) imp[1];
//    }
//    return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees());
//  }
/** Compute relative variable importance for an RF model.
   *
   *  See formulas (45) and (35) in Friedman: "Greedy Function Approximation:
   *  A Gradient Boosting Machine."  The algorithm used here could also compute
   *  per-output-class importance of individual features.
   *
   *  @param model           model whose trees are already serialized
   *  @param ktrees          the K trees built in this iteration
   *  @param tid             index of the latest tree
   *  @param validationFrame not used; importance comes from split statistics
   *  @param scale           if true, rescale so the top variable equals 1.0
   *  @return relative-importance measure for every feature */
@Override
protected VarImp doVarImpCalc(DRFModel model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale) {
    assert model.ntrees() - 1 - _ntreesFromCheckpoint == tid : "varimp computation expect model with already serialized trees: tid=" + tid;
    // Accumulate the least-squares improvement of every split, per column.
    for (DTree t : ktrees) {
        // Iterate over trees
        if (t != null) {
            // Only interior nodes (the first t.len() - t.leaves) can hold splits.
            for (int n = 0; n < t.len() - t.leaves; n++) if (t.node(n) instanceof DecidedNode) {
                // it is split node
                DTree.Split split = t.decided(n)._split;
                if (split.col() != -1) // Skip impossible splits ~ leafs
                    // least squares improvement
                    _improvPerVar[split.col()] += split.improvement();
            }
        }
    }
    // Average the accumulated improvement over all trees in the model.
    float[] varimp = new float[model.nfeatures()];
    int ntreesTotal = model.ntrees() * model.nclasses();
    int maxVar = 0;
    for (int var = 0; var < _improvPerVar.length; var++) {
        varimp[var] = _improvPerVar[var] / ntreesTotal;
        if (varimp[var] > varimp[maxVar])
            maxVar = var;
    }
    // Rescale relative to the maximum, so the most important variable is 1.0.
    // (The old comment claimed a 0..100 scale; the code divides by the max.)
    // BUGFIX: guard against maxVal == 0 (no informative splits yet), which
    // previously produced NaN importances via 0f/0f.
    if (scale) {
        float maxVal = varimp[maxVar];
        if (maxVal > 0) {
            for (int var = 0; var < varimp.length; var++) varimp[var] /= maxVal;
        }
    }
    return new VarImp.VarImpRI(varimp);
}
Also used : DTree(hex.gbm.DTree) DecidedNode(hex.gbm.DTree.DecidedNode)

Example 3 with DTree

Example of using hex.gbm.DTree in the project h2o-2 by h2oai.

From the class DRF, the method buildModel:

/** Builds the full DRF model: grows {@code ntrees} rounds of K trees each,
 *  scoring between rounds, and returns the finished model.
 *
 *  @param model   the (possibly checkpoint-restored) model to extend
 *  @param fr      training frame
 *  @param names   column names (unused here; part of the template contract)
 *  @param domains column domains (unused here; part of the template contract)
 *  @param t_build overall build timer (unused here)
 *  @return the scored model, possibly partial if the job was canceled */
@Override
protected DRFModel buildModel(DRFModel model, final Frame fr, String[] names, String[][] domains, final Timer t_build) {
    // RNG that picks the candidate split columns for every tree.
    final Random rng = createRNG(_seed);
    // Fast-forward the RNG past any checkpointed trees, so the random
    // sequence matches an uninterrupted run.
    for (int i = 0; i < _ntreesFromCheckpoint; i++) {
        rng.nextLong();
    }
    DTree[] ktrees = null;
    // Continue the model's existing tree statistics when restoring from a
    // checkpoint, otherwise start fresh.
    final TreeStats stats = (model.treeStats == null) ? new TreeStats() : model.treeStats;
    int tid = 0;
    while (tid < ntrees) {
        // Score what we have so far — except for the very first pass of a
        // checkpoint-restored model, which was already scored.
        final boolean skipInitialScoring = (tid == 0 && checkpoint != null);
        if (!skipInitialScoring) {
            model = doScoring(model, fr, ktrees, tid, stats, tid == 0, !hasValidation(), build_tree_one_node);
        }
        // Grow K trees this round (K = nclass = response-column domain size).
        // TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
        // Idea: launch more DRF at once.
        final Timer roundTimer = new Timer();
        ktrees = buildNextKTrees(fr, _mtry, sample_rate, rng, tid);
        Log.info(logTag(), (tid + 1) + ". tree was built " + roundTimer.toString());
        // If canceled during building, do not bulkscore
        if (!Job.isRunning(self()))
            break;
        // Fold the freshly built trees into the running statistics.
        stats.updateBy(ktrees);
        tid++;
    }
    if (Job.isRunning(self())) {
        // Final scoring pass over the completed forest.
        model = doScoring(model, fr, ktrees, tid, stats, true, !hasValidation(), build_tree_one_node);
    // Make sure that we did not miss any votes
    //      assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0/*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!";
    }
    return model;
}
Also used : TreeStats(hex.gbm.DTree.TreeModel.TreeStats) Random(java.util.Random) DTree(hex.gbm.DTree)

Aggregations

DTree (hex.gbm.DTree)3 DecidedNode (hex.gbm.DTree.DecidedNode)2 DHistogram (hex.gbm.DHistogram)1 LeafNode (hex.gbm.DTree.LeafNode)1 TreeStats (hex.gbm.DTree.TreeModel.TreeStats)1 UndecidedNode (hex.gbm.DTree.UndecidedNode)1 Random (java.util.Random)1 Frame (water.fvec.Frame)1