use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by apache.
the class KMeansClusterizationExportImportExample method main.
/**
 * Trains a KMeans model on a cached dataset, writes the model to a temporary JSON file and reads it back.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If the temporary model file cannot be created or deleted.
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> KMeans clustering algorithm over cached dataset usage example started.");

    // The node started here is closed by try-with-resources when the block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Assigned inside the inner try so that the cleanup below can tell what was actually created.
        IgniteCache<Integer, Vector> irisCache = null;
        Path exportedMdlFile = null;

        try {
            SandboxMLCache sandbox = new SandboxMLCache(ignite);
            irisCache = sandbox.fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);

            Vectorizer<Integer, Vector, Integer, Double> labeledVectorizer =
                new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

            // Alternative left in the original example: .withDistance(new MinkowskiDistance(2)).
            KMeansTrainer kMeans = new KMeansTrainer().withDistance(
                new WeightedMinkowskiDistance(2, new double[] {5.9360, 2.7700, 4.2600, 1.3260}));

            KMeansModel trainedMdl = kMeans.fit(ignite, irisCache, labeledVectorizer);
            System.out.println("\n>>> Exported KMeans model: " + trainedMdl);

            // Round trip: export the trained model to JSON, then load a new model from that file.
            exportedMdlFile = Files.createTempFile(null, null);
            trainedMdl.toJSON(exportedMdlFile);

            KMeansModel restoredMdl = KMeansModel.fromJSON(exportedMdlFile);
            System.out.println("\n>>> Imported KMeans model: " + restoredMdl);

            System.out.println("\n>>> KMeans clustering algorithm over cached dataset usage example completed.");
        }
        finally {
            // Cache first, temporary file second; each step is skipped if the resource was never created.
            if (irisCache != null) {
                irisCache.destroy();
            }

            if (exportedMdlFile != null) {
                Files.deleteIfExists(exportedMdlFile);
            }
        }
    }
    finally {
        System.out.flush();
    }
}
use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by apache.
the class KeepBinaryTest method test.
/**
 * Populates a cache of binary objects, trains a KMeans model over it with keep-binary enabled and
 * checks the first coordinate of the centre predicted for the point {@code 0.0}.
 */
@Test
public void test() {
    IgniteCache<Integer, BinaryObject> binaryCache = populateCache(ignite);

    KMeansTrainer kMeans = new KMeansTrainer();

    // The dataset builder reads the cache entries as BinaryObject instances.
    CacheBasedDatasetBuilder<Integer, BinaryObject> binaryDatasetBuilder =
        new CacheBasedDatasetBuilder<>(ignite, binaryCache).withKeepBinary(true);

    KMeansModel trainedMdl = kMeans.fit(
        binaryDatasetBuilder,
        new BinaryObjectVectorizer<Integer>("feature1").labeled("label")
    );

    // Index of the cluster the model assigns to the point 0.0.
    Integer clusterOfZero = trainedMdl.predict(VectorUtils.num2Vec(0.0));

    double firstCoordinate = trainedMdl.centers()[clusterOfZero].get(0);

    assertTrue(firstCoordinate == 0);
}
use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by apache.
the class CollectionsTest method test.
/**
 * Feeds pairs of differently constructed objects (matrix views, distance measures, dataset structures
 * and model classes) to the {@code test(...)} and {@code specialTest(...)} helpers.
 * NOTE(review): both helpers are defined outside this snippet; presumably they verify
 * equals/hashCode/collection behaviour of the two arguments - confirm against their bodies.
 */
