Usage of org.dmg.pmml.clustering.ClusteringModel in project jpmml-sparkml (by jpmml).
Class KMeansModelConverter, method encodeModel:
/**
 * Converts the fitted Spark ML k-means model into a center-based PMML clustering model.
 *
 * <p>Each cluster center becomes one {@code Cluster} element. Its id is the zero-based
 * position of the center in {@code clusterCenters()}, and its coordinates are written as a
 * real array.
 *
 * @param schema the feature/label schema; features become the clustering fields
 * @return a CENTER_BASED {@code ClusteringModel} with one cluster per center
 */
@Override
public ClusteringModel encodeModel(Schema schema) {
    KMeansModel model = getTransformer();

    Vector[] centers = model.clusterCenters();

    List<Cluster> clusters = new ArrayList<>();
    int clusterIndex = 0;
    for (Vector center : centers) {
        clusters.add(
            new Cluster()
                .setId(String.valueOf(clusterIndex))
                .setArray(PMMLUtil.createRealArray(VectorUtil.toList(center)))
        );
        clusterIndex++;
    }

    // NOTE(review): the distance is always emitted as squared Euclidean; confirm this matches
    // the distance measure the Spark model was trained with.
    ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
        .setCompareFunction(CompareFunction.ABS_DIFF)
        .setMeasure(new SquaredEuclidean());

    return new ClusteringModel(
        MiningFunction.CLUSTERING,
        ClusteringModel.ModelClass.CENTER_BASED,
        clusters.size(),
        ModelUtil.createMiningSchema(schema.getLabel()),
        comparisonMeasure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
        clusters
    );
}
Usage of org.dmg.pmml.clustering.ClusteringModel in project drools (by kiegroup).
Class ClusteringModelImplementationProviderTest, method getKiePMMLModel:
/**
 * Loads the clustering PMML fixture, compiles it through the provider and checks that the
 * resulting KiePMML model exists and is {@link Serializable}.
 */
@Test
public void getKiePMMLModel() throws Exception {
    final PMML pmml = TestUtils.loadFromFile(SOURCE_FILE);
    final ClusteringModel clusteringModel = getModel(pmml);

    final CommonCompilationDTO<ClusteringModel> compilationDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME,
                                                                   pmml,
                                                                   clusteringModel,
                                                                   new HasClassLoaderMock());
    final KiePMMLClusteringModel kiePMMLModel = PROVIDER.getKiePMMLModel(compilationDTO);

    assertNotNull(kiePMMLModel);
    assertTrue(kiePMMLModel instanceof Serializable);
}
Usage of org.dmg.pmml.clustering.ClusteringModel in project drools (by kiegroup).
Class ClusteringModelImplementationProviderTest, method getModel:
/**
 * Asserts that the PMML document contains exactly one model and that it is a
 * {@code ClusteringModel}, then returns it.
 *
 * @param pmml the loaded PMML document; must not be null
 * @return the single model, cast to {@code ClusteringModel}
 */
