use of com.alibaba.alink.common.utils.RowCollector in project Alink by alibaba.
the class OneHotTrainBatchOpTest method testTransform.
/**
 * Deserializes a single discrete feature-bin description for column {@code col1},
 * converts it to one-hot model rows via
 * {@code OneHotTrainBatchOp.transformFeatureBinsToModel}, reloads the rows with
 * {@code OneHotModelDataConverter} and checks that {@code col1} reports 3 tokens.
 */
@Test
public void testTransform() {
    // Serialized feature bins; the literal is consumed verbatim by FeatureBinsUtil.deSerialize.
    String serializedBins = "[{\"binDivideType\":\"DISCRETE\",\"featureName\":\"col1\",\"bin\":{\"NORM\":[{\"values\":[\"b\"],"
        + "\"total\":2,\"positive\":2,\"index\":1,\"negative\":0,\"totalRate\":0.2222222222222222,"
        + "\"positiveRate\":0.4,\"negativeRate\":0.0,\"positivePercentage\":1.0},{\"values\":[\"c\"],"
        + "\"woe\":-0.2231435513142097,\"total\":2,\"positive\":1,\"index\":2,\"negative\":1,"
        + "\"totalRate\":0.2222222222222222,\"positiveRate\":0.2,\"negativeRate\":0.25,"
        + "\"positivePercentage\":0.5,\"iv\":0.011157177565710483},{\"values\":[\"a\"],"
        + "\"woe\":-0.2231435513142097,\"total\":2,\"positive\":1,\"index\":0,\"negative\":1,"
        + "\"totalRate\":0.2222222222222222,\"positiveRate\":0.2,\"negativeRate\":0.25,"
        + "\"positivePercentage\":0.5,\"iv\":0.011157177565710483}],\"NULL\":{\"total\":1,\"positive\":1,"
        + "\"index\":3,\"negative\":0,\"totalRate\":0.1111111111111111,\"positiveRate\":0.2,\"negativeRate\":0.0,"
        + "\"positivePercentage\":1.0},\"ELSE\":{\"total\":2,\"positive\":0,\"index\":4,\"negative\":2,"
        + "\"totalRate\":0.2222222222222222,\"positiveRate\":0.0,\"negativeRate\":0.5,"
        + "\"positivePercentage\":0.0}},\"featureType\":\"STRING\",\"iv\":0.022314355131420965,\"binCount\":3}]\n";

    // Pair every deserialized calculator with a running index and remember its feature name.
    List<Tuple2<Long, FeatureBinsCalculator>> featureList = new ArrayList<>();
    List<String> featureNames = new ArrayList<>();
    long index = 0L;
    for (FeatureBinsCalculator calculator : FeatureBinsUtil.deSerialize(serializedBins)) {
        featureList.add(Tuple2.of(index++, calculator));
        featureNames.add(calculator.getFeatureName());
    }

    // A freshly constructed collector is used as-is, as in the other RowCollector usages.
    RowCollector modelRows = new RowCollector();
    Params meta = new Params().set(HasSelectedCols.SELECTED_COLS, featureNames.toArray(new String[0]));
    OneHotTrainBatchOp.transformFeatureBinsToModel(featureList, modelRows, meta);

    List<Row> rows = modelRows.getRows();
    // JUnit convention: expected value first, actual value second.
    Assert.assertEquals(3, new OneHotModelDataConverter().load(rows).modelData.getNumberOfTokensOfColumn("col1"));
}
use of com.alibaba.alink.common.utils.RowCollector in project Alink by alibaba.
the class AnyToTripleFlatMapperTest method flatMap.
/**
 * Feeds one row whose {@code kv} column holds two key/value pairs
 * ({@code "1:1.0,4:1.0"}) through an {@code AnyToTripleFlatMapper} configured for
 * KV input and checks that two rows are collected.
 *
 * @throws Exception propagated from the mapper under test
 */
@Test
public void flatMap() throws Exception {
    Params params = new Params()
        .set(FormatTransParams.FROM_FORMAT, FormatType.KV)
        .set(FromKvParams.KV_COL, "kv")
        .set(FromKvParams.KV_COL_DELIMITER, ",")
        .set(FromKvParams.KV_VAL_DELIMITER, ":")
        .set(ToTripleParams.TRIPLE_COLUMN_VALUE_SCHEMA_STR, "col_id int, val double")
        .set(ToTripleParams.RESERVED_COLS, new String[] { "row_id" });
    AnyToTripleFlatMapper transKvToTriple =
        new AnyToTripleFlatMapper(CsvUtil.schemaStr2Schema("row_id long, kv string"), params);
    transKvToTriple.open();

    RowCollector collector = new RowCollector();
    transKvToTriple.flatMap(Row.of(3L, "1:1.0,4:1.0"), collector);

    // JUnit convention: expected value first, actual value second.
    Assert.assertEquals(2, collector.getRows().size());
}
use of com.alibaba.alink.common.utils.RowCollector in project Alink by alibaba.
the class ImputerTrainBatchOp method linkFrom.
/**
 * Builds the imputer model from the first input operator and sets it as this
 * operator's output, using the schema reported by the model data converter.
 *
 * <p>When {@code isNeedStatModel()} is true, the model rows are derived from a
 * summary of the selected columns. Otherwise the model is written with
 * {@code Strategy.VALUE} and the configured fill value, which must be present.
 *
 * @param inputs the input operators; the first one is used
 * @return this operator, with its output set
 * @throws RuntimeException if no statistics model is needed and
 *                          {@code ImputerTrainParams.FILL_VALUE} is not set
 */