@Test
// "unchecked" is needed because the Dataset case below builds raw DatasetRow arrays and instances.
@SuppressWarnings("unchecked")
public void test() {
// Two matrix views that wrap matrices of different sizes and differ in the first numeric argument.
test(new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1), new VectorizedViewMatrix(new DenseMatrix(3, 2), 2, 1, 1, 1));
// Distance measures: each pair consists of two separately created instances of the same class.
specialTest(new ManhattanDistance(), new ManhattanDistance());
specialTest(new HammingDistance(), new HammingDistance());
specialTest(new EuclideanDistance(), new EuclideanDistance());
// The first FeatureMetadata is created as "name2" and then has setName("name1") applied, the second stays "name2".
FeatureMetadata data = new FeatureMetadata("name2");
data.setName("name1");
test(data, new FeatureMetadata("name2"));
// Rows and labeled vectors: no-arg DenseVector versus DenseVector(1).
test(new DatasetRow<>(new DenseVector()), new DatasetRow<>(new DenseVector(1)));
test(new LabeledVector<>(new DenseVector(), null), new LabeledVector<>(new DenseVector(1), null));
// Datasets: empty row/metadata arrays versus single-element arrays.
test(new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {}), new Dataset<DatasetRow<Vector>>(new DatasetRow[] { new DatasetRow() }, new FeatureMetadata[] { new FeatureMetadata() }));
// Logistic regression models that differ only in the second constructor argument (1.0 vs 0.5).
test(new LogisticRegressionModel(new DenseVector(), 1.0), new LogisticRegressionModel(new DenseVector(), 0.5));
// KMeans format and model: same empty centers array, different distance measure.
test(new KMeansModelFormat(new Vector[] {}, new ManhattanDistance()), new KMeansModelFormat(new Vector[] {}, new HammingDistance()));
test(new KMeansModel(new Vector[] {}, new ManhattanDistance()), new KMeansModel(new Vector[] {}, new HammingDistance()));
// SVM models with a null first argument that differ in the second argument (1.0 vs 0.5).
test(new SVMLinearClassificationModel(null, 1.0), new SVMLinearClassificationModel(null, 0.5));
// ANN model: no-arg LabeledVectorSet versus LabeledVectorSet(1, 1).
test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNClassificationModel(new LabeledVectorSet<>(1, 1), new ANNClassificationTrainer.CentroidStat()));
// ANN model format: the two instances differ only in the first constructor argument (1 vs 2).
test(new ANNModelFormat(1, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNModelFormat(2, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()));
}
use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by apache.
the class KMeansModelTest method predictClusters.
/**
 * Builds a {@link KMeansModel} from four fixed centers (one per quadrant) and checks that a point
 * close to each center is assigned that center's index, and that the model's accessors return the
 * values it was constructed with.
 */
@Test
public void predictClusters() {
    DistanceMeasure distanceMeasure = new EuclideanDistance();

    Vector[] centers = new DenseVector[4];

    centers[0] = new DenseVector(new double[] { 1.0, 1.0 });
    centers[1] = new DenseVector(new double[] { -1.0, 1.0 });
    centers[2] = new DenseVector(new double[] { 1.0, -1.0 });
    centers[3] = new DenseVector(new double[] { -1.0, -1.0 });

    KMeansModel mdl = new KMeansModel(centers, distanceMeasure);

    Assert.assertTrue(mdl.toString().contains("KMeansModel"));

    // JUnit expects (expected, actual[, delta]); keeping that order makes failure messages
    // report the expected and actual values the right way round.
    Assert.assertEquals(0.0, mdl.predict(new DenseVector(new double[] { 1.1, 1.1 })), PRECISION);
    Assert.assertEquals(1.0, mdl.predict(new DenseVector(new double[] { -1.1, 1.1 })), PRECISION);
    Assert.assertEquals(2.0, mdl.predict(new DenseVector(new double[] { 1.1, -1.1 })), PRECISION);
    Assert.assertEquals(3.0, mdl.predict(new DenseVector(new double[] { -1.1, -1.1 })), PRECISION);

    // Accessors must hand back exactly what was passed to the constructor.
    Assert.assertEquals(distanceMeasure, mdl.distanceMeasure());
    Assert.assertEquals(4, mdl.amountOfClusters());
    Assert.assertArrayEquals(centers, mdl.centers());
}
use of org.apache.ignite.ml.clustering.kmeans.KMeansModel in project ignite by apache.
the class KMeansFromSparkExample method main.
/**
 * Runs the example: fills a cache via {@code TitanicUtils.readPassengers}, loads a K-means model with
 * {@code SparkModelParser} and prints the predicted cluster next to the label of every cached observation.
 *
 * @param args Command line arguments (not used).
 * @throws FileNotFoundException Declared by the signature; presumably raised when the input data
 *      cannot be found by {@code TitanicUtils.readPassengers} - confirm against that helper.
 */
public static void main(String[] args) throws FileNotFoundException {
    System.out.println();
    System.out.println(">>> K-means model loaded from Spark through serialization over partitioned dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengers(ignite);

            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 5, 6, 4).labeled(1);

            KMeansModel mdl = (KMeansModel) SparkModelParser.parse(SPARK_MDL_PATH, SupportedSparkModels.KMEANS, env);

            System.out.println(">>> K-Means model: " + mdl);
            System.out.println(">>> ------------------------------------");
            System.out.println(">>> | Predicted cluster\t| Is survived\t|");
            System.out.println(">>> ------------------------------------");

            // The cursor is closed by try-with-resources once every entry has been printed.
            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, Vector> observation : observations) {
                    LabeledVector<Double> lv = vectorizer.apply(observation.getKey(), observation.getValue());
                    Vector inputs = lv.features();
                    double isSurvived = lv.label();
                    double clusterId = mdl.predict(inputs);

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", clusterId, isSurvived);
                }
            }

            System.out.println(">>> ---------------------------------");
        }
        finally {
            // dataCache is still null when readPassengers fails; without this guard the resulting
            // NullPointerException would replace the original exception.
            if (dataCache != null) {
                dataCache.destroy();
            }
        }
    }
}
Aggregations