Usage of org.knime.core.data.RowIterator in project knime-core by KNIME.
In class LinRegLearnerNodeModel, method execute:
/**
 * Learns the linear regression model from the training table in
 * {@code inData[0]} by solving the normal equations
 * (A^T x A) x w = A^T x b, where A holds the training data (one row per
 * usable input row, extended by one column fixed to one), w is the vector
 * of weights to learn (number of variables) and b is the target output.
 * <p>
 * A^T x A is accumulated row by row so that the training data never has to
 * be materialized as a {@code double[][]}. The table is scanned twice, or
 * three times when {@code m_isCalcError} is set:
 * (1) accumulate A^T x A and the column means,
 * (2) compute (A^T x A)^-1 x A^T x b,
 * (3) optionally sum up the squared training error into {@code m_error}.
 * Rows whose target is missing, or for which {@code readIntoBuffer} reports
 * a missing value, are skipped in every scan.
 * <p>
 * {@inheritDoc}
 *
 * @throws Exception if the number of rows without missing values does not
 *         exceed the number of unknowns (no over-determined solution), or if
 *         any of the called helpers fails.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable data = (BufferedDataTable) inData[0];
    final DataTableSpec spec = data.getDataTableSpec();
    final String[] includes = computeIncludes(spec);
    // one weight per included column plus one for the column fixed to one
    final int nrUnknown = includes.length + 1;
    double[] means = new double[includes.length];
    // indices of the included columns in the input table
    final int[] colIndizes = new int[includes.length];
    for (int i = 0; i < includes.length; i++) {
        colIndizes[i] = spec.findColumnIndex(includes[i]);
    }
    // index of m_target
    final int target = spec.findColumnIndex(m_target);
    // this is the matrix (A^T x A) where A is the training data including
    // one column fixed to one; it is built manually to avoid holding all
    // the data in a double[][]
    double[][] ata = new double[nrUnknown][nrUnknown];
    double[] buffer = new double[nrUnknown];
    // marks every row that contains missing values; such rows are skipped
    // in all subsequent scans
    BitSet missingSet = new BitSet();
    m_nrRows = data.getRowCount();
    int myProgress = 0;
    // 2 scans over the data, or 3 when the training error is computed too
    final double totalProgress = (2 + (m_isCalcError ? 1 : 0)) * m_nrRows;
    int rowCount = 0;
    boolean hasPrintedWarning = false;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        myProgress++;
        exec.setProgress(myProgress / totalProgress, "Calculating matrix " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        exec.checkCanceled();
        DataCell targetValue = row.getCell(target);
        // read data from row into buffer, skip missing value rows
        boolean containsMissing = targetValue.isMissing() || readIntoBuffer(row, buffer, colIndizes);
        missingSet.set(rowCount, containsMissing);
        if (containsMissing) {
            String errorMessage = "Row \"" + row.getKey().getString() + "\" contains missing values, skipping it.";
            // warn only once; further skipped rows are logged at debug level
            if (!hasPrintedWarning) {
                LOGGER.warn(errorMessage + " Suppress further warnings.");
                hasPrintedWarning = true;
            } else {
                LOGGER.debug(errorMessage);
            }
            m_nrRowsSkipped++;
            // with next row
            continue;
        }
        updateMean(buffer, means);
        // A^T x A is symmetric: accumulate the upper triangle only and
        // mirror it once all rows have been read
        for (int i = 0; i < nrUnknown; i++) {
            for (int j = i; j < nrUnknown; j++) {
                ata[i][j] += buffer[i] * buffer[j];
            }
        }
    }
    // fill the lower triangle from the accumulated upper triangle
    for (int i = 0; i < nrUnknown; i++) {
        for (int j = i + 1; j < nrUnknown; j++) {
            ata[j][i] = ata[i][j];
        }
    }
    assert (m_nrRows == rowCount);
    normalizeMean(means);
    // only rows without missing values contributed to A^T x A; if there are
    // not more of them than unknown variables there is no unique solution
    final int nrValidRows = rowCount - missingSet.cardinality();
    if (nrValidRows <= nrUnknown) {
        throw new Exception("Too few rows to perform regression (" + nrValidRows + " rows without missing values, but degree of freedom of " + nrUnknown + ")");
    }
    exec.setMessage("Calculating pseudo inverse...");
    double[][] ataInverse = MathUtils.inverse(ata);
    checkForNaN(ataInverse);
    // multiply with A^T and b, i.e. (A^T x A)^-1 x A^T x b
    double[] multipliers = new double[nrUnknown];
    rowCount = 0;
    for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
        DataRow row = it.next();
        exec.setMessage("Determining output " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
        myProgress++;
        exec.setProgress(myProgress / totalProgress);
        exec.checkCanceled();
        // rows with missing values were reported in the first scan; skip them silently
        if (missingSet.get(rowCount)) {
            continue;
        }
        // must be called unconditionally: it fills the buffer
        boolean containsMissing = readIntoBuffer(row, buffer, colIndizes);
        assert !containsMissing;
        DataCell targetValue = row.getCell(target);
        double b = ((DoubleValue) targetValue).getDoubleValue();
        for (int i = 0; i < nrUnknown; i++) {
            double buf = 0.0;
            for (int j = 0; j < nrUnknown; j++) {
                buf += ataInverse[i][j] * buffer[j];
            }
            multipliers[i] += buf * b;
        }
    }
    if (m_isCalcError) {
        // sum of squared residuals on the usable training rows
        assert m_error == 0.0;
        rowCount = 0;
        for (RowIterator it = data.iterator(); it.hasNext(); rowCount++) {
            DataRow row = it.next();
            exec.setMessage("Calculating error " + (rowCount + 1) + " (\"" + row.getKey().getString() + "\")");
            myProgress++;
            exec.setProgress(myProgress / totalProgress);
            exec.checkCanceled();
            // rows with missing values were reported in the first scan; skip them silently
            if (missingSet.get(rowCount)) {
                continue;
            }
            // must be called unconditionally: it fills the buffer
            boolean hasMissing = readIntoBuffer(row, buffer, colIndizes);
            assert !hasMissing;
            DataCell targetValue = row.getCell(target);
            double b = ((DoubleValue) targetValue).getDoubleValue();
            double out = 0.0;
            for (int i = 0; i < nrUnknown; i++) {
                out += multipliers[i] * buffer[i];
            }
            m_error += (b - out) * (b - out);
        }
    }
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = (PMMLPortObject) inData[1];
    DataTableSpec outSpec = getLearningSpec(spec);
    // the first weight belongs to the column fixed to one, i.e. the offset
    double offset = multipliers[0];
    multipliers = Arrays.copyOfRange(multipliers, 1, multipliers.length);
    m_params = new LinearRegressionContent(outSpec, offset, multipliers, means);
    // cache the entire table as otherwise the color information
    // may be lost (filtering out the "colored" column)
    m_rowContainer = new DefaultDataArray(data, m_firstRowPaint, m_rowCountPaint);
    m_actualUsedColumns = includes;
    return new PortObject[] { m_params.createPortObject(inPMMLPort, spec, outSpec) };
}
Usage of org.knime.core.data.RowIterator in project knime-core by KNIME.
In class ExecutionContext, method createBufferedDataTable:
/**
 * Copies the rows of {@code table} into a new BufferedDataTable and returns it.
 * The returned table is owned by this node; when the workflow is saved, the
 * entire data is written to disk. This method is provided for convenience:
 * all it does is create a BufferedDataContainer, add every row to it and
 * return a handle to the result.
 * <br /><br />Note: If table is already a BufferedDataTable it is returned
 * unchanged, without copying.
 * <p>This method refers to the first way of storing data,
 * see <a href="#new_data">here</a>.
 * @param table The table to cache.
 * @param subProgressMon The execution monitor that receives one message per
 * copied row and is polled for cancellation. In most cases this is the object
 * on which this method is invoked; it may however be a sub progress monitor.
 * @return A table ready to be returned in the execute method.
 * @throws CanceledExecutionException If canceled.
 */
