use of org.deeplearning4j.berkeley.Counter in project deeplearning4j by deeplearning4j.
the class BasicModelUtils method wordsNearest.
/**
 * Words nearest based on positive and negative words
 *
 * @param words the vector to find the nearest words to
 * @param top the top n words
 * @return the words nearest the mean of the words
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        if (!normalized) {
            synchronized (this) {
                if (!normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    normalized = true;
                }
            }
        }
        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        List<Double> highToLowSimList = getTopN(similarity, top + 20);
        List<WordSimilarity> result = new ArrayList<>();
        for (int i = 0; i < highToLowSimList.size(); i++) {
            String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
            if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
                INDArray otherVec = lookupTable.vector(word);
                double sim = Transforms.cosineSim(words, otherVec);
                result.add(new WordSimilarity(word, sim));
            }
        }
        Collections.sort(result, new SimilarityComparator());
        return getLabels(result, top);
    }
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
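For reference, here is a minimal, self-contained sketch of the same Counter idiom used in the fallback branch above: score every candidate, call keepTopNKeys, and return the remaining keys. The plain double[] vectors and the cosine helper are illustrative assumptions, not part of the BasicModelUtils API.

import org.deeplearning4j.berkeley.Counter;
import java.util.Collection;
import java.util.Map;

public class TopNByScore {
    // Illustrative cosine similarity for raw double[] vectors (assumption, not DL4J API)
    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Same shape as the fallback branch above: score everything, keep the best N keys
    static Collection<String> nearest(Map<String, double[]> vectors, double[] query, int top) {
        Counter<String> distances = new Counter<>();
        for (Map.Entry<String, double[]> e : vectors.entrySet()) {
            distances.incrementCount(e.getKey(), cosine(query, e.getValue()));
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }
}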
use of org.deeplearning4j.berkeley.Counter in project deeplearning4j by deeplearning4j.
the class ParagraphVectors method predictSeveral.
/**
 * Predict several labels based on the document.
 * Computes a similarity with respect to the mean of the
 * representation of words in the document
 * @param document the document
 * @param limit the maximum number of labels to return
 * @return possible labels in descending order
 */
@Deprecated
public Collection<String> predictSeveral(List<VocabWord> document, int limit) {
    /*
        This code was transferred from the original ParagraphVectors DL4J implementation and has yet to be tested
     */
    if (document.isEmpty())
        throw new IllegalStateException("Document has no words inside");
    INDArray arr = Nd4j.create(document.size(), this.layerSize);
    for (int i = 0; i < document.size(); i++) {
        arr.putRow(i, getWordVectorMatrix(document.get(i).getWord()));
    }
    INDArray docMean = arr.mean(0);
    Counter<String> distances = new Counter<>();
    for (String s : labelsSource.getLabels()) {
        INDArray otherVec = getWordVectorMatrix(s);
        double sim = Transforms.cosineSim(docMean, otherVec);
        log.debug("Similarity inside: [" + s + "] -> " + sim);
        distances.incrementCount(s, sim);
    }
    return distances.getSortedKeys().subList(0, limit);
}
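A minimal sketch of the ranking idiom above, runnable on its own: accumulate similarity scores in a Counter and take the first `limit` sorted keys. The label names and scores here are made up for illustration; only the Counter calls mirror predictSeveral.

import org.deeplearning4j.berkeley.Counter;
import java.util.List;

public class LabelRanking {
    public static void main(String[] args) {
        Counter<String> distances = new Counter<>();
        // made-up similarity scores, just to show the idiom
        distances.incrementCount("LABEL_SPORTS", 0.83);
        distances.incrementCount("LABEL_POLITICS", 0.41);
        distances.incrementCount("LABEL_FINANCE", 0.67);

        int limit = 2;
        // getSortedKeys() orders labels by descending score, as relied on in predictSeveral above
        List<String> top = distances.getSortedKeys().subList(0, Math.min(limit, distances.size()));
        System.out.println(top); // expected: [LABEL_SPORTS, LABEL_FINANCE]
    }
}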
use of org.deeplearning4j.berkeley.Counter in project deeplearning4j by deeplearning4j.
the class CountFunction method call.
@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
    // Since we can't be 100% sure that the reported sequence size is correct (or that it doesn't
    // overflow int limits), we recalculate it; we have to loop over the elements for frequencies anyway.
    Counter<Long> localCounter = new Counter<>();
    long seqLen = 0;
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    driver = ela.getTrainingDriver();
    //System.out.println("Initializing VoidParameterServer in CountFunction");
    VoidParameterServer.getInstance().init(voidConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    for (T element : sequence.getElements()) {
        if (element == null)
            continue;
        // FIXME: hashcode is a bad idea here, we need a Long id
        localCounter.incrementCount(element.getStorageId(), 1.0);
        seqLen++;
    }
    // FIXME: we're missing label information here due to shallow vocab mechanics
    if (sequence.getSequenceLabels() != null)
        for (T label : sequence.getSequenceLabels()) {
            localCounter.incrementCount(label.getStorageId(), 1.0);
        }
    accumulator.add(localCounter);
    return Pair.makePair(sequence, seqLen);
}
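Outside of Spark, the counting pattern above boils down to a per-partition Counter<Long> that later gets merged into a global one. Below is a hedged, standalone sketch of that idea; the merge helper and the fake storage ids are assumptions, not the actual ElementsFrequenciesAccumulator code.

import org.deeplearning4j.berkeley.Counter;

public class FrequencyCountSketch {
    // Merge one partition-local counter into a global one, roughly what the
    // frequencies accumulator does behind the scenes (illustrative only)
    static void merge(Counter<Long> global, Counter<Long> local) {
        for (Long key : local.keySet()) {
            global.incrementCount(key, local.getCount(key));
        }
    }

    public static void main(String[] args) {
        long[][] partitions = { {7L, 7L, 13L}, {13L, 42L} }; // fake storage ids
        Counter<Long> global = new Counter<>();
        for (long[] partition : partitions) {
            Counter<Long> local = new Counter<>();
            for (long id : partition) {
                local.incrementCount(id, 1.0); // same call as in CountFunction above
            }
            merge(global, local);
        }
        System.out.println(global.getCount(13L));        // 2.0
        System.out.println((long) global.totalCount());  // 5
    }
}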
use of org.deeplearning4j.berkeley.Counter in project deeplearning4j by deeplearning4j.
the class SparkSequenceVectors method fitSequences.
/**
 * Base training entry point
 *
 * @param corpus the training corpus, as an RDD of sequences
 */
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
    /*
        Basically all we want for the base implementation here is 3 things:
        a) build the vocabulary
        b) build the Huffman tree
        c) do the training

        This way, all classes extending SeqVec (like DeepWalk or Word2Vec) just build their RDD<Sequence<T>>
        and call this method for training, instead of implementing their own routines.
     */
    validateConfiguration();
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (workers > 1) {
        log.info("Repartitioning corpus to {} parts...", workers);
        // repartition() returns a new RDD, so the result has to be assigned back
        corpus = corpus.repartition(workers);
    }
    if (storageLevel != null)
        corpus.persist(storageLevel);
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // this only has an effect if it wasn't called before, in extension classes
    broadcastEnvironment(sc);
    Counter<Long> finalCounter;
    long numberOfSequences = 0;
    // set up a default parameter server configuration if none was provided
    if (paramServerConfiguration == null)
        paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
    isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
    Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
    if (isAutoDiscoveryMode) {
        log.info("Trying auto discovery mode...");
        elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
        ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        finalCounter = elementsFreqAccumExtra.value();
        ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
        // getting list of available hosts
        Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
        log.info("availableHosts: {}", availableHosts);
        if (availableHosts.size() > 1) {
            // now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
            NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
            paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
            // backup shards are optional
            if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
            }
        } else {
            // for single host (aka driver-only, aka spark-local) just run on loopback interface
            paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
            paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }
        log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
    } else {
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
        // set up freqs accumulator
        elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
        CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
        // count all sequence elements and their sum
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        // now we grab counter, which contains frequencies for all SequenceElements in corpus
        finalCounter = elementsFreqAccum.value();
    }
    long numberOfElements = (long) finalCounter.totalCount();
    long numberOfUniqueElements = finalCounter.size();
    log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
    /*
        Build an RDD of reduced SequenceElements: temporarily get rid of labels and stick to some numerical values,
        like index or hashcode, so we can reduce the driver's memory footprint.
     */
    // build the Huffman tree, and update the original RDD with Huffman encoding info
    shallowVocabCache = buildShallowVocabCache(finalCounter);
    shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
    // FIXME: probably we need to reconsider this approach
    JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
    vocabRDD.count();
    /*
        Now we initialize the Shards with values. That call should be started from the driver, which is either
        a Client or a Shard in standalone mode.
     */
    VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
    VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
    // proceed to training
    // also, the training function is the place where we invoke the ParameterServer
    TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    if (configuration != null)
        for (int e = 0; e < configuration.getEpochs(); e++)
            corpus.foreachPartition(partitionTrainer);
    //corpus.foreach(trainer);
    // we're transferring vectors to ExportContainer
    JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
    // at this particular moment training should be pretty much done, and we're good to go for export
    if (exporter != null)
        exporter.export(exportRdd);
    // unpersist, if we persisted the corpus earlier
    if (storageLevel != null)
        corpus.unpersist();
    log.info("Training finished, starting cleanup...");
    VoidParameterServer.getInstance().shutdown();
}
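To make the role of finalCounter more concrete, here is an illustrative sketch of turning the aggregated frequency Counter<Long> into (storageId, frequency) pairs sorted by descending frequency, i.e. the shape a vocabulary/Huffman-tree builder consumes. This is an assumption-laden stand-in, not the actual buildShallowVocabCache() implementation.

import org.deeplearning4j.berkeley.Counter;
import java.util.ArrayList;
import java.util.List;

public class VocabFromCounter {
    // Illustrative only: (storageId, frequency) pairs in descending frequency order
    static List<long[]> toSortedVocab(Counter<Long> finalCounter) {
        List<long[]> vocab = new ArrayList<>();
        for (Long storageId : finalCounter.getSortedKeys()) {
            vocab.add(new long[] { storageId, (long) finalCounter.getCount(storageId) });
        }
        return vocab;
    }

    public static void main(String[] args) {
        Counter<Long> finalCounter = new Counter<>();
        finalCounter.incrementCount(101L, 12.0);
        finalCounter.incrementCount(102L, 3.0);
        finalCounter.incrementCount(103L, 7.0);
        for (long[] entry : toSortedVocab(finalCounter)) {
            System.out.println(entry[0] + " -> " + entry[1]); // 101 -> 12, 103 -> 7, 102 -> 3
        }
    }
}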
use of org.deeplearning4j.berkeley.Counter in project deeplearning4j by deeplearning4j.
the class FoldWithinPartitionFunction method call.
@Override
public Iterator<AtomicLong> call(Integer ind, Iterator<AtomicLong> partition) throws Exception {
    List<AtomicLong> foldedItemList = new ArrayList<AtomicLong>() {
        {
            add(new AtomicLong(0L));
        }
    };
    // Running-state implementation of a cumulative sum
    int foldedItemListSize = 1;
    while (partition.hasNext()) {
        long curPartitionItem = partition.next().get();
        int lastFoldedIndex = foldedItemListSize - 1;
        long lastFoldedItem = foldedItemList.get(lastFoldedIndex).get();
        AtomicLong sumLastCurrent = new AtomicLong(curPartitionItem + lastFoldedItem);
        foldedItemList.set(lastFoldedIndex, sumLastCurrent);
        foldedItemList.add(sumLastCurrent);
        foldedItemListSize += 1;
    }
    // Update the accumulator with this partition's total
    long maxFoldedItem = foldedItemList.remove(foldedItemListSize - 1).get();
    Counter<Integer> partitionIndex2maxItemCounter = new Counter<>();
    partitionIndex2maxItemCounter.incrementCount(ind, maxFoldedItem);
    maxPerPartitionAcc.add(partitionIndex2maxItemCounter);
    return foldedItemList.iterator();
}
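The fold above is a running (cumulative) sum within each partition, with each partition's total reported through a Counter keyed by partition index. Below is a plain-Java sketch of that logic with made-up partition data; the real function runs inside Spark's mapPartitionsWithIndex and reports totals via an accumulator rather than a local Counter.

import org.deeplearning4j.berkeley.Counter;
import java.util.ArrayList;
import java.util.List;

public class CumSumSketch {
    public static void main(String[] args) {
        long[][] partitions = { {3L, 1L, 4L}, {2L, 2L} }; // made-up per-partition counts
        Counter<Integer> maxPerPartition = new Counter<>();

        for (int ind = 0; ind < partitions.length; ind++) {
            List<Long> cumSum = new ArrayList<>();
            long running = 0L;
            for (long item : partitions[ind]) {
                running += item;
                cumSum.add(running); // inclusive running sum, e.g. 3, 4, 8 for {3, 1, 4}
            }
            maxPerPartition.incrementCount(ind, running); // partition total, like maxPerPartitionAcc
            System.out.println("partition " + ind + " cumulative sums: " + cumSum);
        }
        // per-partition totals are what a caller would use to offset partitions globally
        System.out.println("total for partition 0: " + maxPerPartition.getCount(0)); // 8.0
    }
}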
Aggregations