Search in sources :

Example 1 with StackedVectorDatasetTrainer

Use of org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer in the Apache Ignite project.

Class Step_9_Scaling_With_Stacking, method main.

/**
 * Entry point of the tutorial step 9 example: builds a chain of preprocessors over the
 * passengers cache, trains a stacked model (three decision trees aggregated by a logistic
 * regression) on the train split and prints the accuracy measured on the test split.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 9 (scaling with stacking) example started.");

    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        try {
            IgniteCache<Integer, Vector> passengers = TitanicUtils.readPassengers(ignite);

            // Coordinates 0, 3, 4, 5, 6, 8 and 10 become features, coordinate 1 is the label.
            // The original note names the extracted columns as
            // "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"
            // NOTE(review): confirm that this order matches the data layout used by TitanicUtils.
            final Vectorizer<Integer, Vector, Integer, Double> featureExtractor =
                new DummyVectorizer<Integer>(0, 3, 4, 5, 6, 8, 10).labeled(1);

            TrainTestSplit<Integer, Vector> trainTestSplit =
                new TrainTestDatasetSplitter<Integer, Vector>().split(0.75);

            // Preprocessing chain: string encoding -> imputing -> min-max scaling -> normalization.
            // Features with indices 1 and 6 are string-encoded (the second index was marked as
            // changed in the original note).
            Preprocessor<Integer, Vector> encoded = new EncoderTrainer<Integer, Vector>()
                .withEncoderType(EncoderType.STRING_ENCODER)
                .withEncodedFeature(1)
                .withEncodedFeature(6)
                .fit(ignite, passengers, featureExtractor);

            Preprocessor<Integer, Vector> imputed = new ImputerTrainer<Integer, Vector>()
                .fit(ignite, passengers, encoded);

            Preprocessor<Integer, Vector> scaled = new MinMaxScalerTrainer<Integer, Vector>()
                .fit(ignite, passengers, imputed);

            Preprocessor<Integer, Vector> normalized = new NormalizationTrainer<Integer, Vector>()
                .withP(1)
                .fit(ignite, passengers, scaled);

            // Three decision tree submodels that differ only in their first constructor argument.
            DecisionTreeClassificationTrainer treeA = new DecisionTreeClassificationTrainer(5, 0);
            DecisionTreeClassificationTrainer treeB = new DecisionTreeClassificationTrainer(3, 0);
            DecisionTreeClassificationTrainer treeC = new DecisionTreeClassificationTrainer(4, 0);

            // Logistic regression that aggregates the outputs of the submodels.
            LogisticRegressionSGDTrainer aggregator = new LogisticRegressionSGDTrainer()
                .withUpdatesStgy(new UpdatesStrategy<>(
                    new SimpleGDUpdateCalculator(0.2),
                    SimpleGDParameterUpdate.SUM_LOCAL,
                    SimpleGDParameterUpdate.AVG));

            StackedModel<Vector, Vector, Double, LogisticRegressionModel> stackedMdl =
                new StackedVectorDatasetTrainer<>(aggregator)
                    .addTrainerWithDoubleOutput(treeA)
                    .addTrainerWithDoubleOutput(treeB)
                    .addTrainerWithDoubleOutput(treeC)
                    .fit(ignite, passengers, trainTestSplit.getTrainFilter(), normalized);

            System.out.println("\n>>> Trained model: " + stackedMdl);

            double accuracy = Evaluator.evaluate(
                passengers,
                trainTestSplit.getTestFilter(),
                stackedMdl,
                normalized,
                new Accuracy<>());

            double testError = 1 - accuracy;

            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + testError);
            System.out.println(">>> Tutorial step 9 (scaling with stacking) example completed.");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    } finally {
        // Flush standard output regardless of how the example finished.
        System.out.flush();
    }
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) FileNotFoundException(java.io.FileNotFoundException) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) NormalizationTrainer(org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer) StackedVectorDatasetTrainer(org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Ignite(org.apache.ignite.Ignite) EncoderTrainer(org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 2 with StackedVectorDatasetTrainer

Use of org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer in the Apache Ignite project.

Class StackingTest, method testSimpleVectorStack.

/**
 * Trains a stack made of a single MLP submodel aggregated by a linear regression whose labels
 * are multiplied by a constant factor, then checks the predictions for the four binary inputs
 * against the XOR truth table scaled by that factor (tolerance 0.3).
 */
@Test
public void testSimpleVectorStack() {
    // Multiplier applied to the labels seen by the aggregator; expected outputs scale with it.
    final double labelScale = 3;

    StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> stackTrainer =
        new StackedVectorDatasetTrainer<>();

    UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> gdStrategy = new UpdatesStrategy<>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate.SUM_LOCAL,
        SimpleGDParameterUpdate.AVG);

    // Two inputs -> hidden layer of 10 (RELU) -> single output (SIGMOID).
    MLPArchitecture mlpArch = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, false, Activators.SIGMOID);

    DatasetTrainer<MultilayerPerceptron, Double> mlpSubmodelTrainer =
        new MLPTrainer<>(mlpArch, LossFunctions.MSE, gdStrategy, 3000, 10, 50, 123L)
            .withConvertedLabels(VectorUtils::num2Arr);

    StackedModel<Vector, Vector, Double, LinearRegressionModel> stackedMdl = stackTrainer
        .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(lbl -> lbl * labelScale))
        .addMatrix2MatrixTrainer(mlpSubmodelTrainer)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder())
        .fit(
            getCacheMock(xor),
            parts,
            new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // Each row: first input, second input, expected unscaled output.
    double[][] truthTable = {
        {0.0, 0.0, 0.0},
        {0.0, 1.0, 1.0},
        {1.0, 0.0, 1.0},
        {1.0, 1.0, 0.0}
    };

    for (double[] row : truthTable) {
        assertEquals(row[2] * labelScale, stackedMdl.predict(VectorUtils.of(row[0], row[1])), 0.3);
    }
}
Also used : VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) LinearRegressionModel(org.apache.ignite.ml.regressions.linear.LinearRegressionModel) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) StackedVectorDatasetTrainer(org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer) SimpleGDParameterUpdate(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) LinearRegressionLSQRTrainer(org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) SmoothParametrized(org.apache.ignite.ml.optimization.SmoothParametrized) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Aggregations

StackedVectorDatasetTrainer (org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer)2 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)2 SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator)2 FileNotFoundException (java.io.FileNotFoundException)1 Ignite (org.apache.ignite.Ignite)1 TrainerTest (org.apache.ignite.ml.common.TrainerTest)1 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)1 VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils)1 MultilayerPerceptron (org.apache.ignite.ml.nn.MultilayerPerceptron)1 UpdatesStrategy (org.apache.ignite.ml.nn.UpdatesStrategy)1 MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture)1 SmoothParametrized (org.apache.ignite.ml.optimization.SmoothParametrized)1 SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate)1 EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer)1 NormalizationTrainer (org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer)1 LinearRegressionLSQRTrainer (org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer)1 LinearRegressionModel (org.apache.ignite.ml.regressions.linear.LinearRegressionModel)1 LogisticRegressionModel (org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel)1 LogisticRegressionSGDTrainer (org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer)1 DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer)1