Search in sources:

Example 1 with VectorIterator

Use of com.alibaba.alink.common.linalg.VectorIterator in project Alink by Alibaba.

Method calc of class EmCorpusStep.

/**
 * Runs one superstep of the LDA topic sampler on this worker.
 *
 * <p>In step 1 the local bag-of-words vectors are expanded into {@link Document}s whose tokens get
 * random initial topics, and the local document-topic / word-topic counts are built. In every
 * later step each token's topic is resampled with weight
 * {@code (n_wk + beta) * (n_dk + alpha) / (n_k + V * beta)}. After that, this worker's
 * word-topic counts are recomputed from its current assignments and stored back into the context.
 *
 * @param context the iteration context holding the input data, corpus and count matrices
 */
@Override
public void calc(ComContext context) {
    // Re-seed only on the first call (tracked by addedIndex), and only if a seed was configured.
    if (!addedIndex && seed != null) {
        rand.reSeed(seed);
        addedIndex = true;
    }
    List<Integer> vocabularySizes = context.getObj(LdaVariable.vocabularySize);
    int vocabularySize = vocabularySizes.get(0);
    if (context.getStepNo() == 1) {
        initializeCorpus(context, vocabularySize);
    } else {
        resampleTopics(context, vocabularySize);
    }
}

/**
 * First superstep: turns the local input vectors into documents with random topics and builds
 * the initial count matrices.
 *
 * <p>The word-topic matrix has {@code vocabularySize + 1} rows; the extra last row holds the
 * per-topic totals over all words.
 */
private void initializeCorpus(ComContext context, int vocabularySize) {
    DenseMatrix nWordTopics = new DenseMatrix(vocabularySize + 1, numTopic);
    // Registered before the null check, so it is published even when this worker has no data.
    // NOTE(review): relies on getData() exposing the backing array, so the increments made below
    // are visible through the registered object — confirm against DenseMatrix.
    context.putObj(LdaVariable.nWordTopics, nWordTopics.getData());
    List<SparseVector> data = context.getObj(LdaVariable.data);
    if (data == null) {
        return;
    }
    int localDocSize = data.size();
    Document[] docs = new Document[localDocSize];
    DenseMatrix nDocTopics = new DenseMatrix(localDocSize, numTopic);
    int docId = 0;
    for (SparseVector sparseVector : data) {
        Document doc = new Document(countTokens(sparseVector));
        int idx = 0;
        VectorIterator iter = sparseVector.iterator();
        while (iter.hasNext()) {
            int word = iter.getIndex();
            int occurrences = (int) iter.getValue();
            for (int j = 0; j < occurrences; j++) {
                // Assumes rand.nextInt(lo, hi) is inclusive on both ends (as in Commons Math
                // RandomDataGenerator), giving topics 0..numTopic-1 — TODO confirm the type of rand.
                int topic = rand.nextInt(0, numTopic - 1);
                doc.setWordIdxs(idx, word);
                doc.setTopicIdxs(idx, topic);
                updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, topic, 1);
                idx++;
            }
            iter.next();
        }
        docs[docId] = doc;
        docId++;
    }
    context.putObj(LdaVariable.corpus, docs);
    context.putObj(LdaVariable.nDocTopics, nDocTopics);
    // The raw input is no longer needed once it has been expanded into documents.
    context.removeObj(LdaVariable.data);
}

/**
 * Returns how many tokens a document expands to.
 *
 * <p>This must equal the number of tokens the expansion loop in {@link #initializeCorpus} writes:
 * {@code (int) value} per entry, and none for values below 1 (including negative values and NaN).
 * Otherwise the {@link Document} would be too small for the tokens written into it.
 */
private static int countTokens(SparseVector sparseVector) {
    int total = 0;
    for (double value : sparseVector.getValues()) {
        total += Math.max(0, (int) value);
    }
    return total;
}

/**
 * Later supersteps: resamples the topic of every local token, then recomputes this worker's
 * word-topic counts from the new assignments.
 */
