Usage of org.deeplearning4j.models.sequencevectors.sequence.Sequence in the deeplearning4j/deeplearning4j project.
Class KeySequenceConvertFunction, method call.
/**
 * Converts a (label, text) pair into a labelled sequence of words.
 * <p>
 * The first tuple component becomes the sequence label. The second component is tokenized with
 * {@code tokenizerFactory}, which is instantiated on first use. Null or empty tokens are skipped.
 *
 * @param pair tuple of (sequence label, raw text)
 * @return a sequence holding one VocabWord (frequency 1.0) per usable token
 * @throws Exception propagated from tokenizer-factory instantiation or tokenization
 */
@Override
public Sequence<VocabWord> call(Tuple2<String, String> pair) throws Exception {
    Sequence<VocabWord> result = new Sequence<>();
    result.addSequenceLabel(new VocabWord(1.0, pair._1()));

    if (tokenizerFactory == null) {
        instantiateTokenizerFactory();
    }

    for (String tok : tokenizerFactory.create(pair._2()).getTokens()) {
        boolean usable = tok != null && !tok.isEmpty();
        if (usable) {
            result.addElement(new VocabWord(1.0, tok));
        }
    }
    return result;
}
Usage of org.deeplearning4j.models.sequencevectors.sequence.Sequence in the deeplearning4j/deeplearning4j project.
Class SparkSequenceVectorsTest, method testFrequenciesCount.
/**
 * Checks the element frequencies that SparkSequenceVectors counts over the {@code sequencesCyclic} fixture.
 * <p>
 * Expectations: element "0" has frequency 20, elements "1".."9" have frequency 10 each, and the shallow
 * vocabulary holds 10 distinct words.
 */
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);
    SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();
    seqVec.fitSequences(sequences);

    Counter<Long> counter = seqVec.getCounter();
    // element "0" should have frequency of 20
    assertEquals(20, counter.getCount(0L), 1e-5);

    // elements 1 - 9 should have frequencies of 10.
    // The fixture appends a duplicate "0" as the last element of every sequence, so the loop
    // stops before the last index and starts after the leading "0".
    Sequence<VocabWord> sample = sequencesCyclic.get(0);
    for (int e = 1; e < sample.getElements().size() - 1; e++) {
        assertEquals(10, counter.getCount(sample.getElementByIndex(e).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> shallowVocab = seqVec.getShallowVocabCache();
    assertEquals(10, shallowVocab.numWords());

    ShallowSequenceElement zero = shallowVocab.tokenFor(0L);
    ShallowSequenceElement first = shallowVocab.tokenFor(1L);
    // Assert presence before dereferencing, so a missing token fails as an assertion, not an NPE.
    assertNotEquals(null, zero);
    assertNotEquals(null, first);

    assertEquals(20.0, zero.getElementFrequency(), 1e-5);
    assertEquals(0, zero.getIndex());
    assertEquals(10.0, first.getElementFrequency(), 1e-5);
}
Usage of org.deeplearning4j.models.sequencevectors.sequence.Sequence in the deeplearning4j/deeplearning4j project.
Class SparkSequenceVectorsTest, method setUp.
/**
 * Builds the shared cyclic fixture (once) and creates a local Spark context for the test.
 * <p>
 * The fixture is 10 sequences. Each sequence contains the words "0".."9" once, followed by one
 * extra "0". Across all sequences, "0" therefore occurs 20 times and "1".."9" occur 10 times each.
 */
@Before
public void setUp() throws Exception {
    if (sequencesCyclic == null) {
        sequencesCyclic = new ArrayList<>();
        // 10 sequences in total.
        // ASCII loop index: the original used the Cyrillic look-alike 'с' (U+0441).
        for (int seqIdx = 0; seqIdx < 10; seqIdx++) {
            Sequence<VocabWord> sequence = new Sequence<>();
            // words "0".."9", each once per sequence -> 10 occurrences each over the fixture
            for (int e = 0; e < 10; e++) {
                sequence.addElement(new VocabWord(1.0, "" + e, (long) e));
            }
            // one extra "0" per sequence -> "0" totals 20 occurrences over the fixture
            sequence.addElement(new VocabWord(1.0, "0", 0L));
            sequencesCyclic.add(sequence);
        }
    }
    // NOTE(review): a new context is created before every test. Spark allows only one active
    // SparkContext per JVM, so confirm an @After method stops `sc`.
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
    sc = new JavaSparkContext(sparkConf);
}
Usage of org.deeplearning4j.models.sequencevectors.sequence.Sequence in the deeplearning4j/deeplearning4j project.
Class ParagraphVectors, method inferVector.
/**
 * Calculates an inferred vector for the given document.
 * <p>
 * If no sequence learning algorithm is set yet, a PV-DM ({@code DM}) learner is created lazily.
 * It is configured with this model's vocab, lookup table and configuration, and stored for reuse.
 * The document is wrapped in a {@link Sequence} labelled with a random integer before inference.
 *
 * @param document        words of the document; must be non-null and non-empty
 * @param learningRate    learning rate passed to the learner's {@code inferSequence}
 * @param minLearningRate minimum learning rate passed to the learner's {@code inferSequence}
 * @param iterations      iteration count passed to the learner's {@code inferSequence}
 * @return the vector produced by the learner for this document
 * @throws ND4JIllegalStateException if {@code document} is empty
 */
public INDArray inferVector(@NonNull List<VocabWord> document, double learningRate, double minLearningRate, int iterations) {
    // Validate first so an invalid call does not create or store a learner as a side effect.
    if (document.isEmpty()) {
        throw new ND4JIllegalStateException("Impossible to apply inference to empty list of words");
    }

    SequenceLearningAlgorithm<VocabWord> learner = sequenceLearningAlgorithm;
    if (learner == null) {
        // NOTE(review): the unsynchronized read above is only safe for lazy init if
        // sequenceLearningAlgorithm is declared volatile — confirm at the field declaration.
        synchronized (this) {
            if (sequenceLearningAlgorithm == null) {
                log.info("Creating new PV-DM learner...");
                learner = new DM<VocabWord>();
                learner.configure(vocab, lookupTable, configuration);
                sequenceLearningAlgorithm = learner;
            } else {
                learner = sequenceLearningAlgorithm;
            }
        }
    }
    // `learner` now holds the non-null learner. It is used directly rather than re-reading the field.

    Sequence<VocabWord> sequence = new Sequence<>();
    sequence.addElements(document);
    sequence.setSequenceLabel(new VocabWord(1.0, String.valueOf(new Random().nextInt())));

    initLearners();

    return learner.inferSequence(sequence, seed, learningRate, minLearningRate, iterations);
}
Usage of org.deeplearning4j.models.sequencevectors.sequence.Sequence in the deeplearning4j/deeplearning4j project.
Class SentenceTransformer, method transformToSequence.
/**
 * Tokenizes a sentence into a sequence of words and tags it with the next sentence id.
 * <p>
 * Null tokens, and tokens that are empty after {@code trim()}, are dropped. Every remaining
 * token becomes a VocabWord with frequency 1.0.
 *
 * @param object the raw sentence text
 * @return the word sequence, with its id taken from {@code sentenceCounter}
 */
@Override
public Sequence<VocabWord> transformToSequence(String object) {
    Sequence<VocabWord> result = new Sequence<>();
    for (String tok : tokenizerFactory.create(object).getTokens()) {
        // An empty token is also empty after trim(), so one check covers both cases.
        if (tok != null && !tok.trim().isEmpty()) {
            result.addElement(new VocabWord(1.0, tok));
        }
    }
    result.setSequenceId(sentenceCounter.getAndIncrement());
    return result;
}
Aggregations