public BufferedDataTable createBufferedDataTable(final DataTable table, final ExecutionMonitor subProgressMon) throws CanceledExecutionException {
    if (!(table instanceof BufferedDataTable)) {
        final BufferedDataContainer container = createDataContainer(table.getDataTableSpec(), true);
        try {
            final RowIterator rows = table.iterator();
            int index = 0;
            while (rows.hasNext()) {
                final DataRow current = rows.next();
                index++;
                subProgressMon.setMessage("Caching row #" + index + " (\"" + current.getKey() + "\")");
                subProgressMon.checkCanceled();
                container.addRowToTable(current);
            }
        } finally {
            // the container is closed even when iteration is canceled or fails
            container.close();
        }
        final BufferedDataTable result = container.getTable();
        result.setOwnerRecursively(m_node);
        return result;
    }
    return (BufferedDataTable) table;
}
Usage of org.knime.core.data.RowIterator in project knime-core by KNIME.
In class DataContainer, method cache:
/**
 * Copies every row of the argument table into a new container and returns the
 * resulting table. This is handy when you hold a wrapper table and want all of
 * its calculations to be carried out at this point, during this call.
 *
 * @param table The table to cache.
 * @param exec The execution monitor that receives one message per copied row and
 *        is polled for the cancel status.
 * @param maxCellsInMemory The number of cells to be kept in memory before swapping to disk.
 * @return A cache table containing the data from the argument.
 * @throws NullPointerException If the argument is <code>null</code>.
 * @throws CanceledExecutionException If the process has been canceled.
 */
