Search in sources:

Example 1 with FunctionName

Use of org.knime.base.node.mine.regression.pmmlgreg.PMMLGeneralRegressionContent.FunctionName in the knime-core project by KNIME.

Taken from the method initializeFrom of the class PMMLGeneralRegressionTranslator.

/**
 * {@inheritDoc}
 *
 * <p>Reads the first {@code GeneralRegressionModel} of the given document into {@code m_content}: model type,
 * function name, algorithm and model name, target reference category, the optional offset value, the parameter,
 * factor and covariate lists, and the PPMatrix, PCovMatrix and ParamMatrix. Names of parameters, factor and
 * covariate predictors and PPMatrix predictors are translated into column names via a {@link DerivedFieldMapper}
 * created from {@code pmmlDoc}. Missing or empty lists and matrices are stored as empty arrays.
 *
 * @throws IllegalArgumentException if the document contains no general regression model, if the model sets the
 *             unsupported {@code cumulativeLink} attribute, or if a ParamMatrix cell carries a {@code df} value
 *             that does not fit into an {@code int}
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    List<GeneralRegressionModel> models = pmmlDoc.getPMML().getGeneralRegressionModelList();
    if (models.isEmpty()) {
        throw new IllegalArgumentException("No general regression model provided.");
    } else if (models.size() > 1) {
        LOGGER.warn("Multiple general regression models found. Only the first model is considered.");
    }
    GeneralRegressionModel reg = models.get(0);

    // read the model type and function name
    PMMLGeneralRegressionContent.ModelType modelType = getKNIMERegModelType(reg.getModelType());
    m_content.setModelType(modelType);
    FunctionName functionName = getKNIMEFunctionName(reg.getFunctionName());
    m_content.setFunctionName(functionName);
    m_content.setAlgorithmName(reg.getAlgorithmName());
    m_content.setModelName(reg.getModelName());
    if (reg.getCumulativeLink() != null) {
        throw new IllegalArgumentException("The attribute \"cumulativeLink\" is currently not supported.");
    }
    m_content.setTargetReferenceCategory(reg.getTargetReferenceCategory());
    if (reg.isSetOffsetValue()) {
        m_content.setOffsetValue(reg.getOffsetValue());
    }
    if (reg.getLocalTransformations() != null && reg.getLocalTransformations().getDerivedFieldList() != null) {
        updateVectorLengthsBasedOnDerivedFields(reg.getLocalTransformations().getDerivedFieldList());
    }

    // read the parameter list; parameter names are mapped to column names, labels are optional
    ParameterList pmmlParamList = reg.getParameterList();
    if (pmmlParamList != null && pmmlParamList.sizeOfParameterArray() > 0) {
        List<Parameter> pmmlParam = pmmlParamList.getParameterList();
        PMMLParameter[] paramList = new PMMLParameter[pmmlParam.size()];
        for (int i = 0; i < paramList.length; i++) {
            Parameter param = pmmlParam.get(i);
            String name = m_nameMapper.getColumnName(param.getName());
            String label = param.getLabel();
            paramList[i] = label == null ? new PMMLParameter(name) : new PMMLParameter(name, label);
        }
        m_content.setParameterList(paramList);
    } else {
        m_content.setParameterList(new PMMLParameter[0]);
    }

    // read the factor list
    FactorList pmmlFactorList = reg.getFactorList();
    if (pmmlFactorList != null && pmmlFactorList.sizeOfPredictorArray() > 0) {
        m_content.setFactorList(getKNIMEPredictors(pmmlFactorList.getPredictorList()));
    } else {
        m_content.setFactorList(new PMMLPredictor[0]);
    }

    // read the covariate list
    CovariateList covariateList = reg.getCovariateList();
    if (covariateList != null && covariateList.sizeOfPredictorArray() > 0) {
        m_content.setCovariateList(getKNIMEPredictors(covariateList.getPredictorList()));
    } else {
        m_content.setCovariateList(new PMMLPredictor[0]);
    }

    // read the PPMatrix; only the predictor name is mapped, the parameter name is kept as is
    PPMatrix ppMatrix = reg.getPPMatrix();
    if (ppMatrix != null && ppMatrix.sizeOfPPCellArray() > 0) {
        List<PPCell> pmmlCellArray = ppMatrix.getPPCellList();
        PMMLPPCell[] cells = new PMMLPPCell[pmmlCellArray.size()];
        for (int i = 0; i < cells.length; i++) {
            PPCell ppCell = pmmlCellArray.get(i);
            cells[i] = new PMMLPPCell(ppCell.getValue(), m_nameMapper.getColumnName(ppCell.getPredictorName()),
                ppCell.getParameterName(), ppCell.getTargetCategory());
        }
        m_content.setPPMatrix(cells);
    } else {
        m_content.setPPMatrix(new PMMLPPCell[0]);
    }

    // read the PCovMatrix
    PCovMatrix pCovMatrix = reg.getPCovMatrix();
    if (pCovMatrix != null && pCovMatrix.sizeOfPCovCellArray() > 0) {
        List<PCovCell> pCovCellArray = pCovMatrix.getPCovCellList();
        PMMLPCovCell[] covCells = new PMMLPCovCell[pCovCellArray.size()];
        for (int i = 0; i < covCells.length; i++) {
            PCovCell c = pCovCellArray.get(i);
            covCells[i] = new PMMLPCovCell(c.getPRow(), c.getPCol(), c.getTRow(), c.getTCol(), c.getValue(),
                c.getTargetCategory());
        }
        m_content.setPCovMatrix(covCells);
    } else {
        m_content.setPCovMatrix(new PMMLPCovCell[0]);
    }

    // read the ParamMatrix; the degrees of freedom (df) attribute is optional
    ParamMatrix paramMatrix = reg.getParamMatrix();
    if (paramMatrix != null && paramMatrix.sizeOfPCellArray() > 0) {
        List<PCell> pCellArray = paramMatrix.getPCellList();
        PMMLPCell[] cells = new PMMLPCell[pCellArray.size()];
        for (int i = 0; i < cells.length; i++) {
            PCell p = pCellArray.get(i);
            double beta = p.getBeta();
            BigInteger df = p.getDf();
            if (df != null) {
                int dfValue;
                try {
                    // df is an unbounded integer in PMML; intValue() would silently wrap large values
                    dfValue = df.intValueExact();
                } catch (ArithmeticException e) {
                    throw new IllegalArgumentException("The degrees of freedom (df=" + df + ") of parameter \""
                        + p.getParameterName() + "\" exceed the supported integer range.", e);
                }
                cells[i] = new PMMLPCell(p.getParameterName(), beta, dfValue, p.getTargetCategory());
            } else {
                cells[i] = new PMMLPCell(p.getParameterName(), beta, p.getTargetCategory());
            }
        }
        m_content.setParamMatrix(cells);
    } else {
        m_content.setParamMatrix(new PMMLPCell[0]);
    }
}

/**
 * Converts PMML predictors into KNIME predictors, translating each predictor name into a column name via
 * {@code m_nameMapper}.
 *
 * @param pmmlPredictors the PMML predictors to convert, must not be {@code null}
 * @return one {@link PMMLPredictor} per input predictor, in the same order
 */
