Use of zemberek.ner.PerceptronNer.ClassModel in the project zemberek-nlp by ahmetaa.
Shown below: the train method of the class PerceptronNerTrainer.
/**
 * Trains a {@link PerceptronNer} by making {@code iterationCount} passes over
 * {@code trainingSet}.
 *
 * <p>For each token, textual features are extracted and extended with the type ids of up to
 * three preceding tokens of the same sentence. The current model then predicts a type id. When
 * the prediction differs from the token's own type id, the expected class weights are moved by
 * {@code +learningRate} and the predicted class weights by {@code -learningRate}; a second set
 * of per-class models ({@code averages}) receives the same updates scaled by the per-class
 * counters kept in {@code counts}. Those accumulators are handed to {@code averageWeights}
 * together with the model, once per pass on a copy and once at the end on the model itself.
 *
 * @param trainingSet    data to learn from; {@code shuffle()} is called on it before every pass
 * @param devSet         optional evaluation data; when non-null, the averaged copy of the model
 *                       is evaluated on it after every pass and the result is logged
 * @param iterationCount number of passes over the training sentences
 * @param learningRate   magnitude of a single weight update
 * @return a {@code PerceptronNer} built from the model after the final {@code averageWeights} call
 */
public PerceptronNer train(NerDataSet trainingSet, NerDataSet devSet, int iterationCount, float learningRate) {
  // Feature prefixes for the type of the token 1, 2 and 3 positions back (index = distance - 1).
  final String[] previousTypePrefixes = {"PreType=", "2PreType=", "3PreType="};

  Map<String, ClassModel> model = new HashMap<>();
  Map<String, ClassModel> averages = new HashMap<>();
  IntValueMap<String> counts = new IntValueMap<>();

  // Every known type starts with its own weight model and its own accumulator.
  for (String typeId : trainingSet.typeIds) {
    model.put(typeId, new ClassModel(typeId));
    averages.put(typeId, new ClassModel(typeId));
  }

  for (int round = 0; round < iterationCount; round++) {
    int mistakes = 0;
    int seen = 0;
    trainingSet.shuffle();

    for (NerSentence sentence : trainingSet.sentences) {
      for (int pos = 0; pos < sentence.tokens.size(); pos++) {
        seen++;
        NerToken token = sentence.tokens.get(pos);
        String expectedId = token.tokenId;

        FeatureData featureData = new FeatureData(morphology, sentence, pos);
        List<String> features = featureData.getTextualFeatures();

        // Append the types of the preceding tokens, nearest first, as far back as they exist.
        for (int back = 1; back <= previousTypePrefixes.length && back <= pos; back++) {
          features.add(previousTypePrefixes[back - 1] + sentence.tokens.get(pos - back).tokenId);
        }

        ScoredItem<String> best = PerceptronNer.predictTypeAndPosition(model, features);
        String guessedId = best.item;

        if (!guessedId.equals(expectedId)) {
          // Misprediction: bump both counters first, because the accumulator updates below
          // read the already-incremented values.
          counts.addOrIncrement(expectedId);
          counts.addOrIncrement(guessedId);
          mistakes++;

          model.get(expectedId).updateSparse(features, +learningRate);
          model.get(guessedId).updateSparse(features, -learningRate);

          averages.get(expectedId).updateSparse(features, counts.get(expectedId) * learningRate);
          averages.get(guessedId).updateSparse(features, -counts.get(guessedId) * learningRate);
        } else {
          // Correct prediction: weights stay as they are, only the class counter advances.
          counts.addOrIncrement(guessedId);
        }
      }
    }

    Log.info("Iteration %d, Token error = %.6f", round + 1, (double) mistakes / seen);

    // Average a copy so that the weights being trained are left untouched between passes.
    Map<String, ClassModel> snapshot = copyModel(model);
    averageWeights(averages, snapshot, counts);
    PerceptronNer ner = new PerceptronNer(snapshot, morphology);
    if (devSet != null) {
      NerDataSet result = ner.evaluate(devSet);
      Log.info(collectEvaluationData(devSet, result).dump());
    }
  }

  averageWeights(averages, model, counts);
  Log.info("Training finished.");
  return new PerceptronNer(model, morphology);
}
Aggregations