Use of com.alibaba.alink.operator.batch.regression.LinearRegTrainBatchOp in the Alink project by Alibaba.
Class Chap01, method c_5_2.
/**
 * Fits a linear regression of "gmv" on the features "x" and "x*x", stores the
 * trained model under {@code DATA_DIR + "gmv_reg.model"}, then reloads that
 * model and prints the predictions (column "pred") for x = 2018 and x = 2019.
 *
 * @throws Exception if building or executing the batch job fails
 */
static void c_5_2() throws Exception {
    Row[] observations = new Row[] {
        Row.of(2009, 0.5),
        Row.of(2010, 9.36),
        Row.of(2011, 52.0),
        Row.of(2012, 191.0),
        Row.of(2013, 350.0),
        Row.of(2014, 571.0),
        Row.of(2015, 912.0),
        Row.of(2016, 1207.0),
        Row.of(2017, 1682.0)
    };
    BatchOperator<?> rawTrain = new MemSourceBatchOp(observations, new String[] { "x", "gmv" });
    BatchOperator<?> rawPred = new MemSourceBatchOp(new Integer[] { 2018, 2019 }, "x");

    // Add the squared term so the linear model can fit a quadratic trend.
    BatchOperator<?> trainFeatures = rawTrain.select("x, x*x AS x2, gmv");

    LinearRegTrainBatchOp trainer = new LinearRegTrainBatchOp()
        .setFeatureCols("x", "x2")
        .setLabelCol("gmv");
    trainFeatures.link(trainer);

    // Persist the trained model; the sink only runs once execute() is called.
    String modelPath = DATA_DIR + "gmv_reg.model";
    AkSinkBatchOp modelSink = new AkSinkBatchOp()
        .setFilePath(modelPath)
        .setOverwriteSink(true);
    trainer.link(modelSink);
    BatchOperator.execute();

    // Read the model back from disk and apply it to the prediction inputs.
    BatchOperator<?> savedModel = new AkSourceBatchOp().setFilePath(modelPath);
    BatchOperator<?> predFeatures = rawPred.select("x, x*x AS x2");
    LinearRegPredictBatchOp predictor = new LinearRegPredictBatchOp().setPredictionCol("pred");
    predictor.linkFrom(savedModel, predFeatures).print();
}
Use of com.alibaba.alink.operator.batch.regression.LinearRegTrainBatchOp in the Alink project by Alibaba.
Class LinearRegTest, method batchDenseSparseVectorTest.
/**
 * Trains a linear regression on the "svec" column, whose rows mix dense
 * ("1.0 0.0 7.0 ...") and sparse ("0:1.0 2:3.0 ...") vector strings, then
 * predicts on the same data and checks the prediction for the row with id 0.
 */
@Test
public void batchDenseSparseVectorTest() {
    Row[] localRows = new Row[] {
        Row.of(0, "1.0 0.0 7.0 0.0 9.0 .0 .0 .0 .0 .0 .0 .0 .0 .0 .0", "1.0 7.0 9.0", 1.0, 7.0, 9.0, 2),
        Row.of(1, "0:1.0 2:3.0 4:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 3),
        Row.of(2, "0:1.0 2:2.0 4:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 1),
        Row.of(3, "0:1.0 2:3.0 14:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 4),
        Row.of(4, "0:1.0 2:2.0 4:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 5),
        Row.of(5, "0:1.0 2:2.0 4:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 6),
        Row.of(6, "0:1.0 2:2.0 4:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 7),
        Row.of(7, "1.0 0.0 2.0 0.0 4.0 .0 .0 .0 .0 .0 .0 .0 .0 .0 .0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 1)
    };
    String[] veccolNames = new String[] { "id", "svec", "vec", "f0", "f1", "f2", "label" };
    BatchOperator<?> trainData = new MemSourceBatchOp(Arrays.asList(localRows), veccolNames);
    String labelColName = "label";
    LinearRegTrainBatchOp lr = new LinearRegTrainBatchOp()
        .setVectorCol("svec")
        .setStandardization(true)
        .setWithIntercept(true)
        .setEpsilon(1.0e-4)
        .setOptimMethod("LBFGS")
        .setLabelCol(labelColName)
        .setMaxIter(10);
    LinearRegTrainBatchOp model = lr.linkFrom(trainData);
    // Only "id" is reserved, so each result row is (id, predLr).
    List<Row> mixedResult = new LinearRegPredictBatchOp()
        .setReservedCols(new String[] { "id" })
        .setPredictionCol("predLr")
        .setVectorCol("svec")
        .linkFrom(model, trainData)
        .collect();

    boolean checkedFirstRow = false;
    for (Row row : mixedResult) {
        if ((int) row.getField(0) == 0) {
            // JUnit's signature is assertEquals(expected, actual, delta): the
            // reference value goes first so failure messages read correctly.
            Assert.assertEquals(1.9404, Double.parseDouble(row.getField(1).toString()), 0.001);
            checkedFirstRow = true;
        }
    }
    // Without this the test would pass vacuously if the id-0 row were missing.
    Assert.assertTrue("prediction for id 0 is missing from the result", checkedFirstRow);
}
Use of com.alibaba.alink.operator.batch.regression.LinearRegTrainBatchOp in the Alink project by Alibaba.
Class Chap15, method main.
/**
 * Reads the tab-delimited father/son file at {@code DATA_DIR + ORIGIN_FILE},
 * prints a sample and summary statistics, fits a linear regression of "son"
 * on "father", and prints the first predictions (column "linear_reg").
 *
 * @param args command-line arguments (not used)
 * @throws Exception if building or executing the batch job fails
 */
public static void main(String[] args) throws Exception {
    BatchOperator.setParallelism(1);

    CsvSourceBatchOp heights = new CsvSourceBatchOp()
        .setFilePath(DATA_DIR + ORIGIN_FILE)
        .setSchemaStr("father double, son double")
        .setFieldDelimiter("\t")
        .setIgnoreFirstLine(true);

    // Quick look at the data: first rows, then overall statistics.
    heights.firstN(5).print();
    heights.lazyPrintStatistics();

    // Statistics restricted to two narrow bands of the "father" column.
    BatchOperator<?> fathersNear72 = heights.filter("father>=71.5 AND father<72.5");
    fathersNear72.lazyPrintStatistics("father 72");
    BatchOperator<?> fathersNear65 = heights.filter("father>=64.5 AND father<65.5");
    fathersNear65.lazyPrintStatistics("father 65");

    LinearRegTrainBatchOp trainer = new LinearRegTrainBatchOp()
        .setFeatureCols("father")
        .setLabelCol("son");
    LinearRegTrainBatchOp fittedModel = trainer.linkFrom(heights);
    fittedModel.lazyPrintTrainInfo();
    fittedModel.lazyPrintModelInfo();

    LinearRegPredictBatchOp predictor = new LinearRegPredictBatchOp()
        .setPredictionCol("linear_reg");
    LinearRegPredictBatchOp predictions = predictor.linkFrom(fittedModel, heights);
    predictions.lazyPrint(5);

    // Triggers the job, which also flushes the lazy print requests above.
    BatchOperator.execute();
}
Aggregations