
Example 1 with Cluster

Use of org.dmg.pmml.clustering.Cluster in project jpmml-sparkml by jpmml.

From the class KMeansModelConverter, method encodeModel, which converts a Spark ML KMeansModel into a PMML ClusteringModel.

@Override
public ClusteringModel encodeModel(Schema schema) {
    KMeansModel model = getTransformer();
    // One PMML Cluster per Spark cluster center; ids are 0-based here
    List<Cluster> clusters = new ArrayList<>();
    Vector[] clusterCenters = model.clusterCenters();
    for (int i = 0; i < clusterCenters.length; i++) {
        Cluster cluster = new Cluster()
            .setId(String.valueOf(i))
            .setArray(PMMLUtil.createRealArray(VectorUtil.toList(clusterCenters[i])));
        clusters.add(cluster);
    }
    // Points are assigned to the center with the smallest squared Euclidean distance
    ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
        .setCompareFunction(CompareFunction.ABS_DIFF)
        .setMeasure(new SquaredEuclidean());
    return new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(),
        ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters);
}
Also used: KMeansModel (org.apache.spark.ml.clustering.KMeansModel), SquaredEuclidean (org.dmg.pmml.SquaredEuclidean), ArrayList (java.util.ArrayList), Cluster (org.dmg.pmml.clustering.Cluster), Vector (org.apache.spark.ml.linalg.Vector), ComparisonMeasure (org.dmg.pmml.ComparisonMeasure), ClusteringModel (org.dmg.pmml.clustering.ClusteringModel)
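To show what the encoded ComparisonMeasure means at scoring time, here is a minimal, self-contained sketch of the rule a PMML consumer applies to a CENTER_BASED model with kind "distance", compareFunction absDiff and squaredEuclidean: compare each field with |x - y|, sum the squares, and pick the cluster with the smallest value. The class NearestCenter and its methods are hypothetical (they are not part of JPMML or jpmml-evaluator), and the sketch assumes all field weights are 1.

```java
// Hypothetical helper, not part of JPMML: mirrors the scoring rule for a
// CENTER_BASED ClusteringModel whose ComparisonMeasure is a distance with
// compareFunction="absDiff" and squaredEuclidean. Field weights assumed to be 1.
public class NearestCenter {

    // absDiff compares one field value with one center coordinate: |x - y|
    static double absDiff(double x, double y) {
        return Math.abs(x - y);
    }

    // squaredEuclidean sums the squared per-field comparisons
    static double squaredEuclidean(double[] point, double[] center) {
        double sum = 0.0;
        for (int j = 0; j < point.length; j++) {
            double c = absDiff(point[j], center[j]);
            sum += c * c;
        }
        return sum;
    }

    // kind="distance": the cluster with the smallest measure wins
    static int nearest(double[][] centers, double[] point) {
        int best = -1;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.length; i++) {
            double d = squaredEuclidean(point, centers[i]);
            if (d < bestDistance) {
                bestDistance = d;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] centers = {{0.0, 0.0}, {10.0, 10.0}};
        double[] point = {1.0, 2.0};
        int i = nearest(centers, point);
        // prints "cluster id 0, distance 5.0"
        System.out.println("cluster id " + i + ", distance " + squaredEuclidean(point, centers[i]));
    }
}
```

Because Example 1 sets each Cluster id to String.valueOf(i), the array index returned by nearest lines up with the PMML cluster id in this sketch.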

Example 2 with Cluster

Use of org.dmg.pmml.clustering.Cluster in project jpmml-r by jpmml.

From the class KMeansConverter, method encodeModel, which converts an R kmeans object into a PMML ClusteringModel.

@Override
public Model encodeModel(Schema schema) {
    RGenericVector kmeans = getObject();
    // k x p matrix of cluster centers, and the number of points in each cluster
    RDoubleVector centers = kmeans.getDoubleElement("centers");
    RIntegerVector size = kmeans.getIntegerElement("size");
    RIntegerVector centersDim = centers.dim();
    int rows = centersDim.getValue(0);
    int columns = centersDim.getValue(1);
    List<Cluster> clusters = new ArrayList<>();
    RStringVector rowNames = centers.dimnames(0);
    for (int i = 0; i < rowNames.size(); i++) {
        // R matrices are column-major, so row i is read from the flat values with a stride of `rows`
        Cluster cluster = new Cluster(PMMLUtil.createRealArray(FortranMatrixUtil.getRow(centers.getValues(), rows, columns, i)))
            .setId(String.valueOf(i + 1))
            .setName(rowNames.getValue(i))
            .setSize(size.getValue(i));
        clusters.add(cluster);
    }
    ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean())
        .setCompareFunction(CompareFunction.ABS_DIFF);
    ClusteringModel clusteringModel = new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, rows,
        ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters)
        .setOutput(ClusteringModelUtil.createOutput("cluster", DataType.DOUBLE, clusters));
    return clusteringModel;
}
Also used: SquaredEuclidean (org.dmg.pmml.SquaredEuclidean), ArrayList (java.util.ArrayList), Cluster (org.dmg.pmml.clustering.Cluster), ComparisonMeasure (org.dmg.pmml.ComparisonMeasure), ClusteringModel (org.dmg.pmml.clustering.ClusteringModel)
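Example 2 relies on FortranMatrixUtil.getRow to pull one cluster center out of R's flattened centers matrix. R stores matrices column by column, so element (i, j) of an nrow x ncol matrix sits at flat index i + j * nrow. The sketch below is a hypothetical stand-alone re-implementation of that idea (the class ColumnMajor is illustrative; the real jpmml-r helper may differ in signature and error handling).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical re-implementation for illustration only; the real
// FortranMatrixUtil.getRow in jpmml-r may differ in details.
public class ColumnMajor {

    // Column-major (Fortran/R) layout: element (row, j), both 0-based,
    // is stored at index row + j * rows of the flat values list.
    static <E> List<E> getRow(List<E> values, int rows, int columns, int row) {
        if (values.size() != rows * columns) {
            throw new IllegalArgumentException("Expected " + rows * columns + " values, got " + values.size());
        }
        List<E> result = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            result.add(values.get(row + j * rows));
        }
        return result;
    }

    public static void main(String[] args) {
        // Centers for k = 2 clusters over 3 variables, flattened the way R stores
        // matrix(c(1, 4, 2, 5, 3, 6), nrow = 2), whose rows are (1, 2, 3) and (4, 5, 6)
        List<Double> centers = Arrays.asList(1.0, 4.0, 2.0, 5.0, 3.0, 6.0);
        for (int i = 0; i < 2; i++) {
            // Example 2 uses i + 1 as the cluster id, matching R's 1-based cluster numbers
            System.out.println("cluster " + (i + 1) + ": " + getRow(centers, 2, 3, i));
        }
    }
}
```

Reading the flat values in their stored order instead (1.0, 4.0, 2.0, ...) would mix coordinates from different clusters, which is why the stride matters.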

Aggregations (number of the examples above that use each class):

ArrayList (java.util.ArrayList): 2
ComparisonMeasure (org.dmg.pmml.ComparisonMeasure): 2
SquaredEuclidean (org.dmg.pmml.SquaredEuclidean): 2
Cluster (org.dmg.pmml.clustering.Cluster): 2
ClusteringModel (org.dmg.pmml.clustering.ClusteringModel): 2
KMeansModel (org.apache.spark.ml.clustering.KMeansModel): 1
Vector (org.apache.spark.ml.linalg.Vector): 1