Search in sources:

Example 11 with ClassificationModel

use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

the class LanguageDetectorTrainer method train.

/**
 * Trains a classification model and writes it to {@code modelFile}.
 * <p>
 * Training uses the trainer built from the {@code train.machine-learning} section of
 * {@code languageConfig}, fed with {@code eventStream} and {@code descriptors}. The session's
 * external resources are attached to the model before it is persisted. The parent directory of
 * {@code modelFile} is created first when the path has one.
 *
 * @return the trained (and already persisted) model
 * @throws TalismaneException as declared by the training pipeline
 * @throws IOException if the model cannot be written
 */
public ClassificationModel train() throws TalismaneException, IOException {
    ClassificationModelTrainer modelTrainer =
        new ModelTrainerFactory().constructTrainer(languageConfig.getConfig("train.machine-learning"));
    ClassificationModel trained = modelTrainer.trainModel(eventStream, descriptors);
    trained.setExternalResources(session.getExternalResourceFinder().getExternalResources());

    // A bare file name has no parent directory, so there is nothing to create.
    File parent = modelFile.getParentFile();
    if (parent != null) {
        parent.mkdirs();
    }
    trained.persist(modelFile);
    return trained;
}
Also used : ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) File(java.io.File) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 12 with ClassificationModel

use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

the class PerceptronClassificationModelTrainer method train.

/**
 * Runs the perceptron training loop over the events stored in {@code eventFile}.
 * <p>
 * Each iteration re-reads the event file line by line, one {@code PerceptronEvent} per line.
 * <p>
 * When the highest-scoring outcome differs from the event's actual outcome, each of the event's
 * features is updated: its value is added to the actual outcome's weight and subtracted from the
 * predicted outcome's weight.
 * <p>
 * After the pass, the current weights may be added to {@code totalFeatureWeights}. This happens
 * every iteration, or only at selected iterations when {@code isAverageAtIntervals()} is true.
 * <p>
 * At observation points, a model built from a clone of the parameters holding the averaged
 * weights is passed to {@code observer}.
 * <p>
 * Training stops early once this iteration's accuracy differs by less than {@code tolerance} from
 * each of the three previous iterations. Finally, {@code params} is overwritten with the averaged
 * weights.
 *
 * @throws RuntimeException wrapping any {@link IOException} raised while reading the event file
 */
