Search in sources:

Example 11 with TreeNodeSignature

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by KNIME.

The class TreeNodeSignatureFactory, method getChildSignatureFor.

/**
 * Returns the signature of the child with index <b>childIndex</b> of the node identified by
 * <b>parentSignature</b>. If that signature has been created before, the stored instance is returned;
 * otherwise a new signature is created via {@code parentSignature.createChildSignature(childIndex)},
 * stored and returned.
 * <p>
 * Children are usually requested in index order, but out-of-order requests are supported: positions that
 * have not been requested yet are filled with {@code null} placeholders, which are replaced once the
 * respective child is requested. A parent that is not registered yet is registered with an empty child list.
 * <p>
 * This method is synchronized because different threads access it during the tree building process.
 *
 * @param parentSignature signature of the parent node
 * @param childIndex non-negative index of the child within its parent
 * @return signature for the child node, never {@code null}
 * @throws IllegalArgumentException if {@code childIndex} is negative
 */
public synchronized TreeNodeSignature getChildSignatureFor(final TreeNodeSignature parentSignature, final byte childIndex) {
    if (childIndex < 0) {
        throw new IllegalArgumentException("Child index must be non-negative but was " + childIndex);
    }
    // Register unknown parents lazily instead of failing with a NullPointerException.
    final List<TreeNodeSignature> knownChildren =
        m_knownSignatures.computeIfAbsent(parentSignature, k -> new java.util.ArrayList<>());
    // Make childIndex a valid position. List.add(int, E) would throw for childIndex > size() and
    // would shift existing entries instead of filling a gap, so pad with placeholders and use set().
    while (knownChildren.size() <= childIndex) {
        knownChildren.add(null);
    }
    TreeNodeSignature childSignature = knownChildren.get(childIndex);
    if (childSignature == null) {
        childSignature = parentSignature.createChildSignature(childIndex);
        knownChildren.set(childIndex, childSignature);
    }
    return childSignature;
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)

Example 12 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

The class MGradientBoostedTreesLearner, method learn.

/**
 * {@inheritDoc}
 * <p>
 * Starts from a constant prediction equal to the target median and then fits {@code config.getNrModels()}
 * regression trees. In each iteration:
 * <ol>
 * <li>residuals are computed as target minus the current prediction;</li>
 * <li>the {@code alpha} quantile of the residuals is computed ({@code calculateAlphaQuantile});</li>
 * <li>residuals whose absolute value exceeds that quantile are clipped to {@code quantile * signum(residual)};</li>
 * <li>a regression tree is learned on these clipped values;</li>
 * <li>its leaf coefficients are computed ({@code calcCoefficientMap}) and used to update the prediction.</li>
 * </ol>
 *
 * @param exec monitor used for progress reporting and cancellation
 * @return the learned gradient boosted trees model
 * @throws CanceledExecutionException if the execution is canceled
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final int nrRows = actualTarget.getNrRows();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();

    final int maxLevels = config.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // bounded depth (expected to be the common case): presize the factory with 2^(maxLevels - 1)
        signatureFactory = new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1));
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }

    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        exec.checkCanceled();
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // keep residuals within [-quantile, quantile], clip larger ones to +/- quantile
        final double[] gradients = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        // every tree gets its own RandomData instance, seeded from the master generator rd
        final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(config, residualData, getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // report completed work: after the last tree the progress is exactly 1.0
        exec.setProgress((i + 1) / (double) nrModels, "Finished tree " + (i + 1) + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(config, actualData.getMetaData(), models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue, coefficientMaps);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) HashMap(java.util.HashMap) Map(java.util.Map) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Example 13 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

The class AbstractGBTModelExporter, method writeSumSegmentation.

/**
 * Writes the given trees as segments of a segmentation whose multiple model method is SUM.
 * <p>
 * Trees and coefficient maps are paired by iteration order. The segments receive the ids "1", "2", ...
 * and a {@code True} predicate. Equal sizes of both collections are only verified by an assertion.
 *
 * @param segmentation the segmentation to fill
 * @param trees the trees to write, one segment each
 * @param coefficientMaps per-tree coefficient maps, in the same order as {@code trees}
 */
protected void writeSumSegmentation(final Segmentation segmentation, final Collection<TreeModelRegression> trees, final Collection<Map<TreeNodeSignature, Double>> coefficientMaps) {
    assert trees.size() == coefficientMaps.size() : "The number of trees does not match the number of coefficient maps.";
    segmentation.setMultipleModelMethod(MULTIPLEMODELMETHOD.SUM);
    final Iterator<Map<TreeNodeSignature, Double>> maps = coefficientMaps.iterator();
    int segmentId = 0;
    for (final TreeModelRegression tree : trees) {
        segmentId++;
        final Segment segment = segmentation.addNewSegment();
        segment.setId(String.valueOf(segmentId));
        segment.addNewTrue();
        writeTreeIntoSegment(segment, tree, maps.next());
    }
}
Also used : Map(java.util.Map) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) Segment(org.dmg.pmml.SegmentDocument.Segment)

Example 14 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

