Search in sources :

Example 1 with Embedding

use of edu.stanford.nlp.neural.Embedding in project CoreNLP by stanfordnlp.

The method readWordVectors of the class SentimentModel.

/**
 * Fills {@code wordVectors} with every entry of the {@link Embedding} built from
 * {@code op.wordVectors} and {@code op.numHid}, then registers the vector stored under
 * {@code op.unkWord} a second time under the {@code UNKNOWN_WORD} key.
 *
 * <p>The unknown-word vector is validated <em>before</em> it is inserted, so a failed
 * lookup no longer leaves a {@code null}-valued {@code UNKNOWN_WORD} entry behind in the map.
 *
 * @throws RuntimeException if the embedding has no vector for {@code op.unkWord}
 */
void readWordVectors() {
    Embedding embedding = new Embedding(op.wordVectors, op.numHid);
    this.wordVectors = Generics.newTreeMap();
    for (String word : embedding.keySet()) {
        // TODO: factor out unknown word vector code from DVParser
        wordVectors.put(word, embedding.get(word));
    }

    SimpleMatrix unknownWordVector = wordVectors.get(op.unkWord);
    if (unknownWordVector == null) {
        // Fail before touching the map so it never holds a null UNKNOWN_WORD mapping.
        throw new RuntimeException("Unknown word vector not specified in the word vector file");
    }
    wordVectors.put(UNKNOWN_WORD, unknownWordVector);
}
Also used: SimpleMatrix (org.ejml.simple.SimpleMatrix), Embedding (edu.stanford.nlp.neural.Embedding)

Example 2 with Embedding

use of edu.stanford.nlp.neural.Embedding in project CoreNLP by stanfordnlp.

The method readWordVectors of the class DVModel.

/**
 * Rebuilds {@code wordVectors} from the {@link Embedding} constructed from
 * {@code op.lexOptions.wordVectorFile} and adds the special-purpose vectors.
 *
 * <ul>
 *   <li>Each word is passed through {@code op.wordFunction} (when non-null) before being
 *       stored, so several raw words may collapse onto one key; the last one seen wins.</li>
 *   <li>If {@code op.lexOptions.numHid} is not positive, it is set to the element count of
 *       the first vector encountered.</li>
 *   <li>For every enabled {@code op.trainOptions.unknown*Vector} flag, the vectors of all
 *       words matching the corresponding pattern are averaged and stored under the matching
 *       {@code UNKNOWN_*} key; when nothing matched, a copy of the unknown-word vector is
 *       stored instead.</li>
 *   <li>When {@code op.trainOptions.useContextWords} is set, random {@code numHid x 1}
 *       vectors (bounds -0.5 and 0.5, drawn from {@code rand}) are stored under
 *       {@code START_WORD} and {@code END_WORD}.</li>
 * </ul>
 *
 * <p>The unknown-word vector is validated before it is inserted, so a failed lookup does
 * not leave a {@code null}-valued {@code UNKNOWN_WORD} entry in the map.
 *
 * @throws RuntimeException if no vector exists for {@code op.trainOptions.unkWord}
 *         (after applying {@code op.wordFunction})
 */