void train() {
    try {
        // Accuracies of the three previous iterations (most recent first), used for early stopping.
        double prevAccuracy1 = 0.0;
        double prevAccuracy2 = 0.0;
        double prevAccuracy3 = 0.0;
        // Number of weight snapshots accumulated into totalFeatureWeights.
        int averagingCount = 0;
        for (int i = 1; i <= iterations; i++) {
            LOG.debug("Iteration " + i);
            int totalErrors = 0;
            int totalEvents = 0;
            try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(eventFile), "UTF-8")))) {
                while (scanner.hasNextLine()) {
                    String line = scanner.nextLine();
                    PerceptronEvent event = new PerceptronEvent(line);
                    totalEvents++;
                    // don't normalise unless we calculate the log-likelihood,
                    // to avoid the mathematical cost of normalising
                    double[] results = decisionMaker.predict(event.getFeatureIndexes(), event.getFeatureValues());
                    // predicted = index of the highest score (the first one wins on ties)
                    double maxValue = results[0];
                    int predicted = 0;
                    for (int j = 1; j < results.length; j++) {
                        if (results[j] > maxValue) {
                            maxValue = results[j];
                            predicted = j;
                        }
                    }
                    int actual = event.getOutcomeIndex();
                    if (actual != predicted) {
                        // Wrong prediction: reward the actual outcome and penalise the predicted
                        // one, for every feature of this event.
                        for (int j = 0; j < event.getFeatureIndexes().size(); j++) {
                            double[] classWeights = params.getFeatureWeights()[event.getFeatureIndexes().get(j)];
                            classWeights[actual] += event.getFeatureValues().get(j);
                            classWeights[predicted] -= event.getFeatureValues().get(j);
                        }
                        totalErrors++;
                    }
                }
                // Scanner treats an IOException from the underlying reader as end of input
                // and only records it. Rethrow it so a failed read cannot silently truncate
                // this iteration's training data.
                IOException readError = scanner.ioException();
                if (readError != null)
                    throw readError;
            }
            // Decide whether this iteration's weights are added to the running totals.
            boolean addAverage = true;
            if (this.isAverageAtIntervals()) {
                // Average at every iteration up to 20, then only at the perfect squares
                // 25 (5^2) through 196 (14^2).
                if (i <= 20 || i == 25 || i == 36 || i == 49 || i == 64 || i == 81 || i == 100 || i == 121 || i == 144 || i == 169 || i == 196) {
                    addAverage = true;
                    LOG.debug("Averaging at iteration: " + i);
                } else
                    addAverage = false;
            }
            if (addAverage) {
                for (int j = 0; j < params.getFeatureWeights().length; j++) {
                    double[] totalClassWeights = totalFeatureWeights[j];
                    double[] classWeights = params.getFeatureWeights()[j];
                    for (int k = 0; k < params.getOutcomeCount(); k++) {
                        totalClassWeights[k] += classWeights[k];
                    }
                }
                averagingCount++;
            }
            if (observer != null && observationPoints.contains(i)) {
                // Give the observer a model built from a clone whose weights are the averages
                // accumulated so far. NOTE(review): this assumes clone() copies the weight
                // arrays; otherwise the live weights would be overwritten here.
                PerceptronModelParameters cloneParams = params.clone();
                for (int j = 0; j < cloneParams.getFeatureWeights().length; j++) {
                    double[] totalClassWeights = totalFeatureWeights[j];
                    double[] classWeights = cloneParams.getFeatureWeights()[j];
                    for (int k = 0; k < cloneParams.getOutcomeCount(); k++) {
                        classWeights[k] = totalClassWeights[k] / averagingCount;
                    }
                }
                ClassificationModel model = this.getModel(cloneParams, i);
                observer.onNextModel(model, i);
            }
            double accuracy = (double) (totalEvents - totalErrors) / (double) totalEvents;
            LOG.debug("Accuracy: " + accuracy);
            // exit if accuracy hasn't significantly changed in 3 iterations
            if (Math.abs(accuracy - prevAccuracy1) < tolerance && Math.abs(accuracy - prevAccuracy2) < tolerance && Math.abs(accuracy - prevAccuracy3) < tolerance) {
                LOG.info("Accuracy change < " + tolerance + " for 3 iterations: exiting after " + i + " iterations");
                break;
            }
            prevAccuracy3 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy1;
            prevAccuracy1 = accuracy;
        }
        // Average the final weights. This is skipped when no snapshot was accumulated
        // (e.g. zero iterations); dividing by zero would otherwise fill every weight with NaN.
        if (averagingCount > 0) {
            for (int j = 0; j < params.getFeatureWeights().length; j++) {
                double[] totalClassWeights = totalFeatureWeights[j];
                double[] classWeights = params.getFeatureWeights()[j];
                for (int k = 0; k < params.getOutcomeCount(); k++) {
                    classWeights[k] = totalClassWeights[k] / averagingCount;
                }
            }
        }
    } catch (IOException e) {
        LogUtils.logError(LOG, e);
        throw new RuntimeException(e);
    }
}
Also used : Scanner(java.util.Scanner) InputStreamReader(java.io.InputStreamReader) IOException(java.io.IOException) FileInputStream(java.io.FileInputStream) BufferedReader(java.io.BufferedReader) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 13 with ClassificationModel

use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

the class SentenceDetector method getInstance.

/**
 * Returns a sentence detector for the given session, building and caching the shared instance on
 * first use.
 * <p>
 * The model location is read from the configuration key
 * {@code talismane.core.<sessionId>.sentence-detector.model}. Loaded models are cached in
 * {@code modelMap} by that path, and detectors are cached in {@code sentenceDetectorMap} by
 * session id.
 *
 * @param sessionId the session whose configuration selects the model
 * @return the result of {@code cloneSentenceDetector()} on the cached detector
 * @throws IOException if the model file cannot be opened or read
 * @throws ClassNotFoundException if raised while loading the model
 */
public static SentenceDetector getInstance(String sessionId) throws IOException, ClassNotFoundException {
    SentenceDetector sentenceDetector = sentenceDetectorMap.get(sessionId);
    if (sentenceDetector == null) {
        Config config = ConfigFactory.load();
        String configPath = "talismane.core." + sessionId + ".sentence-detector.model";
        String modelFilePath = config.getString(configPath);
        ClassificationModel sentenceModel = modelMap.get(modelFilePath);
        if (sentenceModel == null) {
            InputStream modelFile = ConfigUtils.getFileFromConfig(config, configPath);
            // Close the model stream once the model has been loaded; it was previously leaked.
            // NOTE(review): assumes getClassificationModel fully consumes the stream before
            // returning, so closing it afterwards is safe.
            try (ZipInputStream zipIn = new ZipInputStream(modelFile)) {
                MachineLearningModelFactory factory = new MachineLearningModelFactory();
                sentenceModel = factory.getClassificationModel(zipIn);
            }
            modelMap.put(modelFilePath, sentenceModel);
        }
        sentenceDetector = new SentenceDetector(sentenceModel, sessionId);
        sentenceDetectorMap.put(sessionId, sentenceDetector);
    }
    return sentenceDetector.cloneSentenceDetector();
}
Also used : ZipInputStream(java.util.zip.ZipInputStream) Config(com.typesafe.config.Config) InputStream(java.io.InputStream) MachineLearningModelFactory(com.joliciel.talismane.machineLearning.MachineLearningModelFactory) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 14 with ClassificationModel

