Use of edu.stanford.nlp.classify.LinearClassifierFactory in the CoreNLP project by stanfordnlp.
Class ChineseMaxentLexicon, method finishTraining.
/**
 * Turns the accumulated tagged-word counts into a weighted training set, builds a
 * smoothed tag distribution, and trains the linear classifier stored in {@code scorer}.
 *
 * <p>Words are left out of training when {@code trainOnLowCount} is set and the word
 * was seen more than {@code trainCountThreshold} times, or when the word is a key of
 * {@code functionWordTags}. Each kept word contributes one datum weighted by its count,
 * or by 1 when {@code trainByType} is set.
 *
 * <p>Side effects: sets {@code datumCounter} to null and replaces {@code tagDist} and
 * {@code scorer}.
 */
@Override
public void finishTraining() {
  IntCounter<String> tagsSeen = new IntCounter<>();
  WeightedDataset trainingData = new WeightedDataset(datumCounter.size());

  for (TaggedWord tw : datumCounter.keySet()) {
    int occurrences = datumCounter.getIntCount(tw);
    boolean skip = (trainOnLowCount && occurrences > trainCountThreshold)
        || functionWordTags.containsKey(tw.word());
    if (skip) {
      continue;
    }
    tagsSeen.incrementCount(tw.tag());
    // Type-based training counts every distinct tagged word once.
    int weight = trainByType ? 1 : occurrences;
    trainingData.add(new BasicDatum(featExtractor.makeFeatures(tw.word()), tw.tag()), weight);
  }

  // The counts have been copied into trainingData; drop the field's reference to them.
  datumCounter = null;
  tagDist = Distribution.laplaceSmoothedDistribution(tagsSeen, tagsSeen.size(), 0.5);
  applyThresholds(trainingData);

  verbose("Making classifier...");
  LinearClassifierFactory factory = new LinearClassifierFactory(new QNMinimizer());
  factory.setTol(tol);
  factory.setSigma(sigma);
  if (tuneSigma) {
    factory.setTuneSigmaHeldOut();
  }
  scorer = factory.trainClassifier(trainingData);
  verbose("Done training.");
}
Use of edu.stanford.nlp.classify.LinearClassifierFactory in the CoreNLP project by stanfordnlp.
Class EntityClassifier, method train.
/**
 * Builds a training set from the relation triples of every region of every image,
 * trains a linear classifier on it, serializes the classifier to {@code modelPath},
 * and prints the classifier's accuracy on that same training set to standard error.
 *
 * <p>For each triple whose first word has a node (looked up by index) in the region's
 * enhanced semantic graph, one example is added. Its label is the triple's first word's
 * {@code GoldEntityAnnotation}.
 *
 * @param images     images whose regions supply the examples
 * @param modelPath  file the serialized classifier is written to
 * @param embeddings embeddings passed to the sentence matcher and to {@code getDatum}
 * @throws IOException if writing the classifier fails
 */
private static void train(List<SceneGraphImage> images, String modelPath, Embedding embeddings) throws IOException {
  RVFDataset<String, String> examples = new RVFDataset<>();
  SceneGraphSentenceMatcher matcher = new SceneGraphSentenceMatcher(embeddings);

  for (SceneGraphImage image : images) {
    for (SceneGraphImageRegion region : image.regions) {
      SemanticGraph graph = region.getEnhancedSemanticGraph();
      SemanticGraphEnhancer.enhance(graph);
      for (Triple<IndexedWord, IndexedWord, String> triple : matcher.getRelationTriples(region)) {
        IndexedWord node = graph.getNodeByIndexSafe(triple.first.index());
        if (node == null) {
          // No node at that index in the enhanced graph; this triple yields no example.
          continue;
        }
        examples.add(getDatum(node, triple.first.get(SceneGraphCoreAnnotations.GoldEntityAnnotation.class), embeddings));
      }
    }
  }

  LinearClassifierFactory<String, String> factory =
      new LinearClassifierFactory<>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
  Classifier<String, String> classifier = factory.trainClassifier(examples);
  IOUtils.writeObjectToFile(classifier, modelPath);
  // Training-set accuracy only; not a held-out evaluation.
  System.err.println(classifier.evaluateAccuracy(examples));
}
Use of edu.stanford.nlp.classify.LinearClassifierFactory in the CoreNLP project by stanfordnlp.
Class BoWSceneGraphParser, method train.
/**
 * Trains a classifier on the examples in {@code trainingFile} and saves it to
 * {@code modelPath}.
 *
 * @param trainingFile path to a JSON file with images and scene graphs
 * @param modelPath    path of the file the serialized classifier is written to
 * @throws IOException if an I/O operation performed while loading the examples or
 *                     writing the model fails
 */
public void train(String trainingFile, String modelPath) throws IOException {
  LinearClassifierFactory<String, String> factory =
      new LinearClassifierFactory<>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
  // NOTE(review): the meaning of the boolean flag is defined by getTrainingExamples.
  Dataset<String, String> examples = getTrainingExamples(trainingFile, true);
  IOUtils.writeObjectToFile(factory.trainClassifier(examples), modelPath);
}
Aggregations