Usage of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.
In class LanguageDetectorTrainer, method train:
/**
 * Trains a language-detection classification model from {@code eventStream} and
 * {@code descriptors} using the trainer configured under "train.machine-learning",
 * attaches the session's external resources to it, and persists it to {@code modelFile}.
 * The model file's parent directory is created first when the path has one.
 *
 * @return the freshly trained (and already persisted) model
 * @throws TalismaneException propagated from the training/persistence machinery
 * @throws IOException if the model cannot be written
 */
public ClassificationModel train() throws TalismaneException, IOException {
  ClassificationModelTrainer modelTrainer =
      new ModelTrainerFactory().constructTrainer(languageConfig.getConfig("train.machine-learning"));
  ClassificationModel trainedModel = modelTrainer.trainModel(eventStream, descriptors);
  trainedModel.setExternalResources(session.getExternalResourceFinder().getExternalResources());

  // A bare file name has no parent directory, in which case there is nothing to create.
  File targetDir = modelFile.getParentFile();
  if (targetDir != null) {
    targetDir.mkdirs();
  }

  trainedModel.persist(modelFile);
  return trainedModel;
}
Usage of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.
In class PerceptronClassificationModelTrainer, method train:
/**
 * Runs the averaged-perceptron training loop over the events stored in {@code eventFile}.
 *
 * <p>Each iteration re-reads every line of the event file (UTF-8) as a {@code PerceptronEvent}
 * and picks the highest-scoring outcome from {@code decisionMaker.predict}. On a wrong
 * prediction, it adds the event's feature values to the correct outcome's weights and
 * subtracts them from the predicted outcome's weights.
 *
 * <p>After each iteration:
 * <ul>
 *   <li>the current weights may be added to {@code totalFeatureWeights}. This happens every
 *       iteration, or only at selected iterations when {@code isAverageAtIntervals()} is true;</li>
 *   <li>at iterations listed in {@code observationPoints}, an averaged copy of the model is
 *       handed to {@code observer};</li>
 *   <li>training stops early once accuracy has stayed within {@code tolerance} of each of the
 *       three previous iterations' accuracies.</li>
 * </ul>
 * Finally, the weights in {@code params} are replaced by the accumulated average.
 *
 * @throws RuntimeException wrapping any {@link IOException} raised while reading the event file
 */
