Use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in the Apache Ignite project.
Method isConverged of the class GmmTrainer.
/**
 * Checks whether the algorithm has converged. If it has, the algorithm stops.
 * <p>
 * Two models are considered converged when, for every component, the square root of the squared distance
 * between the old and the new component mean is strictly less than {@code eps}. Only the means are compared.
 *
 * @param oldModel Old model.
 * @param newModel New model.
 * @return {@code true} if the algorithm has converged.
 */
private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
    A.ensure(
        oldModel.countOfComponents() == newModel.countOfComponents(),
        "oldModel.countOfComponents() == newModel.countOfComponents()"
    );

    int compIdx = 0;

    while (compIdx < oldModel.countOfComponents()) {
        MultivariateGaussianDistribution prev = oldModel.distributions().get(compIdx);
        MultivariateGaussianDistribution cur = newModel.distributions().get(compIdx);

        double meanShift = Math.sqrt(prev.mean().getDistanceSquared(cur.mean()));

        // One component whose mean moved by eps or more is enough to report "not converged".
        if (meanShift >= eps)
            return false;

        compIdx++;
    }

    return true;
}
Use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in the Apache Ignite project.
Method testTrivialCasesWithOneComponent of the class GmmModelTest.
/**
 * Verifies a mixture that consists of a single two-dimensional Gaussian component with weight 1:
 * the reported dimension, component count and component probabilities, the predicted component for
 * the mean point, and the likelihood/probability values at the mean.
 */
@Test
public void testTrivialCasesWithOneComponent() {
    // Tolerance used by all floating-point comparisons below.
    final double tolerance = 0.01;

    // Expected density at the mean of the single component.
    final double expDensityAtMean = 0.183;

    Vector center = VectorUtils.of(1., 2.);

    DenseMatrix cov = MatrixUtil.fromList(
        Arrays.asList(
            VectorUtils.of(1, -0.5),
            VectorUtils.of(-0.5, 1)
        ),
        true
    );

    Vector weights = VectorUtils.of(1.0);
    MultivariateGaussianDistribution onlyComponent = new MultivariateGaussianDistribution(center, cov);

    GmmModel mdl = new GmmModel(weights, Collections.singletonList(onlyComponent));

    // Structural properties of the model.
    Assert.assertEquals(2, mdl.dimension());
    Assert.assertEquals(1, mdl.countOfComponents());
    Assert.assertEquals(VectorUtils.of(1.), mdl.componentsProbs());

    // With a single component, the mean point must be assigned to component 0.
    Assert.assertEquals(0., mdl.predict(center), tolerance);

    // Exactly one likelihood value, and it matches the overall probability at the mean.
    Assert.assertEquals(1, mdl.likelihood(center).size());
    Assert.assertEquals(expDensityAtMean, mdl.likelihood(center).get(0), tolerance);
    Assert.assertEquals(expDensityAtMean, mdl.prob(center), tolerance);
}
Aggregations