Usage example of com.alibaba.alink.pipeline.feature.OneHotEncoder in the Alink project by Alibaba.
From the class FmRecommTrainBatchOp, method createFeatureVectors.
/**
 * Builds a single vector feature column ("__fm_features__") for FM training.
 * Categorical columns are one-hot encoded into a temporary vector column, which is
 * then assembled together with the numerical columns; the assembled output reuses
 * the same temporary column name. Only the id column and the vector are kept.
 */
private static BatchOperator<?> createFeatureVectors(BatchOperator<?> featureTable, String idCol, String[] featureCols, String[] categoricalCols) {
    // Every categorical column must be one of the selected feature columns.
    TableUtil.assertSelectedColExist(featureCols, categoricalCols);
    String[] numericalCols = subtract(featureCols, categoricalCols);
    final Long envId = featureTable.getMLEnvironmentId();
    if (categoricalCols.length > 0) {
        // Encode all categorical columns into one vector column; keep the last
        // category (dropLast = false) so every level gets its own dimension.
        OneHotEncoder encoder = new OneHotEncoder()
            .setMLEnvironmentId(envId)
            .setSelectedCols(categoricalCols)
            .setOutputCols("__fm_features__")
            .setDropLast(false);
        featureTable = encoder.fit(featureTable).transform(featureTable);
        // Treat the encoded vector column as one more input to the assembler.
        numericalCols = (String[]) ArrayUtils.add(numericalCols, "__fm_features__");
    }
    // Assemble numerical columns (plus the encoded vector, if any) into the
    // final feature vector, overwriting the temporary column name.
    VectorAssembler assembler = new VectorAssembler()
        .setMLEnvironmentId(envId)
        .setSelectedCols(numericalCols)
        .setOutputCol("__fm_features__")
        .setReservedCols(idCol);
    featureTable = assembler.transform(featureTable);
    // Normalize the vector representation in place.
    featureTable = featureTable.udf("__fm_features__", "__fm_features__", new ConvertVec());
    return featureTable;
}
Usage example of com.alibaba.alink.pipeline.feature.OneHotEncoder in the Alink project by Alibaba.
From the class SplitBatchOpTest, method testSplitAfterOneHot.
@Test
public void testSplitAfterOneHot() throws Exception {
    // One-hot encode the Iris feature columns, keeping only the label column.
    BatchOperator source = Iris.getBatchData();
    OneHotEncoderModel encoderModel = new OneHotEncoder()
        .setSelectedCols(Iris.getFeatureColNames())
        .setReservedCols(Iris.getLabelColName())
        .setOutputCols("features")
        .fit(source);
    BatchOperator encoded = encoderModel.transform(source);
    // Split 40/60: the main output carries the fraction, the side output the rest.
    SplitBatchOp splitter = new SplitBatchOp().setFraction(0.4);
    BatchOperator firstPart = splitter.linkFrom(encoded);
    BatchOperator secondPart = splitter.getSideOutput(0);
    // Iris has 150 rows: 0.4 * 150 = 60, leaving 90 in the side output.
    Assert.assertEquals(firstPart.count(), 60);
    Assert.assertEquals(secondPart.count(), 90);
}
Usage example of com.alibaba.alink.pipeline.feature.OneHotEncoder in the Alink project by Alibaba.
From the class PipelineModelTest, method pipelineTestSetLazy.
@Test
public void pipelineTestSetLazy() throws Exception {
    // Columns to one-hot encode; "id" is carried through untouched.
    final String[] binaryNames = new String[] { "docid", "word", "cnt" };
    TableSchema schema = new TableSchema(
        new String[] { "id", "docid", "word", "cnt" },
        new TypeInformation<?>[] { Types.STRING, Types.STRING, Types.STRING, Types.LONG });
    Row[] trainRows = new Row[] {
        Row.of("0", "doc0", "天", 4L),
        Row.of("1", "doc0", "地", 5L),
        Row.of("2", "doc0", "人", 1L),
        Row.of("3", "doc1", null, 3L),
        Row.of("4", null, "人", 2L),
        Row.of("5", "doc1", "合", 4L),
        Row.of("6", "doc1", "一", 4L),
        Row.of("7", "doc2", "清", 3L),
        Row.of("8", "doc2", "一", 2L),
        Row.of("9", "doc2", "色", 2L)
    };
    BatchOperator trainSource = new MemSourceBatchOp(Arrays.asList(trainRows), schema);
    OneHotEncoder encoder = new OneHotEncoder()
        .setSelectedCols(binaryNames)
        .setOutputCols("results")
        .setDropLast(false);
    // Four identical assemblers; two of them exercise the lazy-print hooks.
    VectorAssembler assemblerLazyData = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .enableLazyPrintTransformData(10, "xxxxxx")
        .setOutputCol("outN");
    VectorAssembler assemblerPlainA = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .setOutputCol("outN");
    VectorAssembler assemblerPlainB = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .setOutputCol("outN");
    VectorAssembler assemblerLazyStat = new VectorAssembler()
        .setSelectedCols(new String[] { "cnt", "results" })
        .enableLazyPrintTransformStat("xxxxxx4")
        .setOutputCol("outN");
    Pipeline pipeline = new Pipeline()
        .add(encoder)
        .add(assemblerLazyData)
        .add(assemblerPlainA)
        .add(assemblerPlainB)
        .add(assemblerLazyStat);
    PipelineModel fittedModel = pipeline.fit(trainSource);
    Row[] predictRows = new Row[] { Row.of("0", "doc0", "天", 4L), Row.of("1", "doc2", null, 3L) };
    // batch predict
    MemSourceBatchOp batchPredictSource = new MemSourceBatchOp(Arrays.asList(predictRows), schema);
    BatchOperator predicted = fittedModel.transform(batchPredictSource).select(new String[] { "docid", "outN" });
    List<Row> predictions = predicted.collect();
    for (Row row : predictions) {
        String docId = row.getField(0).toString();
        // Both known docids are expected to yield a 19-dimensional assembled vector.
        if (docId.equals("doc0") || docId.equals("doc2")) {
            Assert.assertEquals(VectorUtil.getVector(row.getField(1).toString()).size(), 19);
        }
    }
    // stream predict
    MemSourceStreamOp streamPredictSource = new MemSourceStreamOp(Arrays.asList(predictRows), schema);
    fittedModel.transform(streamPredictSource).print();
    StreamOperator.execute();
}
Usage example of com.alibaba.alink.pipeline.feature.OneHotEncoder in the Alink project by Alibaba.
From the class Chap10, method c_3_1.
/**
 * Trains a logistic-regression pipeline (one-hot encoding + vector assembly)
 * on the training file and evaluates binary-classification metrics on the test file.
 */
static void c_3_1() throws Exception {
    BatchOperator<?> trainSet = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_FILE);
    BatchOperator<?> testSet = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_FILE);
    // Encode categorical columns as vectors, assemble all features, then fit LR.
    Pipeline pipeline = new Pipeline()
        .add(new OneHotEncoder()
            .setSelectedCols(CATEGORY_FEATURE_COL_NAMES)
            .setEncode(Encode.VECTOR))
        .add(new VectorAssembler()
            .setSelectedCols(FEATURE_COL_NAMES)
            .setOutputCol(VEC_COL_NAME))
        .add(new LogisticRegression()
            .setVectorCol(VEC_COL_NAME)
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionCol(PREDICTION_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME));
    // Score the test set and lazily print binary-classification metrics ("2" = positive).
    pipeline.fit(trainSet)
        .transform(testSet)
        .link(new EvalBinaryClassBatchOp()
            .setPositiveLabelValueString("2")
            .setLabelCol(LABEL_COL_NAME)
            .setPredictionDetailCol(PRED_DETAIL_COL_NAME)
            .lazyPrintMetrics());
    BatchOperator.execute();
}
Aggregations