Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult in the Alink project by Alibaba.
In class CorrelationBatchOp, method linkFrom:
/**
 * Computes the pairwise correlation matrix of the selected numerical columns of the input.
 *
 * <p>If SELECTED_COLS is not set, every column of the input is used. All selected columns must be
 * numerical. With {@code Method.PEARSON} the correlation is computed directly via
 * {@code StatisticsHelper.pearsonCorrelation}. For any other method the selected columns are
 * first converted to ranks by {@code SpearmanCorrelation.calcRank} and typed as DOUBLE. The
 * ranked table is then fed through a nested {@code CorrelationBatchOp}, which is built without
 * an explicit method (presumably defaulting to PEARSON; TODO confirm).
 *
 * @param inputs upstream operators, resolved through {@code checkAndGetFirst}
 * @return this operator, with its output set to the serialized correlation result
 */
@Override
public CorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String[] selectedColNames = this.getParams().get(SELECTED_COLS);
    if (selectedColNames == null) {
        // No explicit selection: correlate every column of the input.
        selectedColNames = in.getColNames();
    }
    // check col types must be double or bigint
    TableUtil.assertNumericalCols(in.getSchema(), selectedColNames);
    Method corrType = getMethod();
    if (Method.PEARSON == corrType) {
        DataSet<Tuple2<TableSummary, CorrelationResult>> srt = StatisticsHelper.pearsonCorrelation(in, selectedColNames);
        // Kept as an anonymous class so Flink can extract the generic input/output types.
        DataSet<Row> result = srt.flatMap(new FlatMapFunction<Tuple2<TableSummary, CorrelationResult>, Row>() {
            private static final long serialVersionUID = -4498296161046449646L;

            @Override
            public void flatMap(Tuple2<TableSummary, CorrelationResult> summary, Collector<Row> collector) {
                // Only the CorrelationResult (f1) is serialized; the TableSummary is dropped.
                new CorrelationDataConverter().save(summary.f1, collector);
            }
        });
        this.setOutput(result, new CorrelationDataConverter().getModelSchema());
    } else {
        // Use the same operator the Pearson branch and the validation above used,
        // rather than re-reading inputs[0] directly.
        DataSet<Row> data = in.select(selectedColNames).getDataSet();
        DataSet<Row> rank = SpearmanCorrelation.calcRank(data, false);
        // Every ranked column is emitted as DOUBLE, whatever its original numeric type.
        TypeInformation[] colTypes = new TypeInformation[selectedColNames.length];
        for (int i = 0; i < colTypes.length; i++) {
            colTypes[i] = Types.DOUBLE;
        }
        BatchOperator rankOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), rank, selectedColNames, colTypes)).setMLEnvironmentId(getMLEnvironmentId());
        CorrelationBatchOp corrBatchOp = new CorrelationBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setSelectedCols(selectedColNames);
        rankOp.link(corrBatchOp);
        // Expose the nested operator's result as this operator's output.
        this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
    }
    return this;
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult in the Alink project by Alibaba.
In class VectorCorrelationBatchOp, method linkFrom:
/**
 * Computes the correlation matrix between the components of the vector column returned by
 * {@code getSelectedCol()}.
 *
 * <p>For any method other than {@code Method.PEARSON}, the vector column is expanded with
 * {@code StatisticsHelper.transformToColumns} and ranked with
 * {@code SpearmanCorrelation.calcRank(..., true)}. The ranked data (a single STRING column named
 * "col") is then passed through a nested {@code VectorCorrelationBatchOp}. For PEARSON, the
 * correlation comes directly from {@code StatisticsHelper.vectorPearsonCorrelation}.
 *
 * @param inputs upstream operators, resolved through {@code checkAndGetFirst}
 * @return this operator, with its output set to the serialized correlation result
 */
@Override
public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);
    String vectorCol = getSelectedCol();
    Method method = getMethod();

    if (method != Method.PEARSON) {
        // Rank-based path: rank the vector components, then reuse this operator on the ranks.
        DataSet<Row> columns = StatisticsHelper.transformToColumns(source, null, vectorCol, null);
        DataSet<Row> ranked = SpearmanCorrelation.calcRank(columns, true);
        String[] rankedNames = new String[] {"col"};
        TypeInformation[] rankedTypes = new TypeInformation[] {Types.STRING};
        BatchOperator rankedSource =
            new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), ranked, rankedNames, rankedTypes))
                .setMLEnvironmentId(getMLEnvironmentId());
        VectorCorrelationBatchOp nested =
            new VectorCorrelationBatchOp()
                .setMLEnvironmentId(getMLEnvironmentId())
                .setSelectedCol("col");
        rankedSource.link(nested);
        this.setOutput(nested.getDataSet(), nested.getSchema());
        return this;
    }

    DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> summaries =
        StatisticsHelper.vectorPearsonCorrelation(source, vectorCol);
    // An anonymous class (not a lambda) keeps the generic types visible to Flink.
    DataSet<Row> serialized = summaries.flatMap(
        new FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row>() {
            private static final long serialVersionUID = 2134644397476490118L;

            @Override
            public void flatMap(Tuple2<BaseVectorSummary, CorrelationResult> value, Collector<Row> out) throws Exception {
                // Serialize only the correlation part (f1) of each summary tuple.
                new CorrelationDataConverter().save(value.f1, out);
            }
        });
    this.setOutput(serialized, new CorrelationDataConverter().getModelSchema());
    return this;
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult in the Alink project by Alibaba.
In class CorrelationBatchOpTest, method testSpearman:
/**
 * Checks the SPEARMAN correlation of f_double, f_int and f_long on a small table that contains
 * nulls. The assertion runs inside a lazy consumer that is triggered by
 * {@code BatchOperator.execute()}.
 */
