Use of org.nd4j.linalg.util.DeviceLocalNDArray in project deeplearning4j by deeplearning4j.
The class SkipGram, method configure.
/**
 * SkipGram initialization over the given vocabulary and WeightLookupTable.
 *
 * @param vocabCache    vocabulary to train against
 * @param lookupTable   lookup table holding the model weights
 * @param configuration training configuration
 */
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    // if negative sampling is enabled, lazily allocate syn1Neg before wrapping the weights
    if (configuration.getNegative() > 0) {
        if (((InMemoryLookupTable<T>) lookupTable).getSyn1Neg() == null) {
            log.info("Initializing syn1Neg...");
            ((InMemoryLookupTable<T>) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax());
            ((InMemoryLookupTable<T>) lookupTable).setNegative(configuration.getNegative());
            ((InMemoryLookupTable<T>) lookupTable).resetWeights(false);
        }
    }

    // wrap the shared tables in DeviceLocalNDArray so each device works on its own replica
    this.expTable = new DeviceLocalNDArray(Nd4j.create(((InMemoryLookupTable<T>) lookupTable).getExpTable()));
    this.syn0 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn0());
    this.syn1 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1());
    this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
    this.table = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getTable());

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();
    this.variableWindows = configuration.getVariableWindows();
    this.vectorLength = configuration.getLayersSize();
}
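DeviceLocalNDArray is the piece all of these snippets share: it holds one replica of an INDArray per compute device, and get() returns the replica bound to the calling thread's device, so multi-GPU workers do not contend on a single copy of the weights. A minimal usage sketch, matching only the constructor and get() calls seen above (the shape and values are arbitrary, not taken from the project):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;

public class DeviceLocalSketch {
    public static void main(String[] args) {
        INDArray syn0 = Nd4j.rand(1000, 100);               // e.g. vocabSize x layerSize weights
        DeviceLocalNDArray localSyn0 = new DeviceLocalNDArray(syn0);
        INDArray replica = localSyn0.get();                 // replica for the current thread's device
        System.out.println(replica.shapeInfoToString());
    }
}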
Use of org.nd4j.linalg.util.DeviceLocalNDArray in project deeplearning4j by deeplearning4j.
The class CBOW, method configure.
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    // if negative sampling is enabled, lazily allocate syn1Neg before wrapping the weights
    if (configuration.getNegative() > 0) {
        if (((InMemoryLookupTable<T>) lookupTable).getSyn1Neg() == null) {
            logger.info("Initializing syn1Neg...");
            ((InMemoryLookupTable<T>) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax());
            ((InMemoryLookupTable<T>) lookupTable).setNegative(configuration.getNegative());
            ((InMemoryLookupTable<T>) lookupTable).resetWeights(false);
        }
    }

    // per-device replicas of the weight matrices and lookup tables
    this.syn0 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn0());
    this.syn1 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1());
    this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
    this.expTable = new DeviceLocalNDArray(Nd4j.create(((InMemoryLookupTable<T>) lookupTable).getExpTable()));
    this.table = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getTable());
    this.variableWindows = configuration.getVariableWindows();
}
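The expTable wrapped in both configure methods is a precomputed sigmoid lookup, used so the inner training loop can avoid calling Math.exp per sample. InMemoryLookupTable builds its own table internally; the following is only a sketch of how such a table is classically precomputed in word2vec-style code, with the size and range constants being the common reference defaults rather than values read from this project:

// precompute sigmoid(x) for x evenly spaced over [-maxExp, maxExp)
static double[] buildExpTable(int size, double maxExp) {
    double[] table = new double[size];
    for (int i = 0; i < size; i++) {
        double x = (i / (double) size * 2.0 - 1.0) * maxExp; // map index i to [-maxExp, maxExp)
        double e = Math.exp(x);
        table[i] = e / (e + 1.0);                            // sigmoid(x)
    }
    return table;
}

// typical reference defaults: buildExpTable(1000, 6.0)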
Use of org.nd4j.linalg.util.DeviceLocalNDArray in project deeplearning4j by deeplearning4j.
The class CBOW, method iterateSample.
public void iterateSample(T currentWord, int[] windowWords, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, INDArray inferenceVector) {
    int[] idxSyn1 = null;
    int[] codes = null;

    if (configuration.isUseHierarchicSoftmax()) {
        // collect the Huffman codes and syn1 row indices for the current word
        idxSyn1 = new int[currentWord.getCodeLength()];
        codes = new int[currentWord.getCodeLength()];
        for (int p = 0; p < currentWord.getCodeLength(); p++) {
            if (currentWord.getPoints().get(p) < 0)
                continue;

            codes[p] = currentWord.getCodes().get(p);
            idxSyn1[p] = currentWord.getPoints().get(p);
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }

    if (negative > 0) {
        // the negative-sampling table may not have been allocated yet
        if (syn1Neg == null) {
            ((InMemoryLookupTable<T>) lookupTable).initNegative();
            syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
        }
    }

    if (batches.get() == null)
        batches.set(new ArrayList<Aggregate>());

    AggregateCBOW cbow = new AggregateCBOW(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(), table.get(),
                    currentWord.getIndex(), windowWords, idxSyn1, codes, (int) negative, currentWord.getIndex(),
                    lookupTable.layerSize(), alpha, nextRandom.get(), vocabCache.numWords(), numLabels, trainWords,
                    inferenceVector);

    // advance the word2vec-style LCG for the next sample
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

    if (!isInference)
        batches.get().add(cbow);   // defer execution: aggregates are batched and flushed later
    else
        Nd4j.getExecutioner().exec(cbow);
}
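The nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)) line above is the same linear congruential generator the reference word2vec C code uses (the multiplier is also java.util.Random's); each call derives the next pseudo-random state that downstream steps such as subsampling and window shrinking consume. A standalone sketch of the recurrence, where the seed and the window constant are illustrative only:

import java.util.concurrent.atomic.AtomicLong;

public class LcgSketch {
    public static void main(String[] args) {
        AtomicLong nextRandom = new AtomicLong(42);  // any seed
        for (int i = 0; i < 3; i++) {
            // next = |next * 25214903917 + 11|, wrapping in 64-bit arithmetic
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
            int shrink = (int) (nextRandom.get() % 5); // e.g. a reduced window in [0, 5)
            System.out.println(shrink);
        }
    }
}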
Use of org.nd4j.linalg.util.DeviceLocalNDArray in project deeplearning4j by deeplearning4j.
The class SkipGram, method iterateSample.
public double iterateSample(T w1, T lastWord, AtomicLong nextRandom, double alpha, boolean isInference, INDArray inferenceVector) {
    // skip invalid pairs, self-pairs, and the special STOP/UNK tokens
    if (w1 == null || lastWord == null || (lastWord.getIndex() < 0 && !isInference)
                    || w1.getIndex() == lastWord.getIndex() || w1.getLabel().equals("STOP")
                    || lastWord.getLabel().equals("STOP") || w1.getLabel().equals("UNK")
                    || lastWord.getLabel().equals("UNK")) {
        return 0.0;
    }

    double score = 0.0;

    int[] idxSyn1 = null;
    int[] codes = null;

    if (configuration.isUseHierarchicSoftmax()) {
        // collect the Huffman codes and syn1 row indices for the target word
        idxSyn1 = new int[w1.getCodeLength()];
        codes = new int[w1.getCodeLength()];
        for (int i = 0; i < w1.getCodeLength(); i++) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= vocabCache.numWords() || point < 0)
                continue;

            codes[i] = code;
            idxSyn1[i] = point;
        }
    } else {
        idxSyn1 = new int[0];
        codes = new int[0];
    }

    int target = w1.getIndex();

    //negative sampling
    if (negative > 0) {
        if (syn1Neg == null) {
            ((InMemoryLookupTable<T>) lookupTable).initNegative();
            syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
        }
    }

    if (batches.get() == null) {
        batches.set(new ArrayList<Aggregate>());
    }

    //log.info("VocabWords: {}; lastWordIndex: {}; syn1neg: {}", vocabCache.numWords(), lastWord.getIndex(), syn1Neg.get().rows());
    AggregateSkipGram sg = new AggregateSkipGram(syn0.get(), syn1.get(), syn1Neg.get(), expTable.get(), table.get(),
                    lastWord.getIndex(), idxSyn1, codes, (int) negative, target, vectorLength, alpha, nextRandom.get(),
                    vocabCache.numWords(), inferenceVector);

    // advance the word2vec-style LCG for the next sample
    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

    if (!isInference)
        batches.get().add(sg);     // defer execution: aggregates are batched and flushed later
    else
        Nd4j.getExecutioner().exec(sg);

    // score is never accumulated in this method; the returned value is always 0.0
    return score;
}
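In both classes, training-time aggregates are only appended to the thread-local batches list; execution happens later, when the list is flushed. That flushing code is not part of these snippets, so the following is only a hedged sketch of what such a flush could look like, built entirely from the exec(Aggregate) overload already used above for the inference path:

// flush any deferred aggregates for the current thread (a sketch, not the project's own flush logic)
List<Aggregate> pending = batches.get();
if (pending != null && !pending.isEmpty()) {
    for (Aggregate op : pending) {
        Nd4j.getExecutioner().exec(op);
    }
    pending.clear();
}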