Use of org.deeplearning4j.text.movingwindow.Window in the deeplearning4j project by deeplearning4j.
From the class Word2VecDataFetcher, method next.
/**
 * Produces the next {@link DataSet} of window examples.
 *
 * <p>If the cache already holds at least {@code batch} windows, or there are
 * no more files, the result is served from the cache via {@code fromCache()}.
 * Otherwise the next file is opened and scanned line by line until a line
 * yields at least one window; up to {@code batch} of that line's windows are
 * converted to feature/label rows, and any windows beyond {@code batch} are
 * appended to the cache for a later call.
 *
 * <p>Known limitation (kept from the original design): the method returns as
 * soon as one line has produced windows, so the remaining lines of that file
 * are not read by this call.
 *
 * @return the next data set, or {@code null} if the file contained no line
 *         that produced any windows
 * @throws RuntimeException wrapping the {@link IOException} if the file
 *         cannot be read
 */
@Override
public DataSet next() {
    //pop from cache when possible, or when there's nothing left
    if (cache.size() >= batch || !files.hasNext())
        return fromCache();

    File f = files.next();
    try {
        LineIterator lines = FileUtils.lineIterator(f);
        try {
            while (lines.hasNext()) {
                List<Window> windows = Windows.windows(lines.nextLine());
                // Skip lines that yield no windows. If this was the last line,
                // the loop ends and null is returned, same as for an empty file.
                if (windows.isEmpty())
                    continue;

                // Both matrices must have the same number of rows: one per
                // example actually emitted, capped at the batch size.
                int rows = Math.min(windows.size(), batch);
                INDArray input = Nd4j.create(rows, vec.lookupTable().layerSize() * vec.getWindow());
                INDArray outcomes = Nd4j.create(rows, labels.size());
                for (int i = 0; i < rows; i++) {
                    // Features and label must come from the same window. The
                    // cache holds fewer than `batch` entries on this path, so
                    // it cannot be the source of the feature rows.
                    Window window = windows.get(i);
                    input.putRow(i, WindowConverter.asExampleMatrix(window, vec));

                    // Labels not present in `labels` fall back to index 0.
                    int idx = labels.indexOf(window.getLabel());
                    if (idx < 0)
                        idx = 0;
                    outcomes.putRow(i, FeatureUtil.toOutcomeVector(idx, labels.size()));
                }

                /*
                 * Note that I'm aware of possible concerns for sentence sequencing.
                 * This is a hack right now in place of something
                 * that will be way more elegant in the future.
                 */
                if (windows.size() > batch) {
                    cache.addAll(windows.subList(batch, windows.size()));
                }
                return new DataSet(input, outcomes);
            }
        } finally {
            // Release the underlying file handle on every exit path.
            lines.close();
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return null;
}
Use of org.deeplearning4j.text.movingwindow.Window in the deeplearning4j project by deeplearning4j.
From the class Word2VecDataSetIterator, method fromCached.
/**
 * Builds a {@link DataSet} of up to {@code num} examples from the window cache,
 * refilling the cache from the sentence iterator when it holds fewer than
 * {@code num} windows.
 *
 * <p>Each sentence read is split into windows, every window is tagged with the
 * iterator's current label, and the windows are appended to the cache. Windows
 * are then removed from the front of the cache and converted to one feature row
 * and one outcome row each. If a pre-processor is set it is applied to the
 * result before returning.
 *
 * @param num the maximum number of examples to return
 * @return a data set with one row per window taken (fewer than {@code num} when
 *         the source is exhausted), or {@code null} if no windows are available
 */
private DataSet fromCached(int num) {
    // Top up whenever the cache cannot satisfy the request, not only when it is
    // completely empty; otherwise a partially filled cache yields a short batch
    // even though more sentences are available.
    while (cachedWindow.size() < num && iter.hasNext()) {
        String sentence = iter.nextSentence();
        if (sentence.isEmpty())
            continue;
        List<Window> sentenceWindows = Windows.windows(sentence, vec.getTokenizerFactory(), vec.getWindow(), vec);
        for (Window w : sentenceWindows)
            w.setLabel(iter.currentLabel());
        cachedWindow.addAll(sentenceWindows);
    }

    // Take up to `num` windows from the front of the cache.
    List<Window> windows = new ArrayList<>(num);
    for (int i = 0; i < num && !cachedWindow.isEmpty(); i++) {
        windows.add(cachedWindow.remove(0));
    }
    if (windows.isEmpty())
        return null;

    // Size the matrices by the number of windows actually taken. Sizing by
    // `num` made the loops index past the end of `windows` on a short batch.
    int rows = windows.size();
    INDArray inputs = Nd4j.create(rows, inputColumns());
    INDArray labelOutput = Nd4j.create(rows, labels.size());
    for (int i = 0; i < rows; i++) {
        Window window = windows.get(i);
        inputs.putRow(i, WindowConverter.asExampleMatrix(window, vec));
        // NOTE(review): a label missing from `labels` gives indexOf == -1, which is
        // passed straight to toOutcomeVector - confirm that is intended.
        labelOutput.putRow(i, FeatureUtil.toOutcomeVector(labels.indexOf(window.getLabel()), labels.size()));
    }

    DataSet ret = new DataSet(inputs, labelOutput);
    if (preProcessor != null)
        preProcessor.preProcess(ret);
    return ret;
}
Aggregations