Use of edu.stanford.nlp.classify.KNNClassifierFactory in project CoreNLP by stanfordnlp.
The train method of the class KNNSceneGraphParser.
/**
 * Builds a KNN classifier from the images in {@code trainFile} and writes it to
 * {@code modelPath}.
 *
 * <p>Every region of every non-null image becomes one training datum: its features are the
 * counts of each token's {@code word()} within that region, and its label is the string
 * {@code "<image id>_<region index>"}, where the index is the region's position in
 * {@code regions}.
 *
 * @param trainFile argument handed to {@code loadImages} to obtain the training images
 * @param modelPath destination handed to {@code IOUtils.writeObjectToFile}
 * @throws IOException declared by this method; presumably raised while loading the images
 *     or writing the model (confirm against {@code loadImages} / {@code IOUtils})
 */
private void train(String trainFile, String modelPath) throws IOException {
  Map<Integer, SceneGraphImage> imagesById = loadImages(trainFile);

  // Collect one datum per region across all images.
  List<RVFDatum<String, String>> trainingData = Generics.newLinkedList();
  for (SceneGraphImage image : imagesById.values()) {
    // Ids mapped to null carry no regions and are skipped.
    if (image != null) {
      int regionIndex = 0;
      for (SceneGraphImageRegion region : image.regions) {
        // Bag-of-words features: each token's word is counted once per occurrence.
        Counter<String> wordCounts = new ClassicCounter<String>();
        for (CoreLabel token : region.tokens) {
          wordCounts.incrementCount(token.word());
        }
        String label = String.format("%d_%d", image.id, regionIndex);
        trainingData.add(new RVFDatum<String, String>(wordCounts, label));
        regionIndex++;
      }
    }
  }

  KNNClassifierFactory<String, String> factory =
      new KNNClassifierFactory<String, String>(1, false, false);
  KNNClassifier<String, String> model = factory.train(trainingData);
  IOUtils.writeObjectToFile(model, modelPath);
}
Aggregations