Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLPredictor in the knime-core project by KNIME.
Class RegressionPredictorCellFactory, method determineFactorValues.
/**
 * Determines the values of every factor (nominal predictor) of the model, ordered as in the model's PPMatrix.
 * <p>
 * For each factor, the possible values of the corresponding column in {@code trainingSpec} are collected first.
 * Every value referenced by the PPMatrix is then moved to the end of the list. The resulting order therefore
 * follows the PPMatrix, and domain values not referenced there (such as the base line category) stay in front.
 *
 * @param content the content of the general regression model whose factors and PPMatrix are read
 * @param trainingSpec the table spec of the training set
 * @return the factors name mapped to its values
 * @throws InvalidSettingsException If the PMML data dictionary contains more elements for a nominal column
 * than represented in the data, if a factor column is missing from {@code trainingSpec} or its domain provides
 * no possible values, or if the PPMatrix references a value that is not part of the column's domain
 */
protected static Map<String, List<DataCell>> determineFactorValues(final PMMLGeneralRegressionContent content,
    final DataTableSpec trainingSpec) throws InvalidSettingsException {
    final Map<String, List<DataCell>> values = new HashMap<String, List<DataCell>>();
    for (PMMLPredictor factor : content.getFactorList()) {
        final String factorName = factor.getName();
        if (trainingSpec.getColumnSpec(factorName) == null) {
            throw new InvalidSettingsException("The column \"" + factorName
                + "\" is in the model but not in the training data.");
        }
        final Set<DataCell> domain = trainingSpec.getColumnSpec(factorName).getDomain().getValues();
        if (domain == null) {
            // without possible values no dummy variables can be decoded (and addAll(null) would throw an NPE)
            throw new InvalidSettingsException("The domain of column \"" + factorName
                + "\" does not provide any possible values.");
        }
        // maps the string representation used in the PPMatrix to the actual domain cell
        final Map<String, DataCell> domainValues = new HashMap<String, DataCell>();
        for (DataCell cell : domain) {
            domainValues.put(cell.toString(), cell);
        }
        // start with all values, for PMMLGeneralRegression models that do not specify all values in the PPMatrix
        final Set<DataCell> factorValues = new LinkedHashSet<DataCell>(domain);
        int count = 0;
        for (PMMLPPCell ppCell : content.getPPMatrix()) {
            if (factorName.equals(ppCell.getPredictorName())) {
                final DataCell cell = domainValues.get(ppCell.getValue());
                if (cell == null) {
                    // otherwise a null entry would silently end up in the list of factor values
                    throw new InvalidSettingsException("The regression model refers to the value \""
                        + ppCell.getValue() + "\" of column \"" + factorName
                        + "\", which is not part of the column's domain.");
                }
                // move cell to the end of the list, this gives in the end the same ordering
                // as in the PPMatrix of the PMMLGeneralRegression model
                factorValues.remove(cell);
                factorValues.add(cell);
                count++;
            }
        }
        // The base line category may not be in the PPMatrix of the PMMLGeneralRegression model.
        // In this case count is lower than the number of domain values. If count is even
        // lower than that, the base line category is ambiguous.
        final int valuesDataDictionary = domain.size();
        if (count < valuesDataDictionary - 1) {
            throw new InvalidSettingsException("The data dictionary to column \"" + factorName + "\" contains more elements than represented in the regression model " + "(unable to decode dummy variables as reference is unknown: " + valuesDataDictionary + " > " + count + " + 1)");
        }
        values.put(factorName, new ArrayList<DataCell>(factorValues));
    }
    return values;
}
Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLPredictor in the knime-core project by KNIME.
Class LogisticRegressionContent, method createGeneralRegressionContent.
/**
 * Builds a PMML General Regression (multinomial logistic) content object from this logistic regression model.
 * <p>
 * Parameter {@code p0} is the intercept. Each factor contributes one parameter per domain value except its first
 * one, and every other learning field contributes exactly one parameter. For every parameter, one ParamMatrix cell
 * is written per target category except the last. The coefficient is taken from row 0 of {@code m_beta} at column
 * {@code parameterIndex + targetIndex * pCount}. The last target category becomes the target reference category.
 *
 * @return the PMMLGeneralRegressionContent
 */