@Test
public void testSpearman() throws Exception {
    Row[] testArray = new Row[] { Row.of("a", 1L, 1, 2.0, true), Row.of(null, 2L, 2, -3.0, true), Row.of("c", null, null, 2.0, false), Row.of("a", 0L, 0, null, null) };
    String[] colNames = new String[] { "f_string", "f_long", "f_int", "f_double", "f_boolean" };
    MemSourceBatchOp source = new MemSourceBatchOp(Arrays.asList(testArray), colNames);
    CorrelationBatchOp corr = new CorrelationBatchOp().setSelectedCols(new String[] { "f_double", "f_int", "f_long" }).setMethod(HasMethod.Method.SPEARMAN);
    corr.linkFrom(source);
    corr.lazyCollectCorrelation(new Consumer<CorrelationResult>() {
        @Override
        public void accept(CorrelationResult summary) {
            // JUnit's contract is assertArrayEquals(expected, actual, delta); note 10e-4 == 1e-3.
            Assert.assertArrayEquals(
                new double[] { 1.0, -0.39999999999999997, -0.39999999999999997, -0.39999999999999997, 1.0, 1.0, -0.39999999999999997, 1.0, 1.0 },
                summary.getCorrelationMatrix().getArrayCopy1D(true),
                10e-4);
        }
    });
    // Runs the job and fires the lazy consumer above.
    BatchOperator.execute();
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult in the Alink project by Alibaba.
In class VectorCorrelationBatchOpTest, method testSpearman:
/**
 * Checks the SPEARMAN correlation between the two components of a string-encoded vector column.
 * The expected correlation matrix is all ones.
 */
@Test
public void testSpearman() {
    Row[] testArray = new Row[] { Row.of("1.0 2.0"), Row.of("-1.0 -3.0"), Row.of("4.0 2.0") };
    String selectedColName = "vec";
    String[] colNames = new String[] { selectedColName };
    MemSourceBatchOp source = new MemSourceBatchOp(Arrays.asList(testArray), colNames);
    VectorCorrelationBatchOp corr = new VectorCorrelationBatchOp().setSelectedCol("vec").setMethod(HasMethod.Method.SPEARMAN);
    corr.linkFrom(source);
    CorrelationResult corrMat = corr.collectCorrelation();
    // JUnit's contract is assertArrayEquals(expected, actual, delta); note 10e-4 == 1e-3.
    Assert.assertArrayEquals(
        new double[] { 1.0, 1.0, 1.0, 1.0 },
        corrMat.getCorrelationMatrix().getArrayCopy1D(true),
        10e-4);
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult in the Alink project by Alibaba.
In class CorrelationBatchOpTest, method testCorrelation:
/**
 * Checks the PEARSON correlation of f_double, f_int and f_long on a small table that contains
 * nulls.
 */
@Test
public void testCorrelation() {
    Row[] testArray = new Row[] { Row.of("a", 1L, 1, 2.0, true), Row.of(null, 2L, 2, -3.0, true), Row.of("c", null, null, 2.0, false), Row.of("a", 0L, 0, null, null) };
    String[] colNames = new String[] { "f_string", "f_long", "f_int", "f_double", "f_boolean" };
    MemSourceBatchOp source = new MemSourceBatchOp(Arrays.asList(testArray), colNames);
    // Use the type-safe enum, consistent with the SPEARMAN test, instead of the "PEARSON" string.
    CorrelationBatchOp corr = new CorrelationBatchOp().setSelectedCols(new String[] { "f_double", "f_int", "f_long" }).setMethod(HasMethod.Method.PEARSON);
    corr.linkFrom(source);
    corr.lazyPrintCorrelation();
    CorrelationResult corrMat = corr.collectCorrelation();
    System.out.println(corrMat.toString());
    // JUnit's contract is assertArrayEquals(expected, actual, delta); note 10e-4 == 1e-3.
    Assert.assertArrayEquals(
        new double[] { 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0 },
        corrMat.getCorrelationMatrix().getArrayCopy1D(true),
        10e-4);
}
Aggregations