Search in sources :

Example 1 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

Source: the next() method of the class PopularityWalker.

/**
 * Returns the next walk sequence from this graph.
 * <p>
 * The walk starts at the next vertex of the pre-shuffled {@code order} array and performs
 * up to {@code walkLength} hops. At every hop the candidate neighbours are ranked by
 * popularity (their connectivity degree), a window of at most {@code spread} candidates is
 * selected according to {@code popularityMode} (MAXIMUM / MINIMUM / AVERAGE), and the next
 * vertex is drawn from that window either uniformly (PLAIN spectrum) or proportionally to
 * degree (PROPORTIONAL spectrum).
 *
 * @return the sequence of vertex values visited by this walk
 */
@Override
public Sequence<T> next() {
    Sequence<T> sequence = new Sequence<>();
    // vertex IDs already visited in this walk; excluded from candidate edges below
    int[] visitedHops = new int[walkLength];
    Arrays.fill(visitedHops, -1);
    // atomically claim the next starting slot, so concurrent callers get distinct walks
    int startPosition = position.getAndIncrement();
    int lastId = -1;
    int startPoint = order[startPosition];
    startPosition = startPoint;
    for (int i = 0; i < walkLength; i++) {
        Vertex<T> vertex = sourceGraph.getVertex(startPosition);
        int currentPosition = startPosition;
        sequence.addElement(vertex.getValue());
        visitedHops[i] = vertex.vertexID();
        int cSpread = 0;
        // with probability alpha, jump back to the walk's starting vertex (restart bias);
        // skipped on the first hop and when the previous hop was the start itself
        if (alpha > 0 && lastId != startPoint && lastId != -1 && alpha > rng.nextDouble()) {
            startPosition = startPoint;
            continue;
        }
        switch(walkDirection) {
            case RANDOM:
            case FORWARD_ONLY:
            case FORWARD_UNIQUE:
            case FORWARD_PREFERRED:
                {
                    // we get popularity of each node connected to the current node
                    PriorityQueue<Node<T>> queue = new PriorityQueue<>();
                    // neighbours of the current vertex, minus the ones already visited in this walk
                    // ArrayUtils.removeElements(sourceGraph.getConnectedVertexIndices(order[currentPosition]), visitedHops);
                    int[] connections = ArrayUtils.removeElements(sourceGraph.getConnectedVertexIndices(vertex.vertexID()), visitedHops);
                    int start = 0;
                    int stop = 0;
                    int cnt = 0;
                    if (connections.length > 0) {
                        // priority == degree of the neighbour, so the queue yields neighbours
                        // in order of popularity
                        for (int connected : connections) {
                            queue.add(new Node<T>(connected, sourceGraph.getConnectedVertices(connected).size()), sourceGraph.getConnectedVertices(connected).size());
                        }
                        // clamp the selection window to the number of available neighbours
                        cSpread = spread > connections.length ? connections.length : spread;
                        switch(popularityMode) {
                            case MAXIMUM:
                                // window over the most popular neighbours
                                start = 0;
                                stop = start + cSpread - 1;
                                break;
                            case MINIMUM:
                                // window over the least popular neighbours
                                start = connections.length - cSpread;
                                stop = connections.length - 1;
                                break;
                            case AVERAGE:
                                // window centred on the median popularity
                                // NOTE(review): this window spans cSpread+1 entries when cSpread is even,
                                // and stop can reach connections.length (e.g. length=2, cSpread=2) — verify
                                // bounds against weights[] sizing and connections[con] indexing below
                                int mid = connections.length / 2;
                                start = mid - (cSpread / 2);
                                stop = mid + (cSpread / 2);
                                break;
                        }
                        // logger.info("Spread: ["+ cSpread+ "], Connections: ["+ connections.length+"], Start: ["+start+"], Stop: ["+stop+"]");
                        cnt = 0;
                        //logger.info("Queue: " + queue);
                        //logger.info("Queue size: " + queue.size());
                        // drain the queue: rewrite connections[] in popularity order, and collect the
                        // nodes/weights whose rank falls inside the [start, stop] window
                        List<Node<T>> list = new ArrayList<>();
                        double[] weights = new double[cSpread];
                        int fcnt = 0;
                        while (queue.hasNext()) {
                            Node<T> node = queue.next();
                            if (cnt >= start && cnt <= stop) {
                                list.add(node);
                                weights[fcnt] = node.getWeight();
                                fcnt++;
                            }
                            connections[cnt] = node.getVertexId();
                            cnt++;
                        }
                        int con = -1;
                        switch(spectrum) {
                            case PLAIN:
                                {
                                    // uniform pick within the [start, stop] window
                                    con = RandomUtils.nextInt(start, stop + 1);
                                    //    logger.info("Picked selection: " + con);
                                    Vertex<T> nV = sourceGraph.getVertex(connections[con]);
                                    startPosition = nV.vertexID();
                                    lastId = vertex.vertexID();
                                }
                                break;
                            case PROPORTIONAL:
                                {
                                    // roulette-wheel pick: probability proportional to normalized degree
                                    double[] norm = MathArrays.normalizeArray(weights, 1);
                                    double prob = rng.nextDouble();
                                    double floor = 0.0;
                                    for (int b = 0; b < weights.length; b++) {
                                        if (prob >= floor && prob < floor + norm[b]) {
                                            startPosition = list.get(b).getVertexId();
                                            lastId = startPosition;
                                            break;
                                        } else {
                                            floor += norm[b];
                                        }
                                    }
                                }
                                break;
                        }
                    } else {
                        // no unvisited neighbours left: resolve according to the configured policy
                        switch(noEdgeHandling) {
                            case EXCEPTION_ON_DISCONNECTED:
                                throw new NoEdgesException("No more edges at vertex [" + currentPosition + "]");
                            case CUTOFF_ON_DISCONNECTED:
                                // terminate the walk early by pushing the loop counter past walkLength
                                i += walkLength;
                                break;
                            case SELF_LOOP_ON_DISCONNECTED:
                                // stay on the current vertex
                                startPosition = currentPosition;
                                break;
                            case RESTART_ON_DISCONNECTED:
                                // jump back to the walk's starting vertex
                                startPosition = startPoint;
                                break;
                            default:
                                throw new UnsupportedOperationException("Unsupported noEdgeHandling: [" + noEdgeHandling + "]");
                        }
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Unknown WalkDirection: [" + walkDirection + "]");
        }
    }
    return sequence;
}
Also used : NoEdgesException(org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) PriorityQueue(org.deeplearning4j.berkeley.PriorityQueue) ArrayList(java.util.ArrayList) List(java.util.List)

Example 2 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

Source: the testSpeedComparison1() method of the class ParallelTransformerIteratorTest.

/**
 * Benchmarks SentenceTransformer throughput across four configurations:
 * single- vs multi-threaded tokenization, fed either by a plain SentenceIterator
 * or by a (prefetching) LabelAwareIterator. Each run asserts that every produced
 * sequence is non-null and non-empty, and logs the elapsed wall-clock time.
 */
@Test
public void testSpeedComparison1() throws Exception {
    SentenceIterator iterator = new MutipleEpochsSentenceIterator(
                    new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25);

    // Run 1: plain sentence iterator, single-threaded tokenization
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
                    .allowMultithreading(false).tokenizerFactory(factory).build();
    consumeAndLog(transformer.iterator(), "Single-threaded time: {} ms");

    // Run 2: same source, multithreaded tokenization
    iterator.reset();
    transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true)
                    .tokenizerFactory(factory).build();
    consumeAndLog(transformer.iterator(), "Multi-threaded time: {} ms");

    // Run 3: label-aware iterator over a fresh source, single-threaded tokenization
    LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(
                    new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25)).build();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(false)
                    .tokenizerFactory(factory).build();
    consumeAndLog(transformer.iterator(), "Prefetched Single-threaded time: {} ms");

    // Run 4: label-aware iterator, multithreaded tokenization
    lai.reset();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(true)
                    .tokenizerFactory(factory).build();
    consumeAndLog(transformer.iterator(), "Prefetched Multi-threaded time: {} ms");
}

