Use of org.knime.core.data.vector.bytevector.ByteVectorValue in the knime-core project by KNIME.
Class AbstractTrainingRowBuilder, method getVectorLength.
/**
 * Returns the length of the vector stored in {@code vectorCell}.
 * <p>
 * Only bit vectors and byte vectors are supported at the moment. Double vectors
 * ({@code DoubleVectorValue#getLength()}) and double lists ({@code ListDataValue#size()}) should be added here
 * once they become compatible with PMML.
 *
 * @param vectorCell the cell whose vector length is requested
 * @return the length of the bit or byte vector held by the cell
 * @throws IllegalStateException if the cell's type is compatible with neither {@code BitVectorValue} nor
 *             {@code ByteVectorValue}
 */
private static long getVectorLength(final DataCell vectorCell) {
    final DataType type = vectorCell.getType();
    if (type.isCompatible(BitVectorValue.class)) {
        final BitVectorValue bitVector = (BitVectorValue) vectorCell;
        return bitVector.length();
    }
    if (type.isCompatible(ByteVectorValue.class)) {
        final ByteVectorValue byteVector = (ByteVectorValue) vectorCell;
        return byteVector.length();
    }
    throw new IllegalStateException("The provided cell is of unknown vector type \"" + vectorCell.getType() + "\".");
}
Use of org.knime.core.data.vector.bytevector.ByteVectorValue in the knime-core project by KNIME.
Class AbstractTrainingRowBuilder, method build.
/**
 * Encodes the feature cells of {@code row} as a sparse training row.
 * <p>
 * Feature index 0 is the intercept and is always present with value 1.0. The remaining features follow in the
 * order of {@code m_featureCellIndices}:
 * <ul>
 * <li>nominal cells are one-hot encoded over (domain size - 1) slots; the domain value at position 0 acts as the
 * reference category and produces no non-zero entry,</li>
 * <li>cells of columns listed in {@code m_vectorLengths} occupy that many slots; set bits of a bit vector
 * contribute 1.0, and the positions reported by a byte vector's {@code nextCountIndex} contribute their
 * count,</li>
 * <li>numeric cells occupy one slot and are only recorded when not equal to 0.0 according to
 * {@code MathUtils.equals}.</li>
 * </ul>
 * The field buffers {@code m_nonZeroIndices} / {@code m_nonZeroValues} are overwritten on every call and then
 * trimmed to the number of non-zero entries, so concurrent calls on the same instance would interfere.
 * NOTE(review): assumes the buffers are sized for the largest possible number of non-zero features - confirm
 * where they are allocated.
 *
 * @param row the row to encode
 * @param id the id passed through to {@code createTrainingRow}
 * @return the training row created by {@code createTrainingRow} from the non-zero indices and values
 * @throws IllegalStateException if a nominal value is not in the column domain, a vector cell has an
 *             unsupported type, or a cell is neither nominal, a registered vector column, nor numeric
 */
