Search in sources:

Example 1 with ComparisonMeasure

Use of org.dmg.pmml.ComparisonMeasure in project jpmml-sparkml by jpmml.

Used in class KMeansModelConverter, method encodeModel.

/**
 * Converts the fitted Spark {@code KMeansModel} into a center-based PMML {@code ClusteringModel}.
 *
 * <p>Each cluster center becomes one {@code Cluster} whose id is the center's zero-based index
 * in {@code clusterCenters()}. Distances use squared Euclidean measure with an absolute-difference
 * compare function.
 *
 * @param schema the schema supplying the label (for the mining schema) and the clustering features
 * @return the encoded PMML clustering model
 */
@Override
public ClusteringModel encodeModel(Schema schema) {
    KMeansModel kMeansModel = getTransformer();

    Vector[] centers = kMeansModel.clusterCenters();
    List<Cluster> pmmlClusters = new ArrayList<>();

    int index = 0;
    for (Vector center : centers) {
        Cluster pmmlCluster = new Cluster()
            .setId(String.valueOf(index))
            .setArray(PMMLUtil.createRealArray(VectorUtil.toList(center)));
        pmmlClusters.add(pmmlCluster);
        index++;
    }

    ComparisonMeasure distanceMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
        .setCompareFunction(CompareFunction.ABS_DIFF)
        .setMeasure(new SquaredEuclidean());

    return new ClusteringModel(
        MiningFunction.CLUSTERING,
        ClusteringModel.ModelClass.CENTER_BASED,
        pmmlClusters.size(),
        ModelUtil.createMiningSchema(schema.getLabel()),
        distanceMeasure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
        pmmlClusters);
}
Also used : KMeansModel(org.apache.spark.ml.clustering.KMeansModel) SquaredEuclidean(org.dmg.pmml.SquaredEuclidean) ArrayList(java.util.ArrayList) Cluster(org.dmg.pmml.clustering.Cluster) Vector(org.apache.spark.ml.linalg.Vector) ComparisonMeasure(org.dmg.pmml.ComparisonMeasure) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel)

Example 2 with ComparisonMeasure

Use of org.dmg.pmml.ComparisonMeasure in project jpmml-r by jpmml.

Used in class KMeansConverter, method encodeModel.

/**
 * Converts an R {@code kmeans} result object into a center-based PMML {@code ClusteringModel}.
 *
 * <p>Reads the {@code "centers"} matrix and the {@code "size"} vector from the R object. One
 * {@code Cluster} is emitted per row name of the centers matrix. Its id is 1-based, its name is
 * the row name, and its size comes from {@code "size"}. The model also gets an output named
 * {@code "cluster"} of type {@code DOUBLE}.
 *
 * @param schema the schema supplying the label (for the mining schema) and the clustering features
 * @return the encoded PMML clustering model
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector kmeans = getObject();

    RDoubleVector centers = kmeans.getDoubleElement("centers");
    RIntegerVector clusterSizes = kmeans.getIntegerElement("size");

    RIntegerVector dim = centers.dim();
    int rowCount = dim.getValue(0);
    int columnCount = dim.getValue(1);

    List<Cluster> clusters = new ArrayList<>();
    RStringVector rowNames = centers.dimnames(0);

    for (int row = 0; row < rowNames.size(); row++) {
        // NOTE(review): FortranMatrixUtil presumably reads the values in column-major order (R's matrix layout) — confirm.
        Cluster cluster = new Cluster(PMMLUtil.createRealArray(FortranMatrixUtil.getRow(centers.getValues(), rowCount, columnCount, row)));
        cluster.setId(String.valueOf(row + 1));
        cluster.setName(rowNames.getValue(row));
        cluster.setSize(clusterSizes.getValue(row));
        clusters.add(cluster);
    }

    ComparisonMeasure distanceMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean())
        .setCompareFunction(CompareFunction.ABS_DIFF);

    return new ClusteringModel(
            MiningFunction.CLUSTERING,
            ClusteringModel.ModelClass.CENTER_BASED,
            rowCount,
            ModelUtil.createMiningSchema(schema.getLabel()),
            distanceMeasure,
            ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
            clusters)
        .setOutput(ClusteringModelUtil.createOutput("cluster", DataType.DOUBLE, clusters));
}
Also used : SquaredEuclidean(org.dmg.pmml.SquaredEuclidean) ArrayList(java.util.ArrayList) Cluster(org.dmg.pmml.clustering.Cluster) ComparisonMeasure(org.dmg.pmml.ComparisonMeasure) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel)

Example 3 with ComparisonMeasure

Use of org.dmg.pmml.ComparisonMeasure in project drools by kiegroup.

Used in class KiePMMLClusteringModelFactory, method setConstructor.

/**
 * Populates the default constructor of the generated clustering-model class.
 *
 * <p>After the common initialization in {@code KiePMMLModelFactoryUtils.init}, statements are
 * appended to the template's default-constructor body in this order:
 * <ol>
 *   <li>the {@code modelClass} assignment;</li>
 *   <li>one {@code clusters.add(...)} call per PMML cluster;</li>
 *   <li>one {@code clusteringFields.add(...)} call per PMML clustering field;</li>
 *   <li>the {@code comparisonMeasure} assignment;</li>
 *   <li>the {@code missingValueWeights} assignment, only when the model declares one.</li>
 * </ol>
 *
 * @param compilationDTO carries the source PMML {@code ClusteringModel}
 * @param modelTemplate the class declaration whose default constructor is filled in
 * @throws KiePMMLInternalException if {@code modelTemplate} has no default constructor
 */
