Usage of org.apache.ignite.ml.knn.ann.ANNClassificationModel in the Apache Ignite project.
Class ANNClassificationExportImportExample, method main.
/**
 * Entry point of the example.
 *
 * Starts an Ignite node and trains an ANN classifier on a test cache, then evaluates the model.
 * It then writes the model to a temporary JSON file, reads it back and evaluates the re-imported
 * model. The cache and the temporary file are released even if a step fails.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If the temporary JSON file cannot be created, written, read or deleted.
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Tracked outside the inner try so the finally block can release whatever was acquired.
        IgniteCache<Integer, double[]> cache = null;
        Path tmpJsonFile = null;

        try {
            cache = getTestCache(ignite);

            ANNClassificationTrainer annTrainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            // The cast applies to the result of the whole builder chain.
            ANNClassificationModel exportedMdl = (ANNClassificationModel)annTrainer.fit(
                ignite,
                cache,
                new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
            ).withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);

            System.out.println("\n>>> Exported ANN model: " + exportedMdl.toString(true));

            double exportedAccuracy = evaluateModel(cache, exportedMdl);
            System.out.println("\n>>> Accuracy for exported ANN model:" + exportedAccuracy);

            // Round-trip the trained model through a temporary JSON file.
            tmpJsonFile = Files.createTempFile(null, null);
            exportedMdl.toJSON(tmpJsonFile);
            ANNClassificationModel importedMdl = ANNClassificationModel.fromJSON(tmpJsonFile);

            System.out.println("\n>>> Imported ANN model: " + importedMdl.toString(true));

            double importedAccuracy = evaluateModel(cache, importedMdl);
            System.out.println("\n>>> Accuracy for imported ANN model:" + importedAccuracy);

            System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
        }
        finally {
            if (cache != null) {
                cache.destroy();
            }

            if (tmpJsonFile != null) {
                Files.deleteIfExists(tmpJsonFile);
            }
        }
    }
    finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.knn.ann.ANNClassificationModel in the Apache Ignite project.
Class ANNClassificationTest, method testUpdate.
/**
 * Checks that {@link ANNClassificationTrainer#update} yields a model with candidates and with the
 * configured {@code weighted = [false]} setting.
 *
 * The update is run twice: on the same dataset the original model was fit on, and on an empty
 * dataset.
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> cacheMock = new HashMap<>();

    for (int i = 0; i < twoClusters.length; i++)
        cacheMock.put(i, twoClusters[i]);

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    ANNClassificationModel originalMdl = (ANNClassificationModel)trainer.fit(
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // The updates must read rows with the same label coordinate as the original fit (FIRST).
    // Using a different coordinate would re-interpret "the same dataset" with another column
    // as the label.
    ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel)trainer.update(
        originalMdl,
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel)trainer.update(
        originalMdl,
        new HashMap<>(),
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    assertNotNull(updatedOnSameDataset.getCandidates());
    assertTrue(updatedOnSameDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(false).contains("weighted = [false]"));

    assertNotNull(updatedOnEmptyDataset.getCandidates());
    assertTrue(updatedOnEmptyDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(false).contains("weighted = [false]"));
}
Usage of org.apache.ignite.ml.knn.ann.ANNClassificationModel in the Apache Ignite project.
Class ANNClassificationTest, method testBinaryClassification.
/**
 * Trains an ANN classifier on the {@code twoClusters} rows and checks two things.
 *
 * First, the trainer reports back the hyper-parameters it was configured with. Second, the
 * resulting model exposes candidates and mentions {@code weighted = [false]} in every
 * {@code toString} variant.
 */
