Search in sources:

Example 1 with MultiStringIndexerModelData

use of com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData in project Alink by alibaba.

the class ItemCfTrainBatchOp method linkFrom.

/**
 * Wires up the item-based collaborative filtering training job on the first input operator
 * and registers the resulting model rows as this operator's output.
 *
 * Steps, in the order the operators are created:
 * 1. Keep only the user, item and (optional) rate columns.
 * 2. Train a one-hot model on the user and item columns and apply it with INDEX encoding.
 * 3. Group the encoded rows per item and per user to build the two vector data sets.
 * 4. Run a vector nearest-neighbor train/predict pass over the item vectors.
 * 5. Union the user vectors with the item similarities and serialize them, together with
 *    the meta built on subtask 0, through ItemCfRecommModelDataConverter.
 *
 * @param inputs input operators; only the one returned by checkAndGetFirst is used
 * @return this operator
 */
@Override
public ItemCfTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String userCol = getUserCol();
    final String itemCol = getItemCol();
    final String rateCol = getRateCol();
    final boolean hasRate = rateCol != null;
    final TypeInformation<?> userType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), userCol);
    final TypeInformation<?> itemTypeInfo = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), itemCol);
    final String itemType = FlinkTypeConverter.getTypeString(itemTypeInfo);

    // Without a rate column only the Jaccard similarity is accepted.
    if (!hasRate) {
        boolean isJaccard = getSimilarityType().equals(SimilarityType.JACCARD);
        Preconditions.checkArgument(isJaccard, "When rateCol is not given, only Jaccard calc is supported!");
    }

    final String[] inputCols = hasRate
        ? new String[] { userCol, itemCol, rateCol }
        : new String[] { userCol, itemCol };
    in = in.select(inputCols);

    OneHotTrainBatchOp oneHot = new OneHotTrainBatchOp()
        .setSelectedCols(userCol, itemCol)
        .linkFrom(in);

    // The encoded user goes to a new column; the encoded item overwrites the item column.
    final String userEncodeCol = "userEncode";
    BatchOperator<?> encoded = new OneHotPredictBatchOp()
        .setSelectedCols(userCol, itemCol)
        .setOutputCols(userEncodeCol, itemCol)
        .setEncode(HasEncodeWithoutWoe.Encode.INDEX)
        .linkFrom(oneHot, in);

    // Field 1 of this selection is the item column, so rows are grouped per item.
    final String[] perItemCols = hasRate
        ? new String[] { userEncodeCol, itemCol, rateCol }
        : new String[] { userEncodeCol, itemCol };
    DataSet<Row> itemVector = encoded
        .select(perItemCols)
        .getDataSet()
        .groupBy(1)
        .reduceGroup(new ItemVectorGenerator(rateCol, userCol))
        .withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM)
        .name("GenerateItemVector");

    // Field 0 of this selection is the original user column, so rows are grouped per user.
    final String[] perUserCols = hasRate
        ? new String[] { userCol, itemCol, rateCol }
        : new String[] { userCol, itemCol };
    DataSet<Row> userVector = encoded
        .select(perUserCols)
        .getDataSet()
        .groupBy(0)
        .reduceGroup(new UserItemVectorGenerator(rateCol, itemCol))
        .withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM)
        .name("GetUserItems");

    final TypeInformation[] itemVectorTypes = new TypeInformation[] { Types.LONG, VectorTypes.SPARSE_VECTOR };
    BatchOperator<?> itemVectorOp = new DataSetWrapperBatchOp(itemVector, COL_NAMES, itemVectorTypes);

    BatchOperator<?> neighborModel = new VectorNearestNeighborTrainBatchOp()
        .setIdCol(COL_NAMES[0])
        .setSelectedCol(COL_NAMES[1])
        .setMetric(HasFastMetric.Metric.valueOf(this.getSimilarityType().name()))
        .linkFrom(itemVectorOp);

    // NOTE(review): topN is max neighbors + 1, presumably because each item also matches itself;
    // confirm against VectorNearestNeighborPredictBatchOp.
    BatchOperator<?> neighbors = new VectorNearestNeighborPredictBatchOp()
        .setSelectedCol(COL_NAMES[1])
        .setReservedCols(COL_NAMES[0])
        .setTopN(this.getMaxNeighborNumber() + 1)
        .setRadius(1.0 - this.getSimilarityThreshold())
        .linkFrom(neighborModel, itemVectorOp);

    DataSet<Row> itemSimilarities = neighbors
        .select(new String[] { COL_NAMES[0], COL_NAMES[1] })
        .getDataSet()
        .mapPartition(new ItemSimilarityVectorGenerator(itemCol))
        .withBroadcastSet(oneHot.getSideOutput(0).getDataSet(), USER_NUM)
        .name("CalcItemSimilarity");

    // Drop the one-hot model rows whose first field equals 0L before broadcasting them.
    DataSet<Row> itemMapTable = oneHot.getDataSet().filter(new FilterFunction<Row>() {

        private static final long serialVersionUID = 7406134775433418651L;

        @Override
        public boolean filter(Row value) {
            Object firstField = value.getField(0);
            return !firstField.equals(0L);
        }
    });

    final Params trainParams = getParams();
    DataSet<Row> outs = userVector
        .union(itemSimilarities)
        .mapPartition(new RichMapPartitionFunction<Row, Row>() {

            private static final long serialVersionUID = 3779020277896699637L;

            @Override
            public void mapPartition(Iterable<Row> values, Collector<Row> out) {
                // Only subtask 0 builds the meta; every other subtask passes null.
                Params meta = null;
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                    List<Row> itemMapRows = getRuntimeContext().getBroadcastVariable("ITEM_MAP");
                    MultiStringIndexerModelData indexerData = new OneHotModelDataConverter().load(itemMapRows).modelData;
                    int itemCount = (int) indexerData.getNumberOfTokensOfColumn(itemCol);
                    String[] itemTokens = new String[itemCount];
                    for (int idx = 0; idx < itemCount; idx++) {
                        itemTokens[idx] = indexerData.getToken(itemCol, (long) idx);
                    }
                    meta = trainParams.set(ItemCfRecommTrainParams.RATE_COL, rateCol);
                    meta = meta.set(ItemCfRecommModelDataConverter.ITEMS, itemTokens);
                    meta = meta.set(ItemCfRecommModelDataConverter.ITEM_TYPE, itemType);
                    meta = meta.set(ItemCfRecommModelDataConverter.USER_TYPE, FlinkTypeConverter.getTypeString(userType));
                }
                new ItemCfRecommModelDataConverter(userCol, userType, itemCol).save(Tuple2.of(meta, values), out);
            }
        })
        .withBroadcastSet(itemMapTable, "ITEM_MAP")
        .name("build_model");

    this.setOutput(outs, new ItemCfRecommModelDataConverter(userCol, userType, itemCol).getModelSchema());
    return this;
}
Also used : ItemCfRecommModelDataConverter(com.alibaba.alink.operator.common.recommendation.ItemCfRecommModelDataConverter) OneHotTrainBatchOp(com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp) VectorNearestNeighborTrainBatchOp(com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborTrainBatchOp) List(java.util.List) MultiStringIndexerModelData(com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData) OneHotPredictBatchOp(com.alibaba.alink.operator.batch.feature.OneHotPredictBatchOp) VectorNearestNeighborPredictBatchOp(com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborPredictBatchOp) ItemCfRecommTrainParams(com.alibaba.alink.params.recommendation.ItemCfRecommTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) DataSetWrapperBatchOp(com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp) OneHotModelDataConverter(com.alibaba.alink.operator.common.feature.OneHotModelDataConverter) Row(org.apache.flink.types.Row)