static void setConstructor(final CompilationDTO<ClusteringModel> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    KiePMMLModelFactoryUtils.init(compilationDTO, modelTemplate);

    final ConstructorDeclaration defaultConstructor = modelTemplate.getDefaultConstructor()
        .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
    final BlockStmt constructorBody = defaultConstructor.getBody();
    final ClusteringModel pmmlModel = compilationDTO.getModel();

    constructorBody.addStatement(assignExprFrom("modelClass", modelClassFrom(pmmlModel.getModelClass())));

    for (Cluster cluster : pmmlModel.getClusters()) {
        constructorBody.addStatement(methodCallExprFrom("clusters", "add", clusterCreationExprFrom(cluster)));
    }

    for (ClusteringField clusteringField : pmmlModel.getClusteringFields()) {
        constructorBody.addStatement(methodCallExprFrom("clusteringFields", "add", clusteringFieldCreationExprFrom(clusteringField)));
    }

    constructorBody.addStatement(assignExprFrom("comparisonMeasure", comparisonMeasureCreationExprFrom(pmmlModel.getComparisonMeasure())));

    // Missing-value weights are optional in the source model; emit the assignment only when they are present.
    final MissingValueWeights missingValueWeights = pmmlModel.getMissingValueWeights();
    if (missingValueWeights != null) {
        constructorBody.addStatement(assignExprFrom("missingValueWeights", missingValueWeightsCreationExprFrom(missingValueWeights)));
    }
}
Also used : KiePMMLCluster(org.kie.pmml.models.clustering.model.KiePMMLCluster) Arrays(java.util.Arrays) ClassOrInterfaceType(com.github.javaparser.ast.type.ClassOrInterfaceType) CommonCodegenUtils.methodCallExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.methodCallExprFrom) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CommonCodegenUtils.assignExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.assignExprFrom) CommonCodegenUtils.literalExprFrom(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.literalExprFrom) MAIN_CLASS_NOT_FOUND(org.kie.pmml.compiler.commons.utils.JavaParserUtils.MAIN_CLASS_NOT_FOUND) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) MissingValueWeights(org.dmg.pmml.clustering.MissingValueWeights) KiePMMLComparisonMeasure(org.kie.pmml.models.clustering.model.KiePMMLComparisonMeasure) DoubleLiteralExpr(com.github.javaparser.ast.expr.DoubleLiteralExpr) ObjectCreationExpr(com.github.javaparser.ast.expr.ObjectCreationExpr) Map(java.util.Map) Expression(com.github.javaparser.ast.expr.Expression) ComparisonMeasure(org.dmg.pmml.ComparisonMeasure) CompilationUnit(com.github.javaparser.ast.CompilationUnit) KiePMMLClusteringModel(org.kie.pmml.models.clustering.model.KiePMMLClusteringModel) KiePMMLMissingValueWeights(org.kie.pmml.models.clustering.model.KiePMMLMissingValueWeights) NodeList(com.github.javaparser.ast.NodeList) CompilationDTO(org.kie.pmml.compiler.api.dto.CompilationDTO) KiePMMLClusteringField(org.kie.pmml.models.clustering.model.KiePMMLClusteringField) ClusteringField(org.dmg.pmml.clustering.ClusteringField) JavaParserUtils(org.kie.pmml.compiler.commons.utils.JavaParserUtils) Logger(org.slf4j.Logger) BooleanLiteralExpr(com.github.javaparser.ast.expr.BooleanLiteralExpr) 
KiePMMLModelFactoryUtils(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLModelFactoryUtils) JavaParserUtils.getFullClassName(org.kie.pmml.compiler.commons.utils.JavaParserUtils.getFullClassName) KiePMMLClusteringConversionUtils.aggregateFunctionFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.aggregateFunctionFrom) KiePMMLClusteringConversionUtils.compareFunctionFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.compareFunctionFrom) Array(org.dmg.pmml.Array) KiePMMLClusteringConversionUtils.modelClassFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.modelClassFrom) Cluster(org.dmg.pmml.clustering.Cluster) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) MISSING_DEFAULT_CONSTRUCTOR(org.kie.pmml.commons.Constants.MISSING_DEFAULT_CONSTRUCTOR) KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom(org.kie.pmml.models.clustering.compiler.factories.KiePMMLClusteringConversionUtils.comparisonMeasureKindFrom) ClassOrInterfaceDeclaration(com.github.javaparser.ast.body.ClassOrInterfaceDeclaration) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel) ConstructorDeclaration(com.github.javaparser.ast.body.ConstructorDeclaration) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) KiePMMLInternalException(org.kie.pmml.api.exceptions.KiePMMLInternalException) KiePMMLClusteringModel(org.kie.pmml.models.clustering.model.KiePMMLClusteringModel) ClusteringModel(org.dmg.pmml.clustering.ClusteringModel)