@Override
public T build(final DataRow row, final int id) {
// number of entries written to the sparse buffers so far (slot 0 is the intercept)
int nonZeroFeatures = 1;
// index of the first feature slot belonging to the cell currently being processed
int accumulatedIdx = 1;
// the intercept feature is always present
m_nonZeroIndices[0] = 0;
m_nonZeroValues[0] = 1.0F;
for (int i = 0; i < m_featureCellIndices.size(); i++) {
// get cell from row
Integer cellIdx = m_featureCellIndices.get(i);
DataCell cell = row.getCell(cellIdx);
DataType cellType = cell.getType();
// handle cell according to cell type
if (cellType.isCompatible(NominalValue.class)) {
// handle nominal cells
List<DataCell> nominalDomainValues = m_nominalDomainValues.get(cellIdx);
int oneHotIdx = nominalDomainValues.indexOf(cell);
if (oneHotIdx == -1) {
throw new IllegalStateException("DataCell \"" + cell.toString() + "\" is not in the DataColumnDomain. Please apply a " + "Domain Calculator on the columns with nominal values.");
} else if (oneHotIdx > 0) {
// domain value 0 is the reference category (all-zero encoding), hence the "- 1"
m_nonZeroIndices[nonZeroFeatures] = accumulatedIdx + oneHotIdx - 1;
m_nonZeroValues[nonZeroFeatures] = 1.0F;
nonZeroFeatures++;
}
// a nominal column with k domain values occupies k - 1 feature slots
accumulatedIdx += nominalDomainValues.size() - 1;
} else if (m_vectorLengths.containsKey(cellIdx)) {
// handle vector cells
if (cellType.isCompatible(BitVectorValue.class)) {
BitVectorValue bv = (BitVectorValue) cell;
// visit only the set bits; each contributes the value 1.0
for (long s = bv.nextSetBit(0L); s >= 0; s = bv.nextSetBit(s + 1)) {
m_nonZeroIndices[nonZeroFeatures] = (int) (accumulatedIdx + s);
m_nonZeroValues[nonZeroFeatures++] = 1.0F;
}
} else if (cellType.isCompatible(ByteVectorValue.class)) {
ByteVectorValue bv = (ByteVectorValue) cell;
// visit the positions reported by nextCountIndex (presumably the non-zero counts); each contributes its count
for (long s = bv.nextCountIndex(0L); s >= 0; s = bv.nextCountIndex(s + 1)) {
m_nonZeroIndices[nonZeroFeatures] = (int) (accumulatedIdx + s);
m_nonZeroValues[nonZeroFeatures++] = bv.get(s);
}
// uncomment once DoubleVectors can be used with PMML
// } else if (cellType.isCompatible(DoubleVectorValue.class)) {
// // DoubleVectorValue also implements CollectionDataValue but
// // as it then first boxes its values into DataCells, it is much more
// // efficient to access its values via the DoubleVectorValue interface
// DoubleVectorValue dv = (DoubleVectorValue)cell;
// for (int s = 0; s < dv.getLength(); s++) {
// float val = (float)dv.getValue(s);
// if (!MathUtils.equals(val, 0.0)) {
// m_nonZeroIndices[nonZeroFeatures] = accumulatedIdx + s;
// m_nonZeroValues[nonZeroFeatures++] = val;
// }
// }
// uncomment once double lists become compatible with PMML
// NOTE(review): the original draft of this branch did not advance nonZeroFeatures after storing a value,
// so every entry overwrote the previous one; the increment below fixes that.
// } else if (cellType.isCollectionType() && cellType.getCollectionElementType().isCompatible(DoubleValue.class)) {
// CollectionDataValue cv = (CollectionDataValue)cell;
// int s = 0;
// for (DataCell c : cv) {
// // we already checked above that cv contains DoubleValues
// DoubleValue dv = (DoubleValue)c;
// double val = dv.getDoubleValue();
// if (!MathUtils.equals(val, 0.0)) {
// m_nonZeroIndices[nonZeroFeatures] = accumulatedIdx + s;
// m_nonZeroValues[nonZeroFeatures++] = (float)val;
// }
// s++;
// }
} else {
// should never be thrown because we check the compatibility in the constructor
throw new IllegalStateException("DataCell \"" + cell.toString() + "\" is of an unknown vector/collections type.");
}
// a vector column always occupies its registered length, regardless of how many entries were non-zero
accumulatedIdx += m_vectorLengths.get(cellIdx);
} else if (cellType.isCompatible(DoubleValue.class)) {
// handle numerical cells
double val = ((DoubleValue) cell).getDoubleValue();
if (!MathUtils.equals(val, 0.0)) {
m_nonZeroIndices[nonZeroFeatures] = accumulatedIdx;
m_nonZeroValues[nonZeroFeatures++] = (float) val;
}
accumulatedIdx++;
} else {
// a different DataCell of incompatible type.
throw new IllegalStateException("The DataCell \"" + cell.toString() + "\" is of incompatible type \"" + cellType.toPrettyString() + "\".");
}
}
// copy only the populated prefix of the reusable buffers
int[] nonZero = Arrays.copyOf(m_nonZeroIndices, nonZeroFeatures);
float[] values = Arrays.copyOf(m_nonZeroValues, nonZeroFeatures);
return createTrainingRow(row, nonZero, values, id);
}
Use of org.knime.core.data.vector.bytevector.ByteVectorValue in the knime-core project by KNIME.
Class ExpandByteVectorNodeModel, method createCellFactory.
/**
 * {@inheritDoc}
 * <p>
 * The returned factory expands the byte vector found at {@code inputIndex} into one cell per output column.
 * Column {@code i} receives {@code VALUES[count]}, where {@code count} is the value at vector position {@code i}.
 * Output columns beyond the vector's length are filled with missing cells. So are all columns of rows whose cell
 * is not a {@link ByteVectorValue}. Vector positions beyond the number of output columns are ignored.
 */
@Override
protected AbstractCellFactory createCellFactory(final String[] colNames, final DataColumnSpec[] outputColumns, final int inputIndex) {
    return new AbstractCellFactory(outputColumns) {
        @Override
        public DataCell[] getCells(final DataRow row) {
            final DataCell[] vs = new DataCell[colNames.length];
            final DataCell cell = row.getCell(inputIndex);
            // number of leading output cells taken from the vector; stays 0 if the cell is no byte vector
            int filled = 0;
            if (cell instanceof ByteVectorValue) {
                final ByteVectorValue bvv = (ByteVectorValue) cell;
                // Take the minimum in long arithmetic before narrowing. ByteVectorValue.length() is a long, and
                // casting it to int first can wrap around for very long vectors (even to a negative value). That
                // would cause an ArrayIndexOutOfBoundsException or silently dropped values.
                filled = (int) Math.min(vs.length, bvv.length());
                for (int i = 0; i < filled; i++) {
                    vs[i] = VALUES[bvv.get(i)];
                }
            }
            for (int i = filled; i < vs.length; i++) {
                vs[i] = DataType.getMissingCell();
            }
            return vs;
        }
    };
}
Use of org.knime.core.data.vector.bytevector.ByteVectorValue in the knime-core project by KNIME.
Class TreeByteNumericColumnDataCreator, method add.
/**
 * {@inheritDoc}
 * <p>
 * Splits the byte vector in {@code cell} into its individual counts. For each vector position it appends one
 * {@code ByteTuple} (count and current row index) to that position's list. The per-position lists are created
 * from the length of the first vector added; every later vector must have the same length.
 * <p>
 * All checks run before any state is modified. A rejected row therefore leaves this creator unchanged: no
 * partially-added tuples, no lists allocated, row index not advanced.
 *
 * @throws IllegalStateException if the cell is missing or the vector is longer than {@link Integer#MAX_VALUE}
 * @throws IllegalArgumentException if the vector length differs from previously added rows, or a count is
 *             negative or larger than {@code MAX_COUNT}
 */