public void readWordVectors() {
    // Running sums and match counts for each category of special "unknown" vector.
    SimpleMatrix unknownNumberVector = null;
    SimpleMatrix unknownCapsVector = null;
    SimpleMatrix unknownChineseYearVector = null;
    SimpleMatrix unknownChineseNumberVector = null;
    SimpleMatrix unknownChinesePercentVector = null;
    int numberCount = 0;
    int capsCount = 0;
    int chineseYearCount = 0;
    int chineseNumberCount = 0;
    int chinesePercentCount = 0;

    wordVectors = Generics.newTreeMap();

    Embedding rawWordVectors = new Embedding(op.lexOptions.wordVectorFile, op.lexOptions.numHid);
    for (String word : rawWordVectors.keySet()) {
        // Look the vector up under the raw key before the word is normalized.
        SimpleMatrix vector = rawWordVectors.get(word);
        if (op.wordFunction != null) {
            word = op.wordFunction.apply(word);
        }
        wordVectors.put(word, vector);
        if (op.lexOptions.numHid <= 0) {
            op.lexOptions.numHid = vector.getNumElements();
        }

        // The categories are not mutually exclusive: a word may contribute to several sums.
        if (op.trainOptions.unknownNumberVector && (NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
            ++numberCount;
            unknownNumberVector = addToRunningSum(unknownNumberVector, vector);
        }
        if (op.trainOptions.unknownCapsVector && CAPS_PATTERN.matcher(word).matches()) {
            ++capsCount;
            unknownCapsVector = addToRunningSum(unknownCapsVector, vector);
        }
        if (op.trainOptions.unknownChineseYearVector && CHINESE_YEAR_PATTERN.matcher(word).matches()) {
            ++chineseYearCount;
            unknownChineseYearVector = addToRunningSum(unknownChineseYearVector, vector);
        }
        if (op.trainOptions.unknownChineseNumberVector && (CHINESE_NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
            ++chineseNumberCount;
            unknownChineseNumberVector = addToRunningSum(unknownChineseNumberVector, vector);
        }
        if (op.trainOptions.unknownChinesePercentVector && CHINESE_PERCENT_PATTERN.matcher(word).matches()) {
            ++chinesePercentCount;
            unknownChinesePercentVector = addToRunningSum(unknownChinesePercentVector, vector);
        }
    }

    String unkWord = op.trainOptions.unkWord;
    if (op.wordFunction != null) {
        unkWord = op.wordFunction.apply(unkWord);
    }
    SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
    if (unknownWordVector == null) {
        // Fail before touching the map so it never holds a null UNKNOWN_WORD mapping.
        throw new RuntimeException("Unknown word vector not specified in the word vector file");
    }
    wordVectors.put(UNKNOWN_WORD, unknownWordVector);

    if (op.trainOptions.unknownNumberVector) {
        wordVectors.put(UNKNOWN_NUMBER, averageOrFallback(unknownNumberVector, numberCount, unknownWordVector));
    }
    if (op.trainOptions.unknownCapsVector) {
        wordVectors.put(UNKNOWN_CAPS, averageOrFallback(unknownCapsVector, capsCount, unknownWordVector));
    }
    if (op.trainOptions.unknownChineseYearVector) {
        log.info("Matched " + chineseYearCount + " chinese year vectors");
        wordVectors.put(UNKNOWN_CHINESE_YEAR, averageOrFallback(unknownChineseYearVector, chineseYearCount, unknownWordVector));
    }
    if (op.trainOptions.unknownChineseNumberVector) {
        log.info("Matched " + chineseNumberCount + " chinese number vectors");
        wordVectors.put(UNKNOWN_CHINESE_NUMBER, averageOrFallback(unknownChineseNumberVector, chineseNumberCount, unknownWordVector));
    }
    if (op.trainOptions.unknownChinesePercentVector) {
        log.info("Matched " + chinesePercentCount + " chinese percent vectors");
        wordVectors.put(UNKNOWN_CHINESE_PERCENT, averageOrFallback(unknownChinesePercentVector, chinesePercentCount, unknownWordVector));
    }

    if (op.trainOptions.useContextWords) {
        SimpleMatrix start = SimpleMatrix.random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
        SimpleMatrix end = SimpleMatrix.random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
        wordVectors.put(START_WORD, start);
        wordVectors.put(END_WORD, end);
    }
}

/** Returns a copy of {@code vector} when no sum exists yet, otherwise {@code runningSum + vector}. */
private static SimpleMatrix addToRunningSum(SimpleMatrix runningSum, SimpleMatrix vector) {
    return (runningSum == null) ? new SimpleMatrix(vector) : runningSum.plus(vector);
}

/** Returns {@code runningSum / count} when {@code count > 0}, otherwise a copy of {@code fallback}. */
private static SimpleMatrix averageOrFallback(SimpleMatrix runningSum, int count, SimpleMatrix fallback) {
    return (count > 0) ? runningSum.divide(count) : new SimpleMatrix(fallback);
}
Also used: SimpleMatrix (org.ejml.simple.SimpleMatrix), Embedding (edu.stanford.nlp.neural.Embedding)

Aggregations

Embedding (edu.stanford.nlp.neural.Embedding): 2, SimpleMatrix (org.ejml.simple.SimpleMatrix): 2