Use of com.joliciel.talismane.machineLearning.ClassificationModel in the project talismane by joliciel-informatique:
the class LanguageDetector, method getInstance.
/**
 * Returns the {@link LanguageDetector} for the given session, constructing it on
 * first use. Detectors are cached per session id, and the underlying
 * {@link ClassificationModel}s are cached per model file path, so repeated calls
 * are cheap.
 * <p>
 * NOTE(review): the check-then-act on the static {@code languageDetectorMap} and
 * {@code modelMap} caches is not synchronized — concurrent first calls for the
 * same session may each load the model. Confirm whether callers are
 * single-threaded.
 *
 * @param sessionId the Talismane session id, used to locate the model path under
 *          {@code talismane.core.<sessionId>.language-detector.model}
 * @return the cached or newly constructed detector
 * @throws IOException if the model file cannot be read
 * @throws TalismaneException if the model cannot be interpreted
 * @throws ClassNotFoundException if the serialized model references an unknown class
 */
public static LanguageDetector getInstance(String sessionId) throws IOException, TalismaneException, ClassNotFoundException {
	LanguageDetector languageDetector = languageDetectorMap.get(sessionId);
	if (languageDetector == null) {
		Config config = ConfigFactory.load();
		String configPath = "talismane.core." + sessionId + ".language-detector.model";
		String modelFilePath = config.getString(configPath);
		ClassificationModel model = modelMap.get(modelFilePath);
		if (model == null) {
			// try-with-resources: the original left the model stream open (leak).
			try (InputStream modelFile = ConfigUtils.getFileFromConfig(config, configPath);
					ZipInputStream zis = new ZipInputStream(modelFile)) {
				MachineLearningModelFactory factory = new MachineLearningModelFactory();
				model = factory.getClassificationModel(zis);
			}
			modelMap.put(modelFilePath, model);
		}
		languageDetector = new LanguageDetector(model);
		languageDetectorMap.put(sessionId, languageDetector);
	}
	return languageDetector;
}
Use of com.joliciel.talismane.machineLearning.ClassificationModel in the project talismane by joliciel-informatique:
the class SentenceDetectorTrainer, method train.
/**
 * Trains a sentence-detector classification model on the configured event stream
 * and persists it to the configured model file.
 *
 * @return the trained {@link ClassificationModel}
 * @throws TalismaneException if training fails
 * @throws IOException if the model directory cannot be created or the model
 *           cannot be written
 */
public ClassificationModel train() throws TalismaneException, IOException {
	ModelTrainerFactory factory = new ModelTrainerFactory();
	ClassificationModelTrainer trainer = factory.constructTrainer(sentenceConfig.getConfig("train.machine-learning"));
	ClassificationModel model = trainer.trainModel(eventStream, descriptors);
	model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
	File modelDir = modelFile.getParentFile();
	// Fail fast if the target directory cannot be created, rather than letting
	// persist() fail later with a less informative error. (The original
	// ignored the mkdirs() return value.)
	if (modelDir != null && !modelDir.mkdirs() && !modelDir.isDirectory())
		throw new IOException("Could not create model directory: " + modelDir);
	model.persist(modelFile);
	return model;
}
Use of com.joliciel.talismane.machineLearning.ClassificationModel in the project talismane by joliciel-informatique:
the class PosTaggerTrainer, method train.
/**
 * Trains a pos-tagger classification model on the configured event stream and
 * persists it to the configured model file.
 *
 * @return the trained {@link ClassificationModel}
 * @throws TalismaneException if training fails
 * @throws IOException if the model directory cannot be created or the model
 *           cannot be written
 */
public ClassificationModel train() throws TalismaneException, IOException {
	ModelTrainerFactory factory = new ModelTrainerFactory();
	ClassificationModelTrainer trainer = factory.constructTrainer(posTaggerConfig.getConfig("train.machine-learning"));
	ClassificationModel model = trainer.trainModel(eventStream, descriptors);
	model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
	File modelDir = modelFile.getParentFile();
	// Fail fast if the target directory cannot be created, rather than letting
	// persist() fail later with a less informative error. (The original
	// ignored the mkdirs() return value.)
	if (modelDir != null && !modelDir.mkdirs() && !modelDir.isDirectory())
		throw new IOException("Could not create model directory: " + modelDir);
	model.persist(modelFile);
	return model;
}
Use of com.joliciel.talismane.machineLearning.ClassificationModel in the project talismane by joliciel-informatique:
the class PatternTokeniserTrainer, method train.
/**
 * Trains a pattern-tokeniser classification model on the configured event stream
 * and persists it to the configured model file.
 *
 * @return the trained {@link ClassificationModel}
 * @throws TalismaneException if training fails
 * @throws IOException if the model directory cannot be created or the model
 *           cannot be written
 */
