Usage of zemberek.classification.FastTextClassifier in the project zemberek-nlp by ahmetaa.
Class SimpleClassification, method main.
/**
 * Classifies a single sample sentence with a pre-trained news-title category model and prints
 * the top three predictions.
 *
 * @param args optional; when present, {@code args[0]} is used as the model path instead of the
 *     built-in default path
 * @throws IOException propagated from loading the model
 */
public static void main(String[] args) throws IOException {
  // assumes models are generated with NewsTitleCategoryFinder
  Path path = args.length > 0
      ? Paths.get(args[0])
      : Paths.get("/media/ahmetaa/depo/zemberek/data/classification/news-title-category-set.tokenized.model");
  FastTextClassifier classifier = FastTextClassifier.load(path);
  String s = "Beşiktaş berabere kaldı.";
  // Process the input exactly the way the training set was processed:
  // tokenize, join tokens with single spaces, then lowercase with Turkish rules.
  String processed = String.join(" ", TurkishTokenizer.DEFAULT.tokenizeToStrings(s));
  processed = processed.toLowerCase(Turkish.LOCALE);
  // Results: only the top three.
  List<ScoredItem<String>> res = classifier.predict(processed, 3);
  for (ScoredItem<String> re : res) {
    System.out.println(re);
  }
}
Usage of zemberek.classification.FastTextClassifier in the project zemberek-nlp by ahmetaa.
Class EvaluateClassifier, method run.
/**
 * Evaluates the classifier on the {@code input} file, prints the aggregate result, then writes
 * every input line followed by its predictions to the {@code predictions} file.
 *
 * <p>If {@code predictions} is null it is set to {@code <input file name>.predictions},
 * resolved against the current working directory.
 *
 * @throws Exception if the model or input cannot be read, or if writing the predictions file
 *     fails (write errors are no longer silently swallowed by {@link PrintWriter})
 */
@Override
public void run() throws Exception {
  System.out.println("Loading classification model...");
  FastTextClassifier classifier = FastTextClassifier.load(model);
  EvaluationResult result = classifier.evaluate(input, maxPrediction, threshold);
  System.out.println("Result = " + result.toString());
  if (predictions == null) {
    String name = input.toFile().getName();
    predictions = Paths.get("").resolve(name + ".predictions");
  }
  List<String> testLines = Files.readAllLines(input, StandardCharsets.UTF_8);
  try (PrintWriter pw = new PrintWriter(predictions.toFile(), "utf-8")) {
    for (String testLine : testLines) {
      // NOTE(review): the threshold is compared against the raw score, while the printed value
      // is exp(score); presumably scores are log-probabilities, so the threshold is in log
      // space — confirm this matches how evaluate() interprets it.
      List<String> predictedCategories = classifier.predict(testLine, maxPrediction).stream()
          .filter(s -> s.score >= threshold)
          .map(s -> String.format(Locale.ENGLISH, "%s (%.6f)",
              s.item.replace("__label__", ""), Math.exp(s.score)))
          .collect(Collectors.toList());
      pw.println(testLine);
      pw.println("Predictions = " + String.join(", ", predictedCategories));
      pw.println();
    }
    // PrintWriter never throws on write failures; surface them instead of reporting success.
    if (pw.checkError()) {
      throw new java.io.IOException("Failed to write predictions to " + predictions);
    }
  }
  System.out.println("Predictions are written to " + predictions);
}
Usage of zemberek.classification.FastTextClassifier in the project zemberek-nlp by ahmetaa.
Class ClassificationConsole, method run.
/**
 * Interactive console: reads sentences from standard input, preprocesses each one according to
 * {@code preprocessor}, and prints the top {@code predictionCount} predicted labels.
 *
 * <p>The loop ends on a line equal to "exit" or "quit" (surrounding whitespace ignored) or when
 * standard input is exhausted. Previously, end of input threw {@code NoSuchElementException}.
 *
 * @throws Exception if the model or the morphology cannot be loaded
 */