public PMMLGeneralRegressionContent createGeneralRegressionContent() {
    final List<PMMLPredictor> factors = new ArrayList<PMMLPredictor>();
    for (String factorName : m_factorList) {
        factors.add(new PMMLPredictor(factorName));
    }
    final List<PMMLPredictor> covariates = new ArrayList<PMMLPredictor>();
    for (String covariateName : m_covariateList) {
        covariates.add(new PMMLPredictor(covariateName));
    }

    // the ParameterList, the PPMatrix and the ParamMatrix
    final List<PMMLParameter> parameterList = new ArrayList<PMMLParameter>();
    final List<PMMLPPCell> ppMatrix = new ArrayList<PMMLPPCell>();
    final List<PMMLPCell> paramMatrix = new ArrayList<PMMLPCell>();

    final int codedTargets = m_targetCategories.size() - 1;
    final int pCount = m_beta.getColumnDimension() / codedTargets;

    int paramIndex = 0;
    final String interceptName = "p" + paramIndex;
    parameterList.add(new PMMLParameter(interceptName, "Intercept"));
    for (int target = 0; target < codedTargets; target++) {
        final int betaColumn = paramIndex + (target * pCount);
        paramMatrix.add(new PMMLPCell(interceptName, m_beta.get(0, betaColumn), 1,
            m_targetCategories.get(target).toString()));
    }
    paramIndex++;

    for (String colName : m_outSpec.getLearningFields()) {
        if (!m_factorList.contains(colName)) {
            // covariate: a single parameter with PPMatrix value "1"
            final String pName = "p" + paramIndex;
            parameterList.add(new PMMLParameter(pName, colName));
            ppMatrix.add(new PMMLPPCell("1", colName, pName));
            for (int target = 0; target < codedTargets; target++) {
                final int betaColumn = paramIndex + (target * pCount);
                paramMatrix.add(new PMMLPCell(pName, m_beta.get(0, betaColumn), 1,
                    m_targetCategories.get(target).toString()));
            }
            paramIndex++;
            continue;
        }
        // factor: one parameter per domain value; the first domain value gets no parameter
        // (presumably the reference level of the dummy coding - confirm against the learner)
        final Iterator<DataCell> domainIter = m_factorDomainValues.get(colName).iterator();
        domainIter.next();
        while (domainIter.hasNext()) {
            final DataCell level = domainIter.next();
            final String pName = "p" + paramIndex;
            parameterList.add(new PMMLParameter(pName, "[" + colName + "=" + level + "]"));
            ppMatrix.add(new PMMLPPCell(level.toString(), colName, pName));
            for (int target = 0; target < codedTargets; target++) {
                final int betaColumn = paramIndex + (target * pCount);
                paramMatrix.add(new PMMLPCell(pName, m_beta.get(0, betaColumn), 1,
                    m_targetCategories.get(target).toString()));
            }
            paramIndex++;
        }
    }

    // TODO PCovMatrix
    final List<PMMLPCovCell> pCovMatrix = new ArrayList<PMMLPCovCell>();
    final PMMLGeneralRegressionContent content = new PMMLGeneralRegressionContent(ModelType.multinomialLogistic,
        "KNIME Logistic Regression", FunctionName.classification, "LogisticRegression",
        parameterList.toArray(new PMMLParameter[0]), factors.toArray(new PMMLPredictor[0]),
        covariates.toArray(new PMMLPredictor[0]), ppMatrix.toArray(new PMMLPPCell[0]),
        pCovMatrix.toArray(new PMMLPCovCell[0]), paramMatrix.toArray(new PMMLPCell[0]));
    final int referenceIndex = m_targetCategories.size() - 1;
    content.setTargetReferenceCategory(m_targetCategories.get(referenceIndex).toString());
    return content;
}
Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLPredictor in the knime-core project by KNIME.
Class LogisticRegressionContent, method createGeneralRegressionContent (variant with vector column support).
/**
 * Creates a new PMML General Regression Content from this logistic regression model.
 * <p>
 * Parameter {@code p0} is the intercept. The learning fields are then processed in order:
 * <ul>
 * <li>each factor contributes one parameter per domain value except its first one;</li>
 * <li>each vector column (a key of {@code m_vectorLengths}) contributes one parameter per element, named
 * {@code p<first>_<idx>};</li>
 * <li>every other column contributes one parameter.</li>
 * </ul>
 * For every parameter, one ParamMatrix cell is written per target category except the last. The coefficient is
 * read from row 0 of {@code m_beta} at column {@code parameterIndex + targetIndex * pCount}. The last target
 * category is set as target reference category.
 *
 * @return the PMMLGeneralRegressionContent
 */
