
Example 1 with ClassificationModel

Use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

From the class LanguageDetector, method getInstance.

public static synchronized LanguageDetector getInstance(String sessionId) throws IOException, TalismaneException, ClassNotFoundException {
    // synchronized: the get / null-check / put sequences below are not atomic,
    // so concurrent callers could otherwise build duplicate detectors or models
    LanguageDetector languageDetector = languageDetectorMap.get(sessionId);
    if (languageDetector == null) {
        Config config = ConfigFactory.load();
        String configPath = "talismane.core." + sessionId + ".language-detector.model";
        String modelFilePath = config.getString(configPath);
        // models are cached by file path, so sessions sharing a model file share one instance
        ClassificationModel model = modelMap.get(modelFilePath);
        if (model == null) {
            // try-with-resources closes the model stream once it has been read
            try (InputStream modelFile = ConfigUtils.getFileFromConfig(config, configPath);
                    ZipInputStream zipFile = new ZipInputStream(modelFile)) {
                MachineLearningModelFactory factory = new MachineLearningModelFactory();
                model = factory.getClassificationModel(zipFile);
            }
            modelMap.put(modelFilePath, model);
        }
        languageDetector = new LanguageDetector(model);
        languageDetectorMap.put(sessionId, languageDetector);
    }
    return languageDetector;
}
Also used : ZipInputStream(java.util.zip.ZipInputStream) Config(com.typesafe.config.Config) InputStream(java.io.InputStream) MachineLearningModelFactory(com.joliciel.talismane.machineLearning.MachineLearningModelFactory) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)
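getInstance memoizes two things: one LanguageDetector per session and one ClassificationModel per model file path, each via a get, a null check and a put. The declarations of languageDetectorMap and modelMap are not shown, so whether they are thread-safe is unknown. The sketch below shows the same load-once-per-key idea with ConcurrentHashMap.computeIfAbsent, which performs the check and the insert atomically. ModelCache and its string loader are hypothetical; a real loader would read the zipped model through MachineLearningModelFactory and, because computeIfAbsent cannot propagate checked exceptions, would need to wrap IOException in UncheckedIOException.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical generic cache: each key's value is computed at most once.
final class ModelCache<K, V> {
    private final Map<K, V> cache = new ConcurrentHashMap<>();
    private final Function<K, V> loader;

    ModelCache(Function<K, V> loader) {
        this.loader = loader;
    }

    // Runs the loader only if the key is absent; check and insert are atomic.
    V get(K key) {
        return cache.computeIfAbsent(key, loader);
    }

    int size() {
        return cache.size();
    }

    public static void main(String[] args) {
        int[] loads = {0};
        // Stand-in loader: a real one would deserialize a model from the path.
        ModelCache<String, String> models = new ModelCache<>(path -> {
            loads[0]++;
            return "model loaded from " + path;
        });
        String a = models.get("fr/language-detector.zip");
        String b = models.get("fr/language-detector.zip");
        System.out.println(a == b);   // true: the same cached instance
        System.out.println(loads[0]); // 1: the loader ran once
    }
}
```

The ConcurrentHashMap documentation notes that other updates may block while a computation is in progress, so this suits loaders that are reasonably fast; for slow model loading, a single synchronized accessor is the simpler alternative.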

Example 2 with ClassificationModel

Use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

From the class SentenceDetectorTrainer, method train.

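public ClassificationModel train() throws TalismaneException, IOException {
    ModelTrainerFactory factory = new ModelTrainerFactory();
    // build the trainer from the "train.machine-learning" section of the sentence detector config
    ClassificationModelTrainer trainer = factory.constructTrainer(sentenceConfig.getConfig("train.machine-learning"));
    ClassificationModel model = trainer.trainModel(eventStream, descriptors);
    // attach the session's external resources to the model before persisting it
    model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
    // create the output directory if needed; mkdirs() reports failure by returning false
    File modelDir = modelFile.getParentFile();
    if (modelDir != null && !modelDir.isDirectory() && !modelDir.mkdirs())
        throw new IOException("Could not create directory: " + modelDir);
    model.persist(modelFile);
    return model;
}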
Also used : ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) File(java.io.File) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 3 with ClassificationModel

Use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

From the class PosTaggerTrainer, method train.

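public ClassificationModel train() throws TalismaneException, IOException {
    ModelTrainerFactory factory = new ModelTrainerFactory();
    // build the trainer from the "train.machine-learning" section of the POS tagger config
    ClassificationModelTrainer trainer = factory.constructTrainer(posTaggerConfig.getConfig("train.machine-learning"));
    ClassificationModel model = trainer.trainModel(eventStream, descriptors);
    // attach the session's external resources to the model before persisting it
    model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
    // create the output directory if needed; mkdirs() reports failure by returning false
    File modelDir = modelFile.getParentFile();
    if (modelDir != null && !modelDir.isDirectory() && !modelDir.mkdirs())
        throw new IOException("Could not create directory: " + modelDir);
    model.persist(modelFile);
    return model;
}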
Also used : ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) File(java.io.File) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)

Example 4 with ClassificationModel

Use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

From the class PatternTokeniserTrainer, method train.

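public ClassificationModel train() throws TalismaneException, IOException {
    ModelTrainerFactory factory = new ModelTrainerFactory();
    // build the trainer from the "train.machine-learning" section of the tokeniser config
    ClassificationModelTrainer trainer = factory.constructTrainer(tokeniserConfig.getConfig("train.machine-learning"));
    ClassificationModel model = trainer.trainModel(eventStream, descriptors);
    // attach the session's external resources to the model before persisting it
    model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
    // create the output directory if needed; mkdirs() reports failure by returning false
    File modelDir = modelFile.getParentFile();
    if (modelDir != null && !modelDir.isDirectory() && !modelDir.mkdirs())
        throw new IOException("Could not create directory: " + modelDir);
    model.persist(modelFile);
    return model;
}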
Also used : ModelTrainerFactory(com.joliciel.talismane.machineLearning.ModelTrainerFactory) ClassificationModelTrainer(com.joliciel.talismane.machineLearning.ClassificationModelTrainer) File(java.io.File) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel)
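The sentence detector, POS tagger and pattern tokeniser trainers share one sequence: build a trainer from the "train.machine-learning" config, train on the event stream, attach external resources, make sure the model file's directory exists, then persist. The directory step can also be written with java.nio.file, where Files.createDirectories throws an IOException on failure instead of returning false and does not complain if the directory already exists. In this sketch, ModelPersistence and persist are hypothetical names, and writing raw bytes stands in for ClassificationModel.persist(File), whose on-disk format is not shown here.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical helper: write a file, creating any missing parent directories first.
final class ModelPersistence {
    static void persist(Path target, byte[] data) throws IOException {
        Path parent = target.getParent();
        if (parent != null) {
            // No-op if the directories already exist; throws if they cannot be created.
            Files.createDirectories(parent);
        }
        // Creates the file, or truncates and overwrites it if it exists.
        Files.write(target, data);
    }

    public static void main(String[] args) throws IOException {
        Path root = Files.createTempDirectory("models");
        Path modelFile = root.resolve("fr").resolve("pos-tagger.zip");
        persist(modelFile, "dummy model".getBytes(StandardCharsets.UTF_8));
        System.out.println(Files.exists(modelFile)); // true
    }
}
```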

Example 5 with ClassificationModel

Use of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.

From the class LinearSVMModelTrainer, method trainModel.

@Override
public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, Map<String, List<String>> descriptors) {
    // Note: since we want a probabilistic classifier, our options here
    // are limited to logistic regression:
    // L2R_LR: L2-regularized logistic regression (primal)
    // L1R_LR: L1-regularized logistic regression
    // L2R_LR_DUAL: L2-regularized logistic regression (dual)
    SolverType solver = SolverType.valueOf(this.solverType.name());
    if (!solver.isLogisticRegressionSolver())
        throw new JolicielException("To get a probability distribution of outcomes, only logistic regression solvers are supported.");
    TObjectIntMap<String> featureIndexMap = new TObjectIntHashMap<String>(1000, 0.75f, -1);
    TObjectIntMap<String> outcomeIndexMap = new TObjectIntHashMap<String>(100, 0.75f, -1);
    TIntList outcomeList = new TIntArrayList();
    TIntIntMap featureCountMap = new TIntIntHashMap();
    CountingInfo countingInfo = new CountingInfo();
    Feature[][] featureMatrix = this.getFeatureMatrix(corpusEventStream, featureIndexMap, outcomeIndexMap, outcomeList, featureCountMap, countingInfo);
    // apply the cutoff
    if (cutoff > 1) {
        LOG.debug("Feature count (after cutoff): " + countingInfo.featureCountOverCutoff);
        for (int i = 0; i < featureMatrix.length; i++) {
            Feature[] featureArray = featureMatrix[i];
            List<Feature> featureList = new ArrayList<Feature>(featureArray.length);
            for (int j = 0; j < featureArray.length; j++) {
                Feature feature = featureArray[j];
                int featureCount = featureCountMap.get(feature.getIndex());
                if (featureCount >= cutoff)
                    featureList.add(feature);
            }
            // keep only the surviving features; once the row is replaced,
            // the old array is unreachable and can be garbage-collected
            featureMatrix[i] = featureList.toArray(new Feature[0]);
        }
    }
    final String[] outcomeArray2 = new String[outcomeIndexMap.size()];
    outcomeIndexMap.forEachEntry(new TObjectIntProcedure<String>() {

        @Override
        public boolean execute(String key, int value) {
            outcomeArray2[value] = key;
            return true;
        }
    });
    List<String> outcomes = new ArrayList<String>(outcomeIndexMap.size());
    for (String outcome : outcomeArray2) outcomes.add(outcome);
    if (oneVsRest) {
        // find outcomes representing multiple classes
        TIntSet multiClassOutcomes = new TIntHashSet();
        TIntObjectMap<TIntSet> outcomeComponentMap = new TIntObjectHashMap<TIntSet>();
        List<String> atomicOutcomes = new ArrayList<String>();
        TObjectIntMap<String> atomicOutcomeIndexes = new TObjectIntHashMap<String>();
        TIntIntMap oldIndexNewIndexMap = new TIntIntHashMap();
        // first pass: give every atomic (tab-free) outcome a new index
        for (int j = 0; j < outcomes.size(); j++) {
            String outcome = outcomes.get(j);
            if (outcome.indexOf('\t') < 0) {
                int newIndex = atomicOutcomes.size();
                atomicOutcomeIndexes.put(outcome, newIndex);
                oldIndexNewIndexMap.put(j, newIndex);
                atomicOutcomes.add(outcome);
            }
        }
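        // second pass: map each multi-class (tab-separated) outcome to the new
        // indexes of its atomic parts, registering parts never seen on their own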
        for (int j = 0; j < outcomes.size(); j++) {
            String outcome = outcomes.get(j);
            if (outcome.indexOf('\t') >= 0) {
                multiClassOutcomes.add(j);
                TIntSet myComponentOutcomes = new TIntHashSet();
                outcomeComponentMap.put(j, myComponentOutcomes);
                String[] parts = outcome.split("\t", -1);
                for (String part : parts) {
                    int outcomeIndex = outcomeIndexMap.get(part);
                    int newIndex = 0;
                    if (outcomeIndex < 0) {
                        outcomeIndex = countingInfo.currentOutcomeIndex++;
                        outcomeIndexMap.put(part, outcomeIndex);
                        newIndex = atomicOutcomes.size();
                        atomicOutcomeIndexes.put(part, newIndex);
                        oldIndexNewIndexMap.put(outcomeIndex, newIndex);
                        atomicOutcomes.add(part);
                    } else {
                        newIndex = oldIndexNewIndexMap.get(outcomeIndex);
                    }
                    myComponentOutcomes.add(newIndex);
                }
            }
        }
        LinearSVMOneVsRestModel linearSVMModel = new LinearSVMOneVsRestModel(config, descriptors);
        linearSVMModel.setFeatureIndexMap(featureIndexMap);
        linearSVMModel.setOutcomes(atomicOutcomes);
        linearSVMModel.addModelAttribute("solver", this.getSolverType().name());
        linearSVMModel.addModelAttribute("cutoff", "" + this.getCutoff());
        linearSVMModel.addModelAttribute("c", "" + this.getConstraintViolationCost());
        linearSVMModel.addModelAttribute("eps", "" + this.getEpsilon());
        linearSVMModel.addModelAttribute("oneVsRest", "" + this.isOneVsRest());
        linearSVMModel.getModelAttributes().putAll(corpusEventStream.getAttributes());
        // build one 1-vs-All model per outcome
        for (int j = 0; j < atomicOutcomes.size(); j++) {
            String outcome = atomicOutcomes.get(j);
            LOG.info("Building model for outcome: " + outcome);
            // create an outcome array with 1 for the current outcome
            // and 0 for all others
            double[] outcomeArray = new double[countingInfo.numEvents];
            int i = 0;
            TIntIterator outcomeIterator = outcomeList.iterator();
            int myOutcomeCount = 0;
            while (outcomeIterator.hasNext()) {
                boolean isMyOutcome = false;
                int originalOutcomeIndex = outcomeIterator.next();
                if (multiClassOutcomes.contains(originalOutcomeIndex)) {
                    if (outcomeComponentMap.get(originalOutcomeIndex).contains(j))
                        isMyOutcome = true;
                } else {
                    if (oldIndexNewIndexMap.get(originalOutcomeIndex) == j)
                        isMyOutcome = true;
                }
                int myOutcome = (isMyOutcome ? 1 : 0);
                if (myOutcome == 1)
                    myOutcomeCount++;
                outcomeArray[i++] = myOutcome;
            }
            LOG.debug("Found " + myOutcomeCount + " out of " + countingInfo.numEvents + " outcomes of type: " + outcome);
            double[] myOutcomeArray = outcomeArray;
            Feature[][] myFeatureMatrix = featureMatrix;
            if (balanceEventCounts) {
                // Balance the classes by duplicating positive events:
                // proportion = negatives / positives (integer division).
                // Positives are only ever up-sampled; negatives are never dropped.
                int otherCount = countingInfo.numEvents - myOutcomeCount;
                int proportion = myOutcomeCount > 0 ? otherCount / myOutcomeCount : 0;
                if (proportion > 1) {
                    LOG.debug("Balancing events for " + outcome + " by " + proportion);
                    int newSize = otherCount + myOutcomeCount * proportion;
                    myOutcomeArray = new double[newSize];
                    myFeatureMatrix = new Feature[newSize][];
                    int l = 0;
                    for (int k = 0; k < outcomeArray.length; k++) {
                        double myOutcome = outcomeArray[k];
                        Feature[] myFeatures = featureMatrix[k];
                        if (myOutcome == 0) {
                            myOutcomeArray[l] = myOutcome;
                            myFeatureMatrix[l] = myFeatures;
                            l++;
                        } else {
                            for (int m = 0; m < proportion; m++) {
                                myOutcomeArray[l] = myOutcome;
                                myFeatureMatrix[l] = myFeatures;
                                l++;
                            }
                        }
                    } // next event in the original array
                } // requires balancing
            } // balanceEventCounts
            // Build the liblinear training problem:
            //   l = number of training examples
            //   n = number of features
            //   x = feature nodes per example (each row must be ordered by feature index)
            //   y = target values (1 for the current outcome, 0 otherwise)
            Problem problem = new Problem();
            // use the (possibly balanced) array length rather than the original event
            // count, so that duplicated positive events are included in training
            problem.l = myOutcomeArray.length;
            problem.n = countingInfo.currentFeatureIndex;
            problem.x = myFeatureMatrix;
            problem.y = myOutcomeArray;
            Parameter parameter = new Parameter(solver, this.constraintViolationCost, this.epsilon);
            Model model = Linear.train(problem, parameter);
            linearSVMModel.addModel(model);
        }
        return linearSVMModel;
    } else {
        double[] outcomeArray = new double[countingInfo.numEvents];
        int i = 0;
        TIntIterator outcomeIterator = outcomeList.iterator();
        while (outcomeIterator.hasNext()) outcomeArray[i++] = outcomeIterator.next();
        // Build the liblinear training problem:
        //   l = number of training examples
        //   n = number of features
        //   x = feature nodes per example (each row must be ordered by feature index)
        //   y = target values (the outcome index of each event)
        Problem problem = new Problem();
        problem.l = countingInfo.numEvents;
        problem.n = countingInfo.currentFeatureIndex;
        problem.x = featureMatrix;
        problem.y = outcomeArray;
        Parameter parameter = new Parameter(solver, this.constraintViolationCost, this.epsilon);
        Model model = Linear.train(problem, parameter);
        LinearSVMModel linearSVMModel = new LinearSVMModel(model, config, descriptors);
        linearSVMModel.setFeatureIndexMap(featureIndexMap);
        linearSVMModel.setOutcomes(outcomes);
        linearSVMModel.addModelAttribute("solver", this.getSolverType());
        linearSVMModel.addModelAttribute("cutoff", this.getCutoff());
        linearSVMModel.addModelAttribute("cost", this.getConstraintViolationCost());
        linearSVMModel.addModelAttribute("epsilon", this.getEpsilon());
        linearSVMModel.addModelAttribute("oneVsRest", this.isOneVsRest());
        linearSVMModel.getModelAttributes().putAll(corpusEventStream.getAttributes());
        return linearSVMModel;
    }
}
Also used : JolicielException(com.joliciel.talismane.utils.JolicielException) TIntSet(gnu.trove.set.TIntSet) TIntArrayList(gnu.trove.list.array.TIntArrayList) ArrayList(java.util.ArrayList) Feature(de.bwaldvogel.liblinear.Feature) TIntIntMap(gnu.trove.map.TIntIntMap) TIntHashSet(gnu.trove.set.hash.TIntHashSet) TObjectIntHashMap(gnu.trove.map.hash.TObjectIntHashMap) TIntIntHashMap(gnu.trove.map.hash.TIntIntHashMap) TIntIterator(gnu.trove.iterator.TIntIterator) SolverType(de.bwaldvogel.liblinear.SolverType) TIntObjectHashMap(gnu.trove.map.hash.TIntObjectHashMap) ClassificationModel(com.joliciel.talismane.machineLearning.ClassificationModel) Model(de.bwaldvogel.liblinear.Model) MachineLearningModel(com.joliciel.talismane.machineLearning.MachineLearningModel) Parameter(de.bwaldvogel.liblinear.Parameter) Problem(de.bwaldvogel.liblinear.Problem) TIntList(gnu.trove.list.TIntList)
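In one-vs-rest mode, trainModel prepares data in three steps before calling liblinear: it drops features whose corpus count is below the cutoff, turns each atomic outcome into a binary target (an event whose outcome is tab-separated counts as positive for each of its parts), and, if balancing is on, repeats each positive event negatives / positives times. The dependency-free sketch below mirrors those steps under simplifying assumptions: events are plain int[] arrays of feature indexes instead of liblinear Feature nodes, outcomes are strings, and no training call is made. OneVsRestPrep and its methods are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class OneVsRestPrep {

    // Keep only features whose corpus-wide count reaches the cutoff.
    static int[][] applyCutoff(int[][] rows, Map<Integer, Integer> featureCounts, int cutoff) {
        int[][] filtered = new int[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            filtered[i] = Arrays.stream(rows[i])
                    .filter(f -> featureCounts.getOrDefault(f, 0) >= cutoff)
                    .toArray();
        }
        return filtered;
    }

    // 1.0 if the outcome is, or contains as a tab-separated part, the target class.
    static double[] binaryLabels(List<String> outcomes, String target) {
        double[] y = new double[outcomes.size()];
        for (int i = 0; i < outcomes.size(); i++) {
            List<String> parts = Arrays.asList(outcomes.get(i).split("\t", -1));
            y[i] = parts.contains(target) ? 1.0 : 0.0;
        }
        return y;
    }

    // Event indexes after balancing: each positive is repeated
    // (negatives / positives) times when that ratio exceeds 1.
    static int[] balancedIndexes(double[] y) {
        int positives = 0;
        for (double v : y) {
            if (v == 1.0) positives++;
        }
        int negatives = y.length - positives;
        int proportion = positives == 0 ? 1 : Math.max(1, negatives / positives);
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < y.length; i++) {
            int copies = y[i] == 1.0 ? proportion : 1;
            for (int c = 0; c < copies; c++) indexes.add(i);
        }
        return indexes.stream().mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        List<String> outcomes = List.of("NOUN", "VERB", "NOUN\tADJ", "VERB", "VERB", "VERB");
        double[] y = binaryLabels(outcomes, "NOUN");
        System.out.println(Arrays.toString(y)); // [1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
        System.out.println(balancedIndexes(y).length); // 8 = 4 negatives + 2 positives x 2

        Map<Integer, Integer> counts = new HashMap<>();
        counts.put(1, 1);
        counts.put(2, 2);
        counts.put(3, 1);
        int[][] kept = applyCutoff(new int[][] {{1, 2}, {2, 3}}, counts, 2);
        System.out.println(Arrays.deepToString(kept)); // [[2], [2]]
    }
}
```

Because balancing produces more rows than there were original events, the number of training examples handed to the solver has to be the length of the balanced arrays.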

Aggregations

ClassificationModel (com.joliciel.talismane.machineLearning.ClassificationModel) 20
File (java.io.File) 10
ClassificationModelTrainer (com.joliciel.talismane.machineLearning.ClassificationModelTrainer) 8
ModelTrainerFactory (com.joliciel.talismane.machineLearning.ModelTrainerFactory) 8
LetterFeature (com.joliciel.jochre.letterGuesser.features.LetterFeature) 6
LetterFeatureParser (com.joliciel.jochre.letterGuesser.features.LetterFeatureParser) 6
BeamSearchImageAnalyser (com.joliciel.jochre.analyser.BeamSearchImageAnalyser) 5
ImageAnalyser (com.joliciel.jochre.analyser.ImageAnalyser) 5
OriginalShapeLetterAssigner (com.joliciel.jochre.analyser.OriginalShapeLetterAssigner) 5
BoundaryDetector (com.joliciel.jochre.boundaries.BoundaryDetector) 5
DeterministicBoundaryDetector (com.joliciel.jochre.boundaries.DeterministicBoundaryDetector) 5
OriginalBoundaryDetector (com.joliciel.jochre.boundaries.OriginalBoundaryDetector) 5
LetterGuesser (com.joliciel.jochre.letterGuesser.LetterGuesser) 5
JochreException (com.joliciel.jochre.utils.JochreException) 5
LetterAssigner (com.joliciel.jochre.analyser.LetterAssigner) 4
LetterByLetterBoundaryDetector (com.joliciel.jochre.boundaries.LetterByLetterBoundaryDetector) 4
RecursiveShapeSplitter (com.joliciel.jochre.boundaries.RecursiveShapeSplitter) 4
ShapeMerger (com.joliciel.jochre.boundaries.ShapeMerger) 4
ShapeSplitter (com.joliciel.jochre.boundaries.ShapeSplitter) 4
TrainingCorpusShapeMerger (com.joliciel.jochre.boundaries.TrainingCorpusShapeMerger) 4