Search in sources:

Example 11 with KMeansModel

Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by Apache.

In class KMeansTrainerTest, method findOneClusters.

/**
 * A few points, one cluster, one iteration.
 *
 * <p>Trains a single-cluster K-Means model on the shared test data and checks that points on
 * opposite sides of the origin are both assigned to cluster {@code 0}. Also checks that the trainer
 * returned by {@code createAndCheckTrainer()} runs one iteration and uses {@code PRECISION} as epsilon.
 */
@Test
public void findOneClusters() {
    KMeansTrainer trainer = createAndCheckTrainer();

    KMeansModel kMeansMdl = trainer.withAmountOfClusters(1).fit(
        new LocalDatasetBuilder<>(data, parts),
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
    );

    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are correct.
    Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
    assertEquals(0.0, kMeansMdl.predict(firstVector), PRECISION);

    Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
    assertEquals(0.0, kMeansMdl.predict(secondVector), PRECISION);

    assertEquals(1, trainer.getMaxIterations());
    assertEquals(PRECISION, trainer.getEpsilon(), PRECISION);
}
Also used: KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) KMeansTrainer(org.apache.ignite.ml.clustering.kmeans.KMeansTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 12 with KMeansModel

Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by Apache.

In class KMeansTrainerTest, method testUpdateMdl.

/**
 * Checks that updating a single-cluster K-Means model leaves its predictions unchanged, both when
 * the update runs on the original training data and when it runs on an empty dataset.
 */
@Test
public void testUpdateMdl() {
    KMeansTrainer trainer = createAndCheckTrainer();
    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

    KMeansModel originalMdl = trainer.withAmountOfClusters(1)
        .fit(new LocalDatasetBuilder<>(data, parts), vectorizer);

    KMeansModel sameDataMdl = trainer.update(originalMdl, new LocalDatasetBuilder<>(data, parts), vectorizer);
    KMeansModel emptyDataMdl = trainer.update(originalMdl, new LocalDatasetBuilder<>(new HashMap<>(), parts), vectorizer);

    Vector[] probes = {
        new DenseVector(new double[] {2.0, 2.0}),
        new DenseVector(new double[] {-2.0, -2.0})
    };

    // Every updated model must agree with the original model on every probe point.
    for (KMeansModel updatedMdl : new KMeansModel[] {sameDataMdl, emptyDataMdl}) {
        for (Vector probe : probes)
            assertEquals(originalMdl.predict(probe), updatedMdl.predict(probe), PRECISION);
    }
}
Also used: KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) HashMap(java.util.HashMap) KMeansTrainer(org.apache.ignite.ml.clustering.kmeans.KMeansTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 13 with KMeansModel

Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by Apache.

In class SparkModelParser, method loadKMeansModel.

/**
 * Load K-Means model.
 *
 * <p>Reads every record from every row group of the Parquet file at {@code pathToMdl}. Each record
 * contributes one cluster center, whose coefficients are taken from the repeated field 0 of the
 * nested group at field 1 → field 3 of the record.
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment whose logger receives read errors.
 * @return K-Means model using Euclidean distance. If the file cannot be read, the error is logged
 *     and the returned model is built with {@code null} centers.
 */
private static Model loadKMeansModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    Vector[] centers = null;
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        // Collect centers across ALL row groups: a file may hold more than one, and re-allocating
        // the array per group would keep only the last group's centers.
        final java.util.List<DenseVector> loadedCenters = new java.util.ArrayList<>();

        PageReadStore pages;
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader<Group> recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            for (long i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                // NOTE(review): centers are kept in file order; field 0 (presumably the cluster
                // index) is not read. Confirm the writer emits rows in cluster-index order.
                Group clusterCenterCoeff = g.getGroup(1, 0).getGroup(3, 0);
                final int amountOfCoefficients = clusterCenterCoeff.getFieldRepetitionCount(0);
                DenseVector center = new DenseVector(amountOfCoefficients);
                for (int j = 0; j < amountOfCoefficients; j++) {
                    double coefficient = clusterCenterCoeff.getGroup(0, j).getDouble(0, 0);
                    center.set(j, coefficient);
                }
                loadedCenters.add(center);
            }
        }

        // Assign only after a complete read, so a failure never yields a partially filled array.
        centers = loadedCenters.toArray(new DenseVector[0]);
    } catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }
    return new KMeansModel(centers, new EuclideanDistance());
}
Also used: Path(org.apache.hadoop.fs.Path) Group(org.apache.parquet.example.data.Group) SimpleGroup(org.apache.parquet.example.data.simple.SimpleGroup) GroupRecordConverter(org.apache.parquet.example.data.simple.convert.GroupRecordConverter) KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) Configuration(org.apache.hadoop.conf.Configuration) ParquetFileReader(org.apache.parquet.hadoop.ParquetFileReader) RecordReader(org.apache.parquet.io.RecordReader) IOException(java.io.IOException) MessageColumnIO(org.apache.parquet.io.MessageColumnIO) ColumnIOFactory(org.apache.parquet.io.ColumnIOFactory) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) PageReadStore(org.apache.parquet.column.page.PageReadStore) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) MessageType(org.apache.parquet.schema.MessageType)

Example 14 with KMeansModel

Use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by Apache.

In class ANNClassificationTrainer, method getCentroids.

/**
 * Perform KMeans clusterization algorithm to find centroids.
 *
 * <p>Builds a {@link KMeansTrainer} from this trainer's {@code k}, {@code maxIterations},
 * {@code distance} and {@code epsilon} settings, fits it on the given data, and returns the
 * centers of the resulting model.
 *
 * @param vectorizer Upstream vectorizer.
 * @param datasetBuilder The dataset builder.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @return The cluster centers of the fitted model, as a fixed-size list view
 *     ({@link Arrays#asList}) over the array returned by {@code centers()}.
 */
private <K, V> List<Vector> getCentroids(Preprocessor<K, V> vectorizer, DatasetBuilder<K, V> datasetBuilder) {
    KMeansTrainer trainer = new KMeansTrainer()
        .withAmountOfClusters(k)
        .withMaxIterations(maxIterations)
        .withDistance(distance)
        .withEpsilon(epsilon);

    KMeansModel mdl = trainer.fit(datasetBuilder, vectorizer);

    return Arrays.asList(mdl.centers());
}
Also used : KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) KMeansTrainer(org.apache.ignite.ml.clustering.kmeans.KMeansTrainer)

Aggregations

KMeansModel (org.apache.ignite.ml.clustering.kmeans.KMeansModel)14 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)10 KMeansTrainer (org.apache.ignite.ml.clustering.kmeans.KMeansTrainer)9 Ignite (org.apache.ignite.Ignite)6 Test (org.junit.Test)6 EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance)5 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)5 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)4 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)4 Cache (javax.cache.Cache)3 IgniteCache (org.apache.ignite.IgniteCache)3 IOException (java.io.IOException)2 HashMap (java.util.HashMap)2 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)2 BinaryObject (org.apache.ignite.binary.BinaryObject)2 KMeansModelFormat (org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat)2 TrainerTest (org.apache.ignite.ml.common.TrainerTest)2 Path (java.nio.file.Path)1 Map (java.util.Map)1 Configuration (org.apache.hadoop.conf.Configuration)1