/**
 * Drains the given sequence iterator, asserting that every sequence is non-null and
 * non-empty, then logs the elapsed time via the supplied SLF4J message template
 * (must contain exactly one {} placeholder).
 *
 * @param iter            iterator to consume fully
 * @param messageTemplate SLF4J message template for the elapsed-milliseconds log line
 */
private void consumeAndLog(Iterator<Sequence<VocabWord>> iter, String messageTemplate) {
    int cnt = 0;
    long startTime = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    log.info(messageTemplate, System.currentTimeMillis() - startTime);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) MutipleEpochsSentenceIterator(org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator) BasicLabelAwareIterator(org.deeplearning4j.text.documentiterator.BasicLabelAwareIterator) AsyncLabelAwareIterator(org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator) BasicLabelAwareIterator(org.deeplearning4j.text.documentiterator.BasicLabelAwareIterator) LabelAwareIterator(org.deeplearning4j.text.documentiterator.LabelAwareIterator) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) SentenceTransformer(org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) PrefetchingSentenceIterator(org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) MutipleEpochsSentenceIterator(org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator) ClassPathResource(org.datavec.api.util.ClassPathResource) Test(org.junit.Test)

Example 3 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

Source: the call() method of the class TokenizerFunction.

/**
 * Tokenizes the given string and wraps each non-empty token as a {@link VocabWord}
 * with a frequency of 1.0.
 *
 * @param s raw sentence/document text to tokenize
 * @return sequence of vocab words; empty if the input yields no usable tokens
 */