public PMMLGeneralRegressionContent createGeneralRegressionContent() {
    List<PMMLPredictor> factors = new ArrayList<PMMLPredictor>();
    for (String factor : m_factorList) {
        factors.add(new PMMLPredictor(factor));
    }
    List<PMMLPredictor> covariates = new ArrayList<PMMLPredictor>();
    for (String covariate : m_covariateList) {
        covariates.add(new PMMLPredictor(covariate));
    }
    // the ParameterList, the PPMatrix and the ParamMatrix
    List<PMMLParameter> parameterList = new ArrayList<PMMLParameter>();
    List<PMMLPPCell> ppMatrix = new ArrayList<PMMLPPCell>();
    List<PMMLPCell> paramMatrix = new ArrayList<PMMLPCell>();
    // number of target categories that get coefficients (all but the reference category)
    final int targetCount = m_targetCategories.size() - 1;
    int pCount = m_beta.getColumnDimension() / targetCount;
    int p = 0;
    final String interceptName = "p" + p;
    parameterList.add(new PMMLParameter(interceptName, "Intercept"));
    for (int k = 0; k < targetCount; k++) {
        paramMatrix.add(new PMMLPCell(interceptName, m_beta.getEntry(0, p + (k * pCount)), 1, m_targetCategories.get(k).toString()));
    }
    p++;
    for (String colName : m_outSpec.getLearningFields()) {
        if (m_factorList.contains(colName)) {
            Iterator<DataCell> designIter = m_factorDomainValues.get(colName).iterator();
            // Omit first domain value: it gets no parameter of its own
            designIter.next();
            while (designIter.hasNext()) {
                DataCell dvValue = designIter.next();
                String pName = "p" + p;
                parameterList.add(new PMMLParameter(pName, "[" + colName + "=" + dvValue + "]"));
                ppMatrix.add(new PMMLPPCell(dvValue.toString(), colName, pName));
                for (int k = 0; k < targetCount; k++) {
                    paramMatrix.add(new PMMLPCell(pName, m_beta.getEntry(0, p + (k * pCount)), 1, m_targetCategories.get(k).toString()));
                }
                p++;
            }
        } else if (m_vectorLengths.containsKey(colName)) {
            final int length = m_vectorLengths.get(colName);
            // all elements of one vector share the index of the vector's first parameter in their name
            final int pFrozen = p;
            for (int idx = 0; idx < length; ++idx) {
                final String pName = "p" + pFrozen + "_" + idx;
                final String predictorName = VectorHandling.valueAt(colName, idx);
                parameterList.add(new PMMLParameter(pName, predictorName));
                ppMatrix.add(new PMMLPPCell("1", predictorName, pName));
                for (int k = 0; k < targetCount; k++) {
                    paramMatrix.add(new PMMLPCell(pName, m_beta.getEntry(0, p + (k * pCount)), 1, m_targetCategories.get(k).toString()));
                }
                p++;
            }
        } else {
            String pName = "p" + p;
            parameterList.add(new PMMLParameter(pName, colName));
            ppMatrix.add(new PMMLPPCell("1", colName, pName));
            for (int k = 0; k < targetCount; k++) {
                paramMatrix.add(new PMMLPCell(pName, m_beta.getEntry(0, p + (k * pCount)), 1, m_targetCategories.get(k).toString()));
            }
            p++;
        }
    }
    // TODO PCovMatrix
    List<PMMLPCovCell> pCovMatrix = new ArrayList<PMMLPCovCell>();
    PMMLGeneralRegressionContent content = new PMMLGeneralRegressionContent(ModelType.multinomialLogistic, "KNIME Logistic Regression", FunctionName.classification, "LogisticRegression", parameterList.toArray(new PMMLParameter[0]), factors.toArray(new PMMLPredictor[0]), covariates.toArray(new PMMLPredictor[0]), m_vectorLengths, ppMatrix.toArray(new PMMLPPCell[0]), pCovMatrix.toArray(new PMMLPCovCell[0]), paramMatrix.toArray(new PMMLPCell[0]));
    content.setTargetReferenceCategory(m_targetCategories.get(m_targetCategories.size() - 1).toString());
    return content;
}
Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLPredictor in the knime-core project by KNIME.
Class GeneralRegressionPredictorNodeModel, method createRearranger.
/**
 * Creates the column rearranger that appends the output of a {@code LogRegPredictor} to the input table.
 * <p>
 * Only models of type {@code multinomialLogistic} with function name {@code classification} are accepted.
 * Every factor of the model must be a nominal column of {@code inDataSpec`, and every covariate a numeric one.
 *
 * @param content the general regression content of the model
 * @param pmmlSpec the spec of the PMML input; its first target field is predicted
 * @param inDataSpec the spec of the table to predict
 * @return the rearranger appending the prediction columns
 * @throws InvalidSettingsException if {@code content} is {@code null}, the model type or function name is not
 *             supported, a model column is missing from or has the wrong type in {@code inDataSpec}, or the PMML
 *             spec does not define a target field
 */