private PMMLPredictor[] getKNIMEPredictors(final List<Predictor> pmmlPredictors) {
    PMMLPredictor[] predictors = new PMMLPredictor[pmmlPredictors.size()];
    for (int i = 0; i < predictors.length; i++) {
        predictors[i] = new PMMLPredictor(m_nameMapper.getColumnName(pmmlPredictors.get(i).getName()));
    }
    return predictors;
}
Also used: Predictor (org.dmg.pmml.PredictorDocument.Predictor), PPCell (org.dmg.pmml.PPCellDocument.PPCell), DerivedFieldMapper (org.knime.core.node.port.pmml.preproc.DerivedFieldMapper), FunctionName (org.knime.base.node.mine.regression.pmmlgreg.PMMLGeneralRegressionContent.FunctionName), FactorList (org.dmg.pmml.FactorListDocument.FactorList), PCell (org.dmg.pmml.PCellDocument.PCell), ParamMatrix (org.dmg.pmml.ParamMatrixDocument.ParamMatrix), PPMatrix (org.dmg.pmml.PPMatrixDocument.PPMatrix), CovariateList (org.dmg.pmml.CovariateListDocument.CovariateList), PCovMatrix (org.dmg.pmml.PCovMatrixDocument.PCovMatrix), PCovCell (org.dmg.pmml.PCovCellDocument.PCovCell), GeneralRegressionModel (org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel), ParameterList (org.dmg.pmml.ParameterListDocument.ParameterList), Parameter (org.dmg.pmml.ParameterDocument.Parameter), BigInteger (java.math.BigInteger)

Aggregations

BigInteger (java.math.BigInteger): 1, CovariateList (org.dmg.pmml.CovariateListDocument.CovariateList): 1, FactorList (org.dmg.pmml.FactorListDocument.FactorList): 1, GeneralRegressionModel (org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel): 1, PCell (org.dmg.pmml.PCellDocument.PCell): 1, PCovCell (org.dmg.pmml.PCovCellDocument.PCovCell): 1, PCovMatrix (org.dmg.pmml.PCovMatrixDocument.PCovMatrix): 1, PPCell (org.dmg.pmml.PPCellDocument.PPCell): 1, PPMatrix (org.dmg.pmml.PPMatrixDocument.PPMatrix): 1, ParamMatrix (org.dmg.pmml.ParamMatrixDocument.ParamMatrix): 1, Parameter (org.dmg.pmml.ParameterDocument.Parameter): 1, ParameterList (org.dmg.pmml.ParameterListDocument.ParameterList): 1, Predictor (org.dmg.pmml.PredictorDocument.Predictor): 1, FunctionName (org.knime.base.node.mine.regression.pmmlgreg.PMMLGeneralRegressionContent.FunctionName): 1, DerivedFieldMapper (org.knime.core.node.port.pmml.preproc.DerivedFieldMapper): 1