Search in sources :

Example 1 with FmClassifierTrainBatchOp

Use of com.alibaba.alink.operator.batch.classification.FmClassifierTrainBatchOp in project Alink by alibaba.

The method linkFrom of the class FmRecommTrainBatchOp.

/**
 * Trains an FM recommendation model from up to three input tables:
 * 1) the user-item-label table (required), 2) the user features table (optional),
 * 3) the item features table (optional).
 *
 * <p>If the user (or item) features table is missing or {@code null}, the distinct user
 * (or item) IDs found in the samples table are used as a single categorical feature.
 * When a features table is supplied, its ID column must have the same type as the
 * corresponding ID column of the samples table.
 *
 * <p>The output table packs four operators together: the trained FM model, the user
 * feature vectors, the item feature vectors, and the (user, item) history taken from
 * the samples table.
 *
 * @param inputs {@code inputs[0]} is the samples table; {@code inputs[1]} and
 *               {@code inputs[2]}, when present, are the user and item features tables.
 *               Any inputs beyond the third are ignored.
 * @return this operator, with its output table set to the packed model
 * @throws IllegalArgumentException if no samples table is supplied, or if an ID column
 *                                  type differs between a features table and the samples table
 */
@Override
public FmRecommTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Fail with a descriptive message instead of an ArrayIndexOutOfBounds/NullPointer exception.
    Preconditions.checkArgument(inputs != null && inputs.length >= 1 && inputs[0] != null,
        "The user-item-label table must be supplied as the first input.");

    // Names of the intermediate columns produced while assembling the training vectors.
    final String userVectorCol = "__user_features__";
    final String itemVectorCol = "__item_features__";
    final String assembledVectorCol = "__alink_features__";

    BatchOperator<?> samplesOp = inputs[0];
    final Long envId = samplesOp.getMLEnvironmentId();
    BatchOperator<?> userFeaturesOp = inputs.length >= 2 ? inputs[1] : null;
    BatchOperator<?> itemFeaturesOp = inputs.length >= 3 ? inputs[2] : null;

    Params params = getParams();
    String userCol = params.get(USER_COL);
    String itemCol = params.get(ITEM_COL);
    String labelCol = params.get(RATE_COL);
    String[] userFeatureCols = params.get(USER_FEATURE_COLS);
    String[] itemFeatureCols = params.get(ITEM_FEATURE_COLS);
    String[] userCateFeatureCols = params.get(USER_CATEGORICAL_FEATURE_COLS);
    String[] itemCateFeatureCols = params.get(ITEM_CATEGORICAL_FEATURE_COLS);

    if (userFeaturesOp == null) {
        // No user features given: fall back to the user ID itself as a categorical feature.
        userFeaturesOp = samplesOp.select("`" + userCol + "`").distinct();
        userFeatureCols = new String[] { userCol };
        userCateFeatureCols = new String[] { userCol };
    } else {
        Preconditions.checkArgument(
            TableUtil.findColTypeWithAssert(userFeaturesOp.getSchema(), userCol)
                .equals(TableUtil.findColTypeWithAssert(samplesOp.getSchema(), userCol)),
            "user column type mismatch");
    }
    if (itemFeaturesOp == null) {
        // No item features given: fall back to the item ID itself as a categorical feature.
        itemFeaturesOp = samplesOp.select("`" + itemCol + "`").distinct();
        itemFeatureCols = new String[] { itemCol };
        itemCateFeatureCols = new String[] { itemCol };
    } else {
        Preconditions.checkArgument(
            TableUtil.findColTypeWithAssert(itemFeaturesOp.getSchema(), itemCol)
                .equals(TableUtil.findColTypeWithAssert(samplesOp.getSchema(), itemCol)),
            "item column type mismatch");
    }

    // Captured before the joins below, so it holds only the raw (user, item) pairs.
    BatchOperator<?> history = samplesOp.select(new String[] { userCol, itemCol });

    // NOTE(review): createFeatureVectors is defined elsewhere; the join clauses below rely on
    // it exposing the vector in a column named __fm_features__ next to the ID column.
    userFeaturesOp = createFeatureVectors(userFeaturesOp, userCol, userFeatureCols, userCateFeatureCols);
    itemFeaturesOp = createFeatureVectors(itemFeaturesOp, itemCol, itemFeatureCols, itemCateFeatureCols);

    LeftOuterJoinBatchOp joinOp1 = new LeftOuterJoinBatchOp()
        .setMLEnvironmentId(envId)
        .setJoinPredicate("a.`" + userCol + "`=" + "b.`" + userCol + "`")
        .setSelectClause("a.*, b.__fm_features__ as " + userVectorCol);
    LeftOuterJoinBatchOp joinOp2 = new LeftOuterJoinBatchOp()
        .setMLEnvironmentId(envId)
        .setJoinPredicate("a.`" + itemCol + "`=" + "b.`" + itemCol + "`")
        .setSelectClause("a.*, b.__fm_features__ as " + itemVectorCol);
    samplesOp = joinOp1.linkFrom(samplesOp, userFeaturesOp);
    samplesOp = joinOp2.linkFrom(samplesOp, itemFeaturesOp);

    // A left outer join yields null vectors for IDs absent from a features table;
    // CheckNotNull (defined elsewhere) is applied to both vector columns - presumably to reject them.
    samplesOp = samplesOp.udf(userVectorCol, userVectorCol, new CheckNotNull());
    samplesOp = samplesOp.udf(itemVectorCol, itemVectorCol, new CheckNotNull());

    // Combine the user and item vectors into the single vector column the FM trainers read.
    VectorAssembler va = new VectorAssembler()
        .setMLEnvironmentId(envId)
        .setSelectedCols(userVectorCol, itemVectorCol)
        .setOutputCol(assembledVectorCol)
        .setReservedCols(labelCol);
    samplesOp = va.transform(samplesOp);

    // Explicit feedback is trained with the FM regressor, implicit feedback with the FM classifier.
    BatchOperator<?> fmModel;
    if (!implicitFeedback) {
        fmModel = new FmRegressorTrainBatchOp(params)
            .setLabelCol(labelCol)
            .setVectorCol(assembledVectorCol)
            .setMLEnvironmentId(envId);
    } else {
        fmModel = new FmClassifierTrainBatchOp(params)
            .setLabelCol(labelCol)
            .setVectorCol(assembledVectorCol)
            .setMLEnvironmentId(envId);
    }
    fmModel.linkFrom(samplesOp);

    BatchOperator<?> model = PackBatchOperatorUtil.packBatchOps(
        new BatchOperator<?>[] { fmModel, userFeaturesOp, itemFeaturesOp, history });
    setOutputTable(model.getOutputTable());
    return this;
}
Also used : LeftOuterJoinBatchOp(com.alibaba.alink.operator.batch.sql.LeftOuterJoinBatchOp) FmRegressorTrainBatchOp(com.alibaba.alink.operator.batch.regression.FmRegressorTrainBatchOp) VectorAssembler(com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler) FmRecommTrainParams(com.alibaba.alink.params.recommendation.FmRecommTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) FmClassifierTrainBatchOp(com.alibaba.alink.operator.batch.classification.FmClassifierTrainBatchOp)

Aggregations

FmClassifierTrainBatchOp (com.alibaba.alink.operator.batch.classification.FmClassifierTrainBatchOp)1 FmRegressorTrainBatchOp (com.alibaba.alink.operator.batch.regression.FmRegressorTrainBatchOp)1 LeftOuterJoinBatchOp (com.alibaba.alink.operator.batch.sql.LeftOuterJoinBatchOp)1 FmRecommTrainParams (com.alibaba.alink.params.recommendation.FmRecommTrainParams)1 VectorAssembler (com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler)1 Params (org.apache.flink.ml.api.misc.param.Params)1