private ColumnRearranger createRearranger(final PMMLGeneralRegressionContent content, final PMMLPortObjectSpec pmmlSpec, final DataTableSpec inDataSpec) throws InvalidSettingsException {
    if (content == null) {
        throw new InvalidSettingsException("No input");
    }
    // the predictor can only predict logistic regression models
    // (constant-first comparison so that a missing model type / function name is reported, not an NPE)
    if (!ModelType.multinomialLogistic.equals(content.getModelType())) {
        throw new InvalidSettingsException("Model Type: " + content.getModelType() + " is not supported.");
    }
    if (!FunctionName.classification.equals(content.getFunctionName())) {
        throw new InvalidSettingsException("Function Name: " + content.getFunctionName() + " is not supported.");
    }
    // are nominal values
    for (PMMLPredictor factor : content.getFactorList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(factor.getName());
        if (null == columnSpec) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is in the model but not in given table.");
        }
        if (!columnSpec.getType().isCompatible(NominalValue.class)) {
            throw new InvalidSettingsException("The column \"" + factor.getName() + "\" is supposed to be nominal.");
        }
    }
    // are numeric values
    for (PMMLPredictor covariate : content.getCovariateList()) {
        DataColumnSpec columnSpec = inDataSpec.getColumnSpec(covariate.getName());
        if (null == columnSpec) {
            throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is in the model but not in given table.");
        }
        if (!columnSpec.getType().isCompatible(DoubleValue.class)) {
            throw new InvalidSettingsException("The column \"" + covariate.getName() + "\" is supposed to be numeric.");
        }
    }
    // the first target field is used below; report a missing one instead of an IndexOutOfBoundsException
    if (pmmlSpec.getTargetFields().isEmpty()) {
        throw new InvalidSettingsException("The PMML model does not define a target field.");
    }
    ColumnRearranger c = new ColumnRearranger(inDataSpec);
    RegressionPredictorSettings s = createRegressionPredictorSettings(pmmlSpec, inDataSpec);
    c.append(new LogRegPredictor(content, inDataSpec, pmmlSpec, pmmlSpec.getTargetFields().get(0), s));
    return c;
}
Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLPredictor in the knime-core project by KNIME.
Class RegressionPredictorCellFactory, method determineFactorValues.
/**
 * Determines the values of every factor (nominal predictor) of the model, ordered as in the model's PPMatrix.
 * <p>
 * For each factor, the possible values of the corresponding column in {@code trainingSpec} are collected first.
 * Every value referenced by the PPMatrix is then moved to the end of the list. The resulting order therefore
 * follows the PPMatrix, and domain values not referenced there (such as the base line category) stay in front.
 *
 * @param content the content of the general regression model whose factors and PPMatrix are read
 * @param trainingSpec the table spec of the training set
 * @return the factors name mapped to its values
 * @throws InvalidSettingsException If the PMML data dictionary contains more elements for a nominal column
 * than represented in the data, if a factor column is missing from {@code trainingSpec} or its domain provides
 * no possible values, or if the PPMatrix references a value that is not part of the column's domain
 */
