Usage of com.alibaba.alink.operator.batch.classification.FmClassifierTrainBatchOp in the Alink project by Alibaba.
In class FmRecommTrainBatchOp, method linkFrom:
/**
 * Trains an FM-based recommendation model.
 *
 * <p>There are up to 3 input tables: 1) the user-item-label (samples) table, 2) the user features table,
 * 3) the item features table. If the user or item features table is missing, the user or item IDs
 * themselves are used as (categorical) features.
 *
 * <p>When {@code implicitFeedback} is false an FM regressor is trained, otherwise an FM classifier is
 * trained. The output table packs, in order: the trained FM model, the user feature vectors, the item
 * feature vectors, and the user-item history (user and item columns) taken from the samples table.
 *
 * @param inputs the samples table, optionally followed by the user features table and then the item
 *               features table
 * @return this operator
 * @throws IllegalArgumentException if no samples table is supplied, or if the ID column type of a
 *                                  supplied features table differs from that of the samples table
 */
@Override
public FmRecommTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Fail fast with a descriptive message instead of an ArrayIndexOutOfBounds / NullPointerException.
    Preconditions.checkArgument(inputs != null && inputs.length >= 1 && inputs[0] != null,
        "FmRecommTrainBatchOp requires the user-item-label table as its first input");

    BatchOperator<?> samplesOp = inputs[0];
    final Long envId = samplesOp.getMLEnvironmentId();
    BatchOperator<?> userFeaturesOp = inputs.length >= 2 ? inputs[1] : null;
    BatchOperator<?> itemFeaturesOp = inputs.length >= 3 ? inputs[2] : null;

    Params params = getParams();
    String userCol = params.get(USER_COL);
    String itemCol = params.get(ITEM_COL);
    String labelCol = params.get(RATE_COL);
    String[] userFeatureCols = params.get(USER_FEATURE_COLS);
    String[] itemFeatureCols = params.get(ITEM_FEATURE_COLS);
    String[] userCateFeatureCols = params.get(USER_CATEGORICAL_FEATURE_COLS);
    String[] itemCateFeatureCols = params.get(ITEM_CATEGORICAL_FEATURE_COLS);

    if (userFeaturesOp == null) {
        // No user features table: use the distinct user IDs as a single categorical feature.
        userFeaturesOp = samplesOp.select("`" + userCol + "`").distinct();
        userFeatureCols = new String[] { userCol };
        userCateFeatureCols = new String[] { userCol };
    } else {
        // The join below requires the user ID column to have the same type in both tables.
        Preconditions.checkArgument(
            TableUtil.findColTypeWithAssert(userFeaturesOp.getSchema(), userCol)
                .equals(TableUtil.findColTypeWithAssert(samplesOp.getSchema(), userCol)),
            "user column type mismatch");
    }
    if (itemFeaturesOp == null) {
        // No item features table: use the distinct item IDs as a single categorical feature.
        itemFeaturesOp = samplesOp.select("`" + itemCol + "`").distinct();
        itemFeatureCols = new String[] { itemCol };
        itemCateFeatureCols = new String[] { itemCol };
    } else {
        // The join below requires the item ID column to have the same type in both tables.
        Preconditions.checkArgument(
            TableUtil.findColTypeWithAssert(itemFeaturesOp.getSchema(), itemCol)
                .equals(TableUtil.findColTypeWithAssert(samplesOp.getSchema(), itemCol)),
            "item column type mismatch");
    }

    // Keep the observed user-item pairs; they are packed into the model output.
    BatchOperator<?> history = samplesOp.select(new String[] { userCol, itemCol });

    userFeaturesOp = createFeatureVectors(userFeaturesOp, userCol, userFeatureCols, userCateFeatureCols);
    itemFeaturesOp = createFeatureVectors(itemFeaturesOp, itemCol, itemFeatureCols, itemCateFeatureCols);

    // Attach each sample's user and item feature vectors.
    LeftOuterJoinBatchOp joinOp1 = new LeftOuterJoinBatchOp()
        .setMLEnvironmentId(envId)
        .setJoinPredicate("a.`" + userCol + "`=" + "b.`" + userCol + "`")
        .setSelectClause("a.*, b.__fm_features__ as __user_features__");
    LeftOuterJoinBatchOp joinOp2 = new LeftOuterJoinBatchOp()
        .setMLEnvironmentId(envId)
        .setJoinPredicate("a.`" + itemCol + "`=" + "b.`" + itemCol + "`")
        .setSelectClause("a.*, b.__fm_features__ as __item_features__");
    samplesOp = joinOp1.linkFrom(samplesOp, userFeaturesOp);
    samplesOp = joinOp2.linkFrom(samplesOp, itemFeaturesOp);

    // A left outer join leaves the feature column null when a user/item has no matching features row.
    samplesOp = samplesOp.udf("__user_features__", "__user_features__", new CheckNotNull());
    samplesOp = samplesOp.udf("__item_features__", "__item_features__", new CheckNotNull());

    // Concatenate user and item vectors into one FM input vector, keeping only the label column.
    VectorAssembler va = new VectorAssembler()
        .setMLEnvironmentId(envId)
        .setSelectedCols("__user_features__", "__item_features__")
        .setOutputCol("__alink_features__")
        .setReservedCols(labelCol);
    samplesOp = va.transform(samplesOp);

    BatchOperator<?> fmModel;
    if (!implicitFeedback) {
        fmModel = new FmRegressorTrainBatchOp(params)
            .setLabelCol(labelCol)
            .setVectorCol("__alink_features__")
            .setMLEnvironmentId(envId);
    } else {
        fmModel = new FmClassifierTrainBatchOp(params)
            .setLabelCol(labelCol)
            .setVectorCol("__alink_features__")
            .setMLEnvironmentId(envId);
    }
    fmModel.linkFrom(samplesOp);

    BatchOperator<?> model = PackBatchOperatorUtil.packBatchOps(
        new BatchOperator<?>[] { fmModel, userFeaturesOp, itemFeaturesOp, history });
    setOutputTable(model.getOutputTable());
    return this;
}
Aggregations