Use of edu.stanford.nlp.neural.Embedding in the CoreNLP project by stanfordnlp.
From the class SentimentModel, method readWordVectors.
/**
 * Populates {@code wordVectors} from an {@link Embedding} built with
 * {@code op.wordVectors} and {@code op.numHid}.
 *
 * <p>Every key exposed by the embedding is copied into a fresh map obtained from
 * {@code Generics.newTreeMap()}. The vector stored under {@code op.unkWord} is then
 * additionally registered under {@code UNKNOWN_WORD}.
 *
 * @throws RuntimeException if the embedding has no vector for {@code op.unkWord}
 */
void readWordVectors() {
  Embedding embedding = new Embedding(op.wordVectors, op.numHid);
  this.wordVectors = Generics.newTreeMap();
  for (String word : embedding.keySet()) {
    // TODO: factor out unknown word vector code from DVParser
    wordVectors.put(word, embedding.get(word));
  }

  SimpleMatrix unknownWordVector = wordVectors.get(op.unkWord);
  // Validate before inserting so that a failed load never leaves a
  // null entry under UNKNOWN_WORD in the map.
  if (unknownWordVector == null) {
    throw new RuntimeException("Unknown word vector not specified in the word vector file");
  }
  wordVectors.put(UNKNOWN_WORD, unknownWordVector);
}
Use of edu.stanford.nlp.neural.Embedding in the CoreNLP project by stanfordnlp.
From the class DVModel, method readWordVectors.
/**
 * Populates {@code wordVectors} from an {@link Embedding} built with
 * {@code op.lexOptions.wordVectorFile} and {@code op.lexOptions.numHid}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Each embedding key is passed through {@code op.wordFunction} (when non-null) and
 *       stored with its vector. If {@code op.lexOptions.numHid} is not positive it is set
 *       to the element count of the vector being processed.</li>
 *   <li>For every word class enabled in {@code op.trainOptions} (number, caps, Chinese
 *       year / number / percent), the vectors of words matching the corresponding
 *       pattern are summed and counted.</li>
 *   <li>The vector stored under {@code op.trainOptions.unkWord} (also passed through
 *       {@code op.wordFunction}) is registered under {@code UNKNOWN_WORD}.</li>
 *   <li>Each enabled word class is registered under its {@code UNKNOWN_*} key as the
 *       average of the matched vectors, or as a copy of the unknown-word vector when
 *       nothing matched.</li>
 *   <li>When {@code op.trainOptions.useContextWords} is set, random column vectors of
 *       {@code op.lexOptions.numHid} rows drawn from {@code rand} are stored under
 *       {@code START_WORD} and {@code END_WORD}.</li>
 * </ol>
 *
 * @throws RuntimeException if no vector is available for the unknown word
 */
