Use of zemberek.core.ScoredItem in project zemberek-nlp by ahmetaa.
The class DistanceBasedStemmer, method findStems:
public void findStems(String str) {
  // Pad the sentence with two boundary tokens on each side so every real word
  // has a full two-word context to its left and right.
  str = "<s> <s> " + str + " </s> </s>";
  SentenceAnalysis analysis = morphology.analyzeAndDisambiguate(str);
  List<SentenceWordAnalysis> swaList = analysis.getWordAnalyses();
  for (int i = 2; i < analysis.size() - 2; i++) {
    SentenceWordAnalysis swa = swaList.get(i);
    String s = swaList.get(i).getWordAnalysis().getInput();
    List<String> bigramContext = Lists.newArrayList(
        normalize(swaList.get(i - 1).getWordAnalysis().getInput()),
        normalize(swaList.get(i - 2).getWordAnalysis().getInput()),
        normalize(swaList.get(i + 1).getWordAnalysis().getInput()),
        normalize(swaList.get(i + 2).getWordAnalysis().getInput()));
    List<String> unigramContext = Lists.newArrayList(
        normalize(swaList.get(i - 1).getWordAnalysis().getInput()),
        normalize(swaList.get(i + 1).getWordAnalysis().getInput()));
    WordAnalysis wordResults = swa.getWordAnalysis();
    // Candidate stems are the normalized lemmas of all morphological analyses of the word.
    Set<String> stems = wordResults.stream()
        .map(a -> normalize(a.getDictionaryItem().lemma))
        .collect(Collectors.toSet());
    List<ScoredItem<String>> scores = new ArrayList<>();
    for (String stem : stems) {
      if (!distances.containsWord(stem)) {
        Log.info("Cannot find %s in vocab.", stem);
        continue;
      }
      // Score each candidate: distance of the stem to its context words, plus the
      // distance of the surface word to the stem's nearest neighbours (capped by k).
      List<WordDistances.Distance> distances = this.distances.getDistance(stem);
      float score = totalDistance(stem, bigramContext);
      int k = 0;
      for (WordDistances.Distance distance : distances) {
        /* if (s.equals(distance.word)) {
          continue;
        }*/
        score += distance(s, distance.word);
        if (k++ == 10) {
          break;
        }
      }
      scores.add(new ScoredItem<>(stem, score));
    }
    Collections.sort(scores);
    Log.info("%n%s : ", s);
    for (ScoredItem<String> score : scores) {
      Log.info("Lemma = %s Score = %.7f", score.item, score.score);
    }
  }
  Log.info("==== Z disambiguation result ===== ");
  for (SentenceWordAnalysis a : analysis) {
    Log.info("%n%s : ", a.getWordAnalysis().getInput());
    LinkedHashSet<String> items = new LinkedHashSet<>();
    for (SingleAnalysis wa : a.getWordAnalysis()) {
      items.add(wa.getDictionaryItem().toString());
    }
    for (String item : items) {
      Log.info("%s", item);
    }
  }
}
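ScoredItem is simply a value paired with a float score; both fields are public, as the reporting loop above shows. When the intended sort direction matters more than ScoredItem's natural ordering, an explicit comparator keeps it visible. A minimal self-contained sketch of the same collect-score-sort pattern, with illustrative values and assuming a smaller accumulated distance means a better candidate:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import zemberek.core.ScoredItem;

public class StemRankingSketch {

  public static void main(String[] args) {
    List<ScoredItem<String>> scores = new ArrayList<>();
    // Illustrative values; in findStems these come from totalDistance(stem, context).
    scores.add(new ScoredItem<>("el", 3.4f));
    scores.add(new ScoredItem<>("elma", 1.2f));
    scores.add(new ScoredItem<>("elmas", 2.7f));
    // Sort ascending: the assumption here is that a smaller accumulated distance
    // means the stem fits the surrounding context better.
    scores.sort(Comparator.comparingDouble(s -> s.score));
    for (ScoredItem<String> s : scores) {
      System.out.printf("Lemma = %s Score = %.7f%n", s.item, s.score);
    }
  }
}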
Use of zemberek.core.ScoredItem in project zemberek-nlp by ahmetaa.
The class SimpleClassification, method main:
public static void main(String[] args) throws IOException {
  // Assumes the model was generated with NewsTitleCategoryFinder.
  Path path = Paths.get(
      "/media/ahmetaa/depo/zemberek/data/classification/news-title-category-set.tokenized.model");
  FastTextClassifier classifier = FastTextClassifier.load(path);
  String s = "Beşiktaş berabere kaldı.";
  // Process the input exactly the way the training set was processed.
  String processed = String.join(" ", TurkishTokenizer.DEFAULT.tokenizeToStrings(s));
  processed = processed.toLowerCase(Turkish.LOCALE);
  // Get only the top three predictions.
  List<ScoredItem<String>> res = classifier.predict(processed, 3);
  for (ScoredItem<String> re : res) {
    System.out.println(re);
  }
}
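To print the label and score separately instead of relying on ScoredItem.toString, the public item and score fields can be read directly. A hedged variant of the example above: the model path is a placeholder, the import paths assume a recent zemberek-nlp release, and the Math.exp conversion assumes the score is a log probability, as suggested by the EvaluateClassifier example further down.

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Locale;
import zemberek.classification.FastTextClassifier;
import zemberek.core.ScoredItem;

public class PrintLabelAndProbability {

  public static void main(String[] args) throws IOException {
    // Placeholder path; point this at a model produced by your own training run.
    Path path = Paths.get("news-title-category-set.tokenized.model");
    FastTextClassifier classifier = FastTextClassifier.load(path);
    // Input is tokenized and lowercased the same way as in the example above.
    List<ScoredItem<String>> res = classifier.predict("beşiktaş berabere kaldı .", 3);
    for (ScoredItem<String> re : res) {
      String label = re.item.replace("__label__", "");
      double probability = Math.exp(re.score); // assumes the score is a log probability
      System.out.println(String.format(Locale.ENGLISH, "%s (%.6f)", label, probability));
    }
  }
}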
Use of zemberek.core.ScoredItem in project zemberek-nlp by ahmetaa.
The class FastText, method predict:
List<ScoredItem<String>> predict(String line, int k) {
  IntVector words = new IntVector();
  IntVector labels = new IntVector();
  dict_.getLine(line, words, labels, model_.getRng());
  dict_.addWordNgramHashes(words, args_.wordNgrams);
  if (words.isempty()) {
    return Collections.emptyList();
  }
  Vector output = new Vector(dict_.nlabels());
  Vector hidden = model_.computeHidden(words.copyOf());
  List<Model.FloatIntPair> modelPredictions = model_.predict(k, hidden, output);
  List<ScoredItem<String>> result = new ArrayList<>(modelPredictions.size());
  for (Model.FloatIntPair pair : modelPredictions) {
    result.add(new ScoredItem<>(dict_.getLabel(pair.second), pair.first));
  }
  return result;
}
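predict returns at most k ScoredItem entries and an empty list when no token of the line is in the dictionary, so callers should handle both cases. A small hedged sketch of reducing such a list to the single best label, using only the public item and score fields shown on this page:

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import zemberek.core.ScoredItem;

public class BestLabel {

  // Picks the label with the highest score; empty when the model returned no predictions
  // (for example because the input shared no tokens with the FastText dictionary).
  static Optional<String> bestLabel(List<ScoredItem<String>> predictions) {
    return predictions.stream()
        .max(Comparator.comparingDouble(p -> p.score))
        .map(p -> p.item);
  }

  public static void main(String[] args) {
    List<ScoredItem<String>> predictions = Arrays.asList(
        new ScoredItem<>("__label__spor", -0.2f),
        new ScoredItem<>("__label__ekonomi", -2.1f));
    System.out.println(bestLabel(predictions).orElse("no prediction"));
  }
}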
Use of zemberek.core.ScoredItem in project zemberek-nlp by ahmetaa.
The class EvaluateClassifier, method run:
@Override
public void run() throws Exception {
  System.out.println("Loading classification model...");
  FastTextClassifier classifier = FastTextClassifier.load(model);
  EvaluationResult result = classifier.evaluate(input, maxPrediction, threshold);
  System.out.println("Result = " + result.toString());
  if (predictions == null) {
    String name = input.toFile().getName();
    predictions = Paths.get("").resolve(name + ".predictions");
  }
  List<String> testLines = Files.readAllLines(input, StandardCharsets.UTF_8);
  try (PrintWriter pw = new PrintWriter(predictions.toFile(), "utf-8")) {
    for (String testLine : testLines) {
      List<ScoredItem<String>> res = classifier.predict(testLine, maxPrediction);
      res = res.stream().filter(s -> s.score >= threshold).collect(Collectors.toList());
      List<String> predictedCategories = new ArrayList<>();
      for (ScoredItem<String> re : res) {
        predictedCategories.add(String.format(Locale.ENGLISH, "%s (%.6f)",
            re.item.replaceAll("__label__", ""), Math.exp(re.score)));
      }
      pw.println(testLine);
      pw.println("Predictions = " + String.join(", ", predictedCategories));
      pw.println();
    }
  }
  System.out.println("Predictions are written to " + predictions);
}
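Note that the threshold filter above is applied to the raw score, while the written file shows Math.exp(re.score). If the scores are log probabilities, a probability cutoff has to be moved into log space before comparing it against ScoredItem.score. A small sketch of that conversion, with an illustrative cutoff that is not taken from the original command:

public class ThresholdConversion {

  public static void main(String[] args) {
    // Assuming classifier scores are log probabilities (see the Math.exp call above),
    // convert a probability cutoff into the log-space value the filter compares against.
    double minProbability = 0.5; // illustrative cutoff only
    double threshold = Math.log(minProbability); // roughly -0.693
    System.out.printf("Keep predictions with score >= %.4f (probability >= %.2f)%n",
        threshold, minProbability);
  }
}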
Use of zemberek.core.ScoredItem in project zemberek-nlp by ahmetaa.
The class ClassificationConsole, method run:
@Override
public void run() throws Exception {
  Log.info("Loading classification model...");
  FastTextClassifier classifier = FastTextClassifier.load(model);
  if (preprocessor == Preprocessor.LEMMA) {
    morphology = TurkishMorphology.createWithDefaults();
  }
  String input;
  System.out.println("Preprocessing type = " + preprocessor.name());
  System.out.println("Enter sentence:");
  Scanner sc = new Scanner(System.in);
  input = sc.nextLine();
  while (!input.equals("exit") && !input.equals("quit")) {
    if (input.trim().length() == 0) {
      System.out.println("Empty line cannot be processed.");
      input = sc.nextLine();
      continue;
    }
    String processed;
    if (preprocessor == Preprocessor.TOKENIZED) {
      processed = String.join(" ", TurkishTokenizer.DEFAULT.tokenizeToStrings(input));
    } else {
      processed = replaceWordsWithLemma(input);
    }
    processed = removeNonWords(processed).toLowerCase(Turkish.LOCALE);
    System.out.println("Processed Input = " + processed);
    if (processed.trim().length() == 0) {
      System.out.println("Processing result is empty. Enter new sentence.");
      input = sc.nextLine();
      continue;
    }
    List<ScoredItem<String>> res = classifier.predict(processed, predictionCount);
    List<String> predictedCategories = new ArrayList<>();
    for (ScoredItem<String> re : res) {
      predictedCategories.add(String.format(Locale.ENGLISH, "%s (%.6f)",
          re.item.replaceAll("__label__", ""), re.score));
    }
    System.out.println("Predictions = " + String.join(", ", predictedCategories));
    System.out.println();
    input = sc.nextLine();
  }
}
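The LEMMA branch calls replaceWordsWithLemma, which is not shown on this page. Below is a hedged guess at what such a preprocessor could look like, built from calls that do appear here (analyzeAndDisambiguate, getDictionaryItem().lemma) plus SentenceAnalysis.bestAnalysis; the import paths assume a recent zemberek-nlp release, and the real implementation in ClassificationConsole may differ.

import java.util.ArrayList;
import java.util.List;
import zemberek.morphology.TurkishMorphology;
import zemberek.morphology.analysis.SentenceAnalysis;
import zemberek.morphology.analysis.SingleAnalysis;

public class LemmaPreprocessorSketch {

  // Hypothetical stand-in for ClassificationConsole.replaceWordsWithLemma:
  // disambiguate the sentence, then replace every word with the lemma of its best analysis.
  static String replaceWordsWithLemma(TurkishMorphology morphology, String sentence) {
    SentenceAnalysis analysis = morphology.analyzeAndDisambiguate(sentence);
    List<String> lemmas = new ArrayList<>();
    for (SingleAnalysis best : analysis.bestAnalysis()) {
      lemmas.add(best.getDictionaryItem().lemma);
    }
    return String.join(" ", lemmas);
  }

  public static void main(String[] args) throws Exception {
    TurkishMorphology morphology = TurkishMorphology.createWithDefaults();
    System.out.println(replaceWordsWithLemma(morphology, "Beşiktaş berabere kaldı."));
  }
}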