Example usage of zemberek.core.embeddings.FastTextTrainer in the project zemberek-nlp by ahmetaa:
the method getOrTrainFastText of the class AutomaticLabelingExperiment.
/**
 * Loads a previously trained supervised FastText model from {@code modelPath} if one exists on
 * disk; otherwise trains a new model from {@code train} and saves it to {@code modelPath}.
 *
 * @param train     path to the supervised training set.
 * @param modelPath path where the model is (or will be) stored.
 * @return the loaded or freshly trained model.
 * @throws Exception if loading, training, or saving fails.
 */
private FastText getOrTrainFastText(Path train, Path modelPath) throws Exception {
  if (modelPath.toFile().exists()) {
    return FastText.load(modelPath);
  }
  Args args = Args.forSupervised();
  args.thread = 16;
  args.loss = Args.loss_name.hierarchicalSoftmax;
  args.epoch = 100;
  args.wordNgrams = 2;
  args.minCount = 10;
  args.lr = 0.2;
  args.dim = 250;
  args.bucket = 7_000_000;
  FastText model = new FastTextTrainer(args).train(train);
  model.saveModel(modelPath);
  return model;
}
Example usage of zemberek.core.embeddings.FastTextTrainer in the project zemberek-nlp by ahmetaa:
the method runExperiment of the class CategoryPredictionExperiment.
/**
 * Runs the category prediction experiment end to end: prepares train/test sets from the raw
 * corpus, trains (or reuses) a supervised FastText model, evaluates it, and writes per-document
 * top-3 predictions into a report file.
 *
 * @throws Exception if any file operation, training, or evaluation step fails.
 */
private void runExperiment() throws Exception {
  Path corpusPath = experimentRoot.resolve("category.corpus");
  Path train = experimentRoot.resolve("category.train");
  Path test = experimentRoot.resolve("category.test");
  Path titleRaw = experimentRoot.resolve("category.title");
  Path modelPath = experimentRoot.resolve("category.model");
  Path predictionPath = experimentRoot.resolve("category.predictions");

  extractCategoryDocuments(rawCorpusRoot, corpusPath);
  boolean useOnlyTitles = true;
  boolean useLemmas = true;
  generateSets(corpusPath, train, test, useOnlyTitles, useLemmas);
  generateRawSet(corpusPath, titleRaw);

  FastText model;
  if (modelPath.toFile().exists()) {
    Log.info("Reusing existing model %s", modelPath);
    model = FastText.load(modelPath);
  } else {
    Args args = Args.forSupervised();
    args.thread = 4;
    args.model = Args.model_name.supervised;
    args.loss = Args.loss_name.softmax;
    args.epoch = 50;
    args.wordNgrams = 2;
    args.minCount = 0;
    args.lr = 0.5;
    args.dim = 100;
    args.bucket = 5_000_000;
    model = new FastTextTrainer(args).train(train);
    model.saveModel(modelPath);
  }

  EvaluationResult result = model.test(test, 1);
  Log.info(result.toString());

  WebCorpus corpus = new WebCorpus("corpus", "labeled");
  corpus.addDocuments(WebCorpus.loadDocuments(corpusPath));
  Log.info("Testing started.");

  List<String> testLines = Files.readAllLines(test, StandardCharsets.UTF_8);
  try (PrintWriter pw = new PrintWriter(predictionPath.toFile(), "utf-8")) {
    for (String line : testLines) {
      // The document id is the first whitespace-delimited token minus its first character.
      String firstToken = line.substring(0, line.indexOf(' '));
      String id = firstToken.substring(1);
      WebDocument doc = corpus.getDocument(id);

      List<String> predictedCategories = new ArrayList<>();
      for (ScoredItem<String> prediction : model.predict(line, 3)) {
        // Drop extremely low-scored predictions from the report.
        if (prediction.score < -10) {
          continue;
        }
        String label = prediction.item.replaceAll("__label__", "").replaceAll("_", " ");
        predictedCategories.add(
            String.format(Locale.ENGLISH, "%s (%.2f)", label, prediction.score));
      }

      pw.println("id = " + id);
      pw.println();
      pw.println(doc.getTitle());
      pw.println();
      pw.println("Actual Category = " + doc.getCategory());
      pw.println("Predictions = " + String.join(", ", predictedCategories));
      pw.println();
      pw.println("------------------------------------------------------");
      pw.println();
    }
  }
  Log.info("Done.");
}
Example usage of zemberek.core.embeddings.FastTextTrainer in the project zemberek-nlp by ahmetaa:
the method train of the class FastTextClassifierTrainer.
/**
 * Trains a supervised FastText classification model from the given corpus using the
 * parameters collected in the enclosing builder.
 * <p>
 * This instance is registered on the trainer's event bus so that training progress events
 * are forwarded to it.
 *
 * @param corpus path to the labeled training corpus.
 * @return a classifier wrapping the trained FastText model.
 * @throws RuntimeException if training fails; the original exception is kept as the cause.
 */
