Usage of `com.alibaba.alink.operator.batch.dataproc.JsonValueBatchOp` in the Alink project by Alibaba.
Example: method `batchVectorTest` in class `SoftmaxTest`.
/**
 * Trains a {@code SoftmaxTrainBatchOp} model (LBFGS, intercept, no standardization) on the vector
 * column {@code svec}, predicts on the training data, and checks the accuracy extracted from the
 * multi-class evaluation report.
 *
 * <p>The evaluation output's {@code Data} column is parsed by {@link JsonValueBatchOp}, which pulls
 * out {@code $.Accuracy} as the output column {@code Accuracy}.
 */
@Test
public void batchVectorTest() {
    BatchOperator<?> trainData = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
    String labelColName = "label";

    SoftmaxTrainBatchOp lr = new SoftmaxTrainBatchOp()
        .setVectorCol("svec")
        .setStandardization(false)
        .setWithIntercept(true)
        .setEpsilon(1.0e-4)
        .setOptimMethod("LBFGS")
        .setLabelCol(labelColName)
        .setMaxIter(10);
    SoftmaxTrainBatchOp model = lr.linkFrom(trainData);

    // NOTE(review): the evaluator's label column is the prediction column "predLr", so the
    // predictions are evaluated against (essentially) themselves and the asserted accuracy of 1.0
    // says little about model quality. Presumably this should be labelColName; confirm the model
    // really reaches the asserted accuracy on vecrows before switching.
    List<Row> acc = new SoftmaxPredictBatchOp()
        .setPredictionCol("predLr")
        .setVectorCol("svec")
        .setPredictionDetailCol("predDetail")
        .linkFrom(model, trainData)
        .link(new EvalMultiClassBatchOp()
            .setLabelCol("predLr")
            .setPredictionDetailCol("predDetail"))
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] {"Statistics"})
            .setOutputCols(new String[] {"Accuracy"})
            .setJsonPath("$.Accuracy"))
        .collect();

    // JUnit's assertEquals is (expected, actual, delta); the expected value goes first so that
    // failure messages report the values the right way round.
    double accuracy = Double.parseDouble(acc.get(0).getField(0).toString());
    Assert.assertEquals(1.0, accuracy, 0.001);
}
Usage of `com.alibaba.alink.operator.batch.dataproc.JsonValueBatchOp` in the Alink project by Alibaba.
Example: method `testFm` in class `FmClassifierTest`.
/**
 * Trains an {@code FmClassifierTrainBatchOp} model on five labelled samples whose vectors are
 * given as space-separated value strings, predicts on the same data, and checks the accuracy
 * extracted from the binary-classification evaluation report.
 *
 * <p>{@link JsonValueBatchOp} extracts {@code $.Accuracy}, {@code $.AUC} and
 * {@code $.ConfusionMatrix} from the report's {@code Data} column; only the accuracy is asserted.
 */
@Test
public void testFm() {
    BatchOperator<?> trainData = new MemSourceBatchOp(
        new Object[][] {
            {"1.1 2.0", 1.0},
            {"2.1 3.1", 1.0},
            {"3.1 2.2", 1.0},
            {"1.2 3.2", 0.0},
            {"1.2 4.2", 0.0}
        },
        new String[] {"vec", "label"});

    FmClassifierTrainBatchOp adagrad = new FmClassifierTrainBatchOp()
        .setVectorCol("vec")
        .setLabelCol("label")
        .setNumEpochs(10)
        .setInitStdev(0.01)
        .setLearnRate(0.01)
        .setEpsilon(0.0001)
        .linkFrom(trainData);
    adagrad.lazyPrintModelInfo();
    adagrad.lazyPrintTrainInfo();

    BatchOperator<?> result = new FmPredictBatchOp()
        .setVectorCol("vec")
        .setPredictionCol("pred")
        .setPredictionDetailCol("details")
        .linkFrom(adagrad, trainData);

    List<Row> eval = new EvalBinaryClassBatchOp()
        .setLabelCol("label")
        .setPredictionDetailCol("details")
        .linkFrom(result)
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] {"Statistics"})
            .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
            .setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
        .collect();

    // NOTE(review): assumes field 0 of the collected row is the extracted "Accuracy" value —
    // confirm JsonValueBatchOp's column order when reserved columns are present.
    // JUnit's assertEquals is (expected, actual, delta): expected value first.
    double accuracy = Double.parseDouble(eval.get(0).getField(0).toString());
    Assert.assertEquals(0.6, accuracy, 0.01);
}
Usage of `com.alibaba.alink.operator.batch.dataproc.JsonValueBatchOp` in the Alink project by Alibaba.
Example: method `batchTableTest` in class `SoftmaxTest`.
/**
 * Trains a {@code SoftmaxTrainBatchOp} model (LBFGS, intercept, standardization, small L1 term)
 * on the vector column {@code svec}, prints its train/model info lazily, predicts on the training
 * data, and checks the accuracy extracted from the multi-class evaluation report.
 *
 * <p>Unlike {@code batchVectorTest}, the predictor here is not given a vector column explicitly.
 * NOTE(review): presumably it falls back to the column recorded in the model — confirm.
 */