public ClassificationModel train() throws TalismaneException, IOException {
	ModelTrainerFactory factory = new ModelTrainerFactory();
	ClassificationModelTrainer trainer = factory.constructTrainer(tokeniserConfig.getConfig("train.machine-learning"));
	ClassificationModel model = trainer.trainModel(eventStream, descriptors);
	model.setExternalResources(TalismaneSession.get(sessionId).getExternalResourceFinder().getExternalResources());
	File modelDir = modelFile.getParentFile();
	// Fail fast if the target directory cannot be created, rather than letting
	// persist() fail later with a less informative error. (The original
	// ignored the mkdirs() return value.)
	if (modelDir != null && !modelDir.mkdirs() && !modelDir.isDirectory())
		throw new IOException("Could not create model directory: " + modelDir);
	model.persist(modelFile);
	return model;
}
Use of com.joliciel.talismane.machineLearning.ClassificationModel in the project talismane by joliciel-informatique:
the class LinearSVMModelTrainer, method trainModel.
/**
 * Trains a liblinear classification model from the corpus event stream. Only
 * logistic-regression solver types are accepted, because downstream code needs a
 * probability distribution over outcomes. When {@code oneVsRest} is true,
 * compound outcomes (tab-separated lists of atomic outcomes) are decomposed and
 * one binary model is trained per atomic outcome; otherwise a single multi-class
 * model is trained over the outcomes as-is.
 */
