Search in sources :

Example 1 with MultivariateGaussianDistribution

use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.

The class GmmPartitionDataTest, method testUpdatePcxi.

/**
 * Verifies {@code GmmPartitionData.updatePcxi} for two diagonal-covariance components
 * weighted 0.3 and 0.7, comparing {@code data.pcxi(component, row)} against expected
 * values with a tolerance of 1e-2.
 */
@Test
public void testUpdatePcxi() {
    MultivariateGaussianDistribution firstComponent = new MultivariateGaussianDistribution(
        VectorUtils.of(1.0, 0.5),
        new DenseMatrix(new double[] {0.5, 0., 0., 1.}, 2)
    );
    MultivariateGaussianDistribution secondComponent = new MultivariateGaussianDistribution(
        VectorUtils.of(0.0, 0.5),
        new DenseMatrix(new double[] {1.0, 0., 0., 1.}, 2)
    );

    GmmPartitionData.updatePcxi(data, VectorUtils.of(0.3, 0.7), Arrays.asList(firstComponent, secondComponent));

    // expectedByRow[row][component]; iterating row-major keeps the original assertion order.
    double[][] expectedByRow = {
        {0.49, 0.50},
        {0.18, 0.81},
        {0.49, 0.50}
    };

    for (int row = 0; row < expectedByRow.length; row++) {
        for (int comp = 0; comp < expectedByRow[row].length; comp++)
            assertEquals(expectedByRow[row][comp], data.pcxi(comp, row), 1e-2);
    }
}
Also used : MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)

Example 2 with MultivariateGaussianDistribution

use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.

The class GmmTrainer, method updateModel.

/**
 * Runs EM iterations starting from the given model and returns the updated model.
 *
 * Iteration stops when {@code isConverged(old, new)} reports convergence, when
 * {@code maxCountOfIterations} iterations have been performed, or when building the
 * components throws {@link SingularMatrixException} / {@link IllegalArgumentException}.
 * In the last case the last successfully built model is kept and the failure is logged.
 *
 * @param dataset Dataset.
 * @param model Initial model.
 * @return Updated model together with the value returned by the last successful
 *     {@code GmmPartitionData.updatePcxiAndComputeLikelihood} call
 *     ({@link Double#NEGATIVE_INFINITY} if no iteration completed).
 */
@NotNull
private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
    boolean isConverged = false;
    int countOfIterations = 0;
    double maxProbInDataset = Double.NEGATIVE_INFINITY;

    while (!isConverged) {
        MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset, countOfComponents);
        Vector clusterProbs = stats.clusterProbabilities();
        Vector[] newMeans = stats.means().toArray(new Vector[countOfComponents]);

        A.ensure(newMeans.length == model.countOfComponents(), "newMeans.size() == count of components");
        A.ensure(newMeans[0].size() == initialMeans[0].size(), "newMeans[0].size() == initialMeans[0].size()");

        List<Matrix> newCovs = CovarianceMatricesAggregator.computeCovariances(dataset, clusterProbs, newMeans);

        try {
            List<MultivariateGaussianDistribution> components = buildComponents(newMeans, newCovs);
            GmmModel newModel = new GmmModel(clusterProbs, components);

            countOfIterations += 1;
            // The counter is incremented before the check, so '>=' caps the loop at exactly
            // maxCountOfIterations iterations ('>' allowed one extra iteration).
            isConverged = isConverged(model, newModel) || countOfIterations >= maxCountOfIterations;
            model = newModel;

            maxProbInDataset = GmmPartitionData.updatePcxiAndComputeLikelihood(dataset, clusterProbs, components);
        }
        catch (SingularMatrixException | IllegalArgumentException e) {
            // Keep the last good model; include the cause so the failure can be diagnosed.
            String msg = "Cannot construct non-singular covariance matrix by data. " +
                "Try to select other initial means or other model trainer. Iterations will stop." +
                " Cause: " + e.getMessage();
            environment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
            isConverged = true;
        }
    }

    return new UpdateResult(model, maxProbInDataset);
}
Also used : Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) SingularMatrixException(org.apache.ignite.ml.math.exceptions.math.SingularMatrixException) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) NotNull(org.jetbrains.annotations.NotNull)

Example 3 with MultivariateGaussianDistribution

use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.

The class GmmModelTest, method testTwoComponents.

/**
 * Checks the labels predicted by an equally weighted two-component GMM for each
 * component mean and for two additional points.
 */
@Test
public void testTwoComponents() {
    Vector firstMean = VectorUtils.of(1., 2.);
    Vector secondMean = VectorUtils.of(2., 1.);

    DenseMatrix firstCov = MatrixUtil.fromList(
        Arrays.asList(VectorUtils.of(1, -0.25), VectorUtils.of(-0.25, 1)), true);
    DenseMatrix secondCov = MatrixUtil.fromList(
        Arrays.asList(VectorUtils.of(1, 0.5), VectorUtils.of(0.5, 1)), true);

    GmmModel gmm = new GmmModel(
        VectorUtils.of(0.5, 0.5),
        Arrays.asList(
            new MultivariateGaussianDistribution(firstMean, firstCov),
            new MultivariateGaussianDistribution(secondMean, secondCov)
        )
    );

    // Each mean must be assigned to its own component.
    Assert.assertEquals(0., gmm.predict(firstMean), 0.01);
    Assert.assertEquals(1., gmm.predict(secondMean), 0.01);

    // Extra points with expected labels.
    Assert.assertEquals(0., gmm.predict(VectorUtils.of(1.5, 1.5)), 0.01);
    Assert.assertEquals(1., gmm.predict(VectorUtils.of(3., 0.)), 0.01);
}
Also used : MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)