Aggregations

ComparisonMeasure (org.dmg.pmml.ComparisonMeasure)3 Cluster (org.dmg.pmml.clustering.Cluster)3 ClusteringModel (org.dmg.pmml.clustering.ClusteringModel)3 ArrayList (java.util.ArrayList)2 SquaredEuclidean (org.dmg.pmml.SquaredEuclidean)2 CompilationUnit (com.github.javaparser.ast.CompilationUnit)1 NodeList (com.github.javaparser.ast.NodeList)1 ClassOrInterfaceDeclaration (com.github.javaparser.ast.body.ClassOrInterfaceDeclaration)1 ConstructorDeclaration (com.github.javaparser.ast.body.ConstructorDeclaration)1 BooleanLiteralExpr (com.github.javaparser.ast.expr.BooleanLiteralExpr)1 DoubleLiteralExpr (com.github.javaparser.ast.expr.DoubleLiteralExpr)1 Expression (com.github.javaparser.ast.expr.Expression)1 NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr)1 ObjectCreationExpr (com.github.javaparser.ast.expr.ObjectCreationExpr)1 BlockStmt (com.github.javaparser.ast.stmt.BlockStmt)1 ClassOrInterfaceType (com.github.javaparser.ast.type.ClassOrInterfaceType)1 Arrays (java.util.Arrays)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 KMeansModel (org.apache.spark.ml.clustering.KMeansModel)1