use of com.alibaba.alink.operator.common.fm.FmPredictBatchOp in project Alink by alibaba.
the class FmClassifierTest method testFm.
/**
 * Trains an FM classifier on a five-row in-memory data set, predicts on the same rows,
 * evaluates the predictions as a binary classification task and checks the first metric
 * extracted from the evaluation result against 0.6 (tolerance 0.01).
 */
@Test
public void testFm() {
    Object[][] samples = new Object[][] { { "1.1 2.0", 1.0 }, { "2.1 3.1", 1.0 }, { "3.1 2.2", 1.0 }, { "1.2 3.2", 0.0 }, { "1.2 4.2", 0.0 } };
    BatchOperator<?> trainData = new MemSourceBatchOp(samples, new String[] { "vec", "label" });

    FmClassifierTrainBatchOp adagrad = new FmClassifierTrainBatchOp()
        .setVectorCol("vec")
        .setLabelCol("label")
        .setNumEpochs(10)
        .setInitStdev(0.01)
        .setLearnRate(0.01)
        .setEpsilon(0.0001)
        .linkFrom(trainData);
    adagrad.lazyPrintModelInfo();
    adagrad.lazyPrintTrainInfo();

    // Predict on the training rows themselves; the detail column feeds the evaluator below.
    BatchOperator<?> result = new FmPredictBatchOp()
        .setVectorCol("vec")
        .setPredictionCol("pred")
        .setPredictionDetailCol("details")
        .linkFrom(adagrad, trainData);

    // Pull individual metrics out of the JSON held in the evaluator's "Data" column.
    List<Row> eval = new EvalBinaryClassBatchOp()
        .setLabelCol("label")
        .setPredictionDetailCol("details")
        .linkFrom(result)
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] { "Statistics" })
            .setOutputCols(new String[] { "Accuracy", "AUC", "ConfusionMatrix" })
            .setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
        .collect();

    // NOTE(review): field 0 is presumably the "Accuracy" output column; confirm the column order.
    double observed = Double.parseDouble(eval.get(0).getField(0).toString());

    // JUnit's contract is assertEquals(expected, actual, delta); keeping that order makes
    // a failure message report the expected and actual values correctly.
    Assert.assertEquals(0.6, observed, 0.01);
}
use of com.alibaba.alink.operator.common.fm.FmPredictBatchOp in project Alink by alibaba.
the class FmRegressionTest method testFm.
/**
 * Smoke test: trains an FM regressor on a five-row in-memory data set and runs prediction
 * on the same rows. There are no assertions; the test only requires that the pipeline
 * executes and the prediction result can be collected.
 */
@Test
public void testFm() {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    Object[][] samples = new Object[][] { { "1.1 2.0", 1.0 }, { "2.1 3.1", 1.0 }, { "3.1 2.2", 1.0 }, { "1.2 3.2", 0.0 }, { "1.2 4.2", 0.0 } };
    String[] columns = new String[] { "vec", "label" };
    BatchOperator<?> source = new MemSourceBatchOp(samples, columns);

    FmRegressorTrainBatchOp trainer = new FmRegressorTrainBatchOp()
        .setVectorCol("vec")
        .setLabelCol("label")
        .setNumEpochs(10)
        .setInitStdev(0.01)
        .setLearnRate(0.01)
        .setEpsilon(0.0001)
        .linkFrom(source);
    trainer.lazyPrintModelInfo();
    trainer.lazyPrintTrainInfo();

    // Collecting the prediction output triggers execution; the rows themselves are discarded.
    BatchOperator<?> predictions = new FmPredictBatchOp()
        .setVectorCol("vec")
        .setPredictionCol("pred")
        .setPredictionDetailCol("details")
        .linkFrom(trainer, source);
    predictions.collect();
}
use of com.alibaba.alink.operator.common.fm.FmPredictBatchOp in project Alink by alibaba.
the class FmClassifierTest method testFmSparse.
/**
 * Trains an FM classifier on five rows whose vectors are written as "index:value" pairs,
 * predicts on the same rows, evaluates the predictions as a binary classification task
 * and checks the first metric extracted from the evaluation result against 0.8
 * (tolerance 0.01).
 */
@Test
public void testFmSparse() {
    Object[][] samples = new Object[][] { { "1:1.1 3:2.0", 1.0 }, { "2:2.1 10:3.1", 1.0 }, { "3:3.1 7:2.2", 1.0 }, { "1:1.2 5:3.2", 0.0 }, { "3:1.2 7:4.2", 0.0 } };
    BatchOperator<?> trainData = new MemSourceBatchOp(samples, new String[] { "vec", "label" });

    FmClassifierTrainBatchOp adagrad = new FmClassifierTrainBatchOp()
        .setVectorCol("vec")
        .setLabelCol("label")
        .setNumEpochs(10)
        .setInitStdev(0.01)
        .setLearnRate(0.01)
        .setEpsilon(0.0001)
        .linkFrom(trainData);

    // Predict on the training rows themselves; the detail column feeds the evaluator below.
    BatchOperator<?> result = new FmPredictBatchOp()
        .setVectorCol("vec")
        .setPredictionCol("pred")
        .setPredictionDetailCol("details")
        .linkFrom(adagrad, trainData);

    // Pull individual metrics out of the JSON held in the evaluator's "Data" column.
    List<Row> eval = new EvalBinaryClassBatchOp()
        .setLabelCol("label")
        .setPredictionDetailCol("details")
        .linkFrom(result)
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] { "Statistics" })
            .setOutputCols(new String[] { "Accuracy", "AUC", "ConfusionMatrix" })
            .setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
        .collect();

    // NOTE(review): field 0 is presumably the "Accuracy" output column; confirm the column order.
    double observed = Double.parseDouble(eval.get(0).getField(0).toString());

    // JUnit's contract is assertEquals(expected, actual, delta); keeping that order makes
    // a failure message report the expected and actual values correctly.
    Assert.assertEquals(0.8, observed, 0.01);
}
Aggregations