public FastTextClassifier train(Path corpus) {
  Args args = Args.forSupervised();
  args.loss = builder.type == LossType.SOFTMAX ? loss_name.softmax : loss_name.hierarchicalSoftmax;
  args.dim = builder.dimension;
  args.wordNgrams = builder.wordNgramOrder;
  args.thread = builder.threadCount;
  args.epoch = builder.epochCount;
  args.lr = builder.learningRate;
  args.ws = builder.contextWindowSize;
  SubWordHashProvider p = builder.subWordHashProvider;
  args.subWordHashProvider = p;
  args.minn = p.getMinN();
  args.maxn = p.getMaxN();
  args.minCount = builder.minWordCount;
  args.cutoff = builder.quantizationCutOff;
  FastTextTrainer trainer = new FastTextTrainer(args);
  // for catching and forwarding progress events.
  trainer.getEventBus().register(this);
  try {
    return new FastTextClassifier(trainer.train(corpus));
  } catch (Exception e) {
    // Rethrow only; printing the stack trace here would report the same error twice.
    // The cause is preserved in the wrapping RuntimeException for the caller to log.
    throw new RuntimeException(e);
  }
}
Example usage of zemberek.core.embeddings.FastTextTrainer in the project zemberek-nlp by ahmetaa:
the method generateVectorModel of the class DocumentSimilarityExperiment.
/**
 * Trains skip-gram word vectors from {@code input} and saves the resulting model to
 * {@code modelFile}. Sub-word (character n-gram) features are disabled, so only whole-word
 * embeddings are learned.
 *
 * @param input     path to the training corpus.
 * @param modelFile path where the trained vector model is written.
 * @throws Exception if training or saving fails.
 */
public void generateVectorModel(Path input, Path modelFile) throws Exception {
  Args args = Args.forWordVectors(Args.model_name.skipGram);
  args.thread = 16;
  args.epoch = 10;
  args.dim = 250;
  args.bucket = 10;
  args.minCount = 10;
  // No sub-word n-grams: empty hash provider, with minn/maxn zeroed to match.
  args.minn = 0;
  args.maxn = 0;
  args.subWordHashProvider = new EmbeddingHashProviders.EmptySubwordHashProvider();
  FastText fastText = new FastTextTrainer(args).train(input);
  Log.info("Saving vmodel to %s", modelFile);
  fastText.saveModel(modelFile);
}
Example usage of zemberek.core.embeddings.FastTextTrainer in the project zemberek-nlp by ahmetaa:
the method classificationTest of the class FastTextTest.
/**
 * Runs the cooking label guessing task: trains a supervised model, then evaluates the trained,
 * reloaded, quantized, and reloaded-quantized variants. Run with -Xms8G or more.
 */
@Test
@Ignore("Not an actual Test.")
public void classificationTest() throws Exception {
  Path inputRoot = Paths.get("/home/ahmetaa/data/fasttext");
  Path trainFile = inputRoot.resolve("cooking.train");
  Path modelPath = inputRoot.resolve("cooking.model.bin");
  Path quantizedModelPath = inputRoot.resolve("cooking.model.qbin");
  Path testFile = inputRoot.resolve("cooking.valid");

  Args args = Args.forSupervised();
  args.thread = 4;
  args.epoch = 25;
  args.wordNgrams = 2;
  args.minCount = 1;
  args.lr = 1.0;
  args.dim = 100;
  args.bucket = 1_000_000;

  // Train from scratch and persist the model.
  FastText fastText = new FastTextTrainer(args).train(trainFile);
  fastText.saveModel(modelPath);

  Log.info("Testing started.");
  test(fastText, testFile, 1);

  // Reload from disk and evaluate the persisted model.
  fastText = FastText.load(modelPath);
  test(fastText, testFile, 1);

  // Quantize the model and evaluate the compressed variant.
  args.qnorm = false;
  args.cutoff = 3000;
  fastText = fastText.quantize(modelPath, args);
  fastText.saveModel(quantizedModelPath);
  Log.info("Testing quantization result.");
  test(fastText, testFile, 1);

  fastText = FastText.load(quantizedModelPath);
  Log.info("Testing after loading quantized model.");
  test(fastText, testFile, 1);
}
Aggregations