Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter in project Alink by Alibaba.
Example: method linkFrom of class CorrelationBatchOp.
/**
 * Computes the correlation matrix of the selected numerical columns of the first input.
 *
 * <p>When the method is {@code PEARSON}, the Pearson correlation is computed directly via
 * {@code StatisticsHelper.pearsonCorrelation} and serialized with {@code CorrelationDataConverter}.
 * For any other method (the Spearman path), the selected columns are first replaced by their ranks
 * ({@code SpearmanCorrelation.calcRank}) and a nested {@code CorrelationBatchOp}, explicitly set to
 * {@code PEARSON}, correlates those ranks.
 *
 * @param inputs upstream operators; only the first one is used
 * @return this operator, with its output set to the correlation model data
 */
@Override
public CorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    String[] selectedColNames = this.getParams().get(SELECTED_COLS);
    if (selectedColNames == null) {
        // No explicit selection: correlate every column of the input.
        selectedColNames = in.getColNames();
    }
    // Fail fast if any selected column is not numerical (e.g. double or bigint).
    TableUtil.assertNumericalCols(in.getSchema(), selectedColNames);

    Method corrType = getMethod();
    if (Method.PEARSON == corrType) {
        DataSet<Tuple2<TableSummary, CorrelationResult>> srt =
            StatisticsHelper.pearsonCorrelation(in, selectedColNames);
        // Only the correlation result (f1) is emitted; the table summary (f0) is not used here.
        DataSet<Row> result = srt.flatMap(new FlatMapFunction<Tuple2<TableSummary, CorrelationResult>, Row>() {
            private static final long serialVersionUID = -4498296161046449646L;

            @Override
            public void flatMap(Tuple2<TableSummary, CorrelationResult> summary, Collector<Row> collector) {
                new CorrelationDataConverter().save(summary.f1, collector);
            }
        });
        this.setOutput(result, new CorrelationDataConverter().getModelSchema());
    } else {
        // Spearman = Pearson over ranks. Use the validated input `in` rather than inputs[0].
        DataSet<Row> data = in.select(selectedColNames).getDataSet();
        DataSet<Row> rank = SpearmanCorrelation.calcRank(data, false);

        // Every ranked column is registered as DOUBLE, keeping the original column names.
        TypeInformation<?>[] colTypes = new TypeInformation<?>[selectedColNames.length];
        for (int i = 0; i < colTypes.length; i++) {
            colTypes[i] = Types.DOUBLE;
        }

        BatchOperator<?> rankOp = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), rank, selectedColNames, colTypes))
            .setMLEnvironmentId(getMLEnvironmentId());
        // Pin the nested operator to PEARSON so this branch can never re-enter itself.
        CorrelationBatchOp corrBatchOp = new CorrelationBatchOp()
            .setMLEnvironmentId(getMLEnvironmentId())
            .setSelectedCols(selectedColNames)
            .setMethod(Method.PEARSON);
        rankOp.link(corrBatchOp);
        this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
    }
    return this;
}
Usage of com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter in project Alink by Alibaba.
Example: method linkFrom of class VectorCorrelationBatchOp.
/**
 * Computes the correlation matrix over the components of the selected vector column of the first
 * input.
 *
 * <p>When the method is {@code PEARSON}, the Pearson correlation is computed directly via
 * {@code StatisticsHelper.vectorPearsonCorrelation} and serialized with
 * {@code CorrelationDataConverter}. For any other method (the Spearman path), the vector column is
 * expanded with {@code StatisticsHelper.transformToColumns}, ranked with
 * {@code SpearmanCorrelation.calcRank}, and a nested {@code VectorCorrelationBatchOp}, explicitly set
 * to {@code PEARSON}, correlates the ranked data.
 *
 * @param inputs upstream operators; only the first one is used
 * @return this operator, with its output set to the correlation model data
 */
@Override
public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String vectorColName = getSelectedCol();

    Method corrType = getMethod();
    if (Method.PEARSON == corrType) {
        DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> srt =
            StatisticsHelper.vectorPearsonCorrelation(in, vectorColName);
        // Only the correlation result (f1) is emitted; the vector summary (f0) is not used here.
        DataSet<Row> result = srt.flatMap(new FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row>() {
            private static final long serialVersionUID = 2134644397476490118L;

            @Override
            public void flatMap(Tuple2<BaseVectorSummary, CorrelationResult> srt, Collector<Row> collector) throws Exception {
                new CorrelationDataConverter().save(srt.f1, collector);
            }
        });
        this.setOutput(result, new CorrelationDataConverter().getModelSchema());
    } else {
        // Spearman = Pearson over ranks.
        DataSet<Row> data = StatisticsHelper.transformToColumns(in, null, vectorColName, null);
        DataSet<Row> rank = SpearmanCorrelation.calcRank(data, true);

        // The ranked data is registered as a single STRING column named "col".
        // NOTE(review): presumably calcRank(..., true) emits the ranks serialized as a vector string — confirm.
        BatchOperator<?> rankOp = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), rank,
                    new String[] {"col"}, new TypeInformation<?>[] {Types.STRING}))
            .setMLEnvironmentId(getMLEnvironmentId());
        // Pin the nested operator to PEARSON so this branch can never re-enter itself.
        VectorCorrelationBatchOp corrBatchOp = new VectorCorrelationBatchOp()
            .setMLEnvironmentId(getMLEnvironmentId())
            .setSelectedCol("col")
            .setMethod(Method.PEARSON);
        rankOp.link(corrBatchOp);
        this.setOutput(corrBatchOp.getDataSet(), corrBatchOp.getSchema());
    }
    return this;
}
Aggregations