public void readWordVectors() {
  SimpleMatrix numberSum = null;
  SimpleMatrix capsSum = null;
  SimpleMatrix chineseYearSum = null;
  SimpleMatrix chineseNumberSum = null;
  SimpleMatrix chinesePercentSum = null;

  wordVectors = Generics.newTreeMap();

  int numberCount = 0;
  int capsCount = 0;
  int chineseYearCount = 0;
  int chineseNumberCount = 0;
  int chinesePercentCount = 0;

  Embedding rawWordVectors = new Embedding(op.lexOptions.wordVectorFile, op.lexOptions.numHid);
  for (String rawWord : rawWordVectors.keySet()) {
    SimpleMatrix vector = rawWordVectors.get(rawWord);

    // The (possibly normalized) word is used both as the map key and for pattern matching.
    String word = rawWord;
    if (op.wordFunction != null) {
      word = op.wordFunction.apply(word);
    }
    wordVectors.put(word, vector);

    if (op.lexOptions.numHid <= 0) {
      op.lexOptions.numHid = vector.getNumElements();
    }

    if (op.trainOptions.unknownNumberVector
        && (NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
      ++numberCount;
      numberSum = addToRunningSum(numberSum, vector);
    }
    if (op.trainOptions.unknownCapsVector && CAPS_PATTERN.matcher(word).matches()) {
      ++capsCount;
      capsSum = addToRunningSum(capsSum, vector);
    }
    if (op.trainOptions.unknownChineseYearVector && CHINESE_YEAR_PATTERN.matcher(word).matches()) {
      ++chineseYearCount;
      chineseYearSum = addToRunningSum(chineseYearSum, vector);
    }
    if (op.trainOptions.unknownChineseNumberVector
        && (CHINESE_NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
      ++chineseNumberCount;
      chineseNumberSum = addToRunningSum(chineseNumberSum, vector);
    }
    if (op.trainOptions.unknownChinesePercentVector && CHINESE_PERCENT_PATTERN.matcher(word).matches()) {
      ++chinesePercentCount;
      chinesePercentSum = addToRunningSum(chinesePercentSum, vector);
    }
  }

  String unkWord = op.trainOptions.unkWord;
  if (op.wordFunction != null) {
    unkWord = op.wordFunction.apply(unkWord);
  }
  SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
  // Validate before inserting so that a failed load never leaves a
  // null entry under UNKNOWN_WORD in the map.
  if (unknownWordVector == null) {
    throw new RuntimeException("Unknown word vector not specified in the word vector file");
  }
  wordVectors.put(UNKNOWN_WORD, unknownWordVector);

  if (op.trainOptions.unknownNumberVector) {
    wordVectors.put(UNKNOWN_NUMBER, averageOrFallback(numberSum, numberCount, unknownWordVector));
  }
  if (op.trainOptions.unknownCapsVector) {
    wordVectors.put(UNKNOWN_CAPS, averageOrFallback(capsSum, capsCount, unknownWordVector));
  }
  if (op.trainOptions.unknownChineseYearVector) {
    log.info("Matched " + chineseYearCount + " chinese year vectors");
    wordVectors.put(UNKNOWN_CHINESE_YEAR,
        averageOrFallback(chineseYearSum, chineseYearCount, unknownWordVector));
  }
  if (op.trainOptions.unknownChineseNumberVector) {
    log.info("Matched " + chineseNumberCount + " chinese number vectors");
    wordVectors.put(UNKNOWN_CHINESE_NUMBER,
        averageOrFallback(chineseNumberSum, chineseNumberCount, unknownWordVector));
  }
  if (op.trainOptions.unknownChinesePercentVector) {
    log.info("Matched " + chinesePercentCount + " chinese percent vectors");
    wordVectors.put(UNKNOWN_CHINESE_PERCENT,
        averageOrFallback(chinesePercentSum, chinesePercentCount, unknownWordVector));
  }

  if (op.trainOptions.useContextWords) {
    // Draw order (start, then end) is kept fixed so results for a given rand state do not change.
    SimpleMatrix start = SimpleMatrix.random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
    SimpleMatrix end = SimpleMatrix.random(op.lexOptions.numHid, 1, -0.5, 0.5, rand);
    wordVectors.put(START_WORD, start);
    wordVectors.put(END_WORD, end);
  }
}

/**
 * Adds {@code vector} to a running sum, starting the sum with a copy of {@code vector}
 * when no sum exists yet.
 *
 * @param sum the sum so far, or {@code null} if nothing has been accumulated
 * @param vector the vector to add
 * @return the new running sum
 */
private static SimpleMatrix addToRunningSum(SimpleMatrix sum, SimpleMatrix vector) {
  return (sum == null) ? new SimpleMatrix(vector) : sum.plus(vector);
}

/**
 * Returns {@code sum} divided by {@code count}, or a copy of {@code fallback} when
 * {@code count} is not positive.
 *
 * @param sum the accumulated sum; only read when {@code count > 0}
 * @param count how many vectors were accumulated into {@code sum}
 * @param fallback the vector to copy when nothing was accumulated
 * @return the averaged vector or a copy of {@code fallback}
 */
private static SimpleMatrix averageOrFallback(SimpleMatrix sum, int count, SimpleMatrix fallback) {
  return (count > 0) ? sum.divide(count) : new SimpleMatrix(fallback);
}
Aggregations