Use of org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer in the project ignite by apache.
In the class StackingTest, method testSimpleStack.
/**
 * Trains a stack made of a single MLP submodel and a linear regression aggregator on the
 * {@code xor} data set, then checks the predictions for the four binary input points.
 */
@Test
public void testSimpleStack() {
    // Labels seen by the aggregator trainer are multiplied by this value, so the expected
    // predictions below are scaled by it as well.
    final double factor = 3;

    MLPArchitecture mlpArch = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, false, Activators.SIGMOID);

    UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> gdStrategy = new UpdatesStrategy<>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate.SUM_LOCAL,
        SimpleGDParameterUpdate.AVG);

    MLPTrainer<SimpleGDParameterUpdate> rawMlpTrainer =
        new MLPTrainer<>(mlpArch, LossFunctions.MSE, gdStrategy, 3000, 10, 50, 123L);

    // Adapt the MLP trainer so that the resulting model maps Vector -> Vector:
    // the input vector is wrapped into a matrix, the first row of the output matrix is taken back,
    // and labels are converted with VectorUtils::num2Arr.
    DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> vecMlpTrainer =
        AdaptableDatasetTrainer.of(rawMlpTrainer)
            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
            .withConvertedLabels(VectorUtils::num2Arr);

    StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> stackTrainer =
        new StackedDatasetTrainer<>();

    StackedModel<Vector, Vector, Double, LinearRegressionModel> stackedMdl = stackTrainer
        .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
        .addTrainer(vecMlpTrainer)
        .withAggregatorInputMerger(VectorUtils::concat)
        .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
        .withVector2SubmodelInputConverter(IgniteFunction.identity())
        .withOriginalFeaturesKept(IgniteFunction.identity())
        .withEnvironmentBuilder(TestUtils.testEnvBuilder())
        .fit(getCacheMock(xor), parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // Each row is {first input, second input, unscaled expected output}.
    double[][] expectations = {
        {0.0, 0.0, 0.0},
        {0.0, 1.0, 1.0},
        {1.0, 0.0, 1.0},
        {1.0, 1.0, 0.0}
    };

    for (double[] row : expectations) {
        assertEquals(row[2] * factor, stackedMdl.predict(VectorUtils.of(row[0], row[1])), 0.3);
    }
}
Aggregations