Search in sources:

Example 1 using Window

Use of org.deeplearning4j.text.movingwindow.Window in the deeplearning4j project (by deeplearning4j).

Class Word2VecDataFetcher, method next.

/**
 * Returns the next mini-batch of windowed examples.
 *
 * <p>If the cache already holds at least {@code batch} windows, or there are no more files,
 * the batch is served from the cache via {@code fromCache()}. Otherwise the next file is read
 * line by line. The first line that yields at least one window is converted into a
 * {@link DataSet} of at most {@code batch} rows. Any windows beyond {@code batch} are appended
 * to the cache for later calls.
 *
 * @return the next data set, or {@code null} if the file yields no windows at all
 * @throws RuntimeException wrapping any {@link IOException} raised while reading the file
 */
@Override
public DataSet next() {
    //pop from cache when possible, or when there's nothing left
    if (cache.size() >= batch || !files.hasNext())
        return fromCache();
    File f = files.next();
    try {
        LineIterator lines = FileUtils.lineIterator(f);
        try {
            while (lines.hasNext()) {
                List<Window> windows = Windows.windows(lines.nextLine());
                // Skip lines that yield no windows; a zero-row batch carries no examples.
                if (windows.isEmpty())
                    continue;
                // Input and outcome matrices must have the same number of rows.
                int rows = Math.min(windows.size(), batch);
                INDArray input = Nd4j.create(rows, vec.lookupTable().layerSize() * vec.getWindow());
                INDArray outcomes = Nd4j.create(rows, labels.size());
                for (int i = 0; i < rows; i++) {
                    // Features and label must come from the same window.
                    Window window = windows.get(i);
                    input.putRow(i, WindowConverter.asExampleMatrix(window, vec));
                    int idx = labels.indexOf(window.getLabel());
                    // Labels not found in the label list are mapped to index 0.
                    if (idx < 0)
                        idx = 0;
                    outcomes.putRow(i, FeatureUtil.toOutcomeVector(idx, labels.size()));
                }
                /*
                 * Note that I'm aware of possible concerns for sentence sequencing.
                 * This is a hack right now in place of something
                 * that will be way more elegant in the future.
                 */
                if (windows.size() > batch)
                    cache.addAll(windows.subList(batch, windows.size()));
                return new DataSet(input, outcomes);
            }
        } finally {
            // Release the underlying file handle regardless of how the loop exits.
            lines.close();
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return null;
}
Also used: Window (org.deeplearning4j.text.movingwindow.Window), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), IOException (java.io.IOException), File (java.io.File), LineIterator (org.apache.commons.io.LineIterator)

Example 2 using Window

Use of org.deeplearning4j.text.movingwindow.Window in the deeplearning4j project (by deeplearning4j).

Class Word2VecDataSetIterator, method fromCached.

/**
 * Builds a data set from up to {@code num} windows taken from the front of {@code cachedWindow}.
 *
 * <p>First tops up {@code cachedWindow} from the sentence iterator until it holds at least
 * {@code num} windows or the iterator is exhausted. Empty sentences are skipped, and every
 * window is labelled with {@code iter.currentLabel()}. Then up to {@code num} windows are
 * removed from the cache. The returned data set has one row per window actually taken, which
 * may be fewer than {@code num} near the end of the data.
 *
 * @param num maximum number of examples in the returned batch
 * @return the batch (after the optional pre-processor is applied), or {@code null} if no
 *     windows are available
 */
private DataSet fromCached(int num) {
    // Refill until a full batch is available or the source runs dry.
    while (cachedWindow.size() < num && iter.hasNext()) {
        String sentence = iter.nextSentence();
        if (sentence.isEmpty())
            continue;
        List<Window> windows = Windows.windows(sentence, vec.getTokenizerFactory(), vec.getWindow(), vec);
        for (Window w : windows)
            w.setLabel(iter.currentLabel());
        cachedWindow.addAll(windows);
    }
    List<Window> windows = new ArrayList<>(num);
    while (windows.size() < num && !cachedWindow.isEmpty())
        windows.add(cachedWindow.remove(0));
    if (windows.isEmpty())
        return null;
    // Size the matrices by the windows actually taken; the final batch may be short.
    int rows = windows.size();
    INDArray inputs = Nd4j.create(rows, inputColumns());
    INDArray labelOutput = Nd4j.create(rows, labels.size());
    for (int i = 0; i < rows; i++) {
        Window window = windows.get(i);
        inputs.putRow(i, WindowConverter.asExampleMatrix(window, vec));
        // NOTE(review): labels.indexOf returns -1 for an unknown label; confirm how
        // FeatureUtil.toOutcomeVector handles that index.
        labelOutput.putRow(i, FeatureUtil.toOutcomeVector(labels.indexOf(window.getLabel()), labels.size()));
    }
    DataSet ret = new DataSet(inputs, labelOutput);
    if (preProcessor != null)
        preProcessor.preProcess(ret);
    return ret;
}
Also used: Window (org.deeplearning4j.text.movingwindow.Window), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), ArrayList (java.util.ArrayList), CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)

Aggregations

Window (org.deeplearning4j.text.movingwindow.Window): 2, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2, DataSet (org.nd4j.linalg.dataset.DataSet): 2, File (java.io.File): 1, IOException (java.io.IOException): 1, ArrayList (java.util.ArrayList): 1, CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 1, LineIterator (org.apache.commons.io.LineIterator): 1