use of org.dmg.pmml.pmml_4_2.descr.MiningField in project drools by kiegroup.
the class AbstractModel method initMiningFieldMap.
/**
 * Rebuilds {@code miningFieldMap} from the model's mining schema.
 *
 * <p>The map is replaced with a fresh {@link HashMap} and then filled with
 * one entry per mining field of the schema, keyed by the field's name.
 * If two fields share a name, the later one overwrites the earlier one.
 */
protected void initMiningFieldMap() {
    final MiningSchema miningSchema = getMiningSchema();
    miningFieldMap = new HashMap<>();
    final List<MiningField> declaredFields = miningSchema.getMiningFields();
    for (MiningField declared : declaredFields) {
        final String fieldName = declared.getName();
        miningFieldMap.put(fieldName, declared);
    }
}
use of org.dmg.pmml.pmml_4_2.descr.MiningField in project drools by kiegroup.
the class AbstractModel method getOutputFields.
/**
 * Collects the output fields of this model.
 *
 * <p>The result contains, in this order:
 * <ol>
 *   <li>one {@code PMMLOutputField} for every value of {@code outputFieldMap}
 *       (created without a raw data field);</li>
 *   <li>one {@code PMMLOutputField} for every mining field whose usage type is
 *       {@code PREDICTED} or {@code TARGET}, paired with the raw data field of
 *       the data dictionary entry that has the same name.</li>
 * </ol>
 *
 * <p>NOTE(review): a predicted/target mining field without a matching data
 * dictionary entry leads to a {@link NullPointerException} here - confirm
 * whether callers guarantee the entry exists.
 *
 * @return a new, mutable list of output fields
 */
@Override
public List<PMMLOutputField> getOutputFields() {
    final List<PMMLOutputField> result = new ArrayList<>();

    // Explicitly declared output fields come first.
    for (OutputField declared : outputFieldMap.values()) {
        result.add(new PMMLOutputField(declared, null, this.getModelId()));
    }

    final Map<String, MiningField> predictedOrTarget =
            getFilteredMiningFieldMap(true, FIELDUSAGETYPE.PREDICTED, FIELDUSAGETYPE.TARGET);
    final Map<String, PMMLDataField> dictionary = getOwner().getDataDictionaryMap();
    if (predictedOrTarget == null) {
        return result;
    }

    // Mining fields that act as outputs are typed through the data dictionary.
    for (Map.Entry<String, MiningField> entry : predictedOrTarget.entrySet()) {
        final MiningField miningField = entry.getValue();
        final PMMLDataField dictionaryEntry = dictionary.get(entry.getKey());
        result.add(new PMMLOutputField(miningField, dictionaryEntry.getRawDataField(), this.getModelId()));
    }
    return result;
}
use of org.dmg.pmml.pmml_4_2.descr.MiningField in project drools by kiegroup.
the class GuidedScoreCardDRLPersistence method createPMMLDocument.
/**
 * Translates a guided {@link ScoreCardModel} into a PMML document.
 *
 * <p>The method assembles a {@code Scorecard} element consisting of
 * <ul>
 *   <li>model-level extensions (external class, optional agenda group,
 *       optional ruleflow group, imports, resultant score field, package);</li>
 *   <li>a {@code MiningSchema} with one {@code ACTIVE} mining field per
 *       characteristic;</li>
 *   <li>an (empty) {@code Output} element;</li>
 *   <li>the {@code Characteristics} with their attributes, each attribute
 *       carrying a {@code predicateResolver} extension derived from the
 *       attribute's operator and value.</li>
 * </ul>
 * and hands it to {@code ScorecardPMMLGenerator#generateDocument}.
 *
 * @param model the guided scorecard model to convert
 * @return the PMML document produced by the scorecard generator
 */
