Use of org.dmg.pmml.clustering.Cluster in the jpmml-sparkml project by jpmml.
From the class KMeansModelConverter, method encodeModel.
/**
 * Encodes the k-means model returned by {@code getTransformer()} as a PMML
 * center-based {@code ClusteringModel}.
 *
 * <p>One {@code Cluster} is emitted per cluster center, in the order given by
 * {@code clusterCenters()}. Each cluster's id is its zero-based position rendered
 * as a string, and its array holds the center's coordinates. Distances are
 * declared as squared Euclidean with the {@code ABS_DIFF} compare function.
 *
 * @param schema supplies the label for the mining schema and the features for
 *               the clustering fields
 * @return the assembled clustering model
 */
@Override
public ClusteringModel encodeModel(Schema schema) {
    KMeansModel kMeansModel = getTransformer();

    List<Cluster> pmmlClusters = new ArrayList<>();

    int position = 0;
    for (Vector center : kMeansModel.clusterCenters()) {
        Cluster pmmlCluster = new Cluster()
            .setId(String.valueOf(position))
            .setArray(PMMLUtil.createRealArray(VectorUtil.toList(center)));

        pmmlClusters.add(pmmlCluster);
        position++;
    }

    ComparisonMeasure measure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE)
        .setCompareFunction(CompareFunction.ABS_DIFF)
        .setMeasure(new SquaredEuclidean());

    return new ClusteringModel(
        MiningFunction.CLUSTERING,
        ClusteringModel.ModelClass.CENTER_BASED,
        pmmlClusters.size(),
        ModelUtil.createMiningSchema(schema.getLabel()),
        measure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
        pmmlClusters);
}
Use of org.dmg.pmml.clustering.Cluster in the jpmml-r project by jpmml.
From the class KMeansConverter, method encodeModel.
/**
 * Encodes the k-means object returned by {@code getObject()} as a PMML
 * center-based {@code ClusteringModel}.
 *
 * <p>The {@code "centers"} element supplies the cluster centers as a matrix and
 * the {@code "size"} element supplies the per-cluster sizes. One {@code Cluster}
 * is emitted per row name of the centers matrix: its id is the one-based row
 * position, its name is the row name, and its array is the matching row obtained
 * through {@code FortranMatrixUtil.getRow}. The model also receives an output
 * named {@code "cluster"} built by {@code ClusteringModelUtil.createOutput}.
 *
 * @param schema supplies the label for the mining schema and the features for
 *               the clustering fields
 * @return the assembled clustering model, including its output definition
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector kmeansObject = getObject();

    RDoubleVector centerMatrix = kmeansObject.getDoubleElement("centers");
    RIntegerVector clusterSizes = kmeansObject.getIntegerElement("size");

    RIntegerVector matrixDim = centerMatrix.dim();
    int rowCount = matrixDim.getValue(0);
    int columnCount = matrixDim.getValue(1);

    List<Cluster> pmmlClusters = new ArrayList<>();

    RStringVector centerNames = centerMatrix.dimnames(0);
    for (int row = 0; row < centerNames.size(); row++) {
        Cluster pmmlCluster = new Cluster(PMMLUtil.createRealArray(FortranMatrixUtil.getRow(centerMatrix.getValues(), rowCount, columnCount, row)))
            .setId(String.valueOf(row + 1)) // ids are one-based
            .setName(centerNames.getValue(row))
            .setSize(clusterSizes.getValue(row));

        pmmlClusters.add(pmmlCluster);
    }

    ComparisonMeasure measure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean())
        .setCompareFunction(CompareFunction.ABS_DIFF);

    return new ClusteringModel(
        MiningFunction.CLUSTERING,
        ClusteringModel.ModelClass.CENTER_BASED,
        rowCount,
        ModelUtil.createMiningSchema(schema.getLabel()),
        measure,
        ClusteringModelUtil.createClusteringFields(schema.getFeatures()),
        pmmlClusters)
        .setOutput(ClusteringModelUtil.createOutput("cluster", DataType.DOUBLE, pmmlClusters));
}
Aggregations