@Override
public void run() throws Exception {
  Log.info("Loading classification model...");
  FastTextClassifier classifier = FastTextClassifier.load(model);
  if (preprocessor == Preprocessor.LEMMA) {
    // NOTE(review): only the lemma path initializes this field; assumed to be consumed by
    // replaceWordsWithLemma — confirm.
    morphology = TurkishMorphology.createWithDefaults();
  }
  System.out.println("Preprocessing type = " + preprocessor.name());
  System.out.println("Enter sentence:");
  // Wraps System.in, so it is intentionally not closed.
  Scanner sc = new Scanner(System.in);
  while (sc.hasNextLine()) {
    String input = sc.nextLine();
    String command = input.trim();
    if (command.equals("exit") || command.equals("quit")) {
      break;
    }
    if (command.isEmpty()) {
      System.out.println("Empty line cannot be processed.");
      continue;
    }
    String processed;
    if (preprocessor == Preprocessor.TOKENIZED) {
      processed = String.join(" ", TurkishTokenizer.DEFAULT.tokenizeToStrings(input));
    } else {
      processed = replaceWordsWithLemma(input);
    }
    processed = removeNonWords(processed).toLowerCase(Turkish.LOCALE);
    System.out.println("Processed Input = " + processed);
    if (processed.trim().isEmpty()) {
      System.out.println("Processing result is empty. Enter new sentence.");
      continue;
    }
    List<ScoredItem<String>> res = classifier.predict(processed, predictionCount);
    List<String> predictedCategories = new ArrayList<>();
    for (ScoredItem<String> re : res) {
      // NOTE(review): prints the raw score, whereas EvaluateClassifier prints Math.exp(score)
      // for the same predict() output — confirm which scale is intended for display.
      predictedCategories.add(String.format(Locale.ENGLISH, "%s (%.6f)",
          re.item.replace("__label__", ""), re.score));
    }
    System.out.println("Predictions = " + String.join(", ", predictedCategories));
    System.out.println();
  }
}
Usage of zemberek.classification.FastTextClassifier in the project zemberek-nlp by ahmetaa.
Class TrainClassifier, method run.
/**
 * Trains a FastText classifier from {@code input} using the configured hyperparameters and
 * saves the model to {@code output}.
 *
 * <p>When {@code applyQuantization} is set, a quantized model is also written to the same
 * directory as {@code output}. Its file name is the output file name with ".q" appended.
 *
 * @throws IOException if training data cannot be read or a model cannot be written
 */
@Override
public void run() throws IOException {
  Log.info("Generating classification model from %s", input);
  FastTextClassifierTrainer trainer = FastTextClassifierTrainer.builder()
      .epochCount(epochCount)
      .learningRate(learningRate)
      .lossType(lossType)
      .quantizationCutOff(cutOff)
      .minWordCount(minWordCount)
      .threadCount(threadCount)
      .wordNgramOrder(wordNGrams)
      .dimension(dimension)
      .contextWindowSize(contextWindowSize)
      .build();
  Log.info("Training Started.");
  // Presumably lets this command receive the trainer's progress events — confirm.
  trainer.getEventBus().register(this);
  FastTextClassifier trained = trainer.train(input);
  Log.info("Saving classification model to %s", output);
  FastText model = trained.getFastText();
  model.saveModel(output);

  if (!applyQuantization) {
    return;
  }

  Log.info("Applying quantization.");
  if (cutOff > 0) {
    Log.info("Quantization dictionary cut-off value = %d", cutOff);
  }
  // Sibling of `output`; falls back to a bare relative file name when `output` has no parent.
  Path quantizedPath = output.resolveSibling(output.toFile().getName() + ".q");
  Log.info("Saving quantized classification model to %s", quantizedPath);
  FastText quantizedModel = model.quantize(output, model.getArgs());
  quantizedModel.saveModel(quantizedPath);
}
Aggregations