
Example 6 with MultivariateGaussianDistribution

Use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by Apache.

Source: class GmmTrainer, method isConverged.

/**
 * Checks whether the algorithm has converged. If it has, the algorithm stops.
 * Convergence means that no component mean moved by eps or more between iterations.
 *
 * @param oldModel Old model.
 * @param newModel New model.
 * @return True if the algorithm has converged.
 */
private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
    // Both models must describe the same number of mixture components.
    A.ensure(oldModel.countOfComponents() == newModel.countOfComponents(), "oldModel.countOfComponents() == newModel.countOfComponents()");
    for (int i = 0; i < oldModel.countOfComponents(); i++) {
        MultivariateGaussianDistribution d1 = oldModel.distributions().get(i);
        MultivariateGaussianDistribution d2 = newModel.distributions().get(i);
        // Euclidean distance between the old and new means of component i.
        if (Math.sqrt(d1.mean().getDistanceSquared(d2.mean())) >= eps)
            return false;
    }
    return true;
}
Also used: MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution)
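The convergence rule in isConverged can be tried outside Ignite. Below is a minimal, self-contained sketch, assuming component means are plain double[] arrays instead of Ignite Vector objects. The class GmmConvergenceSketch and its methods are hypothetical and not part of Ignite. Like the original, it stops as soon as one component mean has moved by eps or more.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch (not Ignite code) of the rule used by GmmTrainer.isConverged:
 * training has converged when every component mean moved by less than eps.
 * Means are plain double[] here instead of Ignite Vector objects.
 */
final class GmmConvergenceSketch {
    private GmmConvergenceSketch() {}

    /** Euclidean distance between two vectors of equal length. */
    static double distance(double[] a, double[] b) {
        if (a.length != b.length)
            throw new IllegalArgumentException("Dimension mismatch");
        double sumSq = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sumSq += d * d;
        }
        return Math.sqrt(sumSq);
    }

    /** Returns true if no component mean moved by eps or more. */
    static boolean isConverged(List<double[]> oldMeans, List<double[]> newMeans, double eps) {
        // Mirrors the A.ensure check: both models need the same number of components.
        if (oldMeans.size() != newMeans.size())
            throw new IllegalArgumentException("Component counts differ");
        for (int i = 0; i < oldMeans.size(); i++) {
            if (distance(oldMeans.get(i), newMeans.get(i)) >= eps)
                return false;
        }
        return true;
    }

    public static void main(String[] args) {
        List<double[]> before = Arrays.asList(new double[] {0, 0}, new double[] {5, 5});
        List<double[]> after = Arrays.asList(new double[] {0.001, 0}, new double[] {5, 5.002});
        // Both means moved less than 0.01, so this reports convergence.
        System.out.println(isConverged(before, after, 0.01));
        // The second mean moved 0.002 >= 0.0015, so this does not.
        System.out.println(isConverged(before, after, 0.0015));
    }
}
```

Comparing means only keeps the check cheap; a sketch like this ignores changes in covariances and weights, as the original method does.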

Example 7 with MultivariateGaussianDistribution

Use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by Apache.

Source: class GmmModelTest, method testTrivialCasesWithOneComponent.

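/**
 * Checks a GMM with a single component: dimension, component weights,
 * prediction, likelihood and density at the component mean.
 */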
@Test
public void testTrivialCasesWithOneComponent() {
    Vector mean = VectorUtils.of(1., 2.);
    DenseMatrix covariance = MatrixUtil.fromList(Arrays.asList(VectorUtils.of(1, -0.5), VectorUtils.of(-0.5, 1)), true);
    GmmModel gmm = new GmmModel(VectorUtils.of(1.0), Collections.singletonList(new MultivariateGaussianDistribution(mean, covariance)));
    Assert.assertEquals(2, gmm.dimension());
    Assert.assertEquals(1, gmm.countOfComponents());
    Assert.assertEquals(VectorUtils.of(1.), gmm.componentsProbs());
    Assert.assertEquals(0., gmm.predict(mean), 0.01);
    Assert.assertEquals(1, gmm.likelihood(mean).size());
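    // Density of N(mean, covariance) at its mean: 1 / (2 * pi * sqrt(det)) with det = 0.75, about 0.1838.
    Assert.assertEquals(0.183, gmm.likelihood(mean).get(0), 0.01);
    Assert.assertEquals(0.183, gmm.prob(mean), 0.01);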
}
Also used: MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix), Test (org.junit.Test)
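The value 0.183 asserted above is the density of a bivariate Gaussian at its own mean. The sketch below reproduces it for the 2-D case only. It assumes plain arrays in place of Ignite's Vector and DenseMatrix, and it computes a mixture's probability as the weighted sum of component densities. The class GmmDensitySketch and its methods are hypothetical and are not Ignite's implementation.

```java
/**
 * Hypothetical sketch (not Ignite code) of the numbers asserted in
 * testTrivialCasesWithOneComponent, restricted to two dimensions.
 * Density of N(mu, Sigma) at x:
 *   exp(-0.5 * (x - mu)^T Sigma^-1 (x - mu)) / (2 * pi * sqrt(det Sigma))
 */
final class GmmDensitySketch {
    private GmmDensitySketch() {}

    /** Density of a bivariate Gaussian; cov must be a 2x2 positive-definite matrix. */
    static double density2d(double[] x, double[] mean, double[][] cov) {
        double det = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
        if (det <= 0)
            throw new IllegalArgumentException("Covariance must be positive definite");
        // Inverse of a 2x2 matrix [[a, b], [c, d]] is (1 / det) * [[d, -b], [-c, a]].
        double i00 = cov[1][1] / det, i01 = -cov[0][1] / det;
        double i10 = -cov[1][0] / det, i11 = cov[0][0] / det;
        double dx = x[0] - mean[0], dy = x[1] - mean[1];
        // Squared Mahalanobis distance (x - mu)^T Sigma^-1 (x - mu).
        double mahalanobisSq = dx * (i00 * dx + i01 * dy) + dy * (i10 * dx + i11 * dy);
        return Math.exp(-0.5 * mahalanobisSq) / (2 * Math.PI * Math.sqrt(det));
    }

    /** Mixture density: sum over components k of weights[k] * N(x | means[k], covs[k]). */
    static double mixtureProb(double[] x, double[] weights, double[][] means, double[][][] covs) {
        double p = 0;
        for (int k = 0; k < weights.length; k++)
            p += weights[k] * density2d(x, means[k], covs[k]);
        return p;
    }

    public static void main(String[] args) {
        double[] mean = {1., 2.};
        double[][] cov = {{1, -0.5}, {-0.5, 1}};
        // At the mean the exponent is 0, so the density is 1 / (2 * pi * sqrt(0.75)), about 0.1838.
        System.out.println(density2d(mean, mean, cov));
        // With a single component of weight 1.0 the mixture density equals the component density.
        System.out.println(mixtureProb(mean, new double[] {1.0}, new double[][] {mean}, new double[][][] {cov}));
    }
}
```

With one component of weight 1.0, the likelihood of that component and the mixture probability coincide, which is why the test expects 0.183 for both.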

Aggregations

MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution): 7
Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 4
DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix): 3
Test (org.junit.Test): 3
ArrayList (java.util.ArrayList): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
Ignite (org.apache.ignite.Ignite): 1
RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction): 1
GmmModel (org.apache.ignite.ml.clustering.gmm.GmmModel): 1
GmmTrainer (org.apache.ignite.ml.clustering.gmm.GmmTrainer): 1
SingularMatrixException (org.apache.ignite.ml.math.exceptions.math.SingularMatrixException): 1
Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix): 1
LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 1
DataStreamGenerator (org.apache.ignite.ml.util.generators.DataStreamGenerator): 1
GaussRandomProducer (org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer): 1
VectorGeneratorsFamily (org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily): 1
NotNull (org.jetbrains.annotations.NotNull): 1