@Override
public ClassificationModel trainModel(ClassificationEventStream corpusEventStream, Map<String, List<String>> descriptors) {
	// Note: since we want a probabilistic classifier, our options here
	// are limited to logistic regression:
	// L2R_LR: L2-regularized logistic regression (primal)
	// L1R_LR: L1-regularized logistic regression
	// L2R_LR_DUAL: L2-regularized logistic regression (dual)
	SolverType solver = SolverType.valueOf(this.solverType.name());
	if (!solver.isLogisticRegressionSolver())
		throw new JolicielException("To get a probability distribution of outcomes, only logistic regression solvers are supported.");
	// Trove primitive maps; the no-entry value -1 lets a plain get() distinguish
	// "absent" from a real index.
	TObjectIntMap<String> featureIndexMap = new TObjectIntHashMap<String>(1000, 0.75f, -1);
	TObjectIntMap<String> outcomeIndexMap = new TObjectIntHashMap<String>(100, 0.75f, -1);
	// One outcome index per training event, in event order (parallel to the rows
	// of featureMatrix below).
	TIntList outcomeList = new TIntArrayList();
	TIntIntMap featureCountMap = new TIntIntHashMap();
	CountingInfo countingInfo = new CountingInfo();
	Feature[][] featureMatrix = this.getFeatureMatrix(corpusEventStream, featureIndexMap, outcomeIndexMap, outcomeList, featureCountMap, countingInfo);
	// apply the cutoff: drop features observed fewer than `cutoff` times
	if (cutoff > 1) {
		LOG.debug("Feature count (after cutoff): " + countingInfo.featureCountOverCutoff);
		for (int i = 0; i < featureMatrix.length; i++) {
			Feature[] featureArray = featureMatrix[i];
			List<Feature> featureList = new ArrayList<Feature>(featureArray.length);
			for (int j = 0; j < featureArray.length; j++) {
				Feature feature = featureArray[j];
				int featureCount = featureCountMap.get(feature.getIndex());
				if (featureCount >= cutoff)
					featureList.add(feature);
			}
			Feature[] newFeatureArray = new Feature[featureList.size()];
			int j = 0;
			for (Feature feature : featureList) newFeatureArray[j++] = feature;
			// try to force a garbage collect without being too explicit
			// about it
			featureMatrix[i] = null;
			featureArray = null;
			featureMatrix[i] = newFeatureArray;
		}
	}
	// Invert outcomeIndexMap into an index->name array.
	final String[] outcomeArray2 = new String[outcomeIndexMap.size()];
	outcomeIndexMap.forEachEntry(new TObjectIntProcedure<String>() {
		@Override
		public boolean execute(String key, int value) {
			outcomeArray2[value] = key;
			return true;
		}
	});
	List<String> outcomes = new ArrayList<String>(outcomeIndexMap.size());
	for (String outcome : outcomeArray2) outcomes.add(outcome);
	if (oneVsRest) {
		// find outcomes representing multiple classes (tab-separated compounds)
		TIntSet multiClassOutcomes = new TIntHashSet();
		// compound outcome index -> set of atomic (new-index) component outcomes
		TIntObjectMap<TIntSet> outcomeComponentMap = new TIntObjectHashMap<TIntSet>();
		List<String> atomicOutcomes = new ArrayList<String>();
		TObjectIntMap<String> atomicOutcomeIndexes = new TObjectIntHashMap<String>();
		// original outcome index -> index in atomicOutcomes
		TIntIntMap oldIndexNewIndexMap = new TIntIntHashMap();
		// first pass: store all already-atomic outcomes in one data structure
		for (int j = 0; j < outcomes.size(); j++) {
			String outcome = outcomes.get(j);
			if (outcome.indexOf('\t') < 0) {
				int newIndex = atomicOutcomes.size();
				atomicOutcomeIndexes.put(outcome, newIndex);
				oldIndexNewIndexMap.put(j, newIndex);
				atomicOutcomes.add(outcome);
			}
		}
		// second pass: decompose compound outcomes, registering any component not
		// yet seen as a new atomic outcome
		for (int j = 0; j < outcomes.size(); j++) {
			String outcome = outcomes.get(j);
			if (outcome.indexOf('\t') >= 0) {
				multiClassOutcomes.add(j);
				TIntSet myComponentOutcomes = new TIntHashSet();
				outcomeComponentMap.put(j, myComponentOutcomes);
				String[] parts = outcome.split("\t", -1);
				for (String part : parts) {
					int outcomeIndex = outcomeIndexMap.get(part);
					int newIndex = 0;
					if (outcomeIndex < 0) {
						// component never occurred on its own: assign it fresh
						// original and atomic indexes
						outcomeIndex = countingInfo.currentOutcomeIndex++;
						outcomeIndexMap.put(part, outcomeIndex);
						newIndex = atomicOutcomes.size();
						atomicOutcomeIndexes.put(part, newIndex);
						oldIndexNewIndexMap.put(outcomeIndex, newIndex);
						atomicOutcomes.add(part);
					} else {
						newIndex = oldIndexNewIndexMap.get(outcomeIndex);
					}
					myComponentOutcomes.add(newIndex);
				}
			}
		}
		LinearSVMOneVsRestModel linearSVMModel = new LinearSVMOneVsRestModel(config, descriptors);
		linearSVMModel.setFeatureIndexMap(featureIndexMap);
		linearSVMModel.setOutcomes(atomicOutcomes);
		linearSVMModel.addModelAttribute("solver", this.getSolverType().name());
		linearSVMModel.addModelAttribute("cutoff", "" + this.getCutoff());
		linearSVMModel.addModelAttribute("c", "" + this.getConstraintViolationCost());
		linearSVMModel.addModelAttribute("eps", "" + this.getEpsilon());
		linearSVMModel.addModelAttribute("oneVsRest", "" + this.isOneVsRest());
		linearSVMModel.getModelAttributes().putAll(corpusEventStream.getAttributes());
		// build one 1-vs-All model per atomic outcome
		for (int j = 0; j < atomicOutcomes.size(); j++) {
			String outcome = atomicOutcomes.get(j);
			LOG.info("Building model for outcome: " + outcome);
			// create an outcome array with 1 for the current outcome
			// and 0 for all others
			double[] outcomeArray = new double[countingInfo.numEvents];
			int i = 0;
			TIntIterator outcomeIterator = outcomeList.iterator();
			int myOutcomeCount = 0;
			while (outcomeIterator.hasNext()) {
				boolean isMyOutcome = false;
				int originalOutcomeIndex = outcomeIterator.next();
				if (multiClassOutcomes.contains(originalOutcomeIndex)) {
					// compound event counts as positive if this atomic outcome is
					// one of its components
					if (outcomeComponentMap.get(originalOutcomeIndex).contains(j))
						isMyOutcome = true;
				} else {
					if (oldIndexNewIndexMap.get(originalOutcomeIndex) == j)
						isMyOutcome = true;
				}
				int myOutcome = (isMyOutcome ? 1 : 0);
				if (myOutcome == 1)
					myOutcomeCount++;
				outcomeArray[i++] = myOutcome;
			}
			LOG.debug("Found " + myOutcomeCount + " out of " + countingInfo.numEvents + " outcomes of type: " + outcome);
			double[] myOutcomeArray = outcomeArray;
			Feature[][] myFeatureMatrix = featureMatrix;
			if (balanceEventCounts) {
				// we start with the truncated proportion of false
				// events to true events
				// we want these approximately balanced
				// we only balance up, never balance down
				// NOTE(review): assumes myOutcomeCount > 0 — an outcome with zero
				// positive events would divide by zero here. Confirm upstream
				// guarantees every atomic outcome occurs at least once.
				int otherCount = countingInfo.numEvents - myOutcomeCount;
				int proportion = otherCount / myOutcomeCount;
				if (proportion > 1) {
					LOG.debug("Balancing events for " + outcome + " by " + proportion);
					// duplicate each positive event `proportion` times
					int newSize = otherCount + myOutcomeCount * proportion;
					myOutcomeArray = new double[newSize];
					myFeatureMatrix = new Feature[newSize][];
					int l = 0;
					for (int k = 0; k < outcomeArray.length; k++) {
						double myOutcome = outcomeArray[k];
						Feature[] myFeatures = featureMatrix[k];
						if (myOutcome == 0) {
							myOutcomeArray[l] = myOutcome;
							myFeatureMatrix[l] = myFeatures;
							l++;
						} else {
							for (int m = 0; m < proportion; m++) {
								myOutcomeArray[l] = myOutcome;
								myFeatureMatrix[l] = myFeatures;
								l++;
							}
						}
						// is it the right outcome or not?
					}
					// next outcome in original array
				}
				// requires balancing?
			}
			// balance event counts?
			Problem problem = new Problem();
			// problem.l = ... // number of training examples
			// problem.n = ... // number of features
			// problem.x = ... // feature nodes - note: must be ordered
			// by index
			// problem.y = ... // target values
			// number of training examples
			// NOTE(review): when balanceEventCounts inflated the arrays to
			// newSize above, this still uses numEvents, so liblinear would see
			// only the first numEvents rows of the balanced arrays — confirm
			// whether this is intentional.
			problem.l = countingInfo.numEvents;
			// number of features
			problem.n = countingInfo.currentFeatureIndex;
			// feature nodes - note: must be ordered by index
			problem.x = myFeatureMatrix;
			// target values (1.0 = this outcome, 0.0 = any other)
			problem.y = myOutcomeArray;
			Parameter parameter = new Parameter(solver, this.constraintViolationCost, this.epsilon);
			Model model = Linear.train(problem, parameter);
			linearSVMModel.addModel(model);
		}
		return linearSVMModel;
	} else {
		// single multi-class model: targets are the original outcome indexes
		double[] outcomeArray = new double[countingInfo.numEvents];
		int i = 0;
		TIntIterator outcomeIterator = outcomeList.iterator();
		while (outcomeIterator.hasNext()) outcomeArray[i++] = outcomeIterator.next();
		Problem problem = new Problem();
		// problem.l = ... // number of training examples
		// problem.n = ... // number of features
		// problem.x = ... // feature nodes - note: must be ordered by
		// index
		// problem.y = ... // target values
		// number of training examples
		problem.l = countingInfo.numEvents;
		// number of features
		problem.n = countingInfo.currentFeatureIndex;
		// feature nodes - note: must be ordered by index
		problem.x = featureMatrix;
		// target values
		problem.y = outcomeArray;
		Parameter parameter = new Parameter(solver, this.constraintViolationCost, this.epsilon);
		Model model = Linear.train(problem, parameter);
		LinearSVMModel linearSVMModel = new LinearSVMModel(model, config, descriptors);
		linearSVMModel.setFeatureIndexMap(featureIndexMap);
		linearSVMModel.setOutcomes(outcomes);
		linearSVMModel.addModelAttribute("solver", this.getSolverType());
		linearSVMModel.addModelAttribute("cutoff", this.getCutoff());
		linearSVMModel.addModelAttribute("cost", this.getConstraintViolationCost());
		linearSVMModel.addModelAttribute("epsilon", this.getEpsilon());
		linearSVMModel.addModelAttribute("oneVsRest", this.isOneVsRest());
		linearSVMModel.getModelAttributes().putAll(corpusEventStream.getAttributes());
		return linearSVMModel;
	}
}
Aggregations