Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
The class PopularityWalker, method next().
/**
 * This method returns the next walk sequence from this graph.
 *
 * @return next walk sequence
 */
@Override
public Sequence<T> next() {
    Sequence<T> sequence = new Sequence<>();
    int[] visitedHops = new int[walkLength];
    Arrays.fill(visitedHops, -1);
    int startPosition = position.getAndIncrement();
    int lastId = -1;
    int startPoint = order[startPosition];
    startPosition = startPoint;
    for (int i = 0; i < walkLength; i++) {
        Vertex<T> vertex = sourceGraph.getVertex(startPosition);
        int currentPosition = startPosition;
        sequence.addElement(vertex.getValue());
        visitedHops[i] = vertex.vertexID();
        int cSpread = 0;
        // with probability alpha, restart the walk from the starting vertex
        if (alpha > 0 && lastId != startPoint && lastId != -1 && alpha > rng.nextDouble()) {
            startPosition = startPoint;
            continue;
        }
        switch (walkDirection) {
            case RANDOM:
            case FORWARD_ONLY:
            case FORWARD_UNIQUE:
            case FORWARD_PREFERRED: {
                // we get the popularity (degree) of each node connected to the current node
                PriorityQueue<Node<T>> queue = new PriorityQueue<>();
                // ArrayUtils.removeElements(sourceGraph.getConnectedVertexIndices(order[currentPosition]), visitedHops);
                int[] connections = ArrayUtils.removeElements(
                                sourceGraph.getConnectedVertexIndices(vertex.vertexID()), visitedHops);
                int start = 0;
                int stop = 0;
                int cnt = 0;
                if (connections.length > 0) {
                    for (int connected : connections) {
                        queue.add(new Node<T>(connected, sourceGraph.getConnectedVertices(connected).size()));
                    }
                    cSpread = spread > connections.length ? connections.length : spread;
                    switch (popularityMode) {
                        case MAXIMUM:
                            start = 0;
                            stop = start + cSpread - 1;
                            break;
                        case MINIMUM:
                            start = connections.length - cSpread;
                            stop = connections.length - 1;
                            break;
                        case AVERAGE:
                            int mid = connections.length / 2;
                            start = mid - (cSpread / 2);
                            stop = mid + (cSpread / 2);
                            break;
                    }
                    // logger.info("Spread: [" + cSpread + "], Connections: [" + connections.length + "], Start: [" + start + "], Stop: [" + stop + "]");
                    cnt = 0;
                    //logger.info("Queue: " + queue);
                    //logger.info("Queue size: " + queue.size());
                    List<Node<T>> list = new ArrayList<>();
                    double[] weights = new double[cSpread];
                    int fcnt = 0;
                    // drain the queue in priority order, keeping only the [start, stop] window
                    while (!queue.isEmpty()) {
                        Node<T> node = queue.poll();
                        if (cnt >= start && cnt <= stop) {
                            list.add(node);
                            weights[fcnt] = node.getWeight();
                            fcnt++;
                        }
                        connections[cnt] = node.getVertexId();
                        cnt++;
                    }
                    int con = -1;
                    switch (spectrum) {
                        case PLAIN: {
                            // uniform pick within the selected window
                            con = RandomUtils.nextInt(start, stop + 1);
                            // logger.info("Picked selection: " + con);
                            Vertex<T> nV = sourceGraph.getVertex(connections[con]);
                            startPosition = nV.vertexID();
                            lastId = vertex.vertexID();
                        }
                        break;
                        case PROPORTIONAL: {
                            // roulette-wheel pick: probability proportional to node weight
                            double[] norm = MathArrays.normalizeArray(weights, 1);
                            double prob = rng.nextDouble();
                            double floor = 0.0;
                            for (int b = 0; b < weights.length; b++) {
                                if (prob >= floor && prob < floor + norm[b]) {
                                    startPosition = list.get(b).getVertexId();
                                    lastId = startPosition;
                                    break;
                                } else {
                                    floor += norm[b];
                                }
                            }
                        }
                        break;
                    }
                } else {
                    switch (noEdgeHandling) {
                        case EXCEPTION_ON_DISCONNECTED:
                            throw new NoEdgesException("No more edges at vertex [" + currentPosition + "]");
                        case CUTOFF_ON_DISCONNECTED:
                            // pushes i past walkLength, terminating the walk
                            i += walkLength;
                            break;
                        case SELF_LOOP_ON_DISCONNECTED:
                            startPosition = currentPosition;
                            break;
                        case RESTART_ON_DISCONNECTED:
                            startPosition = startPoint;
                            break;
                        default:
                            throw new UnsupportedOperationException("Unsupported noEdgeHandling: [" + noEdgeHandling + "]");
                    }
                }
            }
            break;
            default:
                throw new UnsupportedOperationException("Unknown WalkDirection: [" + walkDirection + "]");
        }
    }
    return sequence;
}
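The PROPORTIONAL branch above is a roulette-wheel selection over the normalized neighbor weights. Below is a minimal, self-contained sketch of the same technique in plain Java, with no deeplearning4j dependencies; the class and method names (RouletteWheel, pickIndex) are illustrative, not part of the walker's API.

import java.util.Random;

public class RouletteWheel {
    /**
     * Picks an index with probability proportional to its weight,
     * mirroring the PROPORTIONAL branch in PopularityWalker.next().
     */
    static int pickIndex(double[] weights, Random rng) {
        double sum = 0.0;
        for (double w : weights)
            sum += w;
        double prob = rng.nextDouble();
        double floor = 0.0;
        for (int b = 0; b < weights.length; b++) {
            double norm = weights[b] / sum; // same role as MathArrays.normalizeArray
            if (prob >= floor && prob < floor + norm)
                return b;
            floor += norm;
        }
        return weights.length - 1; // guard against floating-point round-off
    }

    public static void main(String[] args) {
        // index 2 (weight 6) should be hit roughly six times as often as index 0
        double[] weights = {1, 2, 6};
        int[] hits = new int[weights.length];
        Random rng = new Random(42);
        for (int i = 0; i < 100_000; i++)
            hits[pickIndex(weights, rng)]++;
        System.out.println(java.util.Arrays.toString(hits));
    }
}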
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
The class ParallelTransformerIteratorTest, method testSpeedComparison1().
@Test
public void testSpeedComparison1() throws Exception {
    SentenceIterator iterator = new MutipleEpochsSentenceIterator(
                    new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25);
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(iterator)
                    .allowMultithreading(false).tokenizerFactory(factory).build();
    Iterator<Sequence<VocabWord>> iter = transformer.iterator();
    int cnt = 0;
    long time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    long time2 = System.currentTimeMillis();
    log.info("Single-threaded time: {} ms", time2 - time1);
    iterator.reset();

    transformer = new SentenceTransformer.Builder().iterator(iterator).allowMultithreading(true)
                    .tokenizerFactory(factory).build();
    iter = transformer.iterator();
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Multi-threaded time: {} ms", time2 - time1);

    SentenceIterator baseIterator = iterator;
    baseIterator.reset();

    LabelAwareIterator lai = new BasicLabelAwareIterator.Builder(new MutipleEpochsSentenceIterator(
                    new BasicLineIterator(new ClassPathResource("/big/raw_sentences.txt").getFile()), 25)).build();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(false)
                    .tokenizerFactory(factory).build();
    iter = transformer.iterator();
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Prefetched Single-threaded time: {} ms", time2 - time1);

    lai.reset();
    transformer = new SentenceTransformer.Builder().iterator(lai).allowMultithreading(true)
                    .tokenizerFactory(factory).build();
    iter = transformer.iterator();
    time1 = System.currentTimeMillis();
    while (iter.hasNext()) {
        Sequence<VocabWord> sequence = iter.next();
        assertNotEquals("Failed on [" + cnt + "] iteration", null, sequence);
        assertNotEquals("Failed on [" + cnt + "] iteration", 0, sequence.size());
        cnt++;
    }
    time2 = System.currentTimeMillis();
    log.info("Prefetched Multi-threaded time: {} ms", time2 - time1);
}
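The four timed loops above are identical apart from the iterator under test. A hedged sketch of how that drain-and-time pattern could be factored into a reusable helper; the class and method names (IterationTimer, drainAndTime) are hypothetical, not part of the test class.

import java.util.Iterator;

public final class IterationTimer {
    // Drains any iterator, failing on null elements, and returns
    // elapsed wall-clock milliseconds, as the test loops above do.
    public static <E> long drainAndTime(Iterator<E> iter, String label) {
        long start = System.currentTimeMillis();
        long count = 0;
        while (iter.hasNext()) {
            if (iter.next() == null)
                throw new IllegalStateException("null element at iteration " + count);
            count++;
        }
        long elapsed = System.currentTimeMillis() - start;
        System.out.printf("%s: %d ms for %d elements%n", label, elapsed, count);
        return elapsed;
    }
}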
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
The class TokenizerFunction, method call().
@Override
public Sequence<VocabWord> call(String s) throws Exception {
    if (tokenizerFactory == null)
        instantiateTokenizerFactory();
    List<String> tokens = tokenizerFactory.create(s).getTokens();
    Sequence<VocabWord> seq = new Sequence<>();
    for (String token : tokens) {
        // skip null/empty tokens; each surviving token becomes a VocabWord with frequency 1.0
        if (token == null || token.isEmpty())
            continue;
        seq.addElement(new VocabWord(1.0, token));
    }
    return seq;
}
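The same string-to-Sequence pattern can be reproduced outside Spark without a TokenizerFactory. A minimal sketch, assuming naive whitespace splitting as a stand-in tokenizer (the splitting rule is an assumption for illustration, not what the actual tokenizer factory does):

import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.word2vec.VocabWord;

public class SimpleSequenceBuilder {
    // Builds a Sequence<VocabWord> from a sentence, mirroring TokenizerFunction.call
    // but with whitespace splitting instead of a TokenizerFactory.
    public static Sequence<VocabWord> toSequence(String sentence) {
        Sequence<VocabWord> seq = new Sequence<>();
        for (String token : sentence.trim().split("\\s+")) {
            if (token.isEmpty())
                continue;
            seq.addElement(new VocabWord(1.0, token));
        }
        return seq;
    }

    public static void main(String[] args) {
        Sequence<VocabWord> seq = toSequence("the quick brown fox");
        System.out.println("Sequence size: " + seq.size());
    }
}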
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
The class TrainingFunction, method call().
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    /*
     * Depending on actual training mode, we'll either go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                                .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
     * At this moment we should have everything ready for actual initialization.
     * The only limitation: our sequence is detached from the actual vocabulary,
     * so we need to merge it back virtually.
     */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    if (sequence.size() > 0)
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
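The merge step above maps each element through the shallow vocabulary cache and silently drops misses. A self-contained sketch of that filter-and-map pattern, with a plain Map standing in for shallowVocabCache.tokenFor; the names here (VocabMergeSketch, mergeAgainstVocab) are illustrative only.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class VocabMergeSketch {
    // Stand-in for the vocab lookup: a null result means the element fell
    // below the frequency threshold and is dropped, exactly as above.
    static <K, V> List<V> mergeAgainstVocab(List<K> elements, Map<K, V> vocab) {
        List<V> merged = new ArrayList<>();
        for (K element : elements) {
            V reduced = vocab.get(element);
            if (reduced != null) // unknown/filtered elements vanish from the merged sequence
                merged.add(reduced);
        }
        return merged;
    }

    public static void main(String[] args) {
        Map<String, Integer> vocab = new HashMap<>();
        vocab.put("fox", 7);
        vocab.put("dog", 3);
        // "unicorn" is not in the vocab, so it is silently dropped
        System.out.println(mergeAgainstVocab(Arrays.asList("fox", "unicorn", "dog"), vocab));
    }
}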
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
The class SparkParagraphVectors, method fitLabelledDocuments().
/**
 * This method builds a ParagraphVectors model from a JavaRDD<LabelledDocument>.
 * The documents may arrive either tokenized or non-tokenized.
 *
 * @param documentsRdd RDD of labelled documents to train on
 */
public void fitLabelledDocuments(JavaRDD<LabelledDocument> documentsRdd) {
    validateConfiguration();
    broadcastEnvironment(new JavaSparkContext(documentsRdd.context()));
    JavaRDD<Sequence<VocabWord>> sequenceRDD =
                    documentsRdd.map(new DocumentSequenceConvertFunction(configurationBroadcast));
    super.fitSequences(sequenceRDD);
}
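A hedged usage sketch: building a tiny JavaRDD<LabelledDocument> and handing it to fitLabelledDocuments. The LabelledDocument setters and the SparkParagraphVectors construction are assumptions about the API, shown for illustration only; the actual builder/configuration may differ.

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.text.documentiterator.LabelledDocument;

public class FitLabelledDocumentsExample {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local[*]", "pv-example");

        LabelledDocument doc = new LabelledDocument();
        doc.setContent("the quick brown fox"); // assumed setter
        doc.addLabel("DOC_0");                 // assumed setter

        JavaRDD<LabelledDocument> documentsRdd = sc.parallelize(Arrays.asList(doc));

        SparkParagraphVectors vectors = new SparkParagraphVectors(); // hypothetical setup; real configuration omitted
        vectors.fitLabelledDocuments(documentsRdd);

        sc.stop();
    }
}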