use of org.dmg.pmml.DataDictionary in project shifu by ShifuML.
the class PMMLConstructorFactory method produce.
/**
 * Creates the PMML translator that matches the training algorithm configured in {@code modelConfig}.
 *
 * <p>GBT and RF models get a {@code TreeEnsemblePMMLTranslator}; NN and LR models get a generic
 * {@code PMMLTranslator} whose local-transformation creator is picked from the normalization type.
 * The algorithm name is compared case-insensitively.
 *
 * @param modelConfig       model configuration; supplies the algorithm name and normalization type
 * @param columnConfigList  column configurations handed to every element creator
 * @param isConcise         forwarded to the NN/LR element creators (not used on the tree-ensemble path)
 * @param isOutBaggingToOne forwarded to the {@code PMMLTranslator} constructor
 * @return the translator for the configured algorithm
 * @throws RuntimeException if the algorithm is none of NN, LR, GBT or RF
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    final String algorithm = modelConfig.getTrain().getAlgorithm();

    // Tree ensembles are handled by their own translator and need none of the creators built below.
    if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        TreeEnsemblePmmlCreator ensembleCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(ensembleCreator, treeDictionaryCreator, treeSchemaCreator);
    }

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else {
        throw new RuntimeException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    // Pick the local-transformation creator from the normalization type; anything not listed uses z-score.
    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    switch (modelConfig.getNormalizeType()) {
        case WOE:
        case WEIGHT_WOE:
            localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
            break;
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
            break;
        case ZSCALE_ONEHOT:
            localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        default:
            localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
    }

    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
use of org.dmg.pmml.DataDictionary in project shifu by ShifuML.
the class DataDictionaryCreator method build.
/**
 * Builds the PMML {@code DataDictionary} from the raw (non-expanded) columns.
 *
 * <p>Only columns whose column number is below {@code datasetHeaders.length} are considered; the
 * scan stops at the first column at or beyond that bound (the list is expected to be in order).
 * In non-concise mode every such column is emitted. In concise mode a column is emitted when it is
 * the target, when it is final-selected (and, for a {@code BasicFloatNetwork} with a non-empty
 * feature set, also contained in that feature set), or - in segment-expansion mode - when any of
 * its segment-expanded counterparts is final-selected.
 *
 * @param basicML the trained model; only inspected to see whether it is a {@code BasicFloatNetwork}
 * @return the data dictionary holding one data field per retained column
 */
@Override
public DataDictionary build(BasicML basicML) {
    DataDictionary dictionary = new DataDictionary();
    List<DataField> dataFields = new ArrayList<DataField>();
    boolean segExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segmentCount = segmentExpansions.size();

    // A float network may carry its own feature subset, which further narrows the selected columns.
    boolean isFloatNetwork = basicML instanceof BasicFloatNetwork;
    Set<Integer> featureSet = isFloatNetwork ? ((BasicFloatNetwork) basicML).getFeatureSet() : null;

    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // columns are in order, so everything from here on is a segment-expanded column
            break;
        }

        if (!isConcise) {
            dataFields.add(convertColumnToDataField(columnConfig));
            continue;
        }

        boolean selected = columnConfig.isFinalSelect();
        if (selected && isFloatNetwork) {
            selected = CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum());
        }
        if (selected || columnConfig.isTarget()) {
            dataFields.add(convertColumnToDataField(columnConfig));
            continue;
        }

        if (!segExpansionMode) {
            continue;
        }

        // Even if the raw column is not selected, keep it when one of its segment columns is.
        for (int segment = 1; segment <= segmentCount; segment++) {
            ColumnConfig expanded = columnConfigList.get(datasetHeaders.length * segment + columnConfig.getColumnNum());
            if (expanded.isFinalSelect()) {
                dataFields.add(convertColumnToDataField(columnConfig));
                break;
            }
        }
    }

    int fieldCount = dataFields.size();
    dictionary.addDataFields(dataFields.toArray(new DataField[fieldCount]));
    dictionary.setNumberOfFields(fieldCount);
    return dictionary;
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getTargetFieldsTypeMapWithTargetFieldsWithoutTargets.
/**
 * Checks {@code ModelUtils.getTargetFieldsTypeMap} on a model that has three PREDICTED mining
 * fields and no {@code Targets} element: the result must be a {@code LinkedHashMap} with one entry
 * per mining field, and the entry values, read in iteration order, must match the data types of
 * the corresponding data fields.
 */
@Test
public void getTargetFieldsTypeMapWithTargetFieldsWithoutTargets() {
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    for (int i = 0; i < 3; i++) {
        final DataField dataField = getRandomDataField();
        dataDictionary.addDataFields(dataField);
        miningSchema.addMiningFields(getMiningField(dataField.getName().getValue(), MiningField.UsageType.PREDICTED));
    }
    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);

    Map<String, DATA_TYPE> retrieved = ModelUtils.getTargetFieldsTypeMap(getFieldsFromDataDictionary(dataDictionary), model);

    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    assertTrue(retrieved instanceof LinkedHashMap);

    // Walk the mining fields and the map entries in lock-step: entry order must follow schema order.
    final Iterator<Map.Entry<String, DATA_TYPE>> entries = retrieved.entrySet().iterator();
    for (MiningField miningField : miningSchema.getMiningFields()) {
        DataField matchingField = dataDictionary.getDataFields()
                .stream()
                .filter(df -> df.getName().equals(miningField.getName()))
                .findFirst()
                .get();
        DATA_TYPE expected = DATA_TYPE.byName(matchingField.getDataType().value());
        assertEquals(expected, entries.next().getValue());
    }
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getTargetFieldTypeWithoutTargetField.
/**
 * Calls {@code ModelUtils.getTargetFieldType} on a model whose only mining field has ACTIVE usage
 * and expects the call to throw (presumably because no target field can be found).
 */
@Test(expected = Exception.class)
public void getTargetFieldTypeWithoutTargetField() {
    final String fieldName = "fieldName";

    // Data dictionary with a single categorical string field.
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));

    // Mining schema that references the field as ACTIVE, not as a target/predicted field.
    final MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(getMiningField(fieldName, MiningField.UsageType.ACTIVE));

    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);

    ModelUtils.getTargetFieldType(getFieldsFromDataDictionary(dataDictionary), model);
}
use of org.dmg.pmml.DataDictionary in project drools by kiegroup.
the class ModelUtilsTest method getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType.
/**
 * Checks {@code ModelUtils.getTargetFields} on a model with three PREDICTED mining fields whose
 * {@code Target} elements are created with a {@code null} second argument: every returned entry
 * must map to a mining field by name and carry that mining field's op type.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType() {
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    for (int i = 0; i < 3; i++) {
        final String fieldName = "fieldName-" + i;
        dataDictionary.addDataFields(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
        // The mining field's op type (CONTINUOUS) deliberately differs from the data field's (CATEGORICAL).
        final MiningField miningField = getMiningField(fieldName, MiningField.UsageType.PREDICTED);
        miningField.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(miningField);
        targets.addTargets(getTarget(fieldName, null));
    }
    final Model model = new RegressionModel();
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    List<KiePMMLNameOpType> retrieved = ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);

    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    for (KiePMMLNameOpType nameOpType : retrieved) {
        Optional<MiningField> matchingField = miningSchema.getMiningFields()
                .stream()
                .filter(fld -> nameOpType.getName().equals(fld.getName().getValue()))
                .findFirst();
        assertTrue(matchingField.isPresent());
        OP_TYPE expected = OP_TYPE.byName(matchingField.get().getOpType().value());
        assertEquals(expected, nameOpType.getOpType());
    }
}
Aggregations