void train() {
  try {
    double prevAccuracy1 = 0.0;
    double prevAccuracy2 = 0.0;
    double prevAccuracy3 = 0.0;
    // Number of weight snapshots added to totalFeatureWeights, i.e. the averaging divisor.
    int averagingCount = 0;
    for (int i = 1; i <= iterations; i++) {
      LOG.debug("Iteration " + i);
      int totalErrors = 0;
      int totalEvents = 0;
      try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(eventFile), "UTF-8")))) {
        while (scanner.hasNextLine()) {
          String line = scanner.nextLine();
          PerceptronEvent event = new PerceptronEvent(line);
          totalEvents++;
          // don't normalise unless we calculate the log-likelihood,
          // to avoid mathematical cost of normalising
          double[] results = decisionMaker.predict(event.getFeatureIndexes(), event.getFeatureValues());
          // Arg-max over the raw scores; ties resolve to the lowest outcome index.
          double maxValue = results[0];
          int predicted = 0;
          for (int j = 1; j < results.length; j++) {
            if (results[j] > maxValue) {
              maxValue = results[j];
              predicted = j;
            }
          }
          int actual = event.getOutcomeIndex();
          if (actual != predicted) {
            // Perceptron update: reward the correct outcome, penalise the wrongly predicted one.
            for (int j = 0; j < event.getFeatureIndexes().size(); j++) {
              double[] classWeights = params.getFeatureWeights()[event.getFeatureIndexes().get(j)];
              classWeights[actual] += event.getFeatureValues().get(j);
              classWeights[predicted] -= event.getFeatureValues().get(j);
            }
            totalErrors++;
          }
        }
      }

      // Decide whether this iteration's weights join the running total. With interval
      // averaging, only iterations 1-20 and the perfect squares 25..196 contribute.
      boolean addAverage = true;
      if (this.isAverageAtIntervals()) {
        addAverage = i <= 20 || i == 25 || i == 36 || i == 49 || i == 64 || i == 81 || i == 100 || i == 121 || i == 144 || i == 169 || i == 196;
        if (addAverage) {
          LOG.debug("Averaging at iteration: " + i);
        }
      }
      if (addAverage) {
        for (int j = 0; j < params.getFeatureWeights().length; j++) {
          double[] totalClassWeights = totalFeatureWeights[j];
          double[] classWeights = params.getFeatureWeights()[j];
          for (int k = 0; k < params.getOutcomeCount(); k++) {
            totalClassWeights[k] += classWeights[k];
          }
        }
        averagingCount++;
      }

      if (observer != null && observationPoints.contains(i)) {
        // Hand the observer an averaged copy without disturbing the live weights.
        // averagingCount >= 1 here, because iteration 1 always contributes a snapshot.
        PerceptronModelParameters cloneParams = params.clone();
        for (int j = 0; j < cloneParams.getFeatureWeights().length; j++) {
          double[] totalClassWeights = totalFeatureWeights[j];
          double[] classWeights = cloneParams.getFeatureWeights()[j];
          for (int k = 0; k < cloneParams.getOutcomeCount(); k++) {
            classWeights[k] = totalClassWeights[k] / averagingCount;
          }
        }
        ClassificationModel model = this.getModel(cloneParams, i);
        observer.onNextModel(model, i);
      }

      // NOTE: an empty event file gives totalEvents == 0, so accuracy is NaN and never converges.
      double accuracy = (double) (totalEvents - totalErrors) / (double) totalEvents;
      LOG.debug("Accuracy: " + accuracy);
      // Exit if accuracy hasn't significantly changed in 3 iterations. The check only runs once
      // three real previous accuracies exist; otherwise the 0.0 initial values would make an
      // accuracy below the tolerance look "unchanged" after just one iteration.
      if (i > 3 && Math.abs(accuracy - prevAccuracy1) < tolerance && Math.abs(accuracy - prevAccuracy2) < tolerance && Math.abs(accuracy - prevAccuracy3) < tolerance) {
        LOG.info("Accuracy change < " + tolerance + " for 3 iterations: exiting after " + i + " iterations");
        break;
      }
      prevAccuracy3 = prevAccuracy2;
      prevAccuracy2 = prevAccuracy1;
      prevAccuracy1 = accuracy;
    }

    // Average the final weights. If no snapshot was ever accumulated (e.g. iterations <= 0),
    // dividing by zero would replace every weight with NaN, so leave the weights untouched.
    if (averagingCount > 0) {
      for (int j = 0; j < params.getFeatureWeights().length; j++) {
        double[] totalClassWeights = totalFeatureWeights[j];
        double[] classWeights = params.getFeatureWeights()[j];
        for (int k = 0; k < params.getOutcomeCount(); k++) {
          classWeights[k] = totalClassWeights[k] / averagingCount;
        }
      }
    }
  } catch (IOException e) {
    LogUtils.logError(LOG, e);
    throw new RuntimeException(e);
  }
}
Usage of com.joliciel.talismane.machineLearning.ClassificationModel in project talismane by joliciel-informatique.
In class SentenceDetector, method getInstance:
/**
 * Returns a sentence detector for the given session.
 *
 * <p>The model location is read from the config key
 * {@code talismane.core.<sessionId>.sentence-detector.model}. Loaded models are cached in
 * {@code modelMap}, keyed by that path, and detectors are cached in
 * {@code sentenceDetectorMap}, keyed by session id. The caller receives a clone produced by
 * {@code cloneSentenceDetector()}, not the cached instance itself.
 *
 * @param sessionId the session whose configuration selects the model
 * @return a cloned sentence detector for the session
 * @throws IOException if the model file cannot be opened or read
 * @throws ClassNotFoundException propagated from model loading
 */
public static SentenceDetector getInstance(String sessionId) throws IOException, ClassNotFoundException {
  SentenceDetector sentenceDetector = sentenceDetectorMap.get(sessionId);
  if (sentenceDetector == null) {
    Config config = ConfigFactory.load();
    String configPath = "talismane.core." + sessionId + ".sentence-detector.model";
    String modelFilePath = config.getString(configPath);
    ClassificationModel sentenceModel = modelMap.get(modelFilePath);
    if (sentenceModel == null) {
      MachineLearningModelFactory factory = new MachineLearningModelFactory();
      // Close the zip stream, and with it the underlying model stream, once the model has
      // been read; previously both were left open, leaking a file handle per model load.
      try (InputStream modelFile = ConfigUtils.getFileFromConfig(config, configPath);
          ZipInputStream zipStream = new ZipInputStream(modelFile)) {
        sentenceModel = factory.getClassificationModel(zipStream);
      }
      modelMap.put(modelFilePath, sentenceModel);
    }
    sentenceDetector = new SentenceDetector(sentenceModel, sessionId);
    sentenceDetectorMap.put(sessionId, sentenceDetector);
  }
  return sentenceDetector.cloneSentenceDetector();
}
Usage of com.joliciel.talismane.machineLearning.ClassificationModel in project jochre by urieli.
In class Jochre, method doCommandTrain:
/**
 * Train a letter guessing model and persist it to the session's letter model path.
 *
 * @param featureDescriptors
 *          the feature descriptors for training
 * @param criteria
 *          criteria for selecting images to include when training
 * @param reconstructLetters
 *          whether or not complete letters should be reconstructed for
 *          training, from merged/split letters
 * @throws RuntimeException
 *           if the session has no letter model path
 * @throws JochreException
 *           if {@code featureDescriptors} is null
 */