@SuppressWarnings({ "unchecked" })
@Override
public void add(final RowKey rowKey, final DataCell cell) {
    if (cell.isMissing()) {
        throw new IllegalStateException("Missing values not supported");
    }
    ByteVectorValue v = (ByteVectorValue) cell;
    final long lengthLong = v.length();
    if (lengthLong > Integer.MAX_VALUE) {
        throw new IllegalStateException("Sparse byte vectors not supported");
    }
    final int length = (int) lengthLong;
    if (m_byteTupleLists != null && m_byteTupleLists.length != length) {
        throw new IllegalArgumentException("Byte vectors in table have different length, expected " + m_byteTupleLists.length + " bytes but got " + length + " bytes in row \"" + rowKey + "\"");
    }
    // Validate (and convert) every count before touching any state, so that a rejected row does not leave
    // partially-added tuples behind.
    final byte[] values = new byte[length];
    for (int attrIndex = 0; attrIndex < length; attrIndex++) {
        int val = v.get(attrIndex);
        if (val > MAX_COUNT) {
            throw new IllegalArgumentException("The value \"" + val + "\" is larger than the maximum value \"" + MAX_COUNT + "\".");
        } else if (val < 0) {
            throw new IllegalArgumentException("Negative values are not allowed.");
        }
        values[attrIndex] = (byte) val;
    }
    if (m_byteTupleLists == null) {
        // first row: one tuple list per vector position (generic array creation requires the unchecked cast)
        m_byteTupleLists = new ArrayList[length];
        for (int i = 0; i < length; i++) {
            m_byteTupleLists[i] = new ArrayList<ByteTuple>();
        }
    }
    for (int attrIndex = 0; attrIndex < length; attrIndex++) {
        ByteTuple tuple = new ByteTuple();
        tuple.m_value = values[attrIndex];
        tuple.m_indexInColumn = m_index;
        m_byteTupleLists[attrIndex].add(tuple);
    }
    m_index++;
}
Use of org.knime.core.data.vector.bytevector.ByteVectorValue in the knime-core project by KNIME.
Class RegressionTreeModel, method createByteVectorPredictorRecord.
/**
 * Converts the single byte-vector cell of {@code filterRow} into a predictor record.
 * <p>
 * Each vector position {@code i} is stored under the name
 * {@code TreeNumericColumnMetaData.getAttributeName(i)} with its count boxed as an {@link Integer}. Entries are
 * inserted in position order.
 *
 * @param filterRow a row expected to contain exactly one cell holding the byte vector
 * @return the predictor record, or {@code null} if the cell is missing
 * @throws IllegalArgumentException if the vector length differs from the model's number of attributes
 */
private PredictorRecord createByteVectorPredictorRecord(final DataRow filterRow) {
    assert filterRow.getNumCells() == 1 : "Expected one cell as byte vector data";
    final DataCell vectorCell = filterRow.getCell(0);
    if (vectorCell.isMissing()) {
        // nothing to predict from
        return null;
    }
    final ByteVectorValue byteVector = (ByteVectorValue) vectorCell;
    final long vectorLength = byteVector.length();
    if (vectorLength != getMetaData().getNrAttributes()) {
        final String rowId = filterRow.getKey().getString();
        throw new IllegalArgumentException("The byte-vector in " + rowId + " has an invalid length ("
            + vectorLength + " instead of " + getMetaData().getNrAttributes() + ").");
    }
    // pre-size the map so that inserting all attributes does not trigger a rehash at the default load factor
    final int initialCapacity = (int) (vectorLength / 0.75 + 1.0);
    final Map<String, Object> attributeValues = new LinkedHashMap<String, Object>(initialCapacity);
    int position = 0;
    while (position < vectorLength) {
        final String attributeName = TreeNumericColumnMetaData.getAttributeName(position);
        attributeValues.put(attributeName, Integer.valueOf(byteVector.get(position)));
        position++;
    }
    return new PredictorRecord(attributeValues);
}
Aggregations