
Example 1 with TwoPassRealValueDataIndexer

Use of com.joliciel.talismane.machineLearning.maxent.custom.TwoPassRealValueDataIndexer in the talismane project by joliciel-informatique.

From class MaxentModelTrainer, method trainModel.

@Override
public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, Map<String, List<String>> descriptors) throws IOException {
    // Wrap Talismane's event stream as an OpenNLP EventStream.
    EventStream eventStream = new OpenNLPEventStream(corpusEventStream);
    // Index the events in two passes, dropping predicates seen fewer than `cutoff` times.
    DataIndexer dataIndexer = new TwoPassRealValueDataIndexer(eventStream, cutoff);
    GISTrainer trainer = new GISTrainer(true);
    // Only one regularisation option is applied: smoothing takes precedence over a Gaussian prior.
    if (this.getSmoothing() > 0) {
        trainer.setSmoothing(true);
        trainer.setSmoothingObservation(this.getSmoothing());
    } else if (this.getSigma() > 0) {
        trainer.setGaussianSigma(this.getSigma());
    }
    MaxentModel maxentModel = trainer.trainModel(iterations, dataIndexer, cutoff);
    // Wrap the trained OpenNLP model and record the training parameters with it.
    MaximumEntropyModel model = new MaximumEntropyModel(maxentModel, config, descriptors);
    model.addModelAttribute("cutoff", this.getCutoff());
    model.addModelAttribute("iterations", this.getIterations());
    model.addModelAttribute("sigma", this.getSigma());
    model.addModelAttribute("smoothing", this.getSmoothing());
    model.getModelAttributes().putAll(corpusEventStream.getAttributes());
    return model;
}
Also used: TwoPassRealValueDataIndexer (com.joliciel.talismane.machineLearning.maxent.custom.TwoPassRealValueDataIndexer), DataIndexer (opennlp.model.DataIndexer), ClassificationEventStream (com.joliciel.talismane.machineLearning.ClassificationEventStream), EventStream (opennlp.model.EventStream), MaxentModel (opennlp.model.MaxentModel), GISTrainer (com.joliciel.talismane.machineLearning.maxent.custom.GISTrainer)
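For readers unfamiliar with two-pass indexing, here is a minimal sketch of the general idea the class name suggests. It assumes, based only on the name and package, that TwoPassRealValueDataIndexer is adapted from OpenNLP's two-pass indexer. The first pass counts how often each predicate occurs and keeps those seen at least `cutoff` times. The second pass re-reads the events and maps the surviving predicates to integer ids while keeping each predicate's real value. All names here (TwoPassIndexSketch, Event, IndexedEvent, Result) are hypothetical and are not the Talismane or OpenNLP API. The sketch also leaves out details a production indexer may add, such as spooling events to disk between passes and merging identical events. It needs Java 16+ for records.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of two-pass indexing for real-valued events.
 * Not the Talismane or OpenNLP API; requires Java 16+ for records.
 */
public class TwoPassIndexSketch {

    /** A training event: an outcome plus parallel predicate/value arrays. */
    public record Event(String outcome, String[] predicates, float[] values) {}

    /** An indexed event: outcome and predicates replaced by integer ids. */
    public record IndexedEvent(int outcome, int[] predicates, float[] values) {}

    /** The indexed events together with the id tables built for them. */
    public record Result(List<IndexedEvent> events,
                         Map<String, Integer> predicateIds,
                         Map<String, Integer> outcomeIds) {}

    public static Result index(List<Event> events, int cutoff) {
        // Pass 1: count every occurrence of every predicate.
        Map<String, Integer> counts = new HashMap<>();
        for (Event e : events) {
            for (String p : e.predicates()) {
                counts.merge(p, 1, Integer::sum);
            }
        }

        // Keep predicates seen at least `cutoff` times; ids follow first-seen order.
        Map<String, Integer> predicateIds = new LinkedHashMap<>();
        for (Event e : events) {
            for (String p : e.predicates()) {
                if (counts.get(p) >= cutoff && !predicateIds.containsKey(p)) {
                    predicateIds.put(p, predicateIds.size());
                }
            }
        }

        // Pass 2: re-read the events, keeping surviving predicates with their values.
        Map<String, Integer> outcomeIds = new LinkedHashMap<>();
        List<IndexedEvent> indexed = new ArrayList<>();
        for (Event e : events) {
            List<Integer> ids = new ArrayList<>();
            List<Float> vals = new ArrayList<>();
            for (int i = 0; i < e.predicates().length; i++) {
                Integer id = predicateIds.get(e.predicates()[i]);
                if (id != null) {
                    ids.add(id);
                    vals.add(e.values()[i]);
                }
            }
            if (ids.isEmpty()) {
                continue; // every predicate fell below the cutoff: drop the event
            }
            Integer outcome = outcomeIds.get(e.outcome());
            if (outcome == null) {
                outcome = outcomeIds.size();
                outcomeIds.put(e.outcome(), outcome);
            }
            int[] idArray = new int[ids.size()];
            float[] valueArray = new float[vals.size()];
            for (int i = 0; i < idArray.length; i++) {
                idArray[i] = ids.get(i);
                valueArray[i] = vals.get(i);
            }
            indexed.add(new IndexedEvent(outcome, idArray, valueArray));
        }
        return new Result(indexed, predicateIds, outcomeIds);
    }

    public static void main(String[] args) {
        List<Event> events = List.of(
                new Event("A", new String[] {"predA", "predB"}, new float[] {0.1f, 0.2f}),
                new Event("B", new String[] {"predB", "predA"}, new float[] {0.3f, 0.1f}),
                new Event("A", new String[] {"rare"}, new float[] {1.0f}));
        Result result = index(events, 2);
        System.out.println(result.predicateIds()); // {predA=0, predB=1}
        System.out.println(result.events().size()); // 2: the "rare"-only event is dropped
    }
}
```

The counts are kept separate from the second pass so that the cutoff depends on the whole corpus and not just the events read so far. This is presumably why such an indexer reads the data twice.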

Example 2 with TwoPassRealValueDataIndexer

Use of com.joliciel.talismane.machineLearning.maxent.custom.TwoPassRealValueDataIndexer in the talismane project by joliciel-informatique.

From class TwoPassRealValueDataIndexerTest, method testDataIndexers.

/**
 * This test sets out to show that the scale used for real-valued
 * predicates does not matter when it comes to the probability assigned to
 * each outcome. Strangely, with (1,2) and (10,20) there is no difference,
 * but with (0.1,0.2) and (10,20) there is.
 *
 * @throws Exception
 */
public void testDataIndexers() throws Exception {
    String smallValues = "predA=0.1 predB=0.2 A\n" + "predB=0.3 predA=0.1 B\n";
    String smallTest = "predA=0.2 predB=0.2";
    StringReader smallReader = new StringReader(smallValues);
    EventStream smallEventStream = new RealBasicEventStream(new PlainTextByLineDataStream(smallReader));
    MaxentModel smallModel = GIS.trainModel(100, new OnePassRealValueDataIndexer(smallEventStream, 0), false);
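    // parseContexts strips the "=value" suffix from each context in place
    // and returns the parsed values as a parallel float array.
    String[] contexts = smallTest.split(" ");
    float[] values = RealValueFileEventStream.parseContexts(contexts);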
    double[] smallResults = smallModel.eval(contexts, values);
    String smallResultString = smallModel.getAllOutcomes(smallResults);
    System.out.println("smallResults: " + smallResultString);
    StringReader smallReaderTwoPass = new StringReader(smallValues);
    EventStream smallEventStreamTwoPass = new RealBasicEventStream(new PlainTextByLineDataStream(smallReaderTwoPass));
    MaxentModel smallModelTwoPass = GIS.trainModel(100, new TwoPassRealValueDataIndexer(smallEventStreamTwoPass, 0), false);
    // Re-split, since the earlier parseContexts call modified contexts in place.
    contexts = smallTest.split(" ");
    values = RealValueFileEventStream.parseContexts(contexts);
    double[] smallResultsTwoPass = smallModelTwoPass.eval(contexts, values);
    String smallResultTwoPassString = smallModelTwoPass.getAllOutcomes(smallResultsTwoPass);
    System.out.println("smallResults2: " + smallResultTwoPassString);
    assertEquals(smallResults.length, smallResultsTwoPass.length);
    for (int i = 0; i < smallResults.length; i++) {
        System.out.println(String.format("classify with smallModel: %1$s = %2$f", smallModel.getOutcome(i), smallResults[i]));
        System.out.println(String.format("classify with smallModelTwoPass: %1$s = %2$f", smallModelTwoPass.getOutcome(i), smallResultsTwoPass[i]));
        assertEquals(smallResults[i], smallResultsTwoPass[i], 0.01f);
    }
}
Also used: RealBasicEventStream (opennlp.maxent.RealBasicEventStream), RealValueFileEventStream (opennlp.model.RealValueFileEventStream), EventStream (opennlp.model.EventStream), StringReader (java.io.StringReader), MaxentModel (opennlp.model.MaxentModel), PlainTextByLineDataStream (opennlp.maxent.PlainTextByLineDataStream), TwoPassRealValueDataIndexer (com.joliciel.talismane.machineLearning.maxent.custom.TwoPassRealValueDataIndexer), OnePassRealValueDataIndexer (opennlp.model.OnePassRealValueDataIndexer)
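The Javadoc above expects value scale not to matter. This follows from the maxent model form, where p(o|c) is proportional to exp(sum over j of w[o][j] * v[j]). Multiplying every value v[j] by k and dividing every weight w[o][j] by k leaves each product, and hence each outcome probability, unchanged. So, with no prior or smoothing, both scales can express exactly the same set of models. The test, however, trains for a fixed 100 GIS iterations, and a model stopped at that point need not sit at the optimum. How close it gets may depend on the scale. This is one plausible reading of the observed difference, not something the test checks.

The sketch below verifies only the invariance of the model itself. The class name ScaleInvarianceSketch and its weights are hypothetical and illustrative, not trained values.

```java
/**
 * Hypothetical check that a maxent-style model gives the same outcome
 * probabilities when every real value is multiplied by k and every weight
 * is divided by k. Weights below are illustrative, not trained.
 */
public class ScaleInvarianceSketch {

    /** p(o | c) = exp(sum_j w[o][j] * v[j]) / Z, normalised over outcomes o. */
    public static double[] probabilities(double[][] weights, double[] values) {
        double[] scores = new double[weights.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int o = 0; o < weights.length; o++) {
            double s = 0.0;
            for (int j = 0; j < values.length; j++) {
                s += weights[o][j] * values[j];
            }
            scores[o] = s;
            max = Math.max(max, s);
        }
        // Subtract the largest score before exponentiating, for numerical stability.
        double z = 0.0;
        for (int o = 0; o < scores.length; o++) {
            scores[o] = Math.exp(scores[o] - max);
            z += scores[o];
        }
        for (int o = 0; o < scores.length; o++) {
            scores[o] /= z;
        }
        return scores;
    }

    /** Returns a copy of the weights with every entry divided by k. */
    public static double[][] divideWeights(double[][] weights, double k) {
        double[][] out = new double[weights.length][];
        for (int o = 0; o < weights.length; o++) {
            out[o] = new double[weights[o].length];
            for (int j = 0; j < weights[o].length; j++) {
                out[o][j] = weights[o][j] / k;
            }
        }
        return out;
    }

    /** Returns a copy of the values with every entry multiplied by k. */
    public static double[] multiplyValues(double[] values, double k) {
        double[] out = new double[values.length];
        for (int j = 0; j < values.length; j++) {
            out[j] = values[j] * k;
        }
        return out;
    }

    public static void main(String[] args) {
        // Illustrative weights: rows are outcomes A and B, columns are predA and predB.
        double[][] weights = {{1.5, -0.7}, {-0.4, 2.1}};
        double[] values = {0.2, 0.2}; // the test context "predA=0.2 predB=0.2"
        double k = 100.0;             // rescales 0.2 to 20
        double[] original = probabilities(weights, values);
        double[] rescaled = probabilities(divideWeights(weights, k), multiplyValues(values, k));
        for (int o = 0; o < original.length; o++) {
            System.out.printf("outcome %d: %.6f vs %.6f%n", o, original[o], rescaled[o]);
        }
    }
}
```

Both columns of output agree up to floating-point rounding. Any difference between trained models at different scales therefore has to come from training, not from what the model can represent.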