Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by Alibaba.
Class ModelExporterUtils, method preOrderSerialize:
/**
 * Unions the model data of qualifying stages into {@code packed}, visiting the stages via {@code preOrder}.
 *
 * <p>A stage qualifies when it has a parent ({@code parent >= 0}), has non-null {@code schemaIndices},
 * has no children, and its {@code stage} is a {@link ModelBase}. Every model row of such a stage is
 * copied into a new row with {@code schema}'s arity: field 0 receives the node counter value for that
 * stage (the counter advances once per preOrder callback), and model column {@code i} is written to
 * position {@code schemaIndices[i] + offset}.
 *
 * @param stages stage nodes to traverse; {@code null} or empty returns {@code packed} unchanged
 * @param packed operator whose rows the serialized model data is appended to
 * @param schema schema of the packed rows (also used as the produced row type)
 * @param offset shift added to every schema index when placing model columns
 * @return an operator over {@code packed} unioned with all serialized model rows
 */
private static BatchOperator<?> preOrderSerialize(StageNode[] stages, BatchOperator<?> packed, final TableSchema schema, final int offset) {
    if (stages == null || stages.length == 0) {
        return packed;
    }

    final int rowArity = schema.getFieldTypes().length;
    // Single-element arrays let the lambda below mutate these values.
    final long[] nextNodeId = {0L};
    final BatchOperator<?>[] accumulated = {packed};

    Consumer<Integer> appendLeafModel = nodeIndex -> {
        // Every visited node consumes one id, whether or not it carries model data.
        final long nodeId = nextNodeId[0]++;
        final StageNode node = stages[nodeIndex];

        final boolean isLeafModel = node.parent >= 0
            && node.schemaIndices != null
            && node.children == null
            && node.stage instanceof ModelBase<?>;
        if (!isLeafModel) {
            return;
        }

        final int[] targetIndices = node.schemaIndices;
        DataSet<Row> modelRows = ((ModelBase<?>) node.stage)
            .getModelData()
            .getDataSet()
            .map(new MapFunction<Row, Row>() {
                private static final long serialVersionUID = 5218543921039328938L;

                @Override
                public Row map(Row value) {
                    Row packedRow = new Row(rowArity);
                    packedRow.setField(0, nodeId);
                    for (int col = 0; col < targetIndices.length; col++) {
                        packedRow.setField(targetIndices[col] + offset, value.getField(col));
                    }
                    return packedRow;
                }
            })
            .returns(new RowTypeInfo(schema.getFieldTypes()));

        BatchOperator<?> current = accumulated[0];
        accumulated[0] = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                current.getMLEnvironmentId(),
                current.getDataSet().union(modelRows),
                schema
            )
        ).setMLEnvironmentId(current.getMLEnvironmentId());
    };

    preOrder(stages, appendLeafModel);
    return accumulated[0];
}
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by Alibaba.
Class BaseTuning, method findBestTVSplit:
/**
 * Selects the best candidate pipeline using a single train/validation split.
 *
 * <p>The input is shuffled and passed through a {@link SplitBatchOp} configured with {@code ratio}.
 * Each candidate is fitted on the split's main output and scored by {@code tuningEvaluator} on side
 * output 0. Candidates whose evaluation throws, or whose metric is NaN, are recorded in the report
 * but are never selected. The comparison direction follows {@code tuningEvaluator.isLargerBetter()}.
 *
 * @param in         input data
 * @param ratio      fraction handed to {@link SplitBatchOp#setFraction}
 * @param candidates candidate pipelines; also receives the scores gathered so far
 * @return the best pipeline together with a report covering every candidate
 * @throws RuntimeException if no candidate yields a usable metric, or a candidate cannot be cloned
 */