use of com.joliciel.talismane.machineLearning.ClassificationModel in project jochre by urieli.

the class Jochre method doCommandTrain.

/**
 * Train a letter guessing model and persist it to the session's letter model path.
 *
 * @param featureDescriptors
 *          the feature descriptors for training
 * @param criteria
 *          criteria for selecting images to include when training
 * @param reconstructLetters
 *          whether or not complete letters should be reconstructed for
 *          training, from merged/split letters
 * @throws RuntimeException
 *           if the session has no letter model path
 * @throws JochreException
 *           if {@code featureDescriptors} is null
 */
public void doCommandTrain(List<String> featureDescriptors, CorpusSelectionCriteria criteria, boolean reconstructLetters) {
    if (jochreSession.getLetterModelPath() == null)
        throw new RuntimeException("Missing argument: letterModel");
    if (featureDescriptors == null)
        throw new JochreException("features is required");
    LetterFeatureParser letterFeatureParser = new LetterFeatureParser();
    Set<LetterFeature<?>> features = letterFeatureParser.getLetterFeatureSet(featureDescriptors);
    // Reconstruct letters from merged/split shapes when requested; otherwise use the original
    // boundaries as they are.
    BoundaryDetector boundaryDetector;
    if (reconstructLetters) {
        ShapeSplitter splitter = new TrainingCorpusShapeSplitter(jochreSession);
        ShapeMerger merger = new TrainingCorpusShapeMerger();
        boundaryDetector = new LetterByLetterBoundaryDetector(splitter, merger, jochreSession);
    } else {
        boundaryDetector = new OriginalBoundaryDetector();
    }
    LetterValidator letterValidator = new ComponentCharacterValidator(jochreSession);
    ClassificationEventStream corpusEventStream = new JochreLetterEventStream(features, boundaryDetector, letterValidator, criteria, jochreSession);
    File letterModelFile = new File(jochreSession.getLetterModelPath());
    // getParentFile() is null for a bare file name; calling mkdirs() on it would throw an NPE.
    File letterModelDir = letterModelFile.getParentFile();
    if (letterModelDir != null)
        letterModelDir.mkdirs();
    ModelTrainerFactory modelTrainerFactory = new ModelTrainerFactory();
    ClassificationModelTrainer trainer = modelTrainerFactory.constructTrainer(jochreSession.getConfig());
    ClassificationModel letterModel = trainer.trainModel(corpusEventStream, featureDescriptors);
    letterModel.persist(letterModelFile);
}
Also used : LetterByLetterBoundaryDetector(com.joliciel.jochre.boundaries.LetterByLetterBoundaryDetector) OriginalBoundaryDetector(com.joliciel.jochre.boundaries.OriginalBoundaryDetector) BoundaryDetector(com.joliciel.jochre.boundaries.BoundaryDetector) DeterministicBoundaryDetector(com.joliciel.jochre.boundaries.DeterministicBoundaryDetector) TrainingCorpusShapeMerger(com.joliciel.jochre.boundaries.TrainingCorpusShapeMerger) LetterValidator(com.joliciel.jochre.letterGuesser.LetterValidator) ClassificationEventStream(com.joliciel.talismane.machineLearning.ClassificationEventStream) JochreLetterEventStream(com.joliciel.jochre.letterGuesser.JochreLetterEventStream) ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) JochreException(com.joliciel.jochre.utils.JochreException) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) LetterFeature(com.joliciel.jochre.letterGuesser.features.LetterFeature) ShapeMerger(com.joliciel.jochre.boundaries.ShapeMerger) LetterFeatureParser(com.joliciel.jochre.letterGuesser.features.LetterFeatureParser) TrainingCorpusShapeSplitter(com.joliciel.jochre.boundaries.TrainingCorpusShapeSplitter) RecursiveShapeSplitter(com.joliciel.jochre.boundaries.RecursiveShapeSplitter) ShapeSplitter(com.joliciel.jochre.boundaries.ShapeSplitter) ComponentCharacterValidator(com.joliciel.jochre.letterGuesser.ComponentCharacterValidator) File(java.io.File) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 15 with ClassificationModel

