Search in sources:

Example 1 with OnlineLogLikelihood

Use of com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood in project Alink by Alibaba.

Class LdaTrainBatchOp, method online.

/**
 * Trains an LDA model with the online optimizer and registers the resulting model as this
 * operator's output.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>derive a {@code (count, vectorSize)} "shape" tuple from the vector summary;</li>
 *   <li>build the initial model per partition with {@code OnlineInit}, broadcasting the shape;</li>
 *   <li>run an {@code IterativeComQueue} for at most {@code numIter} iterations, closed by
 *       {@code BuildOnlineLdaModel};</li>
 *   <li>convert the result with {@code BuildResModel}, broadcasting {@code resDocCountModel}
 *       under the name {@code "DocCountModel"};</li>
 *   <li>call {@code setOutput} and {@code saveWordTopicModelAndPerplexity}.</li>
 * </ol>
 *
 * @param dataAndStat      {@code f0} is the document-vector data set; {@code f1} is its vector summary
 * @param numTopic         number of topics
 * @param numIter          maximum number of iterations of the communication queue
 * @param alpha            prior passed to {@code OnlineInit}; exactly {@code -1} means "unset" and is
 *                         replaced by {@code 1.0 / numTopic}
 * @param beta             prior passed to the lambda update, log-likelihood and model-building stages;
 *                         exactly {@code -1} means "unset" and is replaced by {@code 1.0 / numTopic}
 * @param resDocCountModel vocabulary model data broadcast to {@code BuildResModel}
 * @param gammaShape       shape parameter forwarded to the initialisation, corpus and likelihood stages
 * @param seed             random seed forwarded to the stages that accept one (nullable type; whether
 *                         {@code null} is allowed depends on those stages — TODO confirm)
 */
private void online(Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat, int numTopic, int numIter, double alpha, double beta, DataSet<DocCountVectorizerModelData> resDocCountModel, int gammaShape, Integer seed) {
    // -1 acts as an "unset" sentinel for both priors; default each to 1 / numTopic.
    final double effectiveBeta = (beta == -1) ? 1.0 / numTopic : beta;
    final double effectiveAlpha = (alpha == -1) ? 1.0 / numTopic : alpha;

    final double learningOffset = getParams().get(ONLINE_LEARNING_OFFSET);
    final double learningDecay = getParams().get(LEARNING_DECAY);
    final double subSamplingRate = getParams().get(SUBSAMPLING_RATE);
    final boolean shouldOptimizeDocConcentration = getParams().get(OPTIMIZE_DOC_CONCENTRATION);

    final DataSet<Vector> corpus = dataAndStat.f0;

    // An anonymous class (not a lambda) keeps the generic Tuple2<Long, Integer> return type
    // visible to Flink's type extraction.
    final DataSet<Tuple2<Long, Integer>> corpusShape = dataAndStat.f1.map(
        new MapFunction<BaseVectorSummary, Tuple2<Long, Integer>>() {

            private static final long serialVersionUID = 1305270477796787466L;

            @Override
            public Tuple2<Long, Integer> map(BaseVectorSummary summary) {
                // (row count, vector dimension) taken from the summary.
                return new Tuple2<>(summary.count(), summary.vectorSize());
            }
        });

    final DataSet<Tuple2<DenseMatrix, DenseMatrix>> initialModel = corpus
        .mapPartition(new OnlineInit(numTopic, gammaShape, effectiveAlpha, seed))
        .name("init lambda")
        .withBroadcastSet(corpusShape, LdaVariable.shape);

    final DataSet<Row> ldaModelData = new IterativeComQueue()
        .initWithPartitionedData(LdaVariable.data, corpus)
        .initWithBroadcastData(LdaVariable.shape, corpusShape)
        .initWithBroadcastData(LdaVariable.initModel, initialModel)
        // Per-partition corpus step; its partial statistics are then combined across workers.
        .add(new OnlineCorpusStep(numTopic, subSamplingRate, gammaShape, seed))
        .add(new AllReduce(LdaVariable.wordTopicStat))
        .add(new AllReduce(LdaVariable.logPhatPart))
        .add(new AllReduce(LdaVariable.nonEmptyWordCount))
        .add(new AllReduce(LdaVariable.nonEmptyDocCount))
        // Model update, then the log-likelihood evaluation, which is also combined across workers.
        .add(new UpdateLambdaAndAlpha(numTopic, learningOffset, learningDecay, subSamplingRate,
            shouldOptimizeDocConcentration, effectiveBeta))
        .add(new OnlineLogLikelihood(effectiveBeta, numTopic, numIter, gammaShape, seed))
        .add(new AllReduce(LdaVariable.logLikelihood))
        .closeWith(new BuildOnlineLdaModel(numTopic, effectiveBeta))
        .setMaxIter(numIter)
        .exec();

    final DataSet<Row> model = ldaModelData
        .flatMap(new BuildResModel(seed))
        .withBroadcastSet(resDocCountModel, "DocCountModel");

    setOutput(model, new LdaModelDataConverter().getModelSchema());
    // NOTE(review): the meaning of the trailing `true` flag is defined by the helper,
    // presumably "online method" — TODO confirm.
    saveWordTopicModelAndPerplexity(model, numTopic, true);
}
Also used: IterativeComQueue (com.alibaba.alink.common.comqueue.IterativeComQueue), AllReduce (com.alibaba.alink.common.comqueue.communication.AllReduce), UpdateLambdaAndAlpha (com.alibaba.alink.operator.common.clustering.lda.UpdateLambdaAndAlpha), OnlineLogLikelihood (com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood), BuildOnlineLdaModel (com.alibaba.alink.operator.common.clustering.lda.BuildOnlineLdaModel), Tuple2 (org.apache.flink.api.java.tuple.Tuple2), OnlineCorpusStep (com.alibaba.alink.operator.common.clustering.lda.OnlineCorpusStep), LdaModelDataConverter (com.alibaba.alink.operator.common.clustering.LdaModelDataConverter), BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary), Row (org.apache.flink.types.Row), Vector (com.alibaba.alink.common.linalg.Vector), SparseVector (com.alibaba.alink.common.linalg.SparseVector)

Aggregations

IterativeComQueue (com.alibaba.alink.common.comqueue.IterativeComQueue): 1; AllReduce (com.alibaba.alink.common.comqueue.communication.AllReduce): 1; SparseVector (com.alibaba.alink.common.linalg.SparseVector): 1; Vector (com.alibaba.alink.common.linalg.Vector): 1; LdaModelDataConverter (com.alibaba.alink.operator.common.clustering.LdaModelDataConverter): 1; BuildOnlineLdaModel (com.alibaba.alink.operator.common.clustering.lda.BuildOnlineLdaModel): 1; OnlineCorpusStep (com.alibaba.alink.operator.common.clustering.lda.OnlineCorpusStep): 1; OnlineLogLikelihood (com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood): 1; UpdateLambdaAndAlpha (com.alibaba.alink.operator.common.clustering.lda.UpdateLambdaAndAlpha): 1; BaseVectorSummary (com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary): 1; Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 1; Row (org.apache.flink.types.Row): 1