protected Tuple2<Pipeline, Report> findBestTVSplit(BatchOperator<?> in, double ratio, PipelineCandidatesBase candidates) {
    final int candidateCount = candidates.size();

    SplitBatchOp splitter = new SplitBatchOp()
        .setFraction(ratio)
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(
            new TableSourceBatchOp(
                DataSetConversionUtil.toTable(in.getMLEnvironmentId(), shuffle(in.getDataSet()), in.getSchema())
            ).setMLEnvironmentId(getMLEnvironmentId())
        );

    int bestIdx = -1;
    double bestMetric = 0.;
    ArrayList<Double> experienceScores = new ArrayList<>(candidateCount);
    List<Report.ReportElement> reportElements = new ArrayList<>();

    for (int idx = 0; idx < candidateCount; idx++) {
        Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> candidate;
        try {
            candidate = candidates.get(idx, experienceScores);
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }

        double metric = Double.NaN;
        String failure = null;
        try {
            metric = tuningEvaluator.evaluate(candidate.f0.fit(splitter).transform(splitter.getSideOutput(0)));
        } catch (Exception ex) {
            failure = ExceptionUtils.stringifyException(ex);
        }

        if (failure != null) {
            // Evaluation failed: keep a NaN score slot so indices stay aligned, and report the cause.
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f, exception: %s", idx, bestMetric, metric, failure));
            }
            experienceScores.add(idx, metric);
            reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, metric, failure));
            continue;
        }

        experienceScores.add(idx, metric);
        if (Double.isNaN(metric)) {
            reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, metric, "Metric is nan."));
            continue;
        }
        reportElements.add(new Report.ReportElement(candidate.f0, candidate.f1, metric));

        final boolean isNewBest = bestIdx == -1
            || (tuningEvaluator.isLargerBetter() ? metric > bestMetric : metric < bestMetric);
        if (isNewBest) {
            bestMetric = metric;
            bestIdx = idx;
        }

        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f", idx, bestMetric, metric));
        }
    }

    if (bestIdx < 0) {
        throw new RuntimeException("Can not find a best model. Report: " + new Report(tuningEvaluator, reportElements).toPrettyJson());
    }

    try {
        return Tuple2.of(candidates.get(bestIdx, experienceScores).f0, new Report(tuningEvaluator, reportElements));
    } catch (CloneNotSupportedException e) {
        throw new RuntimeException(e);
    }
}
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by Alibaba.
Class BaseTuning, method kFoldCv:
/**
 * Scores {@code pipeline} with k-fold cross validation over data already tagged with fold indices.
 *
 * <p>For fold {@code i}, the rows whose fold index ({@code f0}) differs from {@code i} are used to fit
 * the pipeline, and the rows whose fold index equals {@code i} are used to evaluate the fitted model
 * with {@code tuningEvaluator}. Folds whose evaluation throws are skipped, and their stack traces are
 * appended to the returned reason string. A NaN metric from a fold that did not throw is added to
 * the sum and therefore makes the mean NaN. Exceptions thrown by {@code pipeline.fit} propagate to
 * the caller.
 *
 * @param splitData rows paired with their fold index in {@code f0}
 * @param pipeline  candidate pipeline to evaluate
 * @param schema    schema of the rows in {@code splitData}'s {@code f1}
 * @param k         number of folds
 * @return the mean metric over the folds whose evaluation succeeded ({@code NaN} when none did),
 *         paired with the accumulated failure reasons
 */
private Tuple2<Double, String> kFoldCv(DataSet<Tuple2<Integer, Row>> splitData, Pipeline pipeline, TableSchema schema, int k) {
    double ret = 0.;
    int validSize = 0;
    StringBuilder reason = new StringBuilder();
    for (int i = 0; i < k; ++i) {
        final int loop = i;
        // Training data: every fold except the current one.
        DataSet<Row> trainInput = splitData.filter(new FilterFunction<Tuple2<Integer, Row>>() {
            private static final long serialVersionUID = 2249884521437544236L;

            @Override
            public boolean filter(Tuple2<Integer, Row> value) {
                return value.f0 != loop;
            }
        }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 2618229645786221757L;

            @Override
            public Row map(Tuple2<Integer, Row> value) {
                return value.f1;
            }
        });
        // Validation data: only the current fold.
        DataSet<Row> testInput = splitData.filter(new FilterFunction<Tuple2<Integer, Row>>() {
            private static final long serialVersionUID = 5811166054549336470L;

            @Override
            public boolean filter(Tuple2<Integer, Row> value) {
                return value.f0 == loop;
            }
        }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = -1760709990316111721L;

            @Override
            public Row map(Tuple2<Integer, Row> value) {
                return value.f1;
            }
        });
        PipelineModel model = pipeline.fit(new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), trainInput, schema)).setMLEnvironmentId(getMLEnvironmentId()));
        double localMetric = Double.NaN;
        try {
            // The test source must live in the same ML environment as its table and as the training
            // source; without setMLEnvironmentId it would fall back to the default environment id.
            localMetric = tuningEvaluator.evaluate(model.transform(new TableSourceBatchOp(DataSetConversionUtil.toTable(getMLEnvironmentId(), testInput, schema)).setMLEnvironmentId(getMLEnvironmentId())));
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("kFoldCv, k: %d, i: %d, metric: %f", k, i, localMetric));
            }
        } catch (Exception ex) {
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("kFoldCv err, k: %d, i: %d, metric: %f, exception: %s", k, i, localMetric, ExceptionUtils.stringifyException(ex)));
            }
            reason.append(ExceptionUtils.stringifyException(ex)).append("\n");
            continue;
        }
        ret += localMetric;
        validSize++;
    }
    if (validSize == 0) {
        reason.append("valid size is zero.").append("\n");
        return Tuple2.of(Double.NaN, reason.toString());
    }
    // validSize only ever grows from 0, so at this point it is strictly positive.
    return Tuple2.of(ret / validSize, reason.toString());
}
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by Alibaba.
Class DerbyCatalogTest, method sinkBatch:
/**
 * Writes one row covering every supported column type into the Derby table through the catalog,
 * executes the job, then reads the table back through the catalog and asserts the result is non-empty.
 */