Example 4 with MultivariateGaussianDistribution

use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.

The class GmmClusterizationExample, method main.

/**
 * Runs the example.
 *
 * Starts an Ignite node and fills a cache with 50000 points drawn from three 2D Gaussian
 * generators. It then trains a GMM on that cache and prints the mean vector and covariance
 * matrix of every learned component. The cache is destroyed afterwards.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        long seed = 0;
        IgniteCache<Integer, LabeledVector<Double>> cache = null;

        try {
            cache = ignite.createCache(
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
                    .setAffinity(new RendezvousAffinityFunction(false, 10))
            );

            // Three Gaussian clouds: one rotated by PI/4 and moved to (10, 10), one rotated by -PI/4
            // and moved to (-10, 10), and one without rotation moved to (0, -10).
            // Java evaluates the seed++ arguments left to right, so each producer gets the same seed as before.
            DataStreamGenerator dataStream = new VectorGeneratorsFamily.Builder()
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 2., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .rotate(Math.PI / 4)
                    .move(VectorUtils.of(10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 1., seed++),
                        new GaussRandomProducer(0, 2., seed++))
                    .rotate(-Math.PI / 4)
                    .move(VectorUtils.of(-10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 3., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .move(VectorUtils.of(0., -10.)))
                .build(seed++)
                .asDataStream();

            AtomicInteger nextKey = new AtomicInteger();
            dataStream.fillCacheWithCustomKey(50000, cache, vec -> nextKey.getAndIncrement());

            GmmModel model = new GmmTrainer(1)
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
                .fit(ignite, cache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");

            for (int idx = 0; idx < model.countOfComponents(); idx++) {
                MultivariateGaussianDistribution component = model.distributions().get(idx);

                System.out.println();
                System.out.println("============");
                System.out.println("Component #" + idx);
                System.out.println("============");

                System.out.println("Mean vector = ");
                Tracer.showAscii(component.mean());
                System.out.println();

                System.out.println("Covariance matrix = ");
                Tracer.showAscii(component.covariance());
            }

            System.out.println(">>>");
        }
        finally {
            if (cache != null)
                cache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : LabeledVector(org.apache.ignite.ml.structures.LabeledVector) GmmModel(org.apache.ignite.ml.clustering.gmm.GmmModel) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) GmmTrainer(org.apache.ignite.ml.clustering.gmm.GmmTrainer) MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) DataStreamGenerator(org.apache.ignite.ml.util.generators.DataStreamGenerator) VectorGeneratorsFamily(org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)

Example 5 with MultivariateGaussianDistribution

use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.

The class GmmTrainer, method filterModel.

/**
 * Builds a new model containing only the components whose probability is strictly
 * greater than {@code minClusterProbability}.
 *
 * A component whose probability is less than or equal to the threshold (or NaN) is
 * dropped. The kept probabilities are passed to the new model unchanged, in their
 * original order; they are not renormalized.
 *
 * @param model Model.
 * @return Filtered model.
 */
private GmmModel filterModel(GmmModel model) {
    Vector probs = model.componentsProbs();
    List<MultivariateGaussianDistribution> allDistributions = model.distributions();

    List<Double> keptProbs = new ArrayList<>();
    List<MultivariateGaussianDistribution> keptDistributions = new ArrayList<>();

    for (int idx = 0; idx < model.countOfComponents(); idx++) {
        double prob = probs.get(idx);

        // Written as a positive comparison so that NaN probabilities are dropped.
        boolean keep = prob > minClusterProbability;
        if (!keep)
            continue;

        keptProbs.add(prob);
        keptDistributions.add(allDistributions.get(idx));
    }

    Double[] probArr = keptProbs.toArray(new Double[keptProbs.size()]);
    return new GmmModel(VectorUtils.of(probArr), keptDistributions);
}
Also used : MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) ArrayList(java.util.ArrayList) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution)7 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)4 DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix)3 Test (org.junit.Test)3 ArrayList (java.util.ArrayList)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 Ignite (org.apache.ignite.Ignite)1 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)1 GmmModel (org.apache.ignite.ml.clustering.gmm.GmmModel)1 GmmTrainer (org.apache.ignite.ml.clustering.gmm.GmmTrainer)1 SingularMatrixException (org.apache.ignite.ml.math.exceptions.math.SingularMatrixException)1 Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix)1 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)1 DataStreamGenerator (org.apache.ignite.ml.util.generators.DataStreamGenerator)1 GaussRandomProducer (org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer)1 VectorGeneratorsFamily (org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily)1 NotNull (org.jetbrains.annotations.NotNull)1