Use of org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW in the deeplearning4j project by deeplearning4j.

From the class CBOW, method iterateSample:
public void iterateSample(T currentWord, int[] windowWords, AtomicLong nextRandom, double alpha,
                boolean isInference, int numLabels, boolean trainWords, INDArray inferenceVector) {
    int[] idxSyn1 = null;
    int[] codes = null;
    if (configuration.isUseHierarchicSoftmax()) {
        // Hierarchical softmax: gather the Huffman code and the syn1 row index
        // (inner tree node) for each point on the current word's path.
        idxSyn1 = new int[currentWord.getCodeLength()];
        codes = new int[currentWord.getCodeLength()];
        for (int p = 0; p < currentWord.getCodeLength(); p++) {
            if (currentWord.getPoints().get(p) < 0)
                continue;
            codes[p] = currentWord.getCodes().get(p);
            idxSyn1[p] = currentWord.getPoints().get(p);
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }

    // Negative sampling: lazily initialize the syn1Neg weights on first use.
    if (negative > 0) {
        if (syn1Neg == null) {
            ((InMemoryLookupTable<T>) lookupTable).initNegative();
            syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
        }
    }

    // Lazily create the per-thread list used to batch aggregate ops during training.
    if (batches.get() == null)
        batches.set(new ArrayList<Aggregate>());

    // Build the CBOW aggregate over the context window for the current word.
    AggregateCBOW cbow = new AggregateCBOW(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(),
                    table.get(), currentWord.getIndex(), windowWords, idxSyn1, codes, (int) negative,
                    currentWord.getIndex(), lookupTable.layerSize(), alpha, nextRandom.get(),
                    vocabCache.numWords(), numLabels, trainWords, inferenceVector);

    // Advance the shared RNG state so the next sample draws fresh randomness.
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

    // Training ops are queued in the batch; inference executes immediately.
    if (!isInference)
        batches.get().add(cbow);
    else
        Nd4j.getExecutioner().exec(cbow);
}
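
The constants in the nextRandom update (multiplier 25214903917, increment 11) are the linear congruential generator from the original word2vec C implementation; each call advances the shared RNG state so subsequent samples draw different negatives. A minimal standalone sketch of that update in isolation (hypothetical class and method names, not part of deeplearning4j):

import java.util.concurrent.atomic.AtomicLong;

public class Word2VecLcgDemo {
    // One step of the word2vec LCG, mirroring the update in iterateSample above.
    static long nextRandom(AtomicLong state) {
        state.set(Math.abs(state.get() * 25214903917L + 11));
        return state.get();
    }

    public static void main(String[] args) {
        AtomicLong state = new AtomicLong(42L);
        for (int i = 0; i < 3; i++) {
            // In word2vec, the high bits typically index the unigram table,
            // e.g. (state >>> 16) % tableSize, when picking negative samples.
            System.out.println(nextRandom(state));
        }
    }
}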