Search in sources:

Example 1 with StackedDatasetTrainer

Use of org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer in project Ignite by Apache.

Class StackingTest, method testSimpleStack.

/**
 * Verifies stacking with a single perceptron submodel and a least-squares linear-regression
 * aggregator. The aggregator is trained on labels multiplied by {@code factor}, so the stacked
 * model is expected to reproduce the scaled XOR table within a tolerance of 0.3.
 */
@Test
public void testSimpleStack() {
    final double factor = 3;

    StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> stackTrainer =
        new StackedDatasetTrainer<>();

    // Gradient-descent update strategy used by the perceptron submodel trainer.
    UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> gdStrategy = new UpdatesStrategy<>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate.SUM_LOCAL,
        SimpleGDParameterUpdate.AVG
    );

    // 2 inputs -> hidden layer of 10 RELU units -> 1 SIGMOID output.
    MLPArchitecture architecture = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, false, Activators.SIGMOID);

    // NOTE(review): the numeric arguments are presumably iteration/batch settings plus a seed — confirm
    // against the MLPTrainer constructor.
    MLPTrainer<SimpleGDParameterUpdate> perceptronTrainer =
        new MLPTrainer<>(architecture, LossFunctions.MSE, gdStrategy, 3000, 10, 50, 123L);

    // Adapt the Matrix -> Matrix perceptron trainer into a Vector -> Vector submodel trainer:
    // wrap each input vector into a single-row matrix, and unwrap row 0 of the output matrix.
    DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> submodelTrainer =
        AdaptableDatasetTrainer.of(perceptronTrainer)
            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
            .withConvertedLabels(VectorUtils::num2Arr);

    StackedModel<Vector, Vector, Double, LinearRegressionModel> stackedMdl = stackTrainer
        .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
        .addTrainer(submodelTrainer)
        .withAggregatorInputMerger(VectorUtils::concat)
        .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
        .withVector2SubmodelInputConverter(IgniteFunction.identity())
        .withOriginalFeaturesKept(IgniteFunction.identity())
        .withEnvironmentBuilder(TestUtils.testEnvBuilder())
        .fit(getCacheMock(xor), parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // XOR truth table: each input pair alongside its unscaled expected output.
    double[][] inputs = {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}};
    double[] xorOutputs = {0.0, 1.0, 1.0, 0.0};

    for (int row = 0; row < inputs.length; row++) {
        assertEquals(xorOutputs[row] * factor, stackedMdl.predict(VectorUtils.of(inputs[row])), 0.3);
    }
}
Also used: VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils), DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer), LinearRegressionModel (org.apache.ignite.ml.regressions.linear.LinearRegressionModel), MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture), MLPTrainer (org.apache.ignite.ml.nn.MLPTrainer), SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate), DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix), LinearRegressionLSQRTrainer (org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer), Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix), SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator), AdaptableDatasetModel (org.apache.ignite.ml.trainers.AdaptableDatasetModel), StackedDatasetTrainer (org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), SmoothParametrized (org.apache.ignite.ml.optimization.SmoothParametrized), UpdatesStrategy (org.apache.ignite.ml.nn.UpdatesStrategy), TrainerTest (org.apache.ignite.ml.common.TrainerTest), Test (org.junit.Test)

Aggregations

TrainerTest (org.apache.ignite.ml.common.TrainerTest): 1; StackedDatasetTrainer (org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer): 1; DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer): 1; Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix): 1; DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix): 1; Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 1; VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils): 1; MLPTrainer (org.apache.ignite.ml.nn.MLPTrainer): 1; UpdatesStrategy (org.apache.ignite.ml.nn.UpdatesStrategy): 1; MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture): 1; SmoothParametrized (org.apache.ignite.ml.optimization.SmoothParametrized): 1; SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate): 1; SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator): 1; LinearRegressionLSQRTrainer (org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer): 1; LinearRegressionModel (org.apache.ignite.ml.regressions.linear.LinearRegressionModel): 1; AdaptableDatasetModel (org.apache.ignite.ml.trainers.AdaptableDatasetModel): 1; Test (org.junit.Test): 1