The class AbstractRegressionContentParser, method createNode.

/**
 * Creates a regression tree node from a PMML {@code Node}.
 * <p>
 * The node's score is parsed as the mean. The total sum and the sum of squared deviations are read from
 * the first PMML extensions named {@code TranslationUtil.TOTAL_SUM_KEY} and
 * {@code TranslationUtil.SUM_SQUARED_DEVIATION_KEY}. Missing extensions yield {@code -1}.
 *
 * @param node the PMML node
 * @param targetHelper provides the target column meta data
 * @param signature signature of the node to create
 * @param children already created child nodes
 * @return the created regression node
 */
@Override
public final TreeNodeRegression createNode(final Node node, final TargetColumnHelper<TreeTargetNumericColumnMetaData> targetHelper, final TreeNodeSignature signature, final List<TreeNodeRegression> children) {
    final double mean = Double.parseDouble(node.getScore());
    final OptionalDouble totalSum = firstExtensionAsDouble(node, TranslationUtil.TOTAL_SUM_KEY);
    final OptionalDouble sumSquaredDeviation = firstExtensionAsDouble(node, TranslationUtil.SUM_SQUARED_DEVIATION_KEY);
    return createNodeInternal(node, targetHelper.getMetaData(), signature, mean,
        totalSum.orElse(-1), sumSquaredDeviation.orElse(-1),
        children.toArray(new TreeNodeRegression[children.size()]));
}

/** Parses the value of the first extension of {@code node} named {@code key} as a double, if any. */
private static OptionalDouble firstExtensionAsDouble(final Node node, final String key) {
    return node.getExtensionList().stream()
        .filter(e -> e.getName().equals(key))
        .mapToDouble(e -> Double.parseDouble(e.getValue()))
        .findFirst();
}
Also used : List(java.util.List) TreeTargetNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) OptionalDouble(java.util.OptionalDouble) Node(org.dmg.pmml.NodeDocument.Node)

Example 15 with TreeNodeSignature

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature in project knime-core by knime.

The class ClassificationGBTModelImporter, method importFromPMMLInternal.

/**
 * {@inheritDoc}
 * <p>
 * Expects the top-level segmentation of {@code miningModel} to be a model chain. Every segment except the
 * last one is processed as a class segment: it contributes one list of trees, one list of coefficient maps
 * and one class label. The last segment is not read here.
 *
 * @param miningModel the PMML mining model to import
 * @return the imported multi-class gradient boosted trees model
 * @throws IllegalArgumentException if the segmentation is not a model chain or contains fewer than two
 *             segments
 */
@Override
protected MultiClassGradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    final Segmentation modelChain = miningModel.getSegmentation();
    CheckUtils.checkArgument(modelChain.getMultipleModelMethod() == MULTIPLEMODELMETHOD.MODEL_CHAIN, "The top level segmentation should have multiple model method '%s' but has '%s'", MULTIPLEMODELMETHOD.MODEL_CHAIN, modelChain.getMultipleModelMethod());
    final List<Segment> segments = modelChain.getSegmentList();
    // Without this check an empty chain fails with an uninformative IndexOutOfBoundsException below and a
    // single-segment chain silently produces a model without any class.
    CheckUtils.checkArgument(segments.size() >= 2, "The model chain should contain at least one class segment followed by a final segment but has %s segment(s)", segments.size());
    final int nrClassSegments = segments.size() - 1;
    final List<List<TreeModelRegression>> trees = new ArrayList<>(nrClassSegments);
    final List<List<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<>(nrClassSegments);
    final List<String> classLabels = new ArrayList<>(nrClassSegments);
    for (int i = 0; i < nrClassSegments; i++) {
        final Segment classSegment = segments.get(i);
        final Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> gbtPair = processClassSegment(classSegment);
        trees.add(gbtPair.getFirst());
        coefficientMaps.add(gbtPair.getSecond());
        classLabels.add(extractClassLabel(classSegment));
    }
    // NOTE(review): the initial value is read from the first class segment only; presumably all class
    // segments share the same initial value - confirm against the exporter.
    final double initialValue = extractInitialValue(segments.get(0));
    return MultiClassGradientBoostedTreesModel.create(getMetaDataMapper().getTreeMetaData(), trees, coefficientMaps, initialValue, TreeType.Ordinary, classLabels);
}
Also used : Segmentation(org.dmg.pmml.SegmentationDocument.Segmentation) ArrayList(java.util.ArrayList) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) Segment(org.dmg.pmml.SegmentDocument.Segment) List(java.util.List)

Aggregations

TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)20 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression)8 ArrayList (java.util.ArrayList)6 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)6 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)6 Map (java.util.Map)5 RandomData (org.apache.commons.math.random.RandomData)5 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)5 ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample)5 BitSet (java.util.BitSet)4 HashMap (java.util.HashMap)4 Segment (org.dmg.pmml.SegmentDocument.Segment)4 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 List (java.util.List)3 Segmentation (org.dmg.pmml.SegmentationDocument.Segmentation)3 Test (org.junit.Test)3 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)3