@Test
public void sinkBatch() throws Exception {
    final String[] colNames = new String[] {
        "col_long", "col_string", "col_date", "col_bigdecimal", "col_double", "col_float",
        "col_int", "col_short", "col_time", "col_timestamp", "col_varcharforbit"
    };
    final TypeInformation<?>[] colTypes = new TypeInformation<?>[] {
        Types.LONG, Types.STRING, Types.SQL_DATE, Types.BIG_DEC, Types.DOUBLE, Types.FLOAT,
        Types.INT, Types.SHORT, Types.SQL_TIME, Types.SQL_TIMESTAMP, Types.PRIMITIVE_ARRAY(Types.BYTE)
    };
    final Row sample = Row.of(
        0L, "string", new Date(0), new BigDecimal("0.00"), 0.0, 0.0f, 0, (short) 0,
        new Time(0), new Timestamp(0), new byte[] {0, 1}
    );
    final ObjectPath tablePath = new ObjectPath(DERBY_SCHEMA, DERBY_DB_TABLE);

    MemSourceBatchOp source = new MemSourceBatchOp(Arrays.asList(sample), new TableSchema(colNames, colTypes));
    derby.sinkBatch(tablePath, source.getOutputTable(), new Params(), source.getMLEnvironmentId());
    BatchOperator.execute();

    TableSourceBatchOp readBack = new TableSourceBatchOp(
        derby.sourceBatch(tablePath, new Params(), MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID));
    Assert.assertFalse(readBack.collect().isEmpty());
}
Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by Alibaba.
Class KeywordsExtractionBatchOp, method linkFrom:
/**
 * Extracts the top-N keywords of each input row's selected text column.
 *
 * <p>Steps:
 * <ol>
 *   <li>Append an id column ({@code AppendIdBatchOp.appendIdColName}) to the input.</li>
 *   <li>Remove stop words from the selected column into a temporary column.</li>
 *   <li>Weight words with TF-IDF or TextRank, depending on {@code getMethod()}.</li>
 *   <li>Per id, keep the {@code topN} words with the highest weight, joined by single spaces.</li>
 *   <li>Inner-join the keywords back to the id-tagged input on the id column.</li>
 * </ol>
 * The output has the keyword column first, followed by the reserved input columns reported by
 * {@link OutputColsHelper}. Because the join is an inner join, rows that produce no weighted words
 * do not appear in the output.
 *
 * @param inputs exactly one input operator (validated by {@code checkAndGetFirst})
 * @return this operator, with its output table set
 * @throws RuntimeException if the configured method is not supported
 */
