Example: use of com.alibaba.alink.common.linalg.VectorIterator in the Alink project by Alibaba.
Class EmCorpusStep, method calc:
/**
 * Runs one superstep of collapsed Gibbs sampling for LDA over this worker's documents.
 *
 * <p>Step 1 (initialization): the local {@code SparseVector} documents stored under
 * {@code LdaVariable.data} are expanded into {@code Document}s with one slot per word
 * occurrence. Every occurrence gets a uniformly random initial topic, and the document-topic
 * ({@code nDocTopics}) and word-topic ({@code nWordTopics}) counts are accumulated. The corpus
 * and document-topic counts are stored in the context and the raw data is removed.
 *
 * <p>Later steps: every occurrence's topic is removed from the counts and resampled with weight
 * {@code (n[word][k] + beta) * (n[doc][k] + alpha) / (n[total][k] + vocabularySize * beta)}.
 * The word-topic counts are then recomputed from this worker's final assignments only.
 *
 * <p>{@code nWordTopics} has {@code vocabularySize + 1} rows. Row {@code w} counts topics for
 * word {@code w}, and the extra row {@code vocabularySize} holds per-topic totals.
 * NOTE(review): publishing only local counts assumes a later step sum-reduces
 * {@code LdaVariable.nWordTopics} across workers — confirm in the pipeline definition.
 *
 * @param context iteration context holding the shared LDA variables
 */
@Override
public void calc(ComContext context) {
    // Reseed at most once, so a fixed seed gives a reproducible run.
    if (!addedIndex && seed != null) {
        rand.reSeed(seed);
        addedIndex = true;
    }
    int vocabularySize = ((List<Integer>) context.getObj(LdaVariable.vocabularySize)).get(0);
    if (context.getStepNo() == 1) {
        // Publish the word-topic buffer up front so it exists even when this worker has no
        // data (early return below).
        DenseMatrix nWordTopics = new DenseMatrix(vocabularySize + 1, numTopic);
        context.putObj(LdaVariable.nWordTopics, nWordTopics.getData());
        List<SparseVector> data = context.getObj(LdaVariable.data);
        if (data == null) {
            return;
        }
        int localDocSize = data.size();
        Document[] docs = new Document[localDocSize];
        DenseMatrix nDocTopics = new DenseMatrix(localDocSize, numTopic);
        int docId = 0;
        for (SparseVector sparseVector : data) {
            // Each stored value is an occurrence count. Size the Document with exactly the
            // truncation the expansion loop below uses. Accumulating the raw doubles into an
            // int can disagree with it (negative, NaN or rounding-sensitive fractional values),
            // which would leave unfilled slots or overrun the Document.
            int wordNum = 0;
            for (double value : sparseVector.getValues()) {
                wordNum += Math.max(0, (int) value);
            }
            Document doc = new Document(wordNum);
            int idx = 0;
            VectorIterator iter = sparseVector.iterator();
            while (iter.hasNext()) {
                int word = iter.getIndex();
                int occurrences = Math.max(0, (int) iter.getValue());
                for (int j = 0; j < occurrences; j++) {
                    // NOTE(review): assumes rand is a commons-math RandomDataGenerator, whose
                    // nextInt(lower, upper) is inclusive and rejects lower >= upper. A
                    // single-topic model therefore must not call it.
                    int topic = numTopic > 1 ? rand.nextInt(0, numTopic - 1) : 0;
                    doc.setWordIdxs(idx, word);
                    doc.setTopicIdxs(idx, topic);
                    updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, topic, 1);
                    idx++;
                }
                iter.next();
            }
            docs[docId] = doc;
            docId++;
        }
        context.putObj(LdaVariable.corpus, docs);
        context.putObj(LdaVariable.nDocTopics, nDocTopics);
        context.removeObj(LdaVariable.data);
    } else {
        Document[] docs = context.getObj(LdaVariable.corpus);
        if (docs == null) {
            // No local documents: publish zero counts (as step 1 does). Otherwise the buffer
            // left from the previous step would be contributed again.
            context.putObj(LdaVariable.nWordTopics, new DenseMatrix(vocabularySize + 1, numTopic).getData());
            return;
        }
        DenseMatrix nDocTopics = context.getObj(LdaVariable.nDocTopics);
        DenseMatrix nWordTopics = new DenseMatrix(vocabularySize + 1, numTopic, context.getObj(LdaVariable.nWordTopics), false);
        // Reused per token: p[k] holds the cumulative (unnormalized) weight of topics 0..k.
        double[] p = new double[numTopic];
        int docId = 0;
        for (Document doc : docs) {
            int wordCount = doc.getLength();
            for (int i = 0; i < wordCount; ++i) {
                int word = doc.getWordIdxs(i);
                int topic = doc.getTopicIdxs(i);
                // Take this occurrence out of the counts before computing its conditional.
                updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, topic, -1);
                double pSum = 0;
                for (int k = 0; k < numTopic; k++) {
                    pSum += (nWordTopics.get(word, k) + beta) * (nDocTopics.get(docId, k) + alpha)
                        / (nWordTopics.get(vocabularySize, k) + vocabularySize * beta);
                    p[k] = pSum;
                }
                // Draw a point in [0, pSum) and pick the topic whose cumulative range contains it.
                double u = rand.nextUniform(0, 1) * pSum;
                int newTopic = findProbIdx(p, u);
                doc.setTopicIdxs(i, newTopic);
                updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, newTopic, 1);
            }
            docId++;
        }
        // Recompute the word-topic counts from this worker's final assignments only.
        DenseMatrix localWordTopics = new DenseMatrix(nWordTopics.numRows(), nWordTopics.numCols());
        for (Document doc : docs) {
            int length = doc.getLength();
            for (int i = 0; i < length; i++) {
                localWordTopics.add(doc.getWordIdxs(i), doc.getTopicIdxs(i), 1);
                localWordTopics.add(vocabularySize, doc.getTopicIdxs(i), 1);
            }
        }
        context.putObj(LdaVariable.nWordTopics, localWordTopics.getData());
    }
}
Example: use of com.alibaba.alink.common.linalg.VectorIterator in the Alink project by Alibaba.
Class SparseVectorSummarizer, method visit:
/**
 * Adds one vector to the running summary.
 *
 * <p>A {@code DenseVector} is first re-expressed as a {@code SparseVector} whose indices are
 * {@code 0 .. size - 1}, so every dense entry, including zeros, is recorded. Any other input is
 * cast to {@code SparseVector}. The vector count is incremented and {@code colNum} keeps the
 * largest {@code size()} seen so far. Each stored (index, value) pair is then fed to the
 * {@code VectorStatCol} kept for that index, which is created on first use.
 *
 * <p>When {@code calculateOuterProduct} is set, {@code outerProduct} is grown (if needed) to
 * cover the last stored index, and every pairwise product of stored values is added to it.
 * NOTE(review): sizing from the last index assumes SparseVector keeps its indices sorted.
 *
 * @param vec vector to add; must be a DenseVector or a SparseVector
 * @return this summarizer
 */