private void resampleTopics(ComContext context, int vocabularySize) {
    Document[] docs = context.getObj(LdaVariable.corpus);
    if (docs == null) {
        return;
    }
    DenseMatrix nDocTopics = context.getObj(LdaVariable.nDocTopics);
    DenseMatrix nWordTopics = new DenseMatrix(vocabularySize + 1, numTopic, context.getObj(LdaVariable.nWordTopics), false);
    // Cumulative (unnormalized) topic weights, reused for every token.
    double[] p = new double[numTopic];
    for (int docId = 0; docId < docs.length; docId++) {
        Document doc = docs[docId];
        int wordCount = doc.getLength();
        for (int i = 0; i < wordCount; ++i) {
            int word = doc.getWordIdxs(i);
            int topic = doc.getTopicIdxs(i);
            // Remove the token's current assignment before computing its conditional weights.
            updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, topic, -1);
            double pSum = 0;
            for (int k = 0; k < numTopic; k++) {
                // Row vocabularySize holds the per-topic totals n_k.
                pSum += (nWordTopics.get(word, k) + beta) * (nDocTopics.get(docId, k) + alpha) / (nWordTopics.get(vocabularySize, k) + vocabularySize * beta);
                p[k] = pSum;
            }
            double u = rand.nextUniform(0, 1) * pSum;
            int newTopic = findProbIdx(p, u);
            doc.setTopicIdxs(i, newTopic);
            updateDocWordTopics(nDocTopics, nWordTopics, docId, word, vocabularySize, newTopic, 1);
        }
    }
    // Rebuild the word-topic counts from this worker's assignments only.
    // NOTE(review): presumably combined across workers by a subsequent reduce step — confirm.
    DenseMatrix localWordTopics = new DenseMatrix(nWordTopics.numRows(), nWordTopics.numCols());
    for (Document doc : docs) {
        int length = doc.getLength();
        for (int i = 0; i < length; i++) {
            int assigned = doc.getTopicIdxs(i);
            localWordTopics.add(doc.getWordIdxs(i), assigned, 1);
            localWordTopics.add(vocabularySize, assigned, 1);
        }
    }
    context.putObj(LdaVariable.nWordTopics, localWordTopics.getData());
}
Also used: List (java.util.List), SparseVector (com.alibaba.alink.common.linalg.SparseVector), VectorIterator (com.alibaba.alink.common.linalg.VectorIterator), DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix)

Example 2 with VectorIterator

Use of com.alibaba.alink.common.linalg.VectorIterator in project Alink by Alibaba.

Method visit of class SparseVectorSummarizer.

/**
 * Updates the summary with one vector.
 *
 * <p>A {@link DenseVector} is viewed as a sparse vector with an entry at every index (zeros
 * included), so every column receives a value. For each stored entry the per-column statistics
 * in {@code cols} are updated. When {@code calculateOuterProduct} is set, the products of all
 * pairs of stored entries are also accumulated into {@code outerProduct}, which is grown when a
 * larger index is seen.
 *
 * @param vec the vector to add; must be a {@link DenseVector} or a {@link SparseVector}
 *     (any other type fails the cast)
 * @return this summarizer, for chaining
 */
@Override
public BaseVectorSummarizer visit(Vector vec) {
    SparseVector sv;
    if (vec instanceof DenseVector) {
        DenseVector dv = (DenseVector) vec;
        int[] indices = new int[dv.size()];
        for (int i = 0; i < dv.size(); i++) {
            indices[i] = i;
        }
        // The sparse view shares dv's data array rather than copying it.
        sv = new SparseVector(dv.size(), indices, dv.getData());
    } else {
        sv = (SparseVector) vec;
    }
    count++;
    this.colNum = Math.max(this.colNum, sv.size());
    if (sv.numberOfValues() == 0) {
        return this;
    }
    VectorIterator iter = sv.iterator();
    while (iter.hasNext()) {
        // Single lookup; creates the column's statistics on first sight.
        cols.computeIfAbsent(iter.getIndex(), k -> new VectorStatCol()).visit(iter.getValue());
        iter.next();
    }
    if (calculateOuterProduct) {
        int[] svIndices = sv.getIndices();
        double[] svValues = sv.getValues();
        // Matrix size is max index + 1. Assumes indices are sorted ascending, so the last one is
        // the largest (true for the dense view above) — confirm for SparseVector inputs.
        int size = svIndices[svIndices.length - 1] + 1;
        if (outerProduct == null) {
            outerProduct = DenseMatrix.zeros(size, size);
        } else if (size > outerProduct.numRows()) {
            // NOTE(review): presumably plusEqual adds the accumulated products into the larger
            // zero matrix and returns it — confirm against VectorSummarizerUtil.
            DenseMatrix dpNew = DenseMatrix.zeros(size, size);
            outerProduct = VectorSummarizerUtil.plusEqual(dpNew, outerProduct);
        }
        for (int i = 0; i < svIndices.length; i++) {
            int row = svIndices[i];
            double rowValue = svValues[i];
            for (int j = 0; j < svIndices.length; j++) {
                outerProduct.add(row, svIndices[j], rowValue * svValues[j]);
            }
        }
    }
    return this;
}
Also used: SparseVector (com.alibaba.alink.common.linalg.SparseVector), DenseVector (com.alibaba.alink.common.linalg.DenseVector), VectorIterator (com.alibaba.alink.common.linalg.VectorIterator), DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix)

Aggregations

DenseMatrix (com.alibaba.alink.common.linalg.DenseMatrix): 2, SparseVector (com.alibaba.alink.common.linalg.SparseVector): 2, VectorIterator (com.alibaba.alink.common.linalg.VectorIterator): 2, DenseVector (com.alibaba.alink.common.linalg.DenseVector): 1, List (java.util.List): 1