private static PMML createPMMLDocument(final ScoreCardModel model) {
    final Scorecard pmmlScorecard = ScorecardPMMLUtils.createScorecard();
    final Output output = new Output();
    final Characteristics characteristics = new Characteristics();
    final MiningSchema miningSchema = new MiningSchema();

    // --- model-level extensions -------------------------------------------
    Extension extension = new Extension();
    extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
    extension.setValue(model.getFactName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    String agendaGroup = model.getAgendaGroup();
    if (!StringUtils.isEmpty(agendaGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.AGENDA_GROUP);
        extension.setValue(agendaGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    String ruleFlowGroup = model.getRuleFlowGroup();
    if (!StringUtils.isEmpty(ruleFlowGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.RULEFLOW_GROUP);
        // The ruleflow-group extension must carry the ruleflow group itself
        // (previously the agenda group was written here by mistake).
        extension.setValue(ruleFlowGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    // The imports extension is registered first and its value is filled in
    // once the de-duplicated, comma-terminated list of types has been built.
    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_IMPORTS);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    List<String> imports = new ArrayList<String>();
    StringBuilder importBuilder = new StringBuilder();
    for (Import imp : model.getImports().getImports()) {
        if (!imports.contains(imp.getType())) {
            imports.add(imp.getType());
            importBuilder.append(imp.getType()).append(",");
        }
    }
    extension.setValue(importBuilder.toString());

    extension = new Extension();
    extension.setName(ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_SCORE_FIELD);
    extension.setValue(model.getFieldName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    // A null or empty package name is stored as a null extension value.
    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_PACKAGE);
    String pkgName = model.getPackageName();
    extension.setValue(!(pkgName == null || pkgName.isEmpty()) ? pkgName : null);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    // --- scorecard attributes ---------------------------------------------
    final String modelName = convertToJavaIdentifier(model.getName());
    pmmlScorecard.setModelName(modelName);
    pmmlScorecard.setInitialScore(model.getInitialScore());
    pmmlScorecard.setUseReasonCodes(model.isUseReasonCodes());
    if (model.isUseReasonCodes()) {
        pmmlScorecard.setBaselineScore(model.getBaselineScore());
        pmmlScorecard.setReasonCodeAlgorithm(model.getReasonCodesAlgorithm());
    }

    // --- characteristics, mining fields and attributes --------------------
    for (final org.drools.workbench.models.guided.scorecard.shared.Characteristic characteristic : model.getCharacteristics()) {
        final Characteristic _characteristic = new Characteristic();
        characteristics.getCharacteristics().add(_characteristic);

        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        _characteristic.getExtensions().add(extension);

        // Map the guided data type onto the spreadsheet data type keyword.
        // For an unrecognised type the extension is still added, but its
        // value is left unset.
        extension = new Extension();
        extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
        if ("string".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_TEXT);
        } else if ("int".equalsIgnoreCase(characteristic.getDataType()) || "double".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_NUMBER);
        } else if ("boolean".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_BOOLEAN);
        } else {
            System.out.println(">>>> Found unknown data type :: " + characteristic.getDataType());
        }
        _characteristic.getExtensions().add(extension);

        _characteristic.setBaselineScore(characteristic.getBaselineScore());
        if (model.isUseReasonCodes()) {
            _characteristic.setReasonCode(characteristic.getReasonCode());
        }
        _characteristic.setName(characteristic.getName());

        // Every characteristic contributes one ACTIVE mining field named
        // after the characteristic's field.
        final MiningField miningField = new MiningField();
        miningField.setName(characteristic.getField());
        miningField.setUsageType(FIELDUSAGETYPE.ACTIVE);
        miningField.setInvalidValueTreatment(INVALIDVALUETREATMENTMETHOD.RETURN_INVALID);
        miningSchema.getMiningFields().add(miningField);

        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        miningField.getExtensions().add(extension);

        for (final org.drools.workbench.models.guided.scorecard.shared.Attribute attribute : characteristic.getAttributes()) {
            final Attribute _attribute = new Attribute();
            _characteristic.getAttributes().add(_attribute);

            extension = new Extension();
            extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
            extension.setValue(characteristic.getField());
            _attribute.getExtensions().add(extension);

            if (model.isUseReasonCodes()) {
                _attribute.setReasonCode(attribute.getReasonCode());
            }
            _attribute.setPartialScore(attribute.getPartialScore());

            // Build the textual predicate according to the data type:
            //  - boolean: the upper-cased operator
            //  - string : operator + value when the operator contains "=",
            //             otherwise the value followed by ","
            //  - other  : "operator value" for known numeric operators,
            //             otherwise the value with "," replaced by "-"
            final String operator = attribute.getOperator();
            final String dataType = characteristic.getDataType();
            String predicateResolver;
            if ("boolean".equalsIgnoreCase(dataType)) {
                predicateResolver = operator.toUpperCase();
            } else if ("String".equalsIgnoreCase(dataType)) {
                if (operator.contains("=")) {
                    predicateResolver = operator + attribute.getValue();
                } else {
                    predicateResolver = attribute.getValue() + ",";
                }
            } else {
                if (NUMERIC_OPERATORS.contains(operator)) {
                    predicateResolver = operator + " " + attribute.getValue();
                } else {
                    predicateResolver = attribute.getValue().replace(",", "-");
                }
            }

            extension = new Extension();
            extension.setName("predicateResolver");
            extension.setValue(predicateResolver);
            _attribute.getExtensions().add(extension);
        }
    }

    // The mining schema, output and characteristics are appended after the
    // extensions, in this order.
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(miningSchema);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(output);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(characteristics);
    return new ScorecardPMMLGenerator().generateDocument(pmmlScorecard);
}
use of org.dmg.pmml.pmml_4_2.descr.MiningField in project drools by kiegroup.
the class AbstractModel method getFilteredMiningFieldMap.
/**
 * Returns the subset of {@code miningFieldMap} selected by usage type.
 *
 * @param includeFiltered when {@code true}, keep only fields whose usage type
 *                        is one of {@code filterTypes}; when {@code false},
 *                        keep only fields whose usage type is <em>not</em>
 *                        one of {@code filterTypes}
 * @param filterTypes     the usage types to match against
 * @return a new map from field name to mining field; never {@code null}
 */
public Map<String, MiningField> getFilteredMiningFieldMap(boolean includeFiltered, FIELDUSAGETYPE... filterTypes) {
    final Map<String, MiningField> selected = new HashMap<>();
    final List<FIELDUSAGETYPE> requestedTypes = Arrays.asList(filterTypes);
    for (Map.Entry<String, MiningField> entry : miningFieldMap.entrySet()) {
        final MiningField candidate = entry.getValue();
        final boolean typeListed = requestedTypes.contains(candidate.getUsageType());
        // Keep listed types in "include" mode and unlisted types otherwise.
        if (typeListed == includeFiltered) {
            selected.put(entry.getKey(), candidate);
        }
    }
    return selected;
}
use of org.dmg.pmml.pmml_4_2.descr.MiningField in project drools by kiegroup.
the class AbstractModel method getMiningFields.
/**
 * Collects the mining fields of this model, leaving out those whose usage
 * type is {@code TARGET}.
 *
 * <p>For each remaining field:
 * <ul>
 *   <li>if the data dictionary has an entry with the same name, the field is
 *       built from the mining field and that entry's raw data field;</li>
 *   <li>otherwise, if the direct parent model is a {@code Miningmodel}, the
 *       top-most ancestor is asked for an output field with the same name,
 *       and the mining field is added with that output field's type;</li>
 *   <li>in every other case the field is omitted from the result.</li>
 * </ul>
 *
 * @return a new, mutable list of mining fields
 */
@Override
public List<PMMLMiningField> getMiningFields() {
    final List<PMMLMiningField> result = new ArrayList<>();
    final Map<String, MiningField> nonTargetFields = getFilteredMiningFieldMap(false, FIELDUSAGETYPE.TARGET);
    final Map<String, PMMLDataField> dictionary = getOwner().getDataDictionaryMap();

    for (String fieldName : nonTargetFields.keySet()) {
        final PMMLDataField dictionaryEntry = dictionary.get(fieldName);
        final MiningField miningField = miningFieldMap.get(fieldName);

        if (dictionaryEntry != null) {
            result.add(new PMMLMiningField(miningField, dictionaryEntry.getRawDataField(), this.getModelId(), true));
            continue;
        }

        // No dictionary entry: the type can only come from an output field
        // exposed by the enclosing mining model hierarchy.
        final PMMLMiningField untyped = new PMMLMiningField(miningField, this.getModelId());
        PMML4Model ancestor = this.getParentModel();
        if (!(ancestor instanceof Miningmodel)) {
            continue;
        }
        // Walk up to the top-most ancestor.
        while (ancestor.getParentModel() != null) {
            ancestor = ancestor.getParentModel();
        }
        final PMMLOutputField matchingOutput = ((Miningmodel) ancestor).findOutputField(untyped.getName());
        if (matchingOutput == null) {
            continue;
        }
        untyped.setType(matchingOutput.getType());
        result.add(untyped);
    }
    return result;
}
Aggregations