Usage of `org.apache.ignite.ml.knn.ann.ANNClassificationTrainer` in the Apache Ignite project.
The `main` method of class `ANNClassificationExportImportExample`.
/**
 * Runs the export/import example.
 * <p>
 * Trains an ANN classifier on a test cache and prints the model and its accuracy. It then writes the model
 * to a temporary JSON file, reads it back, and prints the re-imported model and its accuracy. The test
 * cache and the temporary file are removed on exit, even if one of the steps fails.
 *
 * @param args Command line arguments (unused).
 * @throws IOException If an I/O error occurs, e.g. while creating or deleting the temporary JSON file.
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Both resources are tracked outside the inner try so the cleanup below can release whatever was acquired.
        IgniteCache<Integer, double[]> cache = null;
        Path exportFile = null;

        try {
            cache = getTestCache(ignite);

            ANNClassificationTrainer annTrainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            // The trainer's result is reconfigured for prediction: 5 neighbours, Euclidean distance, weighted vote.
            ANNClassificationModel exportedMdl = (ANNClassificationModel)annTrainer
                .fit(ignite, cache, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);

            System.out.println("\n>>> Exported ANN model: " + exportedMdl.toString(true));

            double exportedAccuracy = evaluateModel(cache, exportedMdl);
            System.out.println("\n>>> Accuracy for exported ANN model:" + exportedAccuracy);

            // Round-trip the model through a temporary JSON file.
            exportFile = Files.createTempFile(null, null);
            exportedMdl.toJSON(exportFile);

            ANNClassificationModel importedMdl = ANNClassificationModel.fromJSON(exportFile);
            System.out.println("\n>>> Imported ANN model: " + importedMdl.toString(true));

            double importedAccuracy = evaluateModel(cache, importedMdl);
            System.out.println("\n>>> Accuracy for imported ANN model:" + importedAccuracy);

            System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
        }
        finally {
            if (cache != null)
                cache.destroy();

            if (exportFile != null)
                Files.deleteIfExists(exportFile);
        }
    }
    finally {
        System.out.flush();
    }
}
Usage of `org.apache.ignite.ml.knn.ann.ANNClassificationTrainer` in the Apache Ignite project.
The `main` method of class `ANNClassificationExample`.
/**
 * Runs the example.
 * <p>
 * Trains an ANN classifier on a test cache and predicts the label of every cached observation. It prints
 * a prediction / ground-truth table, then the training time and the cumulative prediction time (both in
 * milliseconds), the number of misclassified observations, the accuracy, and the number of observations.
 * The test cache is destroyed on exit, even if one of the steps fails.
 *
 * @param args Command line arguments (unused).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, double[]> dataCache = null;

        try {
            dataCache = getTestCache(ignite);

            ANNClassificationTrainer trainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            // System.nanoTime() is monotonic and fine-grained. Summing per-call currentTimeMillis() deltas
            // truncates sub-millisecond predictions to 0 and badly undercounts the total.
            long startTrainingTime = System.nanoTime();

            NNClassificationModel knnMdl = trainer
                .fit(ignite, dataCache, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);

            long trainingNanos = System.nanoTime() - startTrainingTime;

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            int amountOfErrors = 0;
            int totalAmount = 0;
            long totalPredictionNanos = 0L;

            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> observation : observations) {
                    // Each cached row is [label, feature_1, ..., feature_n].
                    double[] val = observation.getValue();
                    double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                    double groundTruth = val[0];

                    long startPredictionTime = System.nanoTime();
                    double prediction = knnMdl.predict(new DenseVector(inputs));
                    totalPredictionNanos += System.nanoTime() - startPredictionTime;

                    totalAmount++;
                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
                        amountOfErrors++;

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|%n", prediction, groundTruth);
                }

                System.out.println(">>> ---------------------------------");
                System.out.println("Training costs = " + trainingNanos / 1_000_000L);
                System.out.println("Prediction costs = " + totalPredictionNanos / 1_000_000L);
                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
                System.out.println("\n>>> Total amount of observations " + totalAmount);
                System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
            }
        }
        finally {
            // getTestCache may fail before dataCache is assigned; don't mask that failure with an NPE.
            if (dataCache != null)
                dataCache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Usage of `org.apache.ignite.ml.knn.ann.ANNClassificationTrainer` in the Apache Ignite project.
The `testUpdate` method of class `ANNClassificationTest`.
/**
 * Checks that updating a trained ANN model yields a model with candidates and the configured
 * (non-weighted) voting mode. It covers an update on the same dataset and an update on an empty dataset.
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> cacheMock = new HashMap<>();

    for (int i = 0; i < twoClusters.length; i++)
        cacheMock.put(i, twoClusters[i]);

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    ANNClassificationModel originalMdl = (ANNClassificationModel)trainer
        .fit(cacheMock, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // The updates must read the label from the same coordinate as the original fit. Using LAST here
    // would silently treat a feature column of the same rows as the label.
    ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel)trainer
        .update(originalMdl, cacheMock, parts,
            new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel)trainer
        .update(originalMdl, new HashMap<>(), parts,
            new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    assertNotNull(updatedOnSameDataset.getCandidates());
    assertTrue(updatedOnSameDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(false).contains("weighted = [false]"));

    assertNotNull(updatedOnEmptyDataset.getCandidates());
    assertTrue(updatedOnEmptyDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(false).contains("weighted = [false]"));
}
Usage of `org.apache.ignite.ml.knn.ann.ANNClassificationTrainer` in the Apache Ignite project.
The `testBinaryClassification` method of class `ANNClassificationTest`.
/**
 * Trains an ANN classifier on the two-cluster fixture. Verifies that the trainer getters reflect
 * the fluent configuration, and that the resulting model has candidates and uses non-weighted voting.
 */
@Test
public void testBinaryClassification() {
    Map<Integer, double[]> data = new HashMap<>();

    int key = 0;
    for (double[] row : twoClusters)
        data.put(key++, row);

    ANNClassificationTrainer annTrainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    // Every fluent setter above must be visible through the matching getter.
    Assert.assertEquals(10, annTrainer.getK());
    Assert.assertEquals(10, annTrainer.getMaxIterations());
    TestUtils.assertEquals(1e-4, annTrainer.getEpsilon(), PRECISION);
    Assert.assertEquals(new EuclideanDistance(), annTrainer.getDistance());

    NNClassificationModel trained = annTrainer
        .fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel annMdl = (ANNClassificationModel)trained;
    Assert.assertNotNull(annMdl.getCandidates());

    String notWeighted = "weighted = [false]";
    assertTrue(trained.toString().contains(notWeighted));
    assertTrue(trained.toString(true).contains(notWeighted));
    assertTrue(trained.toString(false).contains(notWeighted));
}
Aggregations