private static ClusteringModel getModel(PMML pmml) {
    assertNotNull(pmml);

    final int modelCount = pmml.getModels().size();
    assertEquals(1, modelCount);

    final Model onlyModel = pmml.getModels().get(0);
    assertTrue(onlyModel instanceof ClusteringModel);

    return (ClusteringModel) onlyModel;
}
Usage of org.dmg.pmml.clustering.ClusteringModel in project jpmml-r (by jpmml).
Class KMeansConverter, method encodeModel:
/**
 * Converts an R {@code kmeans} object into a center-based PMML clustering model.
 *
 * <p>Reads the {@code centers} matrix and the {@code size} vector from the R object. Each row
 * of {@code centers} becomes one {@code Cluster}, with:
 * <ul>
 *   <li>a 1-based id;</li>
 *   <li>the matrix row name as its name;</li>
 *   <li>the corresponding {@code size} entry.</li>
 * </ul>
 *
 * @param schema the feature/label schema; features become the clustering fields
 * @return a CENTER_BASED {@code ClusteringModel} with an output field named "cluster"
 * @throws IllegalArgumentException if the number of row names of {@code centers} differs
 *         from its row dimension
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector kmeans = getObject();

    RDoubleVector centers = kmeans.getDoubleElement("centers");
    RIntegerVector size = kmeans.getIntegerElement("size");

    RIntegerVector centersDim = centers.dim();

    // One row per cluster, one column per feature
    int rows = centersDim.getValue(0);
    int columns = centersDim.getValue(1);

    RStringVector rowNames = centers.dimnames(0);

    // The matrix dimension is the authoritative cluster count; reject inconsistent row names
    // instead of silently emitting a cluster list that disagrees with the declared dimension.
    if (rowNames.size() != rows) {
        throw new IllegalArgumentException("Expected " + rows + " row names for the centers matrix, got " + rowNames.size());
    }

    List<Cluster> clusters = new ArrayList<>(rows);
    for (int i = 0; i < rows; i++) {
        // Matrix values are read through FortranMatrixUtil (presumably column-major R storage)
        Cluster cluster = new Cluster(PMMLUtil.createRealArray(FortranMatrixUtil.getRow(centers.getValues(), rows, columns, i)))
            .setId(String.valueOf(i + 1))
            .setName(rowNames.getValue(i))
            .setSize(size.getValue(i));

        clusters.add(cluster);
    }

    ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean())
        .setCompareFunction(CompareFunction.ABS_DIFF);

    // numberOfClusters must equal the number of Cluster elements actually emitted
    ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(), ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters)
        .setOutput(ClusteringModelUtil.createOutput("cluster", DataType.DOUBLE, clusters));

    return clusteringModel;
}
Usage of org.dmg.pmml.clustering.ClusteringModel in project drools (by kiegroup).
Class KiePMMLClusteringModelFactory, method setConstructor:
/**
 * Fills the default constructor of {@code modelTemplate} with statements that rebuild the
 * given clustering model at runtime.
 *
 * <p>The statements are appended in this order:
 * <ol>
 *   <li>the model class;</li>
 *   <li>one {@code clusters.add(...)} call per cluster;</li>
 *   <li>one {@code clusteringFields.add(...)} call per clustering field;</li>
 *   <li>the comparison measure;</li>
 *   <li>the missing-value weights, when the model defines them.</li>
 * </ol>
 *
 * @param compilationDTO compilation context holding the source {@code ClusteringModel}
 * @param modelTemplate  the generated class whose default constructor is populated
 * @throws KiePMMLInternalException if the template has no default constructor
 */
static void setConstructor(final CompilationDTO<ClusteringModel> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    KiePMMLModelFactoryUtils.init(compilationDTO, modelTemplate);

    final BlockStmt constructorBody = modelTemplate.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())))
            .getBody();

    final ClusteringModel sourceModel = compilationDTO.getModel();

    constructorBody.addStatement(assignExprFrom("modelClass", modelClassFrom(sourceModel.getModelClass())));

    sourceModel.getClusters().forEach(cluster ->
            constructorBody.addStatement(methodCallExprFrom("clusters", "add",
                    KiePMMLClusteringModelFactory.clusterCreationExprFrom(cluster))));

    sourceModel.getClusteringFields().forEach(clusteringField ->
            constructorBody.addStatement(methodCallExprFrom("clusteringFields", "add",
                    KiePMMLClusteringModelFactory.clusteringFieldCreationExprFrom(clusteringField))));

    constructorBody.addStatement(assignExprFrom("comparisonMeasure", comparisonMeasureCreationExprFrom(sourceModel.getComparisonMeasure())));

    // Missing-value weights are optional in the source model; emit the assignment only when present
    final boolean hasMissingValueWeights = sourceModel.getMissingValueWeights() != null;
    if (hasMissingValueWeights) {
        constructorBody.addStatement(assignExprFrom("missingValueWeights", missingValueWeightsCreationExprFrom(sourceModel.getMissingValueWeights())));
    }
}
Aggregations