use of com.joliciel.talismane.machineLearning.ClassificationModel in project jochre by urieli.

the class Jochre method doCommandTrainMerge.

/**
 * Train the letter merging model and persist it to the session's merge model path.
 *
 * @param featureDescriptors
 *          feature descriptors for training
 * @param multiplier
 *          if &gt; 0, will be used to equalize the outcomes
 * @param criteria
 *          the criteria used to select the training corpus
 * @throws RuntimeException
 *           if the session has no merge model path
 * @throws JochreException
 *           if {@code featureDescriptors} is null
 */
public void doCommandTrainMerge(List<String> featureDescriptors, int multiplier, CorpusSelectionCriteria criteria) {
    if (jochreSession.getMergeModelPath() == null)
        throw new RuntimeException("Missing argument: mergeModel");
    if (featureDescriptors == null)
        throw new JochreException("features is required");
    File mergeModelFile = new File(jochreSession.getMergeModelPath());
    // getParentFile() is null for a bare file name; calling mkdirs() on it would throw an NPE.
    File mergeModelDir = mergeModelFile.getParentFile();
    if (mergeModelDir != null)
        mergeModelDir.mkdirs();
    MergeFeatureParser mergeFeatureParser = new MergeFeatureParser();
    Set<MergeFeature<?>> mergeFeatures = mergeFeatureParser.getMergeFeatureSet(featureDescriptors);
    ClassificationEventStream corpusEventStream = new JochreMergeEventStream(criteria, mergeFeatures, jochreSession);
    if (multiplier > 0) {
        // Wrap the corpus stream so that the outcomes are equalised using the multiplier.
        corpusEventStream = new OutcomeEqualiserEventStream(corpusEventStream, multiplier);
    }
    ModelTrainerFactory modelTrainerFactory = new ModelTrainerFactory();
    ClassificationModelTrainer trainer = modelTrainerFactory.constructTrainer(jochreSession.getConfig());
    ClassificationModel mergeModel = trainer.trainModel(corpusEventStream, featureDescriptors);
    mergeModel.persist(mergeModelFile);
}
Also used : MergeFeatureParser(com.joliciel.jochre.boundaries.features.MergeFeatureParser) ClassificationEventStream(com.joliciel.talismane.machineLearning.ClassificationEventStream) ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) JochreException(com.joliciel.jochre.utils.JochreException) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) MergeFeature(com.joliciel.jochre.boundaries.features.MergeFeature) File(java.io.File) JochreMergeEventStream(com.joliciel.jochre.boundaries.JochreMergeEventStream) OutcomeEqualiserEventStream(com.joliciel.talismane.machineLearning.OutcomeEqualiserEventStream) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Aggregations

ClassificationModel (com.joliciel.talismane.machineLearning.ClassificationModel)20 File (java.io.File)10 ClassificationModelTrainer (com.joliciel.talismane.machineLearning.ClassificationModelTrainer)8 ModelTrainerFactory (com.joliciel.talismane.machineLearning.ModelTrainerFactory)8 LetterFeature (com.joliciel.jochre.letterGuesser.features.LetterFeature)6 LetterFeatureParser (com.joliciel.jochre.letterGuesser.features.LetterFeatureParser)6 BeamSearchImageAnalyser (com.joliciel.jochre.analyser.BeamSearchImageAnalyser)5 ImageAnalyser (com.joliciel.jochre.analyser.ImageAnalyser)5 OriginalShapeLetterAssigner (com.joliciel.jochre.analyser.OriginalShapeLetterAssigner)5 BoundaryDetector (com.joliciel.jochre.boundaries.BoundaryDetector)5 DeterministicBoundaryDetector (com.joliciel.jochre.boundaries.DeterministicBoundaryDetector)5 OriginalBoundaryDetector (com.joliciel.jochre.boundaries.OriginalBoundaryDetector)5 LetterGuesser (com.joliciel.jochre.letterGuesser.LetterGuesser)5 JochreException (com.joliciel.jochre.utils.JochreException)5 LetterAssigner (com.joliciel.jochre.analyser.LetterAssigner)4 LetterByLetterBoundaryDetector (com.joliciel.jochre.boundaries.LetterByLetterBoundaryDetector)4 RecursiveShapeSplitter (com.joliciel.jochre.boundaries.RecursiveShapeSplitter)4 ShapeMerger (com.joliciel.jochre.boundaries.ShapeMerger)4 ShapeSplitter (com.joliciel.jochre.boundaries.ShapeSplitter)4 TrainingCorpusShapeMerger (com.joliciel.jochre.boundaries.TrainingCorpusShapeMerger)4