@Test
public void testBinaryClassification() {
    Map<Integer, double[]> cacheMock = new HashMap<>();

    // Key every row by its position in the source array.
    int rowIdx = 0;
    for (double[] row : twoClusters)
        cacheMock.put(rowIdx++, row);

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    // The getters must echo the builder settings above.
    Assert.assertEquals(10, trainer.getK());
    Assert.assertEquals(10, trainer.getMaxIterations());
    TestUtils.assertEquals(1e-4, trainer.getEpsilon(), PRECISION);
    Assert.assertEquals(new EuclideanDistance(), trainer.getDistance());

    NNClassificationModel mdl = trainer.fit(
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    Assert.assertNotNull(((ANNClassificationModel)mdl).getCandidates());

    String notWeighted = "weighted = [false]";
    assertTrue(mdl.toString().contains(notWeighted));
    assertTrue(mdl.toString(true).contains(notWeighted));
    assertTrue(mdl.toString(false).contains(notWeighted));
}
Usage of org.apache.ignite.ml.knn.ann.ANNClassificationModel in the Apache Ignite project.
Class CollectionsTest, method test.
/**
 * Feeds pairs of ML objects to the {@code test} and {@code specialTest} helpers of this class.
 *
 * The pairs cover matrices, distance measures, feature metadata, dataset rows, labeled vectors,
 * datasets, models and model formats.
 */
@Test
@SuppressWarnings("unchecked")
public void test() {
    // Matrix views.
    test(
        new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1),
        new VectorizedViewMatrix(new DenseMatrix(3, 2), 2, 1, 1, 1)
    );

    // Distance measures: each is paired with a second instance of the same type.
    specialTest(new ManhattanDistance(), new ManhattanDistance());
    specialTest(new HammingDistance(), new HammingDistance());
    specialTest(new EuclideanDistance(), new EuclideanDistance());

    // Feature metadata: created as "name2", then renamed to "name1".
    FeatureMetadata renamedMeta = new FeatureMetadata("name2");
    renamedMeta.setName("name1");
    test(renamedMeta, new FeatureMetadata("name2"));

    // Rows, labeled vectors and datasets.
    test(
        new DatasetRow<>(new DenseVector()),
        new DatasetRow<>(new DenseVector(1))
    );
    test(
        new LabeledVector<>(new DenseVector(), null),
        new LabeledVector<>(new DenseVector(1), null)
    );
    test(
        new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {}),
        new Dataset<DatasetRow<Vector>>(
            new DatasetRow[] {new DatasetRow()},
            new FeatureMetadata[] {new FeatureMetadata()}
        )
    );

    // Models and model formats.
    test(
        new LogisticRegressionModel(new DenseVector(), 1.0),
        new LogisticRegressionModel(new DenseVector(), 0.5)
    );
    test(
        new KMeansModelFormat(new Vector[] {}, new ManhattanDistance()),
        new KMeansModelFormat(new Vector[] {}, new HammingDistance())
    );
    test(
        new KMeansModel(new Vector[] {}, new ManhattanDistance()),
        new KMeansModel(new Vector[] {}, new HammingDistance())
    );
    test(
        new SVMLinearClassificationModel(null, 1.0),
        new SVMLinearClassificationModel(null, 0.5)
    );
    test(
        new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
        new ANNClassificationModel(new LabeledVectorSet<>(1, 1), new ANNClassificationTrainer.CentroidStat())
    );
    test(
        new ANNModelFormat(1, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
        new ANNModelFormat(2, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat())
    );
}
Usage of org.apache.ignite.ml.knn.ann.ANNClassificationModel in the Apache Ignite project.
Class LocalModelsTest, method importExportANNModelTest.
/**
 * Exports an ANN classification model through a {@link FileExporter}, loads the resulting
 * {@link ANNModelFormat} back, and checks that a model rebuilt from the loaded format equals the
 * original.
 *
 * @throws IOException Propagated from {@code executeModelTest}.
 */
@Test
public void importExportANNModelTest() throws IOException {
    executeModelTest(mdlFilePath -> {
        final LabeledVectorSet<LabeledVector> centers = new LabeledVectorSet<>();

        NNClassificationModel mdl = new ANNClassificationModel(centers, new ANNClassificationTrainer.CentroidStat())
            .withK(4)
            .withDistanceMeasure(new ManhattanDistance())
            .withWeighted(true);

        Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
        mdl.saveModel(exporter, mdlFilePath);

        ANNModelFormat load = (ANNModelFormat)exporter.load(mdlFilePath);

        Assert.assertNotNull(load);

        // Rebuild every exported setting from the loaded format, including the weighted flag.
        // Hard-coding it would hide a lost or corrupted value in the round trip.
        NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates(), new ANNClassificationTrainer.CentroidStat())
            .withK(load.getK())
            .withDistanceMeasure(load.getDistanceMeasure())
            .withWeighted(load.isWeighted());

        Assert.assertEquals("ANN model imported from file differs from the exported one", mdl, importedMdl);

        return null;
    });
}
Aggregations