Usage of org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer in the Apache Ignite project.
Class SVMBinaryClassificationExample, method main.
/**
 * Run example.
 * <p>
 * Starts Ignite with {@code examples/config/example-ignite.xml}, trains a linear SVM binary classifier
 * on the cache returned by {@code getTestCache(ignite)}, and then scans the same cache. For each entry it
 * prints the prediction next to its ground truth. Afterwards it prints the absolute number of errors,
 * the accuracy and a 2x2 confusion matrix.
 * <p>
 * Each cached {@code double[]} value is treated as {@code [label, feature_1, ..., feature_n]}: both the
 * trainer's extractors and the evaluation loop take index 0 as the label and the remaining positions
 * as the features.
 *
 * @param args Command line arguments (not used).
 * @throws InterruptedException If interrupted while waiting for the training/evaluation thread to finish.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println();
    System.out.println(">>> SVM Binary classification model over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Training and evaluation run inside an IgniteThread created with this node's instance name;
        // the main thread only starts it and waits for it with join().
        IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
            SVMBinaryClassificationExample.class.getSimpleName(), () -> {
            IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);

            SVMLinearBinaryClassificationTrainer<Integer, double[]> trainer =
                new SVMLinearBinaryClassificationTrainer<>();

            // Feature extractor: every column after the first. Label extractor: column 0.
            // NOTE(review): meaning of the trailing argument (4) is defined by the trainer's fit()
            // signature, which is not visible here — confirm against the Ignite ML API in use.
            SVMLinearBinaryClassificationModel mdl = trainer.fit(
                new CacheBasedDatasetBuilder<>(ignite, dataCache),
                (k, v) -> Arrays.copyOfRange(v, 1, v.length),
                (k, v) -> v[0],
                4
            );

            System.out.println(">>> SVM model " + mdl);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            int amountOfErrors = 0;
            int totalAmount = 0;

            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
            // Rows are indexed by the predicted label, columns by the ground-truth label.
            int[][] confusionMtx = {{0, 0}, {0, 0}};

            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> observation : observations) {
                    double[] val = observation.getValue();
                    double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                    double groundTruth = val[0];

                    double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));

                    totalAmount++;
                    // Exact comparison: predictions and labels are compared as discrete class values.
                    if (groundTruth != prediction)
                        amountOfErrors++;

                    confusionMtx[labelToIndex(prediction)][labelToIndex(groundTruth)]++;

                    // %n (not \n) so the line separator matches the platform.
                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|%n", prediction, groundTruth);
                }
            }

            System.out.println(">>> ---------------------------------");
            System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);

            // Guard against an empty scan, which would otherwise print an accuracy of NaN (0 / 0).
            if (totalAmount > 0)
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
            else
                System.out.println("\n>>> Accuracy is undefined: the dataset is empty");

            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
        });

        igniteThread.start();
        igniteThread.join();
    }
}

/**
 * Maps a class label to its row/column index in the 2x2 confusion matrix.
 * <p>
 * The label is truncated to {@code int} before the comparison, so any value whose integer part is
 * {@code -1} maps to index {@code 0}; every other value maps to index {@code 1}. For labels
 * {@code -1.0} and {@code 1.0} this gives {@code -1.0 -> 0} and {@code 1.0 -> 1}.
 *
 * @param lbl Class label (either a prediction or a ground-truth value).
 * @return {@code 0} when the truncated label equals {@code -1}, otherwise {@code 1}.
 */
private static int labelToIndex(double lbl) {
    return (int) lbl == -1 ? 0 : 1;
}
Aggregations