use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.
the class GmmPartitionDataTest method testUpdatePcxi.
/**
 * Calls {@code GmmPartitionData.updatePcxi} on the shared {@code data} fixture with
 * two mixture components weighted 0.3 / 0.7 and checks the resulting
 * {@code pcxi(cluster, row)} values for the first three rows.
 */
@Test
public void testUpdatePcxi() {
    final double eps = 1e-2;

    // First component: mean (1.0, 0.5), diagonal covariance diag(0.5, 1).
    MultivariateGaussianDistribution firstComponent = new MultivariateGaussianDistribution(
        VectorUtils.of(1.0, 0.5),
        new DenseMatrix(new double[] { 0.5, 0., 0., 1. }, 2)
    );

    // Second component: mean (0.0, 0.5), identity covariance.
    MultivariateGaussianDistribution secondComponent = new MultivariateGaussianDistribution(
        VectorUtils.of(0.0, 0.5),
        new DenseMatrix(new double[] { 1.0, 0., 0., 1. }, 2)
    );

    GmmPartitionData.updatePcxi(
        data,
        VectorUtils.of(0.3, 0.7),
        Arrays.asList(firstComponent, secondComponent)
    );

    // Row 0.
    assertEquals(0.49, data.pcxi(0, 0), eps);
    assertEquals(0.50, data.pcxi(1, 0), eps);

    // Row 1.
    assertEquals(0.18, data.pcxi(0, 1), eps);
    assertEquals(0.81, data.pcxi(1, 1), eps);

    // Row 2.
    assertEquals(0.49, data.pcxi(0, 2), eps);
    assertEquals(0.50, data.pcxi(1, 2), eps);
}
use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.
the class GmmTrainer method updateModel.
/**
 * Repeatedly re-estimates the mixture on the given dataset, starting from {@code model}.
 * <p>
 * Each pass aggregates cluster probabilities and means, computes covariance matrices,
 * builds a new {@link GmmModel} from them and then refreshes the per-partition data via
 * {@code GmmPartitionData.updatePcxiAndComputeLikelihood}. The loop ends when
 * {@code isConverged(previous, next)} returns {@code true}, when the pass counter becomes
 * greater than {@code maxCountOfIterations}, or when a {@link SingularMatrixException} /
 * {@link IllegalArgumentException} is raised while building or applying the new model
 * (in which case a message is logged and the loop stops).
 *
 * @param dataset Dataset.
 * @param model Model to start from.
 * @return Result holding the last model assigned during the loop (or the input model if
 *      none was built) together with the last value returned by
 *      {@code updatePcxiAndComputeLikelihood}, or {@code Double.NEGATIVE_INFINITY} if that
 *      call never completed.
 */
@NotNull
private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
    double likelihood = Double.NEGATIVE_INFINITY;
    int passes = 0;
    boolean stop = false;

    do {
        MeanWithClusterProbAggregator.AggregatedStats aggregated =
            MeanWithClusterProbAggregator.aggreateStats(dataset, countOfComponents);

        Vector weights = aggregated.clusterProbabilities();
        Vector[] means = aggregated.means().toArray(new Vector[countOfComponents]);

        // Sanity checks on the shape of the aggregated means.
        A.ensure(means.length == model.countOfComponents(), "newMeans.size() == count of components");
        A.ensure(means[0].size() == initialMeans[0].size(), "newMeans[0].size() == initialMeans[0].size()");

        List<Matrix> covariances = CovarianceMatricesAggregator.computeCovariances(dataset, weights, means);

        try {
            List<MultivariateGaussianDistribution> distributions = buildComponents(means, covariances);
            GmmModel candidate = new GmmModel(weights, distributions);

            passes++;
            stop = isConverged(model, candidate) || passes > maxCountOfIterations;

            // The candidate replaces the current model before the partition data is refreshed,
            // so it is kept even if the refresh below throws.
            model = candidate;
            likelihood = GmmPartitionData.updatePcxiAndComputeLikelihood(dataset, weights, distributions);
        }
        catch (SingularMatrixException | IllegalArgumentException e) {
            String msg = "Cannot construct non-singular covariance matrix by data. "
                + "Try to select other initial means or other model trainer. Iterations will stop.";

            environment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
            stop = true;
        }
    }
    while (!stop);

    return new UpdateResult(model, likelihood);
}
use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.
the class GmmModelTest method testTwoComponents.
/**
 * Builds an equally weighted two-component {@link GmmModel} and checks which component
 * index {@code predict} reports for each component mean and for two additional points.
 */
