Usage examples of Counter (org.deeplearning4j.berkeley.Counter) in project deeplearning4j

Example 1 with Counter

Use of org.deeplearning4j.berkeley.Counter in project deeplearning4j (by deeplearning4j).

Source: the class BasicModelUtils, method wordsNearest.

/**
 * Words nearest based on positive and negative words.
 *
 * @param words the query vector (typically the mean of the positive and negative word vectors)
 * @param top the top n words
 * @return the words nearest the mean of the words
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        // normalize the rows of syn0 to unit length once, guarded by double-checked locking
        if (!normalized) {
            synchronized (this) {
                if (!normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    normalized = true;
                }
            }
        }
        // cosine similarity of the (unit-normalized) query against every row of syn0
        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        // oversample by 20 so that filtered tokens (UNK, STOP, nulls) still leave enough candidates
        List<Double> highToLowSimList = getTopN(similarity, top + 20);
        List<WordSimilarity> result = new ArrayList<>();
        for (int i = 0; i < highToLowSimList.size(); i++) {
            String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
            if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
                INDArray otherVec = lookupTable.vector(word);
                double sim = Transforms.cosineSim(words, otherVec);
                result.add(new WordSimilarity(word, sim));
            }
        }
        Collections.sort(result, new SimilarityComparator());
        return getLabels(result, top);
    }
    // fallback for other lookup-table implementations: brute-force cosine similarity over the vocabulary
    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
Also used : InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) Counter(org.deeplearning4j.berkeley.Counter) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
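
The fallback branch above leans entirely on Counter's top-N trimming. Here is a minimal, self-contained sketch of that pattern, with made-up similarity scores standing in for the Transforms.cosineSim results (only incrementCount, keepTopNKeys, and keySet are used, exactly as in the method; the class name TopNDemo is illustrative):

import org.deeplearning4j.berkeley.Counter;

public class TopNDemo {
    public static void main(String[] args) {
        Counter<String> distances = new Counter<>();
        // hypothetical cosine similarities between a query vector and vocabulary words
        distances.incrementCount("day", 0.91);
        distances.incrementCount("week", 0.87);
        distances.incrementCount("night", 0.73);
        distances.incrementCount("cat", 0.12);
        // drop everything but the 2 highest-scoring keys
        distances.keepTopNKeys(2);
        // prints the surviving keys, e.g. [day, week]; keySet() order is not guaranteed
        System.out.println(distances.keySet());
    }
}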

Example 2 with Counter

Use of org.deeplearning4j.berkeley.Counter in project deeplearning4j (by deeplearning4j).

Source: the class ParagraphVectors, method predictSeveral.

/**
 * Predict several labels based on the document.
 * Computes a similarity with respect to the mean of the
 * representations of the words in the document.
 *
 * @param document the document
 * @param limit the maximum number of labels to return
 * @return possible labels in descending order
 */
@Deprecated
public Collection<String> predictSeveral(List<VocabWord> document, int limit) {
    /*
     * This code was transferred from the original ParagraphVectors DL4J implementation, and has yet to be tested.
     */
    if (document.isEmpty())
        throw new IllegalStateException("Document has no words inside");
    INDArray arr = Nd4j.create(document.size(), this.layerSize);
    for (int i = 0; i < document.size(); i++) {
        arr.putRow(i, getWordVectorMatrix(document.get(i).getWord()));
    }
    INDArray docMean = arr.mean(0);
    Counter<String> distances = new Counter<>();
    for (String s : labelsSource.getLabels()) {
        INDArray otherVec = getWordVectorMatrix(s);
        double sim = Transforms.cosineSim(docMean, otherVec);
        log.debug("Similarity inside: [{}] -> {}", s, sim);
        distances.incrementCount(s, sim);
    }
    // getSortedKeys() orders keys by descending count; guard against limit exceeding the number of labels
    return distances.getSortedKeys().subList(0, Math.min(limit, distances.size()));
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Counter(org.deeplearning4j.berkeley.Counter) INDArray(org.nd4j.linalg.api.ndarray.INDArray)
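
The last line of predictSeveral relies on Counter.getSortedKeys() returning keys in descending count order, as the Javadoc's "descending order" promise implies. A small sketch of that pattern with invented label similarities (the Math.min guard mirrors the fix above; SortedLabelsDemo is an illustrative name):

import org.deeplearning4j.berkeley.Counter;
import java.util.List;

public class SortedLabelsDemo {
    public static void main(String[] args) {
        Counter<String> distances = new Counter<>();
        // hypothetical similarities between the document mean vector and each label vector
        distances.incrementCount("finance", 0.42);
        distances.incrementCount("health", 0.81);
        distances.incrementCount("sports", 0.63);
        int limit = 2;
        // getSortedKeys() orders keys by descending count; take at most 'limit' of them
        List<String> top = distances.getSortedKeys().subList(0, Math.min(limit, distances.size()));
        System.out.println(top); // [health, sports]
    }
}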

Example 3 with Counter

Use of org.deeplearning4j.berkeley.Counter in project deeplearning4j (by deeplearning4j).

Source: the class CountFunction, method call.

@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
    // we can't be sure the reported sequence size is correct, or that it hasn't overflowed int limits,
    // so we recalculate it here; we have to loop over the elements for their frequencies anyway
    Counter<Long> localCounter = new Counter<>();
    long seqLen = 0;
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    driver = ela.getTrainingDriver();
    VoidParameterServer.getInstance().init(voidConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    for (T element : sequence.getElements()) {
        if (element == null)
            continue;
        // FIXME: hashcode is bad idea here. we need Long id
        localCounter.incrementCount(element.getStorageId(), 1.0);
        seqLen++;
    }
    // FIXME: we're missing label information here due to shallow vocab mechanics
    if (sequence.getSequenceLabels() != null)
        for (T label : sequence.getSequenceLabels()) {
            localCounter.incrementCount(label.getStorageId(), 1.0);
        }
    accumulator.add(localCounter);
    return Pair.makePair(sequence, seqLen);
}
Also used : Counter(org.deeplearning4j.berkeley.Counter) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
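
The accumulator.add(localCounter) call at the end relies on per-task counters being merged into one global frequency table. Below is a sketch of such a merge using only Counter methods that appear elsewhere in these examples (incrementCount, keySet, getCount); the real ElementsFrequenciesAccumulator may do this with a dedicated bulk operation, and CounterMergeDemo is an illustrative name:

import org.deeplearning4j.berkeley.Counter;

public class CounterMergeDemo {

    // merge src into dst by summing counts key by key
    static void merge(Counter<Long> dst, Counter<Long> src) {
        for (Long key : src.keySet()) {
            dst.incrementCount(key, src.getCount(key));
        }
    }

    public static void main(String[] args) {
        Counter<Long> partitionA = new Counter<>();
        partitionA.incrementCount(42L, 3.0);
        partitionA.incrementCount(7L, 1.0);

        Counter<Long> partitionB = new Counter<>();
        partitionB.incrementCount(42L, 2.0);

        merge(partitionA, partitionB);
        System.out.println(partitionA.getCount(42L)); // 5.0
    }
}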

Example 4 with Counter

Use of org.deeplearning4j.berkeley.Counter in project deeplearning4j (by deeplearning4j).

Source: the class SparkSequenceVectors, method fitSequences.

/**
 * Base training entry point.
 *
 * @param corpus the corpus of sequences to train on
 */
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
    /*
     * All we want for the base implementation here is three things:
     * a) build the vocabulary
     * b) build the Huffman tree
     * c) do the training
     *
     * That way, classes extending SeqVec (like DeepWalk or Word2Vec) just build their RDD<Sequence<T>>
     * and call this method for training, instead of implementing their own routines.
     */
    validateConfiguration();
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (workers > 1) {
        log.info("Repartitioning corpus to {} parts...", workers);
        // repartition() returns a new RDD; the result must be captured, or the call has no effect
        corpus = corpus.repartition(workers);
    }
    if (storageLevel != null)
        corpus.persist(storageLevel);
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // this has an effect only if it wasn't already called before, e.g. in extension classes
    broadcastEnvironment(sc);
    Counter<Long> finalCounter;
    long numberOfSequences = 0;
    // fall back to a default parameter-server configuration if none was provided
    if (paramServerConfiguration == null)
        paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
    isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
    Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
    if (isAutoDiscoveryMode) {
        log.info("Trying auto discovery mode...");
        elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
        ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        finalCounter = elementsFreqAccumExtra.value();
        ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
        // getting list of available hosts
        Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
        log.info("availableHosts: {}", availableHosts);
        if (availableHosts.size() > 1) {
            // now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
            NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
            paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
            // backup shards are optional
            if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
            }
        } else {
            // for single host (aka driver-only, aka spark-local) just run on loopback interface
            paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
            paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }
        log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
    } else {
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
        // set up freqs accumulator
        elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
        CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
        // count all sequence elements and their sum
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        // now we grab counter, which contains frequencies for all SequenceElements in corpus
        finalCounter = elementsFreqAccum.value();
    }
    long numberOfElements = (long) finalCounter.totalCount();
    long numberOfUniqueElements = finalCounter.size();
    log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
    /*
     * Build an RDD of reduced SequenceElements: temporarily drop the labels and stick to numerical values
     * (index or hashcode), so we can reduce the driver's memory footprint.
     */
    // build huffman tree, and update original RDD with huffman encoding info
    shallowVocabCache = buildShallowVocabCache(finalCounter);
    shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
    // FIXME: probably we need to reconsider this approach
    JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
    vocabRDD.count();
    /*
     * Now we initialize the Shards with values. This call should be started from the driver, which is
     * either a Client, or a Shard in standalone mode.
     */
    VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
    VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
    // proceed to training
    // also, training function is the place where we invoke ParameterServer
    TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    if (configuration != null) {
        for (int e = 0; e < configuration.getEpochs(); e++) {
            corpus.foreachPartition(partitionTrainer);
        }
    }
    // per-element alternative to partition-wise training: corpus.foreach(trainer);
    // we're transferring vectors to ExportContainer
    JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
    // at this particular moment training should be pretty much done, and we're good to go for export
    if (exporter != null)
        exporter.export(exportRdd);
    // unpersist, if we persisted the corpus earlier
    if (storageLevel != null)
        corpus.unpersist();
    log.info("Training finish, starting cleanup...");
    VoidParameterServer.getInstance().shutdown();
}
Also used : NetworkOrganizer(org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) ExtraCounter(org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter) Counter(org.deeplearning4j.berkeley.Counter) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DL4JInvalidConfigException(org.deeplearning4j.exception.DL4JInvalidConfigException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
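
fitSequences builds its frequency accumulator with corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator()), i.e. the Spark 1.x AccumulatorParam API. Below is a sketch of what such a parameter could look like for Counter<Long>; it is written in the spirit of ElementsFrequenciesAccumulator, not copied from its actual source:

import org.apache.spark.AccumulatorParam;
import org.deeplearning4j.berkeley.Counter;

// Merges per-task Counter<Long> instances into one global frequency counter.
public class CounterAccumulatorSketch implements AccumulatorParam<Counter<Long>> {

    @Override
    public Counter<Long> addAccumulator(Counter<Long> c1, Counter<Long> c2) {
        return addInPlace(c1, c2);
    }

    @Override
    public Counter<Long> addInPlace(Counter<Long> r1, Counter<Long> r2) {
        // sum r2's counts into r1, key by key
        for (Long key : r2.keySet()) {
            r1.incrementCount(key, r2.getCount(key));
        }
        return r1;
    }

    @Override
    public Counter<Long> zero(Counter<Long> initialValue) {
        // an empty counter is the identity element for the merge
        return new Counter<>();
    }
}

Registered via sc.accumulator(new Counter<Long>(), new CounterAccumulatorSketch()), this mirrors the accumulator setup in the method above. Note that AccumulatorParam was deprecated in Spark 2.x in favor of AccumulatorV2; the sketch matches the Spark 1.x style used throughout this example.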

Example 5 with Counter

Use of org.deeplearning4j.berkeley.Counter in project deeplearning4j (by deeplearning4j).

Source: the class FoldWithinPartitionFunction, method call.

@Override
public Iterator<AtomicLong> call(Integer ind, Iterator<AtomicLong> partition) throws Exception {
    List<AtomicLong> foldedItemList = new ArrayList<>();
    foldedItemList.add(new AtomicLong(0L));
    // running (inclusive) cumulative sum over the partition's counts
    int foldedItemListSize = 1;
    while (partition.hasNext()) {
        long curPartitionItem = partition.next().get();
        int lastFoldedIndex = foldedItemListSize - 1;
        long lastFoldedItem = foldedItemList.get(lastFoldedIndex).get();
        AtomicLong sumLastCurrent = new AtomicLong(curPartitionItem + lastFoldedItem);
        foldedItemList.set(lastFoldedIndex, sumLastCurrent);
        foldedItemList.add(sumLastCurrent);
        foldedItemListSize += 1;
    }
    // Update Accumulator
    long maxFoldedItem = foldedItemList.remove(foldedItemListSize - 1).get();
    Counter<Integer> partitionIndex2maxItemCounter = new Counter<>();
    partitionIndex2maxItemCounter.incrementCount(ind, maxFoldedItem);
    maxPerPartitionAcc.add(partitionIndex2maxItemCounter);
    return foldedItemList.iterator();
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) Counter(org.deeplearning4j.berkeley.Counter) ArrayList(java.util.ArrayList)
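
The function returns the running (inclusive) prefix sums within one partition and records the partition's total in maxPerPartitionAcc under the partition index. One plausible way to consume that accumulator afterwards is to turn the per-partition totals into global start offsets via an exclusive prefix sum on the driver; the sketch below illustrates the idea and is not code from deeplearning4j:

import org.deeplearning4j.berkeley.Counter;

public class PartitionOffsetsDemo {
    public static void main(String[] args) {
        // what maxPerPartitionAcc might hold after the job: partition index -> total count in that partition
        Counter<Integer> maxPerPartition = new Counter<>();
        maxPerPartition.incrementCount(0, 4.0);
        maxPerPartition.incrementCount(1, 7.0);
        maxPerPartition.incrementCount(2, 2.0);

        int numPartitions = 3;
        long offset = 0;
        // exclusive prefix sum: partition i starts where partitions 0..i-1 end
        for (int i = 0; i < numPartitions; i++) {
            System.out.println("partition " + i + " starts at global index " + offset);
            offset += (long) maxPerPartition.getCount(i);
        }
    }
}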

Aggregations

Counter (org.deeplearning4j.berkeley.Counter): 12
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 5
AtomicLong (java.util.concurrent.atomic.AtomicLong): 3
Pair (org.deeplearning4j.berkeley.Pair): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
ArrayList (java.util.ArrayList): 2
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 2
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 2
RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 2
DL4JInvalidConfigException (org.deeplearning4j.exception.DL4JInvalidConfigException): 1
ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer): 1
ExtraCounter (org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter): 1
NetworkInformation (org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation): 1
NetworkOrganizer (org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer): 1
MaxPerPartitionAccumulator (org.deeplearning4j.spark.text.accumulators.MaxPerPartitionAccumulator): 1
WordFreqAccumulator (org.deeplearning4j.spark.text.accumulators.WordFreqAccumulator): 1
VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 1