Use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.
The class SagLogRegLearner, method learn.
/**
 * {@inheritDoc}
 */
@Override
public LogRegLearnerResult learn(final TrainingData<ClassificationTrainingRow> data,
    final ExecutionMonitor progressMonitor) throws CanceledExecutionException, InvalidSettingsException {
    AbstractSGOptimizer sgOpt = createOptimizer(m_settings, data);
    SimpleProgress progMon = new SimpleProgress(progressMonitor.getProgressMonitor());
    LogRegLearnerResult result = sgOpt.optimize(m_settings.getMaxEpoch(), data, progMon);
    // keep any warning raised by the optimizer so the node can report it later
    Optional<String> warning = sgOpt.getWarning();
    if (warning.isPresent()) {
        m_warning = warning.get();
    }
    return result;
}
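For orientation, a call site might look like the following sketch. It mirrors how LogRegCoordinator (further down) drives a learner; the progress fraction and the warning hook are assumptions, not part of this class:

    // hypothetical call site; settings, data and exec are prepared elsewhere
    LogRegLearner learner = new SagLogRegLearner(settings);
    LogRegLearnerResult result = learner.learn(data, exec.createSubProgress(0.9));
    String warning = learner.getWarningMessage(); // assumed to be null when no warning was raised
    if (warning != null) {
        setWarningMessage(warning); // NodeModel-style hook, assumed
    }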
Use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.
The class SparseClassificationTrainingRowTest, method testFeatureIterator.
/**
 * Tests the {@link FeatureIterator} returned by {@link ClassificationTrainingRow#getFeatureIterator()}.
 *
 * @throws Exception if the test fails
 */
@Test
public void testFeatureIterator() throws Exception {
    SparseClassificationTrainingRow row = createRow();
    FeatureIterator fi = row.getFeatureIterator();
    for (int i = 0; i < INDICES.length; i++) {
        assertTrue(fi.hasNext());
        assertTrue(fi.next());
        assertEquals(INDICES[i], fi.getFeatureIndex());
        // exact match required, hence a delta of 0
        assertEquals(VALUES[i], fi.getFeatureValue(), 0);
        if (i == 2) {
            // a spawned iterator starts one feature behind its parent
            FeatureIterator sfi = fi.spawn();
            assertEquals(INDICES[i - 1], sfi.getFeatureIndex());
            assertEquals(VALUES[i - 1], sfi.getFeatureValue(), 0);
        }
    }
    // the iterator is exhausted after the last stored feature
    assertFalse(fi.hasNext());
    assertFalse(fi.next());
}
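Outside of tests, the typical consumption pattern is a loop driven by next(). The sketch below computes a dot product against a dense weight vector, using only the iterator methods exercised above; the method itself is illustrative, not part of the API:

    // sketch: dot product of a sparse row with a dense weight vector
    static double dot(final ClassificationTrainingRow row, final double[] weights) {
        double sum = 0.0;
        for (FeatureIterator it = row.getFeatureIterator(); it.next();) {
            sum += weights[it.getFeatureIndex()] * it.getFeatureValue();
        }
        return sum;
    }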
Use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.
The class LogRegCoordinator, method learn.
/**
 * Performs the learning task by creating the appropriate LogRegLearner and all other objects
 * necessary for successful training.
 *
 * @param trainingData a DataTable that contains the data on which to learn the logistic regression model
 * @param exec the execution context of the corresponding KNIME node
 * @return the content of the logistic regression model
 * @throws InvalidSettingsException if the settings cause inconsistencies during training
 * @throws CanceledExecutionException if the training is canceled
 */
LogisticRegressionContent learn(final BufferedDataTable trainingData, final ExecutionContext exec)
    throws InvalidSettingsException, CanceledExecutionException {
    CheckUtils.checkArgument(trainingData.size() > 0, "The input table is empty. Please provide data to learn on.");
    CheckUtils.checkArgument(trainingData.size() <= Integer.MAX_VALUE, "The input table contains too many rows.");
    LogRegLearner learner;
    if (m_settings.getSolver() == Solver.IRLS) {
        learner = new IrlsLearner(m_settings.getMaxEpoch(), m_settings.getEpsilon(), m_settings.isCalcCovMatrix());
    } else {
        learner = new SagLogRegLearner(m_settings);
    }
    // fraction of the progress bar reserved for the domain recalculation (1/11)
    double calcDomainTime = 1.0 / (5.0 * 2.0 + 1.0);
    exec.setMessage("Analyzing categorical data");
    BufferedDataTable dataTable =
        recalcDomainForTargetAndLearningFields(trainingData, exec.createSubExecutionContext(calcDomainTime));
    checkConstantLearningFields(dataTable);
    exec.setMessage("Building logistic regression model");
    ExecutionMonitor trainExec = exec.createSubProgress(1.0 - calcDomainTime);
    LogRegLearnerResult result;
    TrainingRowBuilder<ClassificationTrainingRow> rowBuilder = new SparseClassificationTrainingRowBuilder(dataTable,
        m_pmmlOutSpec, m_settings.getTargetReferenceCategory(), m_settings.getSortTargetCategories(),
        m_settings.getSortIncludesCategories());
    TrainingData<ClassificationTrainingRow> data;
    Long seed = m_settings.getSeed();
    if (m_settings.isInMemory()) {
        data = new InMemoryData<ClassificationTrainingRow>(dataTable, seed, rowBuilder);
    } else {
        // disk-backed, chunked access for tables that should not be held in memory
        data = new DataTableTrainingData<ClassificationTrainingRow>(trainingData, seed, rowBuilder,
            m_settings.getChunkSize(), exec.createSilentSubExecutionContext(0.0));
    }
    checkShapeCompatibility(data);
    result = learner.learn(data, trainExec);
    LogisticRegressionContent content = createContentFromLearnerResult(result, rowBuilder, trainingData.getDataTableSpec());
    addToWarning(learner.getWarningMessage());
    return content;
}
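Written out, the progress split used above is plain arithmetic:

    calcDomainTime = 1 / (5 * 2 + 1) = 1/11 ≈ 0.091

so the domain recalculation receives roughly 9% of the node's progress bar and the training phase the remaining ~91%.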
Use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.
The class IrlsLearner, method likelihood.
// private RealMatrix getStdErrorMatrix(final RealMatrix xTwx) {
//     RealMatrix covMat = new QRDecomposition(xTwx).getSolver().getInverse().scalarMultiply(-1);
//     // the standard error estimate
//     RealMatrix stdErr = MatrixUtils.createRealMatrix(covMat.getColumnDimension(),
//         covMat.getRowDimension());
//     for (int i = 0; i < covMat.getRowDimension(); i++) {
//         stdErr.setEntry(i, i, sqrt(abs(covMat.getEntry(i, i))));
//     }
//     return stdErr;
// }
/**
 * Computes the log-likelihood at the given beta.
 *
 * @param iter iterator over the training data
 * @param beta parameter vector
 * @param rC regressor count
 * @param tcC target category count
 * @param exec execution monitor used to check for cancellation
 * @return the log-likelihood of the training data at beta
 * @throws CanceledExecutionException if the execution is canceled
 */
private double likelihood(final Iterator<ClassificationTrainingRow> iter, final RealMatrix beta, final int rC,
    final int tcC, final ExecutionMonitor exec) throws CanceledExecutionException {
    double loglike = 0;
    // dense working row: rC regressors plus the intercept
    RealMatrix x = MatrixUtils.createRealMatrix(1, rC + 1);
    while (iter.hasNext()) {
        exec.checkCanceled();
        ClassificationTrainingRow row = iter.next();
        fillXFromRow(x, row);
        // denominator term: sum over the non-reference categories of exp(beta_i^T x)
        double sumEBetaTx = 0;
        for (int i = 0; i < tcC - 1; i++) {
            RealMatrix betaITx = x.multiply(beta.getSubMatrix(0, 0, i * (rC + 1), (i + 1) * (rC + 1) - 1).transpose());
            sumEBetaTx += Math.exp(betaITx.getEntry(0, 0));
        }
        int y = row.getCategory();
        // the last category is the reference; its linear predictor is implicitly 0
        double yBetaTx = 0;
        if (y < tcC - 1) {
            yBetaTx = x.multiply(beta.getSubMatrix(0, 0, y * (rC + 1), (y + 1) * (rC + 1) - 1).transpose()).getEntry(0, 0);
        }
        loglike += yBetaTx - Math.log(1 + sumEBetaTx);
    }
    return loglike;
}
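In formula form, the loop computes the multinomial log-likelihood with the last of the K = tcC target categories as reference (x_n is the augmented feature row including the intercept):

    \ell(\beta) = \sum_n \Big( [y_n < K-1]\,\beta_{y_n}^{\top} x_n - \log\Big(1 + \sum_{k=0}^{K-2} \exp\big(\beta_k^{\top} x_n\big)\Big) \Big)

where [\cdot] is the Iverson bracket, so rows in the reference category contribute only the normalizing term.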
Use of org.knime.base.node.mine.regression.logistic.learner4.data.ClassificationTrainingRow in project knime-core by knime.
The class IrlsLearner, method fillXFromRow.
private static void fillXFromRow(final RealMatrix x, final ClassificationTrainingRow row) {
    FeatureIterator iter = row.getFeatureIterator();
    boolean hasNext = iter.next();
    // expand the sparse row into the dense matrix x, writing 0.0 for every skipped column
    for (int i = 0; i < x.getColumnDimension(); i++) {
        double val = 0.0;
        if (hasNext && iter.getFeatureIndex() == i) {
            val = iter.getFeatureValue();
            hasNext = iter.next();
        }
        x.setEntry(0, i, val);
    }
}
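The same expansion written against plain arrays makes the behavior easy to see; this standalone sketch (names assumed, equivalent in effect):

    // expand a sparse (indices, values) pair into a dense vector of the given width
    static double[] toDense(final int[] indices, final double[] values, final int width) {
        double[] x = new double[width]; // entries default to 0.0
        for (int k = 0; k < indices.length; k++) {
            x[indices[k]] = values[k]; // place the non-zero features
        }
        return x;
    }
    // toDense(new int[]{0, 3}, new double[]{1.0, 2.5}, 6) -> [1.0, 0.0, 0.0, 2.5, 0.0, 0.0]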