Example 1 with FtrlLearningKernel

Use of com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp.FtrlLearningKernel in project Alink by alibaba.

From the class OnlineLearningTest, method Test.

@Test
public void Test() throws Exception {
    // Feature and label columns of the training data.
    String[] xVars = new String[] { "f0", "f1", "f2", "f3" };
    String yVar = "labels";
    // getData(true) is a helper defined elsewhere in OnlineLearningTest (not shown here).
    BatchOperator trainData = (BatchOperator) getData(true);
    // Batch-trained logistic regression model; it is passed to the FTRL operators below.
    LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp().setLabelCol(yVar).setFeatureCols(xVars).setOptimMethod("lbfgs").linkFrom(trainData);
    FtrlTrainStreamOp ftrl = new FtrlTrainStreamOp(lr).setAlpha(0.1).setBeta(0.1).setL1(0.1).setL2(0.1).setFeatureCols(xVars).setLabelCol(yVar).setTimeInterval(1).setWithIntercept(false);
    // Exercise the FTRL learning kernel directly, without running a streaming job.
    FtrlLearningKernel kernel = new FtrlLearningKernel();
    kernel.setModelParams(new Params(), 2, new Object[] { 1, 0 });
    kernel.calcLocalWx(new double[] { 1, 2 }, new DenseVector(2), 0);
    kernel.getFeedbackVar(new double[] { 1, 2 });
    // updateModel changes coef in place; it is called once with a dense and once with a sparse vector.
    double[] coef = new double[] { 2.0, 3.0 };
    kernel.updateModel(coef, new DenseVector(2), new double[] { 1, 1 }, 1L, 0, 0);
    // Sparse vector of size 2 with both entries set to 1.
    SparseVector svec = new SparseVector(2);
    svec.add(0, 1);
    svec.add(1, 1);
    kernel.updateModel(coef, svec, new double[] { 1, 1 }, 1L, 0, 0);
    ftrl.setLearningKernel(kernel);
    // JUnit's assertEquals takes (expected, actual, delta).
    Assert.assertEquals(-0.08761006569007045, coef[0], 0.0001);
    Assert.assertEquals(-0.08761006569007045, coef[1], 0.0001);
    // The remaining statements only check that the Params-based constructors and the predict operators can be built.
    FtrlTrainStreamOp ftrlw = new FtrlTrainStreamOp(lr, new Params()).setAlpha(0.1).setBeta(0.1).setL1(0.1).setL2(0.1).setFeatureCols(xVars).setLabelCol(yVar).setTimeInterval(1).setWithIntercept(false);
    FtrlPredictStreamOp pred = new FtrlPredictStreamOp(lr).setPredictionCol("pred").setVectorCol("vec");
    FtrlPredictStreamOp predp = new FtrlPredictStreamOp(lr, new Params()).setPredictionCol("pred").setVectorCol("vec");
}
Also used: LogisticRegressionTrainBatchOp (com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp), FtrlLearningKernel (com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp.FtrlLearningKernel), Params (org.apache.flink.ml.api.misc.param.Params), SparseVector (com.alibaba.alink.common.linalg.SparseVector), BatchOperator (com.alibaba.alink.operator.batch.BatchOperator), DenseVector (com.alibaba.alink.common.linalg.DenseVector), Test (org.junit.Test)
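The test sets alpha, beta, L1 and L2 on FtrlTrainStreamOp but does not show what the kernel does with them. For orientation, here is a minimal sketch of the standard per-coordinate FTRL-Proximal update for logistic regression (McMahan et al., KDD 2013), using the same four hyper-parameters. The class FtrlProximalSketch and all of its methods are hypothetical and written only for illustration; this is not Alink's FtrlLearningKernel, whose internals (and the exact meaning of arguments such as those passed to updateModel) are not shown on this page.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of FTRL-Proximal for logistic regression.
 * Not Alink's FtrlLearningKernel; it only illustrates the textbook update.
 */
public class FtrlProximalSketch {
    private final double alpha, beta, l1, l2;
    private final double[] z; // per-coordinate accumulated (adjusted) gradients
    private final double[] n; // per-coordinate accumulated squared gradients

    public FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight of coordinate i; the L1 term makes it exactly 0 when |z_i| <= l1. */
    public double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    /** Predicted probability of label 1 for a sparse sample given as (indices, values). */
    public double predict(int[] idx, double[] val) {
        double wx = 0.0;
        for (int k = 0; k < idx.length; k++) {
            wx += weight(idx[k]) * val[k];
        }
        return 1.0 / (1.0 + Math.exp(-wx));
    }

    /** One online step with label y in {0, 1}; only the sample's non-zero coordinates change. */
    public void update(int[] idx, double[] val, double y) {
        double p = predict(idx, val);
        for (int k = 0; k < idx.length; k++) {
            int i = idx[k];
            double g = (p - y) * val[k]; // log-loss gradient for coordinate i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * weight(i); // weight(i) still uses the pre-update state
            n[i] += g * g;
        }
    }

    /** Materializes the current weight vector. */
    public double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < w.length; i++) {
            w[i] = weight(i);
        }
        return w;
    }

    public static void main(String[] args) {
        // alpha = beta = l1 = l2 = 0.1, the values the test passes to FtrlTrainStreamOp.
        FtrlProximalSketch ftrl = new FtrlProximalSketch(2, 0.1, 0.1, 0.1, 0.1);
        int[] idx = {0, 1};
        double[] val = {1.0, 1.0};
        for (int t = 0; t < 5; t++) {
            ftrl.update(idx, val, 1.0);
        }
        System.out.println(Arrays.toString(ftrl.weights()));
    }
}
```

Storing z and n instead of the weights is what lets the sparse and dense cases share one code path: a sample only touches the coordinates it lists, and each weight is recomputed lazily from its own accumulators.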