Search in sources:

Example 1 with AggregateSkipGram

Use of org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram in the deeplearning4j project by deeplearning4j.

From the class SkipGram, method iterateSample.

/**
 * Builds one skip-gram training step for the pair ({@code w1}, {@code lastWord}) as an
 * {@link AggregateSkipGram}. During training the step is queued into the per-call {@code batches}
 * list; during inference it is executed immediately.
 *
 * <p>The pair is skipped (nothing is queued or executed and {@code nextRandom} is not advanced) when:
 * <ul>
 *   <li>either word is {@code null};</li>
 *   <li>{@code lastWord} has a negative index outside inference mode;</li>
 *   <li>both words have the same index;</li>
 *   <li>either word is labelled {@code "STOP"} or {@code "UNK"}.</li>
 * </ul>
 *
 * <p>With hierarchical softmax enabled, only Huffman points inside {@code [0, numWords)} are passed
 * to the aggregate. Out-of-range points are dropped together with their codes, rather than being
 * left behind as zero-filled entries.
 *
 * @param w1              word whose index is used as the aggregate's target and whose codes/points
 *                        feed hierarchical softmax
 * @param lastWord        word whose index is passed to the aggregate as the first word index
 *                        (presumably the syn0 row being trained — confirm against AggregateSkipGram)
 * @param nextRandom      random state forwarded to the aggregate, then advanced with
 *                        {@code abs(r * 25214903917 + 11)}
 * @param alpha           learning rate forwarded to the aggregate
 * @param isInference     {@code true} to execute the aggregate immediately instead of queuing it
 * @param inferenceVector forwarded unchanged to the aggregate
 * @return always {@code 0.0}; this method does not compute a score
 */
public double iterateSample(T w1, T lastWord, AtomicLong nextRandom, double alpha, boolean isInference, INDArray inferenceVector) {
    if (w1 == null || lastWord == null
            || (lastWord.getIndex() < 0 && !isInference)
            || w1.getIndex() == lastWord.getIndex()
            || w1.getLabel().equals("STOP") || lastWord.getLabel().equals("STOP")
            || w1.getLabel().equals("UNK") || lastWord.getLabel().equals("UNK")) {
        return 0.0;
    }

    final int numWords = vocabCache.numWords();

    int[] idxSyn1;
    int[] codes;
    if (configuration.isUseHierarchicSoftmax()) {
        final int codeLength = w1.getCodeLength();
        idxSyn1 = new int[codeLength];
        codes = new int[codeLength];
        int kept = 0;
        for (int i = 0; i < codeLength; i++) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            // Drop invalid points entirely. Merely skipping the slot would leave a zero-filled
            // (point 0, code 0) entry that the aggregate would still train against.
            if (point < 0 || point >= numWords) {
                continue;
            }
            codes[kept] = code;
            idxSyn1[kept] = point;
            kept++;
        }
        if (kept < codeLength) {
            // Trim so the aggregate only sees the valid (point, code) pairs.
            idxSyn1 = java.util.Arrays.copyOf(idxSyn1, kept);
            codes = java.util.Arrays.copyOf(codes, kept);
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }

    final int target = w1.getIndex();

    // Lazily create the negative-sampling weights the first time they are needed.
    if (negative > 0 && syn1Neg == null) {
        InMemoryLookupTable<T> inMemoryTable = (InMemoryLookupTable<T>) lookupTable;
        inMemoryTable.initNegative();
        syn1Neg = new DeviceLocalNDArray(inMemoryTable.getSyn1Neg());
    }

    if (batches.get() == null) {
        batches.set(new ArrayList<Aggregate>());
    }

    AggregateSkipGram sg = new AggregateSkipGram(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(), table.get(), lastWord.getIndex(), idxSyn1, codes, (int) negative, target, vectorLength, alpha, nextRandom.get(), numWords, inferenceVector);

    // Advance the random state for the next call; same constants as the original implementation.
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

    if (isInference) {
        Nd4j.getExecutioner().exec(sg);
    } else {
        batches.get().add(sg);
    }

    // No score is computed by this method.
    return 0.0;
}
Also used: InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable), DeviceLocalNDArray (org.nd4j.linalg.util.DeviceLocalNDArray), AggregateSkipGram (org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram), Aggregate (org.nd4j.linalg.api.ops.aggregates.Aggregate)

Aggregations

InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 1, Aggregate (org.nd4j.linalg.api.ops.aggregates.Aggregate): 1, AggregateSkipGram (org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram): 1, DeviceLocalNDArray (org.nd4j.linalg.util.DeviceLocalNDArray): 1