@Test
public void testTwoComponents() {
    final double eps = 0.01;

    // Component 0: centred at (1, 2) with negatively correlated coordinates.
    Vector firstMean = VectorUtils.of(1., 2.);
    DenseMatrix firstCov = MatrixUtil.fromList(
        Arrays.asList(VectorUtils.of(1, -0.25), VectorUtils.of(-0.25, 1)),
        true
    );
    MultivariateGaussianDistribution firstDistr = new MultivariateGaussianDistribution(firstMean, firstCov);

    // Component 1: centred at (2, 1) with positively correlated coordinates.
    Vector secondMean = VectorUtils.of(2., 1.);
    DenseMatrix secondCov = MatrixUtil.fromList(
        Arrays.asList(VectorUtils.of(1, 0.5), VectorUtils.of(0.5, 1)),
        true
    );
    MultivariateGaussianDistribution secondDistr = new MultivariateGaussianDistribution(secondMean, secondCov);

    GmmModel gmm = new GmmModel(VectorUtils.of(0.5, 0.5), Arrays.asList(firstDistr, secondDistr));

    Assert.assertEquals(0., gmm.predict(firstMean), eps);
    Assert.assertEquals(1., gmm.predict(secondMean), eps);
    Assert.assertEquals(0., gmm.predict(VectorUtils.of(1.5, 1.5)), eps);
    Assert.assertEquals(1., gmm.predict(VectorUtils.of(3., 0.)), eps);
}
use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.
the class GmmClusterizationExample method main.
/**
 * Runs the example: starts an Ignite node, fills a cache with generated two-dimensional
 * points, fits a {@link GmmModel} on it and prints the mean vector and covariance matrix
 * of every component. The cache is destroyed afterwards.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // Start the Ignite node; it is closed automatically when the block is left.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Every generator below takes the next value of this counter as its seed,
        // so the order of the seed++ expressions must not change.
        long seed = 0;
        IgniteCache<Integer, LabeledVector<Double>> cache = null;

        try {
            CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg =
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE");
            cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

            cache = ignite.createCache(cacheCfg);

            // The dataset is a family of three Gaussians; the first two are rotated
            // by PI/4 and -PI/4 respectively before being moved to their centres.
            DataStreamGenerator points = new VectorGeneratorsFamily.Builder()
                .add(RandomProducer.vectorize(new GaussRandomProducer(0, 2., seed++), new GaussRandomProducer(0, 3., seed++))
                    .rotate(Math.PI / 4)
                    .move(VectorUtils.of(10., 10.)))
                .add(RandomProducer.vectorize(new GaussRandomProducer(0, 1., seed++), new GaussRandomProducer(0, 2., seed++))
                    .rotate(-Math.PI / 4)
                    .move(VectorUtils.of(-10., 10.)))
                .add(RandomProducer.vectorize(new GaussRandomProducer(0, 3., seed++), new GaussRandomProducer(0, 3., seed++))
                    .move(VectorUtils.of(0., -10.)))
                .build(seed++)
                .asDataStream();

            // Cache keys are consecutive integers handed out by this counter.
            AtomicInteger nextKey = new AtomicInteger();
            points.fillCacheWithCustomKey(50000, cache, v -> nextKey.getAndIncrement());

            GmmModel model = new GmmTrainer(1)
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
                .fit(ignite, cache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");

            for (int componentIdx = 0; componentIdx < model.countOfComponents(); componentIdx++) {
                MultivariateGaussianDistribution component = model.distributions().get(componentIdx);

                System.out.println();
                System.out.println("============");
                System.out.println("Component #" + componentIdx);
                System.out.println("============");

                System.out.println("Mean vector = ");
                Tracer.showAscii(component.mean());

                System.out.println();
                System.out.println("Covariance matrix = ");
                Tracer.showAscii(component.covariance());
            }

            System.out.println(">>>");
        }
        finally {
            // The cache may be null if its creation failed.
            if (cache != null)
                cache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
use of org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution in project ignite by apache.
the class GmmTrainer method filterModel.
/**
 * Builds a new model that keeps only the components whose probability is strictly
 * greater than {@code minClusterProbability}; all other components are dropped.
 * The kept probabilities are copied as they are.
 *
 * @param model Model.
 * @return Filtered model.
 */
private GmmModel filterModel(GmmModel model) {
    Vector srcProbs = model.componentsProbs();
    List<MultivariateGaussianDistribution> srcDistrs = model.distributions();

    List<Double> keptProbs = new ArrayList<>();
    List<MultivariateGaussianDistribution> keptDistrs = new ArrayList<>();

    for (int idx = 0; idx < model.countOfComponents(); idx++) {
        double p = srcProbs.get(idx);

        // Negated "greater than" rather than "<=" so that a NaN probability is dropped too.
        if (!(p > minClusterProbability))
            continue;

        keptProbs.add(p);
        keptDistrs.add(srcDistrs.get(idx));
    }

    Double[] boxedProbs = keptProbs.toArray(new Double[0]);

    return new GmmModel(VectorUtils.of(boxedProbs), keptDistrs);
}
Aggregations