protected static Map<String, List<DataCell>> determineFactorValues(final PMMLGeneralRegressionContent content,
    final DataTableSpec trainingSpec) throws InvalidSettingsException {
    final Map<String, List<DataCell>> values = new HashMap<String, List<DataCell>>();
    for (PMMLPredictor factor : content.getFactorList()) {
        final String factorName = factor.getName();
        if (trainingSpec.getColumnSpec(factorName) == null) {
            throw new InvalidSettingsException("The column \"" + factorName
                + "\" is in the model but not in the training data.");
        }
        final Set<DataCell> domain = trainingSpec.getColumnSpec(factorName).getDomain().getValues();
        if (domain == null) {
            // without possible values no dummy variables can be decoded (and addAll(null) would throw an NPE)
            throw new InvalidSettingsException("The domain of column \"" + factorName
                + "\" does not provide any possible values.");
        }
        // maps the string representation used in the PPMatrix to the actual domain cell
        final Map<String, DataCell> domainValues = new HashMap<String, DataCell>();
        for (DataCell cell : domain) {
            domainValues.put(cell.toString(), cell);
        }
        // start with all values, for PMMLGeneralRegression models that do not specify all values in the PPMatrix
        final Set<DataCell> factorValues = new LinkedHashSet<DataCell>(domain);
        int count = 0;
        for (PMMLPPCell ppCell : content.getPPMatrix()) {
            if (factorName.equals(ppCell.getPredictorName())) {
                final DataCell cell = domainValues.get(ppCell.getValue());
                if (cell == null) {
                    // otherwise a null entry would silently end up in the list of factor values
                    throw new InvalidSettingsException("The regression model refers to the value \""
                        + ppCell.getValue() + "\" of column \"" + factorName
                        + "\", which is not part of the column's domain.");
                }
                // move cell to the end of the list, this gives in the end the same ordering
                // as in the PPMatrix of the PMMLGeneralRegression model
                factorValues.remove(cell);
                factorValues.add(cell);
                count++;
            }
        }
        // The base line category may not be in the PPMatrix of the PMMLGeneralRegression model.
        // In this case count is lower than the number of domain values. If count is even
        // lower than that, the base line category is ambiguous.
        final int valuesDataDictionary = domain.size();
        if (count < valuesDataDictionary - 1) {
            throw new InvalidSettingsException("The data dictionary to column \"" + factorName + "\" contains more elements than represented in the regression model " + "(unable to decode dummy variables as reference is unknown: " + valuesDataDictionary + " > " + count + " + 1)");
        }
        values.put(factorName, new ArrayList<DataCell>(factorValues));
    }
    return values;
}
Aggregations