Use of opennlp.model.MaxentModel in the talismane project by joliciel-informatique.
The trainModel method of the MaxentModelTrainer class.
/**
 * Trains a maximum-entropy classification model from the given event stream
 * using OpenNLP's GIS trainer, and records the training parameters as model
 * attributes on the result.
 *
 * @param corpusEventStream the stream of classification events to train on
 * @param descriptors feature descriptors to store alongside the model
 * @return the trained model, annotated with cutoff/iterations/sigma/smoothing
 * @throws IOException if reading the event stream fails
 */
@Override
public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, Map<String, List<String>> descriptors) throws IOException {
    // Adapt the Talismane event stream to OpenNLP and index it in two passes.
    EventStream events = new OpenNLPEventStream(corpusEventStream);
    DataIndexer indexer = new TwoPassRealValueDataIndexer(events, cutoff);

    // Configure the GIS trainer; smoothing takes precedence over Gaussian sigma.
    GISTrainer gisTrainer = new GISTrainer(true);
    if (this.getSmoothing() > 0) {
        gisTrainer.setSmoothing(true);
        gisTrainer.setSmoothingObservation(this.getSmoothing());
    } else if (this.getSigma() > 0) {
        gisTrainer.setGaussianSigma(this.getSigma());
    }

    MaxentModel maxentModel = gisTrainer.trainModel(iterations, indexer, cutoff);

    // Package the trained model with its configuration and training metadata
    // so the parameters travel with the persisted model.
    MaximumEntropyModel model = new MaximumEntropyModel(maxentModel, config, descriptors);
    model.addModelAttribute("cutoff", this.getCutoff());
    model.addModelAttribute("iterations", this.getIterations());
    model.addModelAttribute("sigma", this.getSigma());
    model.addModelAttribute("smoothing", this.getSmoothing());
    model.getModelAttributes().putAll(corpusEventStream.getAttributes());
    return model;
}
Use of opennlp.model.MaxentModel in the talismane project by joliciel-informatique.
The testDataIndexers method of the TwoPassRealValueDataIndexerTest class.
/**
 * This test sets out to prove that the scale you use on real valued
 * predicates doesn't matter when it comes to the probability assigned to each
 * outcome. Strangely, if we use (1,2) and (10,20) there's no difference. If
 * we use (0.1,0.2) and (10,20) there is a difference.
 *
 * @throws Exception if training or evaluation fails
 */
public void testDataIndexers() throws Exception {
    String smallValues = "predA=0.1 predB=0.2 A\n" + "predB=0.3 predA=0.1 B\n";
    String smallTest = "predA=0.2 predB=0.2";

    // Train with the one-pass indexer and evaluate the test context.
    StringReader smallReader = new StringReader(smallValues);
    EventStream smallEventStream = new RealBasicEventStream(new PlainTextByLineDataStream(smallReader));
    MaxentModel smallModel = GIS.trainModel(100, new OnePassRealValueDataIndexer(smallEventStream, 0), false);
    String[] contexts = smallTest.split(" ");
    float[] values = RealValueFileEventStream.parseContexts(contexts);
    double[] smallResults = smallModel.eval(contexts, values);
    String smallResultString = smallModel.getAllOutcomes(smallResults);
    System.out.println("smallResults: " + smallResultString);

    // Train with the two-pass indexer on identical data and evaluate again.
    StringReader smallReaderTwoPass = new StringReader(smallValues);
    EventStream smallEventStreamTwoPass = new RealBasicEventStream(new PlainTextByLineDataStream(smallReaderTwoPass));
    MaxentModel smallModelTwoPass = GIS.trainModel(100, new TwoPassRealValueDataIndexer(smallEventStreamTwoPass, 0), false);
    contexts = smallTest.split(" ");
    values = RealValueFileEventStream.parseContexts(contexts);
    double[] smallResultsTwoPass = smallModelTwoPass.eval(contexts, values);
    // BUG FIX: previously re-queried smallModel/smallResults here, so the
    // "smallResults2" line always duplicated the one-pass output instead of
    // showing the two-pass model's outcomes.
    String smallResultTwoPassString = smallModelTwoPass.getAllOutcomes(smallResultsTwoPass);
    System.out.println("smallResults2: " + smallResultTwoPassString);

    // Both indexers must produce the same outcome distribution.
    assertEquals(smallResults.length, smallResultsTwoPass.length);
    for (int i = 0; i < smallResults.length; i++) {
        System.out.println(String.format("classify with smallModel: %1$s = %2$f", smallModel.getOutcome(i), smallResults[i]));
        System.out.println(String.format("classify with smallModelTwoPass: %1$s = %2$f", smallModelTwoPass.getOutcome(i), smallResultsTwoPass[i]));
        assertEquals(smallResults[i], smallResultsTwoPass[i], 0.01f);
    }
}
Aggregations