Usage of org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel.TreeType in the knime-core project (by KNIME).
Class RegressionTreeModel, method createLearnAttributeRow.
/**
 * Converts a learn row into the row of attributes the tree was trained on.
 * <p>
 * For {@code Ordinary} trees the row is returned unchanged. For {@code BitVector} trees the vector in the first cell
 * is expanded into one StringCell per bit ("1" for a set bit, "0" otherwise). For {@code ByteVector} trees the vector
 * in the first cell is expanded into one IntCell per vector position. {@code null} is returned when that vector cell
 * is missing or its length differs from the number of attributes recorded in the model's meta data.
 *
 * @param learnRow the row to convert
 * @param learnSpec the spec of {@code learnRow}; not consulted by the current implementation
 * @return the attribute row, or {@code null} in the cases described above
 * @throws IllegalStateException if the tree type is not handled
 */
public DataRow createLearnAttributeRow(final DataRow learnRow, final DataTableSpec learnSpec) {
    final TreeType treeType = getType();
    switch (treeType) {
        case Ordinary:
            return learnRow;
        case BitVector:
            return expandBitVectorRow(learnRow);
        case ByteVector:
            return expandByteVectorRow(learnRow);
        default:
            throw new IllegalStateException("Type unknown (not implemented): " + treeType);
    }
}

/**
 * Expands the bit vector in the first cell of {@code row} into "1"/"0" StringCells.
 * Returns null if the cell is missing or the vector length does not match the attribute count.
 */
private DataRow expandBitVectorRow(final DataRow row) {
    final DataCell vectorCell = row.getCell(0);
    if (vectorCell.isMissing()) {
        return null;
    }
    final BitVectorValue bits = (BitVectorValue)vectorCell;
    final long bitCount = bits.length();
    final int attributeCount = getMetaData().getNrAttributes();
    if (bitCount != attributeCount) {
        // TODO indicate error message (a length mismatch is currently only signalled by a null row)
        return null;
    }
    // One "1" cell and one "0" cell instance are shared by all positions of the expanded row.
    final DataCell setBit = new StringCell("1");
    final DataCell clearBit = new StringCell("0");
    final DataCell[] expanded = new DataCell[attributeCount];
    for (int pos = 0; pos < attributeCount; pos++) {
        expanded[pos] = bits.get(pos) ? setBit : clearBit;
    }
    return new DefaultRow(row.getKey(), expanded);
}

/**
 * Expands the byte vector in the first cell of {@code row} into one IntCell per position.
 * Returns null if the cell is missing or the vector length does not match the attribute count.
 */
private DataRow expandByteVectorRow(final DataRow row) {
    final DataCell vectorCell = row.getCell(0);
    if (vectorCell.isMissing()) {
        return null;
    }
    final ByteVectorValue counts = (ByteVectorValue)vectorCell;
    final long vectorLength = counts.length();
    final int attributeCount = getMetaData().getNrAttributes();
    if (vectorLength != attributeCount) {
        return null;
    }
    final DataCell[] expanded = new DataCell[attributeCount];
    for (int pos = 0; pos < attributeCount; pos++) {
        expanded[pos] = new IntCell(counts.get(pos));
    }
    return new DefaultRow(row.getKey(), expanded);
}
Usage of org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel.TreeType in the knime-core project (by KNIME).
Class RegressionTreeModel, method load.
/**
 * Reads a {@link RegressionTreeModel} from a stream.
 * <p>
 * The stream must contain, in order: an int version header, the tree type, the tree meta data, a single
 * {@link TreeModelRegression}, and a terminating 0 byte. The argument stream is NOT closed by this method, whether
 * it returns normally or throws. Only the internal buffering wrapper is released.
 *
 * @param in the stream to read from; it remains open afterwards
 * @param exec progress monitor; not consulted by the current implementation
 * @return the loaded model
 * @throws IOException if the stored version is newer than supported, the content cannot be read, or the model is
 *             not terminated by a 0 byte
 * @throws CanceledExecutionException declared for API compatibility; not thrown by the current implementation
 */
public static RegressionTreeModel load(final InputStream in, final ExecutionMonitor exec) throws IOException,
    CanceledExecutionException {
    // Wrapping the argument (zip input) stream in a buffered stream reduces read time, e.g. from 42s to 2s.
    // NonClosableInputStream keeps the caller's stream open when the wrapper chain is closed, so the wrapper can be
    // released on every path (including exceptions) without closing 'in'.
    try (TreeModelDataInputStream input =
        new TreeModelDataInputStream(new BufferedInputStream(new NonClosableInputStream(in)))) {
        final int version = input.readInt();
        if (version > 20140201) {
            throw new IOException("Tree Ensemble version " + version + " not supported");
        }
        final TreeType type = TreeType.load(input);
        final TreeMetaData metaData = TreeMetaData.load(input);
        final TreeModelRegression model;
        try {
            model = TreeModelRegression.load(input, metaData);
            if (input.readByte() != 0) {
                throw new IOException("Model not terminated by 0 byte");
            }
        } catch (IOException e) {
            // Add context while keeping the original exception as the cause.
            throw new IOException("Can't read tree model. " + e.getMessage(), e);
        }
        return new RegressionTreeModel(metaData, model, type);
    }
}
Usage of org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel.TreeType in the knime-core project (by KNIME).
Class RegressionTreeModel, method getLearnAttributeSpec.
/**
 * Returns a table spec describing the learn attributes (not the target).
 * <p>
 * For {@code Ordinary} trees this is {@code learnSpec} itself. For vector-based trees the spec is expanded to one
 * column per attribute recorded in the model's meta data:
 * <ul>
 * <li>{@code BitVector} trees get StringCell columns ("0" or "1"), named by
 * {@code TreeBitColumnMetaData.getAttributeName}.</li>
 * <li>{@code ByteVector} trees get IntCell columns, named by {@code TreeNumericColumnMetaData.getAttributeName}.</li>
 * </ul>
 *
 * @param learnSpec the original learn spec (also the return value for ordinary data)
 * @return the learn attribute spec
 * @throws IllegalStateException if the tree type is not handled
 */
public DataTableSpec getLearnAttributeSpec(final DataTableSpec learnSpec) {
    final TreeType treeType = getType();
    // Decide which column flavour to build; ordinary data needs no expansion at all.
    final boolean bitColumns;
    switch (treeType) {
        case Ordinary:
            return learnSpec;
        case BitVector:
            bitColumns = true;
            break;
        case ByteVector:
            bitColumns = false;
            break;
        default:
            throw new IllegalStateException("Type unknown (not implemented): " + treeType);
    }
    final int attributeCount = getMetaData().getNrAttributes();
    final DataColumnSpec[] columns = new DataColumnSpec[attributeCount];
    for (int col = 0; col < attributeCount; col++) {
        final String columnName = bitColumns ? TreeBitColumnMetaData.getAttributeName(col)
            : TreeNumericColumnMetaData.getAttributeName(col);
        columns[col] =
            new DataColumnSpecCreator(columnName, bitColumns ? StringCell.TYPE : IntCell.TYPE).createSpec();
    }
    return new DataTableSpec(columns);
}
Aggregations