public static DataTable cache(final DataTable table, final ExecutionMonitor exec, final int maxCellsInMemory) throws CanceledExecutionException {
    final DataContainer container = new DataContainer(table.getDataTableSpec(), true, maxCellsInMemory);
    try {
        final RowIterator rows = table.iterator();
        for (int rowNumber = 1; rows.hasNext(); rowNumber++) {
            final DataRow current = rows.next();
            exec.setMessage("Caching row #" + rowNumber + " (\"" + current.getKey() + "\")");
            exec.checkCanceled();
            container.addRowToTable(current);
        }
    } finally {
        // always close, also when canceled, so the container is finalized
        container.close();
    }
    return container.getTable();
}
Usage of org.knime.core.data.RowIterator in project knime-core by KNIME.
In class AppendedRowsIterator, method initNextTable:
/**
 * Switches to the next underlying table. It takes that table's iterator and
 * spec from the next entry in {@code m_iteratorSuppliers`}. It then builds
 * {@code m_curMapping}, which maps each column of {@code m_spec} to a column
 * index of the new table.
 * <p>
 * Columns of {@code m_spec} that the new table does not have are mapped to
 * indices placed after the table's own columns. Each of them is backed by a
 * missing cell in {@code m_curMissingCells}. If no column is absent and every
 * column maps onto its own index, both fields are set to {@code null}.
 */
private void initNextTable() {
    assert (m_curItIndex < m_iteratorSuppliers.length - 1);
    m_curItIndex++;
    final Pair<RowIterator, DataTableSpec> next = m_iteratorSuppliers[m_curItIndex].get();
    m_curIterator = next.getFirst();
    final DataTableSpec nextSpec = next.getSecond();
    final int ownColumns = nextSpec.getNumColumns();
    final int allColumns = m_spec.getNumColumns();
    // expected count of m_spec columns absent from the next table
    final int missingNumber = allColumns - ownColumns;
    m_curMissingCells = new DataCell[missingNumber];
    int nrMissing = 0;
    m_curMapping = new int[allColumns];
    boolean identity = true;
    for (int c = 0; c < allColumns; c++) {
        final String name = m_spec.getColumnSpec(c).getName();
        int mapped = nextSpec.findColumnIndex(name);
        if (mapped < 0) {
            // absent column: route it to a missing cell appended after the
            // table's own columns
            mapped = ownColumns + nrMissing;
            m_curMissingCells[nrMissing] = DataType.getMissingCell();
            nrMissing++;
        }
        m_curMapping[c] = mapped;
        identity = identity && mapped == c;
    }
    if (nrMissing == 0 && identity) {
        // rows of the next table need no re-mapping
        m_curMapping = null;
        m_curMissingCells = null;
    }
    assert nrMissing == missingNumber;
}
Usage of org.knime.core.data.RowIterator in project knime-core by KNIME.
In class TableContentModelTest, method testGetRow:
/**
 * Method being tested: DataRow getRow(int).
 * Verifies that the model returns each row of DATA at its iteration index and
 * that the indices -1 and OBJECT_DATA.length are rejected with an
 * IndexOutOfBoundsException.
 */
public final void testGetRow() {
    final TableContentModel model = new TableContentModel(DATA);
    final RowIterator expectedRows = DATA.iterator();
    int index = 0;
    while (expectedRows.hasNext()) {
        assertEquals(expectedRows.next(), model.getRow(index));
        index++;
    }
    // both indices lie just outside the valid range
    for (final int invalid : new int[] {-1, OBJECT_DATA.length}) {
        try {
            model.getRow(invalid);
            fail("Expected " + IndexOutOfBoundsException.class + " not thrown");
        } catch (IndexOutOfBoundsException e) {
            NodeLogger.getLogger(getClass()).debug("Got expected exception: " + e.getClass().getName(), e);
        }
    }
    // further checking is done at testCachingStrategy() and other
    // test methods
}
Aggregations