Example 2 with MultiStringIndexerModelData

use of com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData in project Alink by alibaba.

the class NaiveBayesModelInfo method getCategoryFeatureInfo.

/**
 * Gets the feature information of the categorical features.
 * For every categorical feature value, computes its weight under each label divided by its
 * total weight over all labels.
 *
 * The result is keyed as: feature name -> label -> feature value (token) -> proportion.
 * An empty map is returned when the string indexer model has no meta or no selected columns.
 * As a side effect, the values seen for each categorical feature are stored in cateFeatureValue.
 */
public HashMap<Object, HashMap<Object, HashMap<Object, Double>>> getCategoryFeatureInfo() {
    MultiStringIndexerModelData indexerModel = new MultiStringIndexerModelDataConverter().load(stringIndexerModelSerialized);
    boolean hasSelectedCols = indexerModel.meta != null && indexerModel.meta.contains(HasSelectedCols.SELECTED_COLS);
    if (!hasSelectedCols) {
        return new HashMap<>(0);
    }

    String[] cateCols = indexerModel.meta.get(HasSelectedCols.SELECTED_COLS);
    int cateColCount = cateCols.length;

    // One index -> token lookup table per categorical column.
    HashMap<Long, String>[] indexToToken = new HashMap[cateColCount];
    for (int col = 0; col < cateColCount; col++) {
        indexToToken[col] = new HashMap<>((int) indexerModel.getNumberOfTokensOfColumn(cateCols[col]));
    }
    for (Tuple3<Integer, String, Long> entry : indexerModel.tokenAndIndex) {
        indexToToken[entry.f0].put(entry.f2, entry.f1);
    }

    // label -> feature name -> token -> proportion
    HashMap<Object, HashMap<Object, HashMap<Object, Double>>> byLabel = new HashMap<>(labelSize);
    // Position of the current feature among the categorical columns.
    int cateIndex = 0;
    for (int f = 0; f < featureSize; f++) {
        if (!isCategorical[f]) {
            continue;
        }
        String featureName = featureNames[f];
        HashSet<Object> seenTokens = new HashSet<>();

        // First pass: total weight of every token index over all labels.
        double[] totals = new double[Math.toIntExact(indexerModel.getNumberOfTokensOfColumn(cateCols[cateIndex]))];
        for (int l = 0; l < labelSize; l++) {
            SparseVector weights = featureInfo[l][f];
            int[] indices = weights.getIndices();
            double[] values = weights.getValues();
            for (int k = 0; k < indices.length; k++) {
                totals[indices[k]] += values[k];
            }
        }

        // Second pass: per-label weight divided by the total computed above.
        for (int l = 0; l < labelSize; l++) {
            SparseVector weights = featureInfo[l][f];
            int[] indices = weights.getIndices();
            double[] values = weights.getValues();

            HashMap<Object, HashMap<Object, Double>> perFeature = byLabel.get(labels[l]);
            if (perFeature == null) {
                perFeature = new HashMap<>();
                byLabel.put(labels[l], perFeature);
            }

            HashMap<Object, Double> proportions = new HashMap<>();
            for (int k = 0; k < indices.length; k++) {
                Object token = indexToToken[cateIndex].get((long) indices[k]);
                seenTokens.add(token);
                proportions.put(token, values[k] / totals[indices[k]]);
            }
            perFeature.put(featureName, proportions);
        }

        cateFeatureValue.put(featureName, seenTokens);
        cateIndex++;
    }

    // Re-key the result from label -> feature to feature -> label.
    HashMap<Object, HashMap<Object, HashMap<Object, Double>>> byFeature = new HashMap<>(featureSize);
    for (int f = 0; f < featureSize; f++) {
        if (!isCategorical[f]) {
            continue;
        }
        String featureName = featureNames[f];
        HashMap<Object, HashMap<Object, Double>> perLabel = new HashMap<>(labelSize);
        for (Object label : labels) {
            perLabel.put(label, byLabel.get(label).get(featureName));
        }
        byFeature.put(featureName, perLabel);
    }
    return byFeature;
}
Also used : MultiStringIndexerModelDataConverter(com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SparseVector(com.alibaba.alink.common.linalg.SparseVector) MultiStringIndexerModelData(com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData) HashSet(java.util.HashSet)

Example 3 with MultiStringIndexerModelData

use of com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData in project Alink by alibaba.

the class CrossFeatureModelMapper method loadModel.

/**
 * Loads the multi-string-indexer model rows and prepares the per-feature lookup structures.
 *
 * After this call:
 * - selectedColIndices holds the positions of the model's selected columns in dataColNames;
 * - tokenAndIndex[i] maps each non-null token of feature i to its index;
 * - nullIndex[i] is the index of the null token of feature i, or -1 when there is none;
 * - carry[i] is the product of the token counts of features 0..i-1 (carry[0] is 1);
 * - svLength is the product of the token counts of all features.
 *
 * @param modelRows serialized model rows understood by MultiStringIndexerModelDataConverter
 * @throws ArithmeticException if a token count, a carry value or svLength does not fit in an int
 */
@Override
public void loadModel(List<Row> modelRows) {
    MultiStringIndexerModelData data = new MultiStringIndexerModelDataConverter().load(modelRows);
    String[] selectedCols = data.meta.get(CrossFeatureTrainParams.SELECTED_COLS);
    selectedColIndices = TableUtil.findColIndices(dataColNames, selectedCols);
    int featureNumber = data.tokenNumber.size();
    tokenAndIndex = new HashMap[featureNumber];
    nullIndex = new int[featureNumber];
    // -1 marks "no null token recorded for this feature".
    Arrays.fill(nullIndex, -1);
    carry = new int[featureNumber];
    carry[0] = 1;
    for (int i = 0; i < featureNumber - 1; i++) {
        // Multiply in long and narrow with a check: a plain (int) cast would silently wrap
        // around when the crossed cardinality exceeds Integer.MAX_VALUE.
        long nextCarry = Math.multiplyExact(data.tokenNumber.get(i).longValue(), (long) carry[i]);
        carry[i + 1] = Math.toIntExact(nextCarry);
    }
    long totalLength = Math.multiplyExact(
        (long) carry[featureNumber - 1], data.tokenNumber.get(featureNumber - 1).longValue());
    svLength = Math.toIntExact(totalLength);
    for (int i = 0; i < featureNumber; i++) {
        int thisSize = Math.toIntExact(data.tokenNumber.get(i).longValue());
        tokenAndIndex[i] = new HashMap<>(thisSize);
    }
    for (Tuple3<Integer, String, Long> tuple3 : data.tokenAndIndex) {
        if (tuple3.f1 == null) {
            // The null token cannot be a map key here; its index is kept separately.
            nullIndex[tuple3.f0] = tuple3.f2.intValue();
        } else {
            tokenAndIndex[tuple3.f0].put(tuple3.f1, tuple3.f2.intValue());
        }
    }
    dataIndices = new int[featureNumber];
}
Also used : MultiStringIndexerModelDataConverter(com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter) MultiStringIndexerModelData(com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData)

Aggregations

MultiStringIndexerModelData (com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData): 3
MultiStringIndexerModelDataConverter (com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter): 2
SparseVector (com.alibaba.alink.common.linalg.SparseVector): 1
OneHotPredictBatchOp (com.alibaba.alink.operator.batch.feature.OneHotPredictBatchOp): 1
OneHotTrainBatchOp (com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp): 1
VectorNearestNeighborPredictBatchOp (com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborPredictBatchOp): 1
VectorNearestNeighborTrainBatchOp (com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborTrainBatchOp): 1
DataSetWrapperBatchOp (com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp): 1
OneHotModelDataConverter (com.alibaba.alink.operator.common.feature.OneHotModelDataConverter): 1
ItemCfRecommModelDataConverter (com.alibaba.alink.operator.common.recommendation.ItemCfRecommModelDataConverter): 1
ItemCfRecommTrainParams (com.alibaba.alink.params.recommendation.ItemCfRecommTrainParams): 1
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
HashSet (java.util.HashSet): 1
List (java.util.List): 1
Params (org.apache.flink.ml.api.misc.param.Params): 1
Row (org.apache.flink.types.Row): 1