Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in the Apache Ignite project:
the method findOneClusters of the class KMeansTrainerTest.
/**
 * Trains on a few points with a single cluster and a single iteration, and
 * checks that both probe vectors are assigned to that one cluster (label 0).
 */
@Test
public void findOneClusters() {
    KMeansTrainer trainer = createAndCheckTrainer();

    // The trained model is a K-Means model, so don't call it "knnMdl".
    KMeansModel mdl = trainer.withAmountOfClusters(1)
        .fit(new LocalDatasetBuilder<>(data, parts),
            new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // JUnit convention: expected value first, actual value second — otherwise
    // failure messages report the values swapped.
    Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
    assertEquals(0.0, mdl.predict(firstVector), PRECISION);

    Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
    assertEquals(0.0, mdl.predict(secondVector), PRECISION);

    // The trainer configuration from createAndCheckTrainer() must be intact.
    assertEquals(1, trainer.getMaxIterations());
    assertEquals(PRECISION, trainer.getEpsilon(), PRECISION);
}
Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in the Apache Ignite project:
the method testUpdateMdl of the class KMeansTrainerTest.
/**
 * Checks that updating a fitted model — either on the same dataset or on an
 * empty dataset — produces the same predictions as the original model.
 */
@Test
public void testUpdateMdl() {
    KMeansTrainer trainer = createAndCheckTrainer();
    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

    KMeansModel originalMdl = trainer.withAmountOfClusters(1)
        .fit(new LocalDatasetBuilder<>(data, parts), vectorizer);
    KMeansModel updatedMdlOnSameDataset =
        trainer.update(originalMdl, new LocalDatasetBuilder<>(data, parts), vectorizer);
    KMeansModel updatedMdlOnEmptyDataset =
        trainer.update(originalMdl, new LocalDatasetBuilder<>(new HashMap<>(), parts), vectorizer);

    Vector[] probes = {
        new DenseVector(new double[] {2.0, 2.0}),
        new DenseVector(new double[] {-2.0, -2.0})
    };

    for (Vector probe : probes) {
        // Updating on the same data must not change the model's predictions...
        assertEquals(originalMdl.predict(probe), updatedMdlOnSameDataset.predict(probe), PRECISION);
        // ...and updating on no data must leave the model untouched as well.
        assertEquals(originalMdl.predict(probe), updatedMdlOnEmptyDataset.predict(probe), PRECISION);
    }
}
Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in the Apache Ignite project:
the method loadKMeansModel of the class SparkModelParser.
/**
 * Load a K-Means model from a Spark-written parquet file.
 *
 * @param pathToMdl Path to the parquet model file.
 * @param learningEnvironment Learning environment used for logging read errors.
 * @return K-Means model built from the cluster centers read from the file
 *         (a model with no centers if the file could not be read).
 */
private static Model loadKMeansModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    // Accumulate centers across ALL row groups. Re-allocating an array per
    // row group (as before) silently discarded every group but the last one.
    // Fully qualified to avoid relying on imports not visible here.
    java.util.List<Vector> centers = new java.util.ArrayList<>();

    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        while (null != (pages = r.readNextRowGroup())) {
            final int rows = (int) pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));

            for (int i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                // Field 0 carries the cluster index; it is not needed because
                // records arrive in cluster order. The center coordinates sit
                // in the nested group at (field 1, rep 0) -> (field 3, rep 0).
                Group clusterCenterCoeff = g.getGroup(1, 0).getGroup(3, 0);

                final int amountOfCoefficients = clusterCenterCoeff.getFieldRepetitionCount(0);
                Vector center = new DenseVector(amountOfCoefficients);

                for (int j = 0; j < amountOfCoefficients; j++) {
                    double coefficient = clusterCenterCoeff.getGroup(0, j).getDouble(0, 0);
                    center.set(j, coefficient);
                }

                centers.add(center);
            }
        }
    } catch (IOException e) {
        // Log through the learning environment; printStackTrace() was redundant
        // and bypassed the configured logging. The model built below will
        // simply have no centers in this case (instead of a null array).
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
    }

    return new KMeansModel(centers.toArray(new Vector[0]), new EuclideanDistance());
}
Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in the Apache Ignite project:
the method getCentroids of the class ANNClassificationTrainer.
/**
 * Runs the K-Means clustering algorithm to find the centroids.
 *
 * @param vectorizer Upstream vectorizer.
 * @param datasetBuilder The dataset builder.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return The list of centroid vectors.
 */
private <K, V, C extends Serializable> List<Vector> getCentroids(Preprocessor<K, V> vectorizer, DatasetBuilder<K, V> datasetBuilder) {
    // Configure the K-Means trainer from this trainer's own hyperparameters.
    KMeansTrainer kMeansTrainer = new KMeansTrainer()
        .withAmountOfClusters(k)
        .withMaxIterations(maxIterations)
        .withDistance(distance)
        .withEpsilon(epsilon);

    KMeansModel clusterMdl = kMeansTrainer.fit(datasetBuilder, vectorizer);

    return Arrays.asList(clusterMdl.centers());
}
Aggregations