Usage of com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood in the Alink project by Alibaba.
In class LdaTrainBatchOp, method online:
/**
 * Trains an LDA model with the online optimizer and publishes the result as this operator's output.
 *
 * <p>A value of {@code -1} for {@code alpha} or {@code beta} means "not specified" and is replaced
 * by {@code 1.0 / numTopic} before any stage is built.
 *
 * @param dataAndStat      input vectors ({@code f0}) paired with their summary ({@code f1}); the
 *                         summary supplies the (count, vector size) pair broadcast as the shape
 * @param numTopic         number of topics
 * @param numIter          maximum iteration count of the iterative queue; also passed to
 *                         {@code OnlineLogLikelihood}
 * @param alpha            prior passed to {@code OnlineInit}, or {@code -1} for the default
 *                         (presumably the document concentration, which
 *                         {@code OPTIMIZE_DOC_CONCENTRATION} may tune — confirm)
 * @param beta             prior passed to the lambda update, log-likelihood and model-building
 *                         stages, or {@code -1} for the default
 * @param resDocCountModel vectorizer model data broadcast to {@code BuildResModel} under the name
 *                         "DocCountModel"
 * @param gammaShape       shape value forwarded to {@code OnlineInit}, {@code OnlineCorpusStep} and
 *                         {@code OnlineLogLikelihood}
 * @param seed             random seed forwarded to the init, corpus-step, log-likelihood and
 *                         result-building stages; boxed, so it may be {@code null} — TODO confirm
 *                         how those stages treat {@code null}
 */
private void online(Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat, int numTopic, int numIter, double alpha, double beta, DataSet<DocCountVectorizerModelData> resDocCountModel, int gammaShape, Integer seed) {
    // -1 is the "unset" sentinel for both priors; fall back to 1 / numTopic.
    final double topicConcentration = (beta == -1) ? 1.0 / numTopic : beta;
    final double docConcentration = (alpha == -1) ? 1.0 / numTopic : alpha;

    final double learningOffset = getParams().get(ONLINE_LEARNING_OFFSET);
    final double learningDecay = getParams().get(LEARNING_DECAY);
    final double subSamplingRate = getParams().get(SUBSAMPLING_RATE);
    final boolean optimizeDocConcentration = getParams().get(OPTIMIZE_DOC_CONCENTRATION);

    final DataSet<Vector> corpus = dataAndStat.f0;

    // (count, vector size) taken from the input summary; broadcast to the init and iteration stages.
    // Kept as an anonymous class (not a lambda) so Flink can see the generic Tuple2 output type.
    final DataSet<Tuple2<Long, Integer>> corpusShape = dataAndStat.f1
        .map(new MapFunction<BaseVectorSummary, Tuple2<Long, Integer>>() {
            private static final long serialVersionUID = 1305270477796787466L;

            @Override
            public Tuple2<Long, Integer> map(BaseVectorSummary summary) {
                return Tuple2.of(summary.count(), summary.vectorSize());
            }
        });

    final DataSet<Tuple2<DenseMatrix, DenseMatrix>> initialModel = corpus
        .mapPartition(new OnlineInit(numTopic, gammaShape, docConcentration, seed))
        .name("init lambda")
        .withBroadcastSet(corpusShape, LdaVariable.shape);

    // Each iteration: per-partition corpus step, AllReduce of its partial statistics, lambda/alpha
    // update, then log-likelihood evaluation (also AllReduced). The final state is turned into
    // model rows by BuildOnlineLdaModel.
    final DataSet<Row> ldaModelData = new IterativeComQueue()
        .initWithPartitionedData(LdaVariable.data, corpus)
        .initWithBroadcastData(LdaVariable.shape, corpusShape)
        .initWithBroadcastData(LdaVariable.initModel, initialModel)
        .add(new OnlineCorpusStep(numTopic, subSamplingRate, gammaShape, seed))
        .add(new AllReduce(LdaVariable.wordTopicStat))
        .add(new AllReduce(LdaVariable.logPhatPart))
        .add(new AllReduce(LdaVariable.nonEmptyWordCount))
        .add(new AllReduce(LdaVariable.nonEmptyDocCount))
        .add(new UpdateLambdaAndAlpha(numTopic, learningOffset, learningDecay, subSamplingRate, optimizeDocConcentration, topicConcentration))
        .add(new OnlineLogLikelihood(topicConcentration, numTopic, numIter, gammaShape, seed))
        .add(new AllReduce(LdaVariable.logLikelihood))
        .closeWith(new BuildOnlineLdaModel(numTopic, topicConcentration))
        .setMaxIter(numIter)
        .exec();

    final DataSet<Row> resultModel = ldaModelData
        .flatMap(new BuildResModel(seed))
        .withBroadcastSet(resDocCountModel, "DocCountModel");

    setOutput(resultModel, new LdaModelDataConverter().getModelSchema());
    saveWordTopicModelAndPerplexity(resultModel, numTopic, true);
}
Aggregations