Use of com.alibaba.alink.operator.common.clustering.lda.BuildEmLdaModel in project Alink by Alibaba.
The method gibbsSample of the class LdaTrainBatchOp.
/**
 * Runs the iterative LDA training queue over the input vectors and publishes the resulting
 * model rows as this operator's output.
 *
 * <p>A value of {@code -1} for {@code alpha} is replaced with {@code 50.0 / numTopic + 1}, and a
 * value of {@code -1} for {@code beta} is replaced with {@code 0.01 + 1}, before either is handed
 * to the queue steps.
 *
 * @param dataAndStat      input vectors in {@code f0} and their summary in {@code f1}; the
 *                         summary's vector size is broadcast as the vocabulary size
 * @param numTopic         topic count passed to every step of the queue
 * @param numIter          passed to the log-likelihood step and used as the queue's max iteration
 * @param alpha            alpha hyper-parameter, or {@code -1} to have it substituted as above
 * @param beta             beta hyper-parameter, or {@code -1} to have it substituted as above
 * @param resDocCountModel broadcast to the final flat-map under the name {@code "DocCountModel"}
 * @param seed             seed forwarded to the corpus step and to the result-model builder
 */
private void gibbsSample(
        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat,
        int numTopic,
        int numIter,
        double alpha,
        double beta,
        DataSet<DocCountVectorizerModelData> resDocCountModel,
        Integer seed) {
    // Resolve the -1 sentinels into concrete values without reassigning the parameters.
    final double resolvedAlpha = (alpha == -1) ? 50.0 / numTopic + 1 : alpha;
    final double resolvedBeta = (beta == -1) ? 0.01 + 1 : beta;

    // The vocabulary size is read off the vector summary.
    DataSet<Integer> vocabSize = dataAndStat.f1.map(new MapFunction<BaseVectorSummary, Integer>() {
        private static final long serialVersionUID = -7170259222827300492L;

        @Override
        public Integer map(BaseVectorSummary summary) {
            return summary.vectorSize();
        }
    });
    DataSet<Vector> vectors = dataAndStat.f0;

    // Each compute step is followed by an AllReduce over the variable named in that AllReduce.
    DataSet<Row> trainedRows = new IterativeComQueue()
            .initWithPartitionedData(LdaVariable.data, vectors)
            .initWithBroadcastData(LdaVariable.vocabularySize, vocabSize)
            .add(new EmCorpusStep(numTopic, resolvedAlpha, resolvedBeta, seed))
            .add(new AllReduce(LdaVariable.nWordTopics))
            .add(new EmLogLikelihood(numTopic, resolvedAlpha, resolvedBeta, numIter))
            .add(new AllReduce(LdaVariable.logLikelihood))
            .closeWith(new BuildEmLdaModel(numTopic, resolvedAlpha, resolvedBeta))
            .setMaxIter(numIter)
            .exec();

    // Turn the queue output into the final model rows, with the doc-count model broadcast in.
    DataSet<Row> modelRows = trainedRows
            .flatMap(new BuildResModel(seed))
            .withBroadcastSet(resDocCountModel, "DocCountModel");

    setOutput(modelRows, new LdaModelDataConverter().getModelSchema());
    saveWordTopicModelAndPerplexity(modelRows, numTopic, false);
}
Aggregations