Use of org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer in the Apache Ignite project.
Class Step_11_Boosting, method main.
/**
 * Run example.
 * <p>
 * Reads the Titanic passengers into a cache, builds a preprocessing chain (string encoding, imputing,
 * min-max scaling, normalization), trains a gradient-boosted binary classifier on trees using the
 * training filter of a 0.75 train/test split, and prints accuracy and test error on the test filter.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 11 (Boosting) example started.");

    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        try {
            IgniteCache<Integer, Vector> dataCache = TitanicUtils.readPassengers(ignite);

            // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"; column 1 is the label.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer =
                new DummyVectorizer<Integer>(0, 3, 4, 5, 6, 8, 10).labeled(1);

            TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
                .split(0.75);

            // String-encode the features with indices 1 and 6.
            Preprocessor<Integer, Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Vector>()
                .withEncoderType(EncoderType.STRING_ENCODER)
                .withEncodedFeature(1)
                .withEncodedFeature(6)
                .fit(ignite, dataCache, vectorizer);

            Preprocessor<Integer, Vector> imputingPreprocessor = new ImputerTrainer<Integer, Vector>()
                .fit(ignite, dataCache, strEncoderPreprocessor);

            Preprocessor<Integer, Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Vector>()
                .fit(ignite, dataCache, imputingPreprocessor);

            Preprocessor<Integer, Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Vector>()
                .withP(1)
                .fit(ignite, dataCache, minMaxScalerPreprocessor);

            // Create classification trainer: gradient boosting over decision trees.
            // NOTE(review): constructor args are presumably (gradient step size, iterations, max depth,
            // min impurity decrease) — confirm against the GDBBinaryClassifierOnTreesTrainer Javadoc.
            GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.5, 500, 4, 0.)
                .withCheckConvergenceStgyFactory(new MedianOfMedianConvergenceCheckerFactory(0.1));

            // Train the gradient-boosted model on the training part of the split.
            GDBModel mdl = trainer.fit(ignite, dataCache, split.getTrainFilter(), normalizationPreprocessor);

            System.out.println("\n>>> Trained model: " + mdl.toString(true));

            double accuracy = Evaluator.evaluate(
                dataCache,
                split.getTestFilter(),
                mdl,
                normalizationPreprocessor,
                MetricName.ACCURACY
            );

            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + (1 - accuracy));

            System.out.println(">>> Tutorial step 11 (Boosting) example completed.");
        }
        catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
    finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer in the Apache Ignite project.
Class Step_7_Split_train_test, method main.
/**
 * Entry point of tutorial step 7.
 * <p>
 * Loads the Titanic passengers into a cache and chains string encoding, imputing, min-max scaling and
 * normalization preprocessors. It then fits a decision tree classifier on the training filter of a
 * 0.75 train/test split and reports accuracy and error on the test filter.
 *
 * @param args Command line arguments (ignored).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 7 (split to train and test) example started.");

    try (Ignite node = Ignition.start("examples/config/example-ignite.xml")) {
        try {
            IgniteCache<Integer, Vector> passengers = TitanicUtils.readPassengers(node);

            // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"; column 1 is the label.
            final Vectorizer<Integer, Vector, Integer, Double> featureVectorizer =
                new DummyVectorizer<Integer>(0, 3, 4, 5, 6, 8, 10).labeled(1);

            TrainTestSplit<Integer, Vector> trainTestSplit =
                new TrainTestDatasetSplitter<Integer, Vector>().split(0.75);

            // Stage 1: string-encode the features with indices 1 and 6.
            EncoderTrainer<Integer, Vector> stringEncoderTrainer = new EncoderTrainer<Integer, Vector>()
                .withEncoderType(EncoderType.STRING_ENCODER)
                .withEncodedFeature(1)
                .withEncodedFeature(6);
            Preprocessor<Integer, Vector> encoded = stringEncoderTrainer.fit(node, passengers, featureVectorizer);

            // Stage 2: fill in missing values.
            Preprocessor<Integer, Vector> imputed =
                new ImputerTrainer<Integer, Vector>().fit(node, passengers, encoded);

            // Stage 3: min-max scaling.
            Preprocessor<Integer, Vector> scaled =
                new MinMaxScalerTrainer<Integer, Vector>().fit(node, passengers, imputed);

            // Stage 4: normalization with p = 1.
            NormalizationTrainer<Integer, Vector> normalizationTrainer =
                new NormalizationTrainer<Integer, Vector>().withP(1);
            Preprocessor<Integer, Vector> normalized = normalizationTrainer.fit(node, passengers, scaled);

            // NOTE(review): (5, 0) are presumably max depth and min impurity decrease — confirm in Javadoc.
            DecisionTreeClassificationTrainer treeTrainer = new DecisionTreeClassificationTrainer(5, 0);

            // Train decision tree model on the training part only.
            DecisionTreeModel treeModel =
                treeTrainer.fit(node, passengers, trainTestSplit.getTrainFilter(), normalized);

            System.out.println("\n>>> Trained model: " + treeModel);

            double accuracy = Evaluator.evaluate(
                passengers,
                trainTestSplit.getTestFilter(),
                treeModel,
                normalized,
                new Accuracy<>()
            );
            double testError = 1 - accuracy;

            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + testError);

            System.out.println(">>> Tutorial step 7 (split to train and test) example completed.");
        }
        catch (FileNotFoundException ex) {
            ex.printStackTrace();
        }
    }
    finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer in the Apache Ignite project.
Class Step_8_CV, method main.
/**
 * Run example.
 * <p>
 * Builds the Titanic preprocessing chain, then grid-searches the normalization power {@code p} and the
 * decision tree depth using 3-fold cross-validation restricted to the training filter of a 0.75 split.
 * It retrains with the best combination and prints accuracy and test error on the test filter.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 8 (cross-validation) example started.");

    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        try {
            IgniteCache<Integer, Vector> dataCache = TitanicUtils.readPassengers(ignite);

            // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"; column 1 is the label.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer =
                new DummyVectorizer<Integer>(0, 3, 4, 5, 6, 8, 10).labeled(1);

            TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
                .split(0.75);

            // String-encode the features with indices 1 and 6.
            Preprocessor<Integer, Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Vector>()
                .withEncoderType(EncoderType.STRING_ENCODER)
                .withEncodedFeature(1)
                .withEncodedFeature(6)
                .fit(ignite, dataCache, vectorizer);

            Preprocessor<Integer, Vector> imputingPreprocessor = new ImputerTrainer<Integer, Vector>()
                .fit(ignite, dataCache, strEncoderPreprocessor);

            Preprocessor<Integer, Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Vector>()
                .fit(ignite, dataCache, imputingPreprocessor);

            // Tune hyper-parameters with K-fold Cross-Validation on the split training set.
            int[] pSet = new int[] {1, 2};
            int[] maxDeepSet = new int[] {1, 2, 3, 4, 5, 10, 20};

            int bestP = 1;
            int bestMaxDeep = 1;

            // Double.MIN_VALUE is the smallest *positive* double, not the most negative one. Starting from
            // -Infinity guarantees the first evaluated combination becomes the initial best, even when its
            // average score is 0.
            double avg = Double.NEGATIVE_INFINITY;

            for (int p : pSet) {
                // The normalizer depends only on p, so fit it once per p instead of once per (p, maxDeep).
                Preprocessor<Integer, Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Vector>()
                    .withP(p)
                    .fit(ignite, dataCache, minMaxScalerPreprocessor);

                for (int maxDeep : maxDeepSet) {
                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(maxDeep, 0);

                    CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator = new CrossValidation<>();

                    double[] scores = scoreCalculator
                        .withIgnite(ignite)
                        .withUpstreamCache(dataCache)
                        .withTrainer(trainer)
                        .withMetric(MetricName.ACCURACY)
                        .withFilter(split.getTrainFilter())
                        .withPreprocessor(normalizationPreprocessor)
                        .withAmountOfFolds(3)
                        .isRunningOnPipeline(false)
                        .scoreByFolds();

                    System.out.println("Scores are: " + Arrays.toString(scores));

                    // An empty score array has no average; NaN never compares greater than the current best.
                    final double currAvg = Arrays.stream(scores).average().orElse(Double.NaN);

                    if (currAvg > avg) {
                        avg = currAvg;
                        bestP = p;
                        bestMaxDeep = maxDeep;
                    }

                    System.out.println("Avg is: " + currAvg + " with p: " + p + " with maxDeep: " + maxDeep);
                }
            }

            System.out.println("Train with p: " + bestP + " and maxDeep: " + bestMaxDeep);

            Preprocessor<Integer, Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Vector>()
                .withP(bestP)
                .fit(ignite, dataCache, minMaxScalerPreprocessor);

            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(bestMaxDeep, 0);

            // Train decision tree model with the best hyper-parameters on the training part of the split.
            DecisionTreeModel bestMdl = trainer.fit(ignite, dataCache, split.getTrainFilter(), normalizationPreprocessor);

            System.out.println("\n>>> Trained model: " + bestMdl);

            double accuracy = Evaluator.evaluate(
                dataCache,
                split.getTestFilter(),
                bestMdl,
                normalizationPreprocessor,
                new Accuracy<>()
            );

            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + (1 - accuracy));

            System.out.println(">>> Tutorial step 8 (cross-validation) example completed.");
        }
        catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
    finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer in the Apache Ignite project.
Class SVMMultiClassClassificationExample, method main.
/**
 * Run example.
 * <p>
 * Trains two multi-class linear SVM models on the cached dataset. One uses the raw features and the
 * other uses features produced by a {@link NormalizationPreprocessor}. The example then compares their
 * predictions, error counts, accuracies and confusion matrices over every cached observation. Each
 * model is evaluated on the same kind of features it was trained on.
 *
 * @param args Command line arguments (not used).
 * @throws InterruptedException If interrupted while waiting for the example thread to finish.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println();
    System.out.println(">>> SVM Multi-class classification model over cached dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
            SVMMultiClassClassificationExample.class.getSimpleName(), () -> {
                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);

                SVMLinearMultiClassClassificationTrainer<Integer, double[]> trainer =
                    new SVMLinearMultiClassClassificationTrainer<>();

                // Element 0 of each cached row is the label; the remaining elements are the features.
                // NOTE(review): the trailing 5 is passed as a column count — confirm it matches getTestCache.
                SVMLinearMultiClassClassificationModel mdl = trainer.fit(
                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
                    (k, v) -> v[0],
                    5
                );

                System.out.println(">>> SVM Multi-class model");
                System.out.println(mdl.toString());

                NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>();

                NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit(
                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
                    5
                );

                // Same trainer, but features now come from the normalization preprocessor.
                SVMLinearMultiClassClassificationModel mdlWithNormalization = trainer.fit(
                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
                    preprocessor,
                    (k, v) -> v[0],
                    5
                );

                System.out.println(">>> SVM Multi-class model with normalization");
                System.out.println(mdlWithNormalization.toString());

                System.out.println(">>> ----------------------------------------------------------------");
                System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
                System.out.println(">>> ----------------------------------------------------------------");

                int amountOfErrors = 0;
                int amountOfErrorsWithNormalization = 0;
                int totalAmount = 0;

                // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                // Rows are indexed by prediction, columns by ground truth.
                int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
                int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};

                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, double[]> observation : observations) {
                        double[] val = observation.getValue();
                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);

                        // mdlWithNormalization was trained on preprocessed features, so it must be evaluated
                        // on the same preprocessed features rather than on the raw inputs.
                        double[] normalizedInputs = preprocessor.apply(observation.getKey(), val);

                        double groundTruth = val[0];

                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
                        double predictionWithNormalization =
                            mdlWithNormalization.apply(new DenseLocalOnHeapVector(normalizedInputs));

                        totalAmount++;

                        int groundTruthIdx = toConfusionMtxIdx(groundTruth);

                        // Collect data for model
                        if (groundTruth != prediction)
                            amountOfErrors++;

                        confusionMtx[toConfusionMtxIdx(prediction)][groundTruthIdx]++;

                        // Collect data for model with normalization
                        if (groundTruth != predictionWithNormalization)
                            amountOfErrorsWithNormalization++;

                        confusionMtxWithNormalization[toConfusionMtxIdx(predictionWithNormalization)][groundTruthIdx]++;

                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n",
                            prediction, predictionWithNormalization, groundTruth);
                    }

                    System.out.println(">>> ----------------------------------------------------------------");

                    System.out.println("\n>>> -----------------SVM model-------------");
                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
                    System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));

                    System.out.println("\n>>> -----------------SVM model with Normalization-------------");
                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double) totalAmount));
                    System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
                }
            });

        igniteThread.start();

        igniteThread.join();
    }
}

/**
 * Maps a class label (or prediction) to its confusion-matrix row/column index.
 * <p>
 * The label is truncated to {@code int}. Label 1 maps to 0, label 3 maps to 1, and any other label maps
 * to 2.
 * NOTE(review): presumably the dataset has exactly three classes, two of which are labelled 1 and 3.
 *
 * @param lbl Class label or predicted label.
 * @return Index in {@code [0, 2]}.
 */
