Usage of org.deeplearning4j.berkeley.Counter in the deeplearning4j project (deeplearning4j/deeplearning4j): class TextPipeline, method setup().
/**
 * Initializes the Spark context, the broadcast copy of the stop-word list,
 * and the word-frequency accumulator used during corpus processing.
 */
private void setup() {
    // Wrap the corpus RDD's underlying SparkContext in a Java-friendly facade.
    this.sc = new JavaSparkContext(corpusRDD.context());
    // Ship the stop-word list to every executor once.
    this.stopWordBroadCast = sc.broadcast(stopWords);
    // Accumulator that merges per-partition word-frequency counters on the driver.
    this.wordFreqAcc = sc.accumulator(new Counter<String>(), new WordFreqAccumulator());
}
Usage of org.deeplearning4j.berkeley.Counter in the deeplearning4j project (deeplearning4j/deeplearning4j): class ParagraphVectors, method predict().
/**
 * Predicts the label of a document.
 *
 * Each word in the document is mapped to its vector, the vectors are averaged
 * into a single centroid, and the label whose vector is most cosine-similar
 * to that centroid is returned.
 *
 * @param document the words of the document; must be non-empty
 * @return the label with the highest cosine similarity to the document mean
 * @throws IllegalStateException if the document contains no words
 * @deprecated carried over from the original ParagraphVectors DL4J implementation; untested
 */
@Deprecated
public String predict(List<VocabWord> document) {
    /*
    This code was transferred from the original ParagraphVectors DL4J implementation
    and has yet to be tested.
    */
    if (document.isEmpty()) {
        throw new IllegalStateException("Document has no words inside");
    }

    // Stack each word's vector as one row of a matrix, then average the rows.
    INDArray wordMatrix = Nd4j.create(document.size(), this.layerSize);
    int row = 0;
    for (VocabWord word : document) {
        wordMatrix.putRow(row++, getWordVectorMatrix(word.getWord()));
    }
    INDArray centroid = wordMatrix.mean(0);

    // Score every known label against the document centroid.
    Counter<String> similarities = new Counter<>();
    for (String label : labelsSource.getLabels()) {
        INDArray labelVec = getWordVectorMatrix(label);
        similarities.incrementCount(label, Transforms.cosineSim(centroid, labelVec));
    }
    return similarities.argMax();
}
Usage of org.deeplearning4j.berkeley.Counter in the deeplearning4j project (deeplearning4j/deeplearning4j): class GloVe, method learnSequence().
/**
 * Learns the whole corpus using the GloVe algorithm.
 *
 * Despite the per-sequence signature, training is driven entirely by the
 * pre-built co-occurrence table; only the first invocation performs work,
 * and all subsequent calls hit the early-termination guard and return 0.
 *
 * @param sequence the sequence that triggered this call (training itself iterates the co-occurrence table)
 * @param nextRandom shared RNG state threaded through the learning framework
 * @param learningRate learning rate supplied by the caller; the epochs/worker setup comes from {@code configuration}
 * @return always 0
 */
@Override
public synchronized double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
/*
GloVe learning algorithm is implemented like a hack over settled ElementsLearningAlgorithm mechanics. It's called in SequenceVectors context, but actually only for the first call.
All subsequent calls will meet the early termination condition, and will be successfully ignored. But since elements vectors will be updated within the first call,
this will allow compatibility with everything beyond this implementation
*/
// Early exit: a previous call already ran the full GloVe pass.
if (isTerminate.get())
return 0;
// Shared across worker threads: total pairs processed and per-epoch error.
final AtomicLong pairsCount = new AtomicLong(0);
final Counter<Integer> errorCounter = new Counter<>();
for (int i = 0; i < configuration.getEpochs(); i++) {
// TODO: shuffle should be built in another way.
//if (shuffle)
//Collections.shuffle(coList);
// Fresh iterator per epoch; all workers pull pairs from this single shared iterator.
Iterator<Pair<Pair<T, T>, Double>> pairs = coOccurrences.iterator();
List<GloveCalculationsThread> threads = new ArrayList<>();
for (int x = 0; x < workers; x++) {
threads.add(x, new GloveCalculationsThread(i, x, pairs, pairsCount, errorCounter));
threads.get(x).start();
}
// Wait for every worker to finish the epoch before reporting.
for (int x = 0; x < workers; x++) {
try {
threads.get(x).join();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// errorCounter is keyed by epoch index, so getCount(i) is this epoch's accumulated error.
log.info("Processed [" + pairsCount.get() + "] pairs, Error was [" + errorCounter.getCount(i) + "]");
}
// Mark training done so later learnSequence calls become no-ops.
isTerminate.set(true);
return 0;
}
Usage of org.deeplearning4j.berkeley.Counter in the deeplearning4j project (deeplearning4j/deeplearning4j): class BasicModelUtils, method wordsNearestSum().
/**
 * Returns the words nearest to the given vector (typically the sum/mean of
 * positive and negative query words).
 *
 * @param words the query vector to compare against the vocabulary
 * @param top the number of nearest words to return
 * @return the words nearest the mean of the words
 */
@Override
public Collection<String> wordsNearestSum(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        // Fast path: score every row of syn0 at once with a weighted dot product,
        // then take the indices of the highest scores.
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
        INDArray distances = syn0.mulRowVector(weights).sum(1);
        INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
        INDArray sort = sorted[0];
        List<String> ret = new ArrayList<>();
        if (top > sort.length())
            top = sort.length();
        // Filler tokens (UNK/STOP) are skipped without shrinking the requested
        // result size: each skip extends the scan window by one.
        int end = top;
        for (int i = 0; i < end; i++) {
            String add = vocabCache.wordAtIndex(sort.getInt(i));
            if (add == null || add.equals("UNK") || add.equals("STOP")) {
                end++;
                if (end >= sort.length())
                    break;
                continue;
            }
            // Reuse the word already fetched above instead of a second,
            // redundant vocabCache.wordAtIndex() lookup.
            ret.add(add);
        }
        return ret;
    }
    // Fallback: brute-force cosine similarity over the whole vocabulary.
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
Usage of org.deeplearning4j.berkeley.Counter in the deeplearning4j project (deeplearning4j/deeplearning4j): class FlatModelUtils, method wordsNearest().
/**
 * Performs a full scan over the whole vocabulary, building a descending list
 * of the words most similar to the supplied vector.
 *
 * @param words the reference vector to compare against
 * @param top how many nearest words to keep
 * @return the words nearest the mean of the words
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    Counter<String> similarity = new Counter<>();
    for (String candidate : vocabCache.words()) {
        INDArray candidateVec = lookupTable.vector(candidate);
        // Both operands are dup()'d per iteration — presumably to shield the
        // originals from in-place ops inside cosineSim; preserved as-is.
        double sim = Transforms.cosineSim(words.dup(), candidateVec.dup());
        similarity.incrementCount(candidate, sim);
    }
    similarity.keepTopNKeys(top);
    return similarity.getSortedKeys();
}
Aggregations