Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
The class EnsembleOptionsPanel, method loadSettings.
/**
 * Populates the controls of this panel with the values stored in <b>cfg</b>.
 * <p>
 * Buttons and check boxes are switched via {@code doClick()} instead of setting their state directly.
 * NOTE(review): presumably so that the listeners registered on them run as well - confirm.
 *
 * @param cfg the configuration whose values are shown in the panel
 */
public void loadSettings(final TreeEnsembleLearnerConfiguration cfg) {
    m_nrModelsSpinner.setValue(cfg.getNrModels());

    // ----- row sampling -----
    final double rowFraction = cfg.getDataFractionPerTree();
    final boolean withReplacement = cfg.isDataSelectionWithReplacement();
    final RowSamplingMode rowMode = cfg.getRowSamplingMode();
    // Row sampling counts as active as soon as any of its three settings deviates from "use all rows once".
    final boolean samplingActive =
        withReplacement || rowMode == RowSamplingMode.EqualSize || rowFraction < 1.0;
    m_dataFractionPerTreeSpinner.setValue(rowFraction);
    final boolean checkerOutOfSync = m_dataFractionPerTreeChecker.isSelected() != samplingActive;
    if (checkerOutOfSync) {
        // Toggle only when the check box does not already show the desired state.
        m_dataFractionPerTreeChecker.doClick();
    }
    if (!withReplacement) {
        m_dataSamplingWithOutReplacementChecker.doClick();
    } else {
        m_dataSamplingWithReplacementChecker.doClick();
    }
    m_dataSamplingModeComboBox.setSelectedItem(rowMode);

    // ----- column sampling -----
    double linearFraction = cfg.getColumnFractionLinearValue();
    boolean differentAttributesPerNode = cfg.isUseDifferentAttributesAtEachNode();
    final ColumnSamplingMode columnMode = cfg.getColumnSamplingMode();
    switch (columnMode) {
        case SquareRoot:
            m_columnFractionSqrtButton.doClick();
            // The linear spinner is shown with 1.0 for this mode.
            linearFraction = 1.0;
            break;
        case Absolute:
            m_columnFractionAbsoluteButton.doClick();
            break;
        case Linear:
            m_columnFractionLinearButton.doClick();
            break;
        case None:
            m_columnFractionNoneButton.doClick();
            // Without column sampling the per-node option is forced off and the spinner shows 1.0.
            differentAttributesPerNode = false;
            linearFraction = 1.0;
            break;
    }
    m_columnFractionLinearTreeSpinner.setValue(linearFraction);
    m_columnFractionAbsoluteTreeSpinner.setValue(cfg.getColumnAbsoluteValue());
    if (!differentAttributesPerNode) {
        m_columnUseSameSetOfAttributesForNodes.doClick();
    } else {
        m_columnUseDifferentSetOfAttributesForNodes.doClick();
    }

    // ----- random seed -----
    final Long seed = cfg.getSeed();
    final boolean hasSeed = seed != null;
    if (m_seedChecker.isSelected() != hasSeed) {
        m_seedChecker.doClick();
    }
    // Without a stored seed the text field is pre-filled with the current time in milliseconds.
    final long shownSeed = hasSeed ? seed.longValue() : System.currentTimeMillis();
    m_seedTextField.setText(Long.toString(shownSeed));
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
The class EnsembleOptionsPanel, method saveSettings.
/**
 * Writes the values currently shown in this panel into <b>cfg</b>.
 * <p>
 * Values of controls that belong to a deselected option are not read; fixed fallback values are stored
 * instead (data fraction {@code 1.0}, no replacement, column fraction {@code 1.0},
 * {@code TreeEnsembleLearnerConfiguration.DEF_COLUMN_ABSOLUTE}, {@code null} seed).
 *
 * @param cfg the configuration to write to
 * @throws InvalidSettingsException if no column sampling policy is selected or the seed text cannot be
 *             parsed as a {@code long}
 */
public void saveSettings(final TreeEnsembleLearnerConfiguration cfg) throws InvalidSettingsException {
    cfg.setNrModels((Integer) m_nrModelsSpinner.getValue());

    // Row sampling: the spinner and the replacement button are only honoured if sampling is switched on.
    double dataFrac;
    boolean isSamplingWithReplacement;
    if (m_dataFractionPerTreeChecker.isSelected()) {
        dataFrac = (Double) m_dataFractionPerTreeSpinner.getValue();
        isSamplingWithReplacement = m_dataSamplingWithReplacementChecker.isSelected();
    } else {
        dataFrac = 1.0;
        isSamplingWithReplacement = false;
    }
    cfg.setDataFractionPerTree(dataFrac);
    cfg.setDataSelectionWithReplacement(isSamplingWithReplacement);
    // NOTE(review): the sampling mode is stored even when the sampling check box is deselected, while
    // loadSettings treats RowSamplingMode.EqualSize as "sampling enabled" - confirm this round trip is intended.
    cfg.setRowSamplingMode((RowSamplingMode) m_dataSamplingModeComboBox.getSelectedItem());

    // Column sampling: exactly one of the four buttons decides the mode and which spinner is read.
    ColumnSamplingMode cf;
    double columnFrac = 1.0;
    int columnAbsolute = TreeEnsembleLearnerConfiguration.DEF_COLUMN_ABSOLUTE;
    boolean isUseDifferentAttributesAtEachNode = m_columnUseDifferentSetOfAttributesForNodes.isSelected();
    if (m_columnFractionNoneButton.isSelected()) {
        cf = ColumnSamplingMode.None;
        // Without column sampling the per-node option is always stored as false.
        isUseDifferentAttributesAtEachNode = false;
    } else if (m_columnFractionLinearButton.isSelected()) {
        cf = ColumnSamplingMode.Linear;
        columnFrac = (Double) m_columnFractionLinearTreeSpinner.getValue();
    } else if (m_columnFractionAbsoluteButton.isSelected()) {
        cf = ColumnSamplingMode.Absolute;
        columnAbsolute = (Integer) m_columnFractionAbsoluteTreeSpinner.getValue();
    } else if (m_columnFractionSqrtButton.isSelected()) {
        cf = ColumnSamplingMode.SquareRoot;
    } else {
        throw new InvalidSettingsException("No column selection policy selected");
    }
    cfg.setColumnSamplingMode(cf);
    cfg.setColumnFractionLinearValue(columnFrac);
    cfg.setColumnAbsoluteValue(columnAbsolute);
    cfg.setUseDifferentAttributesAtEachNode(isUseDifferentAttributesAtEachNode);

    // Seed: null means "no fixed seed".
    Long seed = null;
    if (m_seedChecker.isSelected()) {
        final String seedText = m_seedTextField.getText();
        // Long.valueOf does not accept surrounding whitespace (e.g. from a pasted value), so strip it first.
        // A null text is passed through unchanged and rejected by Long.valueOf with a NumberFormatException.
        final String trimmedSeedText = seedText == null ? null : seedText.trim();
        try {
            seed = Long.valueOf(trimmedSeedText);
        } catch (NumberFormatException e) {
            throw new InvalidSettingsException("Unable to parse seed \"" + seedText + "\"", e);
        }
    }
    cfg.setSeed(seed);
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
The class TreeEnsembleRegressionLearnerNodeDialogPane, method saveSettingsTo.
/**
 * {@inheritDoc}
 * <p>
 * Collects the values of the attribute selection, tree options and ensemble options panels in a fresh
 * configuration and saves that configuration to <b>settings</b>.
 */
@Override
protected void saveSettingsTo(final NodeSettingsWO settings) throws InvalidSettingsException {
    // NOTE(review): 'true' is presumably the regression flag of the configuration - confirm.
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);

    // Each panel contributes its part; any of them may reject invalid input.
    m_attributeSelectionPanel.saveSettings(config);
    m_treeOptionsPanel.saveSettings(config);
    m_ensembleOptionsPanel.saveSettings(config);

    config.save(settings);
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
The class TreeNominalColumnDataTest, method createConfig.
/**
 * Creates a learner configuration for tests: a single model, no column sampling, no row sampling
 * (fraction 1.0 without replacement), the same attributes at every node and binary nominal splits.
 * For classification the split criterion is additionally set to Gini.
 *
 * @param isRegression passed to the configuration constructor; if {@code false} the Gini criterion is set
 * @return the prepared configuration
 * @throws InvalidSettingsException declared because the configuration setters may reject a value
 */
private static TreeEnsembleLearnerConfiguration createConfig(final boolean isRegression) throws InvalidSettingsException {
    final TreeEnsembleLearnerConfiguration cfg = new TreeEnsembleLearnerConfiguration(isRegression);

    cfg.setColumnSamplingMode(ColumnSamplingMode.None);
    cfg.setDataSelectionWithReplacement(false);
    cfg.setNrModels(1);
    cfg.setUseDifferentAttributesAtEachNode(false);
    cfg.setDataFractionPerTree(1.0);
    cfg.setUseBinaryNominalSplits(true);

    if (isRegression) {
        // No split criterion is configured for regression.
        return cfg;
    }
    cfg.setSplitCriterion(SplitCriterion.Gini);
    return cfg;
}
Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by knime.
The class TreeNominalColumnDataTest, method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1.
/**
 * Tests the XGBoost missing value handling variant, in which the split search tries for each split which
 * direction for missing values yields the better gain.
 * <p>
 * The attribute column is built from a CSV string that contains {@code ?} entries, the target is taken
 * from the two-class tennis data.
 *
 * @throws Exception if the configuration or the test data cannot be created
 */
@Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue1() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator generator = new TestDataGenerator(cfg);

    // Only the target column of the tennis data is needed; the attribute column is replaced by one with missings.
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennisData = twoClassTennisData(cfg);
    final String csvWithMissings = "S,?,O,R,S,R,S,?,O,?";
    final TreeNominalColumnData attributeColumn =
        generator.createNominalAttributeColumn(csvWithMissings, "column containing missing values", 0);
    final TreeTargetNominalColumnData targetColumn = tennisData.getSecond();

    // Every row carries weight 1.
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);

    // Original row indices listed in the ordering of the attribute column.
    final int[] originalIndices = {0, 4, 6, 2, 8, 3, 5, 1, 7, 9};
    final int[] columnIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    final DataMemberships memberships = new MockDataColMem(originalIndices, columnIndices, weights);

    final SplitCandidate candidate = attributeColumn.calcBestSplitClassification(memberships,
        targetColumn.getDistribution(weights, cfg), targetColumn, TestDataGenerator.createRandomData());
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));

    final NominalBinarySplitCandidate binarySplit = (NominalBinarySplitCandidate) candidate;
    final TreeNodeNominalBinaryCondition[] conditions = binarySplit.getChildConditions();
    assertEquals("Wrong gain value.", 0.18, binarySplit.getGainValue(), 1e-8);

    // Both children are expected to test against the same value set, with complementary set logic.
    final String[] expectedValues = {"S", "R"};
    assertArrayEquals("Values in nominal condition did not match", expectedValues, conditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", expectedValues, conditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, conditions[1].getSetLogic());

    // Missing values are expected in the first child only.
    assertTrue("Missing values are not sent to the correct child.", conditions[0].acceptsMissings());
    assertFalse("Missing values are not sent to the correct child.", conditions[1].acceptsMissings());
}
Aggregations