private static int toConfusionMtxIdx(double lbl) {
    int cls = (int) lbl;

    return cls == 1 ? 0 : (cls == 3 ? 1 : 2);
}
Use of org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer in the Apache Ignite project.
Class NormalizationExample, method main.
/**
 * Run example.
 * <p>
 * Creates and fills a cache, then normalizes coordinates 1 and 2 of each cached vector with
 * {@code withP(1)}. It prints dataset statistics through {@code DatasetHelper} and always destroys the
 * cache afterwards, if the cache was created.
 *
 * @param args Command line arguments (not used).
 * @throws Exception If the example fails.
 */
public static void main(String[] args) throws Exception {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Normalization example started.");

        IgniteCache<Integer, Vector> data = null;
        try {
            data = createCache(ignite);

            // Use coordinates 1 and 2 of every cached vector as features.
            Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(1, 2);

            // Defines the preprocessor that normalizes the extracted features.
            // NOTE(review): withP(1) presumably selects the L1 norm — confirm in the NormalizationTrainer docs.
            Preprocessor<Integer, Vector> preprocessor = new NormalizationTrainer<Integer, Vector>()
                .withP(1)
                .fit(ignite, data, vectorizer);

            // Creates a cache based simple dataset containing features and providing standard dataset API.
            try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, data, preprocessor)) {
                new DatasetHelper(dataset).describe();
            }

            System.out.println(">>> Normalization example completed.");
        }
        finally {
            // createCache may fail before data is assigned. Without this guard, a NullPointerException
            // would replace (and hide) the original exception.
            if (data != null)
                data.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Aggregations