public void doCommandTrain(List<String> featureDescriptors, CorpusSelectionCriteria criteria, boolean reconstructLetters) {
  if (jochreSession.getLetterModelPath() == null)
    throw new RuntimeException("Missing argument: letterModel");
  if (featureDescriptors == null)
    throw new JochreException("features is required");

  LetterFeatureParser letterFeatureParser = new LetterFeatureParser();
  Set<LetterFeature<?>> features = letterFeatureParser.getLetterFeatureSet(featureDescriptors);

  BoundaryDetector boundaryDetector;
  if (reconstructLetters) {
    ShapeSplitter splitter = new TrainingCorpusShapeSplitter(jochreSession);
    ShapeMerger merger = new TrainingCorpusShapeMerger();
    boundaryDetector = new LetterByLetterBoundaryDetector(splitter, merger, jochreSession);
  } else {
    boundaryDetector = new OriginalBoundaryDetector();
  }

  LetterValidator letterValidator = new ComponentCharacterValidator(jochreSession);
  ClassificationEventStream corpusEventStream = new JochreLetterEventStream(features, boundaryDetector, letterValidator, criteria, jochreSession);

  File letterModelFile = new File(jochreSession.getLetterModelPath());
  // getParentFile() returns null for a bare file name (e.g. "letters.zip"); calling mkdirs()
  // on it unconditionally would throw a NullPointerException.
  File letterModelDir = letterModelFile.getParentFile();
  if (letterModelDir != null) {
    letterModelDir.mkdirs();
  }

  ModelTrainerFactory modelTrainerFactory = new ModelTrainerFactory();
  ClassificationModelTrainer trainer = modelTrainerFactory.constructTrainer(jochreSession.getConfig());
  ClassificationModel letterModel = trainer.trainModel(corpusEventStream, featureDescriptors);
  letterModel.persist(letterModelFile);
}
Usage of com.joliciel.talismane.machineLearning.ClassificationModel in project jochre by urieli.
In class Jochre, method doCommandTrainMerge:
/**
 * Train the letter merging model and persist it to the session's merge model path.
 *
 * @param featureDescriptors
 *          feature descriptors for training
 * @param multiplier
 *          if &gt; 0, the event stream is wrapped in an
 *          {@code OutcomeEqualiserEventStream} with this multiplier to equalize the outcomes
 * @param criteria
 *          the criteria used to select the training corpus
 * @throws RuntimeException
 *           if the session has no merge model path
 * @throws JochreException
 *           if {@code featureDescriptors} is null
 */
public void doCommandTrainMerge(List<String> featureDescriptors, int multiplier, CorpusSelectionCriteria criteria) {
  if (jochreSession.getMergeModelPath() == null)
    throw new RuntimeException("Missing argument: mergeModel");
  if (featureDescriptors == null)
    throw new JochreException("features is required");

  File mergeModelFile = new File(jochreSession.getMergeModelPath());
  // getParentFile() returns null for a bare file name (e.g. "merge.zip"); calling mkdirs()
  // on it unconditionally would throw a NullPointerException.
  File mergeModelDir = mergeModelFile.getParentFile();
  if (mergeModelDir != null) {
    mergeModelDir.mkdirs();
  }

  MergeFeatureParser mergeFeatureParser = new MergeFeatureParser();
  Set<MergeFeature<?>> mergeFeatures = mergeFeatureParser.getMergeFeatureSet(featureDescriptors);
  ClassificationEventStream corpusEventStream = new JochreMergeEventStream(criteria, mergeFeatures, jochreSession);
  if (multiplier > 0) {
    corpusEventStream = new OutcomeEqualiserEventStream(corpusEventStream, multiplier);
  }

  ModelTrainerFactory modelTrainerFactory = new ModelTrainerFactory();
  ClassificationModelTrainer trainer = modelTrainerFactory.constructTrainer(jochreSession.getConfig());
  ClassificationModel mergeModel = trainer.trainModel(corpusEventStream, featureDescriptors);
  mergeModel.persist(mergeModelFile);
}
Aggregations