@Override
public Sequence<VocabWord> call(String s) throws Exception {
    // tokenizer factory is created lazily on first use within this executor
    if (tokenizerFactory == null)
        instantiateTokenizerFactory();
    Sequence<VocabWord> sequence = new Sequence<>();
    for (String token : tokenizerFactory.create(s).getTokens()) {
        // skip null/blank tokens the tokenizer may produce
        if (token != null && !token.isEmpty())
            sequence.addElement(new VocabWord(1.0, token));
    }
    return sequence;
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)

Example 4 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

Source: the call() method of the class TrainingFunction.

/**
 * Trains on a single sequence within a Spark worker: lazily initializes the broadcast
 * configuration, the parameter server, and the learning algorithms, then maps the
 * sequence's elements (and labels, if sequence training is enabled) onto the shallow
 * vocabulary and submits the resulting frame to the distributed parameter server.
 */
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    /*
     * Depending on actual training mode, we'll either go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    // lazy, per-executor initialization of the parameter server and its training driver
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    // NOTE(review): duplicate of the null-check at the top of this method —
    // vectorsConfiguration is already guaranteed non-null here (assuming the broadcast value is)
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    // ELA configured lazily; unlike the paramServer branch above, this path also calls configure()
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
         at this moment we should have everything ready for actual initialization
         the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
        */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    // NOTE(review): elementsLearningAlgorithm can still be null here if only a sequence
    // learning algorithm was configured — frameSequence() would then NPE; confirm intended
    if (sequence.size() > 0)
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 5 with Sequence

use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.

Source: the fitLabelledDocuments() method of the class SparkParagraphVectors.

/**
 * Builds the ParagraphVectors model from a {@code JavaRDD<LabelledDocument>}.
 * Documents may be either raw (non-tokenized) or already tokenized.
 *
 * @param documentsRdd RDD of labelled documents to train on
 */
public void fitLabelledDocuments(JavaRDD<LabelledDocument> documentsRdd) {
    validateConfiguration();
    broadcastEnvironment(new JavaSparkContext(documentsRdd.context()));
    // convert each labelled document into a Sequence<VocabWord> and delegate to the base fit
    super.fitSequences(documentsRdd.map(new DocumentSequenceConvertFunction(configurationBroadcast)));
}
Also used : JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) DocumentSequenceConvertFunction(org.deeplearning4j.spark.models.paragraphvectors.functions.DocumentSequenceConvertFunction)

Aggregations

Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)18 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)11 Test (org.junit.Test)5 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)4 ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)4 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)4 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 ArrayList (java.util.ArrayList)2 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)2 ClassPathResource (org.datavec.api.util.ClassPathResource)2 SequenceIterator (org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator)2 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)2 SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)2 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)2 FileLabelAwareIterator (org.deeplearning4j.text.documentiterator.FileLabelAwareIterator)2 MutipleEpochsSentenceIterator (org.deeplearning4j.text.sentenceiterator.MutipleEpochsSentenceIterator)2 PrefetchingSentenceIterator (org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator)2 RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)2 List (java.util.List)1