@Override
public ImputerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String[] selectedColNames = getSelectedCols();
    Strategy strategy = getStrategy();

    // The converter carries the selected columns with their types and supplies the model schema.
    ImputerModelDataConverter converter = new ImputerModelDataConverter();
    converter.selectedColNames = selectedColNames;
    converter.selectedColTypes = TableUtil.findColTypesWithAssertAndHint(in.getSchema(), selectedColNames);

    DataSet<Row> rows;
    if (isNeedStatModel()) {
        // Statistics-backed model: reuse the column types resolved above instead of looking them up again.
        rows = StatisticsHelper.summary(in, selectedColNames)
            .flatMap(new BuildImputerModel(selectedColNames, converter.selectedColTypes, strategy));
    } else {
        // No statistics needed: the model only stores the user-supplied fill value.
        if (!getParams().contains(ImputerTrainParams.FILL_VALUE)) {
            throw new RuntimeException("In VALUE strategy, the filling value is necessary.");
        }
        String fillValue = getFillValue();
        RowCollector collector = new RowCollector();
        converter.save(Tuple3.of(Strategy.VALUE, null, fillValue), collector);
        rows = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getExecutionEnvironment()
            .fromCollection(collector.getRows());
    }

    this.setOutput(rows, converter.getModelSchema());
    return this;
}
use of com.alibaba.alink.common.utils.RowCollector in project Alink by alibaba.
the class RecommResultToTableMapperTest method flatMap.
/**
 * Runs every row of the test fixture {@code rows} through a
 * {@code FlattenKObjectMapper} that flattens the {@code recomm} column into a
 * single {@code calc} output column of type {@code long}, and checks that six
 * rows are collected in total.
 */
@Test
public void flatMap() {
    Params params = new Params()
        .set(FlattenKObjectParams.SELECTED_COL, "recomm")
        .set(FlattenKObjectParams.OUTPUT_COLS, new String[] { "calc" })
        .set(FlattenKObjectParams.OUTPUT_COL_TYPES, new String[] { "long" });
    FlattenKObjectMapper mapper = new FlattenKObjectMapper(dataSchema, params);

    RowCollector collector = new RowCollector();
    for (Row row : rows) {
        mapper.flatMap(row, collector);
    }

    // JUnit convention: expected value first, actual value second.
    Assert.assertEquals(6, collector.getRows().size());
}
use of com.alibaba.alink.common.utils.RowCollector in project Alink by alibaba.
the class KMeansOutputModel method calc.
/**
 * Assembles the final KMeans model rows from the objects stored in the context.
 *
 * <p>Only the task with id 0 produces output; every other task returns
 * {@code null}. Of the two stored centroid tuples, the one with the larger
 * {@code f0} is used; when they are equal the second one wins.
 *
 * @param context the context holding vector size, k, both centroid tuples and
 *                the all-reduce buffer
 * @return the serialized model rows, or {@code null} when the task id is not 0
 */
@Override
public List<Row> calc(ComContext context) {
    if (context.getTaskId() != 0) {
        return null;
    }

    Integer dimension = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
    Integer clusterCount = context.getObj(KMeansTrainBatchOp.K);
    Tuple2<Integer, FastDistanceMatrixData> first = context.getObj(KMeansTrainBatchOp.CENTROID1);
    Tuple2<Integer, FastDistanceMatrixData> second = context.getObj(KMeansTrainBatchOp.CENTROID2);
    double[] reduced = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);

    // Strictly greater picks the first tuple; ties fall through to the second.
    FastDistanceMatrixData chosen = first.f0 > second.f0 ? first.f1 : second.f1;

    KMeansTrainModelData model = new KMeansTrainModelData();
    model.centroids = new ArrayList<>();
    DenseMatrix vectors = chosen.getVectors();

    // The buffer is read with a stride of (dimension + 1); the value at offset
    // `dimension` inside each stride is passed as the third ClusterSummary argument
    // (presumably the cluster weight - confirm against the buffer's producer).
    final int stride = dimension + 1;
    final int offset = dimension;
    for (int clusterId = 0; clusterId < clusterCount; clusterId++) {
        DenseVector center = new DenseVector(vectors.getColumn(clusterId));
        double weight = reduced[offset + clusterId * stride];
        model.centroids.add(new KMeansTrainModelData.ClusterSummary(center, clusterId, weight));
    }

    KMeansTrainModelData.ParamSummary summary = new KMeansTrainModelData.ParamSummary();
    model.params = summary;
    summary.k = clusterCount;
    summary.vectorColName = vectorColName;
    summary.distanceType = distanceType;
    summary.vectorSize = dimension;
    summary.latitudeColName = latitudeColName;
    summary.longtitudeColName = longtitudeColName;

    RowCollector sink = new RowCollector();
    new KMeansModelDataConverter().save(model, sink);
    return sink.getRows();
}
Aggregations