@Override
public KeywordsExtractionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String docId = "doc_alink_id";
    // Temporary column that holds the stop-word-filtered text.
    final String filteredCol = "selectedColName";
    String selectedColName = this.getSelectedCol();
    TableUtil.assertSelectedColExist(in.getColNames(), selectedColName);
    String outputColName = this.getOutputCol();
    if (null == outputColName) {
        outputColName = selectedColName;
    }
    OutputColsHelper outputColsHelper = new OutputColsHelper(in.getSchema(), outputColName, Types.STRING);
    final Integer topN = this.getTopN();
    Method method = this.getMethod();

    // The appended id is used to group keywords per row and to join them back to the input.
    BatchOperator<?> inWithId = new TableSourceBatchOp(AppendIdBatchOp.appendId(in.getDataSet(), in.getSchema(), getMLEnvironmentId())).setMLEnvironmentId(getMLEnvironmentId());
    DataSet<Row> weights;
    StopWordsRemoverBatchOp filterOp = new StopWordsRemoverBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setSelectedCol(selectedColName).setOutputCol(filteredCol);
    BatchOperator<?> filtered = filterOp.linkFrom(inWithId);
    switch (method) {
        case TF_IDF: {
            DocWordCountBatchOp wordCount = new DocWordCountBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setDocIdCol(AppendIdBatchOp.appendIdColName).setContentCol(filteredCol);
            TfidfBatchOp tfIdf = new TfidfBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setDocIdCol(AppendIdBatchOp.appendIdColName).setWordCol("word").setCountCol("cnt");
            BatchOperator<?> op = filtered.link(wordCount).link(tfIdf);
            // Rows of (id, word, tfidf weight).
            weights = op.select(AppendIdBatchOp.appendIdColName + ", " + "word, tfidf").getDataSet();
            break;
        }
        case TEXT_RANK: {
            DataSet<Row> data = filtered.select(AppendIdBatchOp.appendIdColName + ", " + filteredCol).getDataSet();
            // Capture the params locally so the Flink function does not need this operator.
            final Params params = getParams();
            weights = data.flatMap(new FlatMapFunction<Row, Row>() {
                private static final long serialVersionUID = -4083643981693873537L;

                @Override
                public void flatMap(Row row, Collector<Row> collector) throws Exception {
                    // Run TextRank on each row and emit every keyword row it returns.
                    Row[] out = TextRank.getKeyWords(row, params.get(KeywordsExtractionParams.DAMPING_FACTOR), params.get(KeywordsExtractionParams.WINDOW_SIZE), params.get(KeywordsExtractionParams.MAX_ITER), params.get(KeywordsExtractionParams.EPSILON));
                    for (Row keyword : out) {
                        collector.collect(keyword);
                    }
                }
            });
            break;
        }
        default: {
            throw new RuntimeException("Not support this type!");
        }
    }
    DataSet<Row> res = weights.groupBy(new KeySelector<Row, String>() {
        private static final long serialVersionUID = 801794449492798203L;

        @Override
        public String getKey(Row row) {
            // Group on the string form of the id; a null id forms its own "NULL" group.
            Object obj = row.getField(0);
            if (obj == null) {
                return "NULL";
            }
            return obj.toString();
        }
    }).reduceGroup(new GroupReduceFunction<Row, Row>() {
        private static final long serialVersionUID = -4051509261188494119L;

        @Override
        public void reduce(Iterable<Row> rows, Collector<Row> collector) {
            List<Row> list = new ArrayList<>();
            for (Row row : rows) {
                list.add(row);
            }
            // Highest weight (field 2) first.
            list.sort((row1, row2) -> Double.compare((Double) row2.getField(2), (Double) row1.getField(2)));
            int len = Math.min(list.size(), topN);
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < len; i++) {
                if (i > 0) {
                    builder.append(' ');
                }
                builder.append(list.get(i).getField(1).toString());
            }
            Row out = new Row(2);
            out.setField(0, list.get(0).getField(0));
            out.setField(1, builder.toString());
            collector.collect(out);
        }
    });
    // Set the output into table.
    Table tmpTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), res, new String[] { docId, outputColName }, new TypeInformation<?>[] { Types.LONG, Types.STRING });
    StringBuilder selectClause = new StringBuilder("a." + outputColName);
    for (String keepColName : outputColsHelper.getReservedColumns()) {
        selectClause.append(',').append(keepColName);
    }
    JoinBatchOp join = new JoinBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setType("join").setSelectClause(selectClause.toString()).setJoinPredicate(docId + "=" + AppendIdBatchOp.appendIdColName);
    this.setOutputTable(join.linkFrom(new TableSourceBatchOp(tmpTable).setMLEnvironmentId(getMLEnvironmentId()), inWithId).getOutputTable());
    return this;
}
Aggregations