@Test
public void batchTableTest() {
    BatchOperator<?> trainData = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
    String labelColName = "label";

    SoftmaxTrainBatchOp lr = new SoftmaxTrainBatchOp()
        .setVectorCol("svec")
        .setStandardization(true)
        .setWithIntercept(true)
        .setEpsilon(1.0e-4)
        .setL1(0.0001)
        .setOptimMethod("lbfgs")
        .setLabelCol(labelColName)
        .setMaxIter(20);
    SoftmaxTrainBatchOp model = lr.linkFrom(trainData);
    model.lazyPrintTrainInfo();
    model.lazyPrintModelInfo();

    // NOTE(review): the evaluator's label column is the prediction column "predLr", so the
    // predictions are evaluated against (essentially) themselves and the asserted accuracy of 1.0
    // says little about model quality. Presumably this should be labelColName; confirm the model
    // really reaches the asserted accuracy on vecrows before switching.
    List<Row> acc = new SoftmaxPredictBatchOp()
        .setPredictionCol("predLr")
        .setPredictionDetailCol("predDetail")
        .linkFrom(model, trainData)
        .link(new EvalMultiClassBatchOp()
            .setLabelCol("predLr")
            .setPredictionDetailCol("predDetail"))
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] {"Statistics"})
            .setOutputCols(new String[] {"Accuracy"})
            .setJsonPath("$.Accuracy"))
        .collect();

    // JUnit's assertEquals is (expected, actual, delta): expected value first.
    double accuracy = Double.parseDouble(acc.get(0).getField(0).toString());
    Assert.assertEquals(1.0, accuracy, 0.001);
}
Usage of `com.alibaba.alink.operator.batch.dataproc.JsonValueBatchOp` in the Alink project by Alibaba.
Example: method `testFmSparse` in class `FmClassifierTest`.
/**
 * Sparse-input counterpart of {@code testFm}: trains an {@code FmClassifierTrainBatchOp} model on
 * five samples whose vectors are given as {@code index:value} strings, predicts on the same data,
 * and checks the accuracy extracted from the binary-classification evaluation report.
 *
 * <p>{@link JsonValueBatchOp} extracts {@code $.Accuracy}, {@code $.AUC} and
 * {@code $.ConfusionMatrix} from the report's {@code Data} column; only the accuracy is asserted.
 */
@Test
public void testFmSparse() {
    BatchOperator<?> trainData = new MemSourceBatchOp(
        new Object[][] {
            {"1:1.1 3:2.0", 1.0},
            {"2:2.1 10:3.1", 1.0},
            {"3:3.1 7:2.2", 1.0},
            {"1:1.2 5:3.2", 0.0},
            {"3:1.2 7:4.2", 0.0}
        },
        new String[] {"vec", "label"});

    FmClassifierTrainBatchOp adagrad = new FmClassifierTrainBatchOp()
        .setVectorCol("vec")
        .setLabelCol("label")
        .setNumEpochs(10)
        .setInitStdev(0.01)
        .setLearnRate(0.01)
        .setEpsilon(0.0001)
        .linkFrom(trainData);

    BatchOperator<?> result = new FmPredictBatchOp()
        .setVectorCol("vec")
        .setPredictionCol("pred")
        .setPredictionDetailCol("details")
        .linkFrom(adagrad, trainData);

    List<Row> eval = new EvalBinaryClassBatchOp()
        .setLabelCol("label")
        .setPredictionDetailCol("details")
        .linkFrom(result)
        .link(new JsonValueBatchOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] {"Statistics"})
            .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
            .setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
        .collect();

    // NOTE(review): assumes field 0 of the collected row is the extracted "Accuracy" value —
    // confirm JsonValueBatchOp's column order when reserved columns are present.
    // JUnit's assertEquals is (expected, actual, delta): expected value first.
    double accuracy = Double.parseDouble(eval.get(0).getField(0).toString());
    Assert.assertEquals(0.8, accuracy, 0.01);
}
Aggregations