@Override
public BaseVectorSummarizer visit(Vector vec) {
    final SparseVector sv;
    if (!(vec instanceof DenseVector)) {
        sv = (SparseVector) vec;
    } else {
        DenseVector dense = (DenseVector) vec;
        int[] positions = new int[dense.size()];
        for (int pos = 0; pos < positions.length; pos++) {
            positions[pos] = pos;
        }
        sv = new SparseVector(dense.size(), positions, dense.getData());
    }

    count++;
    this.colNum = Math.max(this.colNum, sv.size());

    // Nothing stored: only the count and the column width change.
    if (sv.numberOfValues() == 0) {
        return this;
    }

    for (VectorIterator it = sv.iterator(); it.hasNext(); it.next()) {
        int index = it.getIndex();
        double value = it.getValue();
        if (!cols.containsKey(index)) {
            VectorStatCol fresh = new VectorStatCol();
            fresh.visit(value);
            cols.put(index, fresh);
        } else {
            cols.get(index).visit(value);
        }
    }

    if (!calculateOuterProduct) {
        return this;
    }
    int[] indices = sv.getIndices();
    double[] values = sv.getValues();
    int dim = indices[indices.length - 1] + 1;
    if (outerProduct == null) {
        outerProduct = DenseMatrix.zeros(dim, dim);
    } else if (dim > outerProduct.numRows()) {
        // Grow the accumulator; presumably plusEqual folds the old sums into the larger zero matrix.
        outerProduct = VectorSummarizerUtil.plusEqual(DenseMatrix.zeros(dim, dim), outerProduct);
    }
    for (int a = 0; a < indices.length; a++) {
        int row = indices[a];
        double rowValue = values[a];
        for (int b = 0; b < indices.length; b++) {
            outerProduct.add(row, indices[b], rowValue * values[b]);
        }
    }
    return this;
}
Aggregations