Search in sources :

Example 1 with Args

use of zemberek.core.embeddings.Args in project zemberek-nlp by ahmetaa.

the class AutomaticLabelingExperiment method getOrTrainFastText.

/**
 * Returns a FastText model for the given model location, training one only when no model file
 * exists there yet.
 *
 * <p>When a file is present at {@code modelPath} it is loaded and returned. Otherwise a
 * supervised model is trained on {@code train} with the settings configured below, saved to
 * {@code modelPath}, and returned.
 *
 * @param train training data handed to the trainer when no saved model exists
 * @param modelPath file the model is loaded from, or saved to after training
 * @return the loaded or freshly trained model
 * @throws Exception propagated from loading, training or saving the model
 */
private FastText getOrTrainFastText(Path train, Path modelPath) throws Exception {
    // Fast path: a previously saved model is reused as-is.
    if (modelPath.toFile().exists()) {
        return FastText.load(modelPath);
    }

    Args trainingArgs = Args.forSupervised();
    // Model shape.
    trainingArgs.loss = Args.loss_name.hierarchicalSoftmax;
    trainingArgs.dim = 250;
    trainingArgs.wordNgrams = 2;
    trainingArgs.bucket = 7_000_000;
    trainingArgs.minCount = 10;
    // Optimisation schedule.
    trainingArgs.lr = 0.2;
    trainingArgs.epoch = 100;
    trainingArgs.thread = 16;

    FastText trained = new FastTextTrainer(trainingArgs).train(train);
    // Persist the result so the next call takes the fast path above.
    trained.saveModel(modelPath);
    return trained;
}
Also used : Args(zemberek.core.embeddings.Args) FastTextTrainer(zemberek.core.embeddings.FastTextTrainer) FastText(zemberek.core.embeddings.FastText)

Example 2 with Args

use of zemberek.core.embeddings.Args in project zemberek-nlp by ahmetaa.

the class CategoryPredictionExperiment method runExperiment.

/**
 * Runs the category prediction experiment from corpus extraction to a prediction report.
 *
 * <p>Steps, in order: extract the category documents into a corpus file, generate the
 * train/test sets and the raw title set, load the model from disk (or train and save it when
 * the model file is missing), log the evaluation on the test set, and finally write the top
 * predictions for every test line to the prediction file.
 *
 * @throws Exception propagated from any of the extraction, training, evaluation or I/O steps
 */
private void runExperiment() throws Exception {
    // Every artifact of the experiment is a file directly under the experiment root.
    Path corpusFile = experimentRoot.resolve("category.corpus");
    Path trainFile = experimentRoot.resolve("category.train");
    Path testFile = experimentRoot.resolve("category.test");
    Path rawTitleFile = experimentRoot.resolve("category.title");
    Path modelFile = experimentRoot.resolve("category.model");
    Path predictionFile = experimentRoot.resolve("category.predictions");

    extractCategoryDocuments(rawCorpusRoot, corpusFile);
    final boolean useOnlyTitles = true;
    final boolean useLemmas = true;
    generateSets(corpusFile, trainFile, testFile, useOnlyTitles, useLemmas);
    generateRawSet(corpusFile, rawTitleFile);

    FastText model;
    if (!modelFile.toFile().exists()) {
        // No saved model yet: train a supervised model and persist it for later runs.
        Args trainingArgs = Args.forSupervised();
        trainingArgs.thread = 4;
        trainingArgs.model = Args.model_name.supervised;
        trainingArgs.loss = Args.loss_name.softmax;
        trainingArgs.epoch = 50;
        trainingArgs.wordNgrams = 2;
        trainingArgs.minCount = 0;
        trainingArgs.lr = 0.5;
        trainingArgs.dim = 100;
        trainingArgs.bucket = 5_000_000;
        model = new FastTextTrainer(trainingArgs).train(trainFile);
        model.saveModel(modelFile);
    } else {
        Log.info("Reusing existing model %s", modelFile);
        model = FastText.load(modelFile);
    }

    EvaluationResult evaluation = model.test(testFile, 1);
    Log.info(evaluation.toString());

    // The labeled corpus is needed to look up title and actual category per test document.
    WebCorpus labeledCorpus = new WebCorpus("corpus", "labeled");
    labeledCorpus.addDocuments(WebCorpus.loadDocuments(corpusFile));

    Log.info("Testing started.");
    List<String> lines = Files.readAllLines(testFile, StandardCharsets.UTF_8);
    try (PrintWriter report = new PrintWriter(predictionFile.toFile(), "utf-8")) {
        for (String line : lines) {
            // The document id is the first space-delimited token with its leading character dropped.
            int firstSpace = line.indexOf(' ');
            String leadingToken = line.substring(0, firstSpace);
            String id = leadingToken.substring(1);
            WebDocument document = labeledCorpus.getDocument(id);

            List<String> predictedCategories = new ArrayList<>();
            List<ScoredItem<String>> topPredictions = model.predict(line, 3);
            for (ScoredItem<String> prediction : topPredictions) {
                // Predictions scoring below -10 are left out of the report.
                if (prediction.score < -10) {
                    continue;
                }
                String readableLabel = prediction.item.replaceAll("__label__", "").replaceAll("_", " ");
                String formatted = String.format(Locale.ENGLISH, "%s (%.2f)", readableLabel, prediction.score);
                predictedCategories.add(formatted);
            }

            report.println("id = " + id);
            report.println();
            report.println(document.getTitle());
            report.println();
            report.println("Actual Category = " + document.getCategory());
            report.println("Predictions   = " + String.join(", ", predictedCategories));
            report.println();
            report.println("------------------------------------------------------");
            report.println();
        }
    }
    Log.info("Done.");
}
Also used : Path(java.nio.file.Path) Args(zemberek.core.embeddings.Args) ScoredItem(zemberek.core.ScoredItem) ArrayList(java.util.ArrayList) FastTextTrainer(zemberek.core.embeddings.FastTextTrainer) EvaluationResult(zemberek.core.embeddings.FastText.EvaluationResult) WebDocument(zemberek.corpus.WebDocument) WebCorpus(zemberek.corpus.WebCorpus) FastText(zemberek.core.embeddings.FastText) PrintWriter(java.io.PrintWriter)

Example 3 with Args

use of zemberek.core.embeddings.Args in project zemberek-nlp by ahmetaa.

the class FastTextClassifierTrainer method train.

/**
 * Trains a supervised FastText classifier on the given corpus using the settings held by the
 * builder.
 *
 * <p>The builder's loss type, dimension, word n-gram order, thread count, epoch count,
 * learning rate, context window size, sub-word hash provider (and its min/max n), minimum
 * word count and quantization cut-off are copied into the trainer arguments. This instance is
 * registered on the trainer's event bus before training starts.
 *
 * @param corpus training corpus handed to the trainer
 * @return a classifier wrapping the trained model
 * @throws RuntimeException wrapping, as its cause, any exception raised while training
 */
public FastTextClassifier train(Path corpus) {
    Args args = Args.forSupervised();
    args.loss = builder.type == LossType.SOFTMAX ? loss_name.softmax : loss_name.hierarchicalSoftmax;
    args.dim = builder.dimension;
    args.wordNgrams = builder.wordNgramOrder;
    args.thread = builder.threadCount;
    args.epoch = builder.epochCount;
    args.lr = builder.learningRate;
    args.ws = builder.contextWindowSize;
    // The character n-gram range is taken from the hash provider so both stay in agreement.
    SubWordHashProvider p = builder.subWordHashProvider;
    args.subWordHashProvider = p;
    args.minn = p.getMinN();
    args.maxn = p.getMaxN();
    args.minCount = builder.minWordCount;
    args.cutoff = builder.quantizationCutOff;
    FastTextTrainer trainer = new FastTextTrainer(args);
    // for catching and forwarding progress events.
    trainer.getEventBus().register(this);
    try {
        return new FastTextClassifier(trainer.train(corpus));
    } catch (Exception e) {
        // The failure is rethrown with its cause attached; printing the stack trace here as
        // well would report the same error twice and bypass the caller's error handling.
        throw new RuntimeException("FastText training failed for corpus: " + corpus, e);
    }
}
Also used : Args(zemberek.core.embeddings.Args) SubWordHashProvider(zemberek.core.embeddings.SubWordHashProvider) FastTextTrainer(zemberek.core.embeddings.FastTextTrainer)

Example 4 with Args

use of zemberek.core.embeddings.Args in project zemberek-nlp by ahmetaa.

the class DocumentSimilarityExperiment method generateVectorModel.

/**
 * Trains a skip-gram word vector model on the input and saves it to the given file.
 *
 * <p>Sub-word information is switched off: {@code minn} and {@code maxn} are both zero and an
 * empty sub-word hash provider is used.
 *
 * @param input training text handed to the trainer
 * @param modelFile destination the trained model is saved to
 * @throws Exception propagated from training or saving the model
 */
public void generateVectorModel(Path input, Path modelFile) throws Exception {
    Args vectorArgs = Args.forWordVectors(Args.model_name.skipGram);

    // Training schedule and vector size.
    vectorArgs.thread = 16;
    vectorArgs.epoch = 10;
    vectorArgs.dim = 250;
    vectorArgs.bucket = 10;
    vectorArgs.minCount = 10;

    // No character n-grams. The original author left two alternatives disabled here:
    // setting wordNgrams to 2, and using Dictionary.CharacterNgramHashProvider(minn, maxn).
    vectorArgs.minn = 0;
    vectorArgs.maxn = 0;
    vectorArgs.subWordHashProvider = new EmbeddingHashProviders.EmptySubwordHashProvider();

    FastTextTrainer trainer = new FastTextTrainer(vectorArgs);
    FastText vectorModel = trainer.train(input);
    Log.info("Saving vmodel to %s", modelFile);
    vectorModel.saveModel(modelFile);
}
Also used : Args(zemberek.core.embeddings.Args) EmbeddingHashProviders(zemberek.core.embeddings.EmbeddingHashProviders) FastTextTrainer(zemberek.core.embeddings.FastTextTrainer) FastText(zemberek.core.embeddings.FastText)

Example 5 with Args

use of zemberek.core.embeddings.Args in project zemberek-nlp by ahmetaa.

the class FastTextTest method quantizationTest.

/**
 * Runs the dbpedia classification task and compares a model before and after quantization.
 * Run with -Xms8G or more.
 *
 * <p>Not an actual unit test: it reads from a hard-coded local data directory and is ignored
 * by default.
 */
@Test
@Ignore("Not an actual Test.")
public void quantizationTest() throws Exception {
    Path dataDir = Paths.get("/home/ahmetaa/projects/fastText/data");
    // Alternative smaller training set noted by the original author: "train.10k".
    Path trainFile = dataDir.resolve("dbpedia.train");
    Path testFile = dataDir.resolve("dbpedia.test");
    Path modelFile = dataDir.resolve("10k.model.bin");
    Path quantizedModelFile = dataDir.resolve("10k.model.qbin");

    Args args = Args.forSupervised();
    args.thread = 4;
    args.epoch = 15;
    args.wordNgrams = 2;
    args.minCount = 5;
    args.lr = 0.1;
    args.dim = 30;
    args.bucket = 1_000_000;

    // The model is loaded from disk here; the original author left the training call
    // FastText.train(trainFile, argz) disabled at this point.
    FastText model = FastText.load(modelFile);
    model.saveModel(modelFile);

    Log.info("Testing started.");
    test(model, testFile, 1);

    // Same evaluation again on a freshly reloaded copy of the saved model.
    model = FastText.load(modelFile);
    test(model, testFile, 1);

    // Quantize with the cut-off set and qnorm disabled, then save the quantized model.
    args.qnorm = false;
    args.cutoff = 15000;
    model = model.quantize(modelFile, args);
    model.saveModel(quantizedModelFile);
    Log.info("Testing quantization result.");
    test(model, testFile, 1);

    // Finally evaluate the quantized model after a save/load round trip.
    model = FastText.load(quantizedModelFile);
    Log.info("Testing after loading quantized model.");
    test(model, testFile, 1);
}
Also used : Path(java.nio.file.Path) Args(zemberek.core.embeddings.Args) FastText(zemberek.core.embeddings.FastText) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

Args (zemberek.core.embeddings.Args): 8; FastText (zemberek.core.embeddings.FastText): 7; FastTextTrainer (zemberek.core.embeddings.FastTextTrainer): 7; Path (java.nio.file.Path): 5; Ignore (org.junit.Ignore): 4; Test (org.junit.Test): 4; EmbeddingHashProviders (zemberek.core.embeddings.EmbeddingHashProviders): 2; EvaluationResult (zemberek.core.embeddings.FastText.EvaluationResult): 2; PrintWriter (java.io.PrintWriter): 1; ArrayList (java.util.ArrayList): 1; ScoredItem (zemberek.core.ScoredItem): 1; SubWordHashProvider (zemberek.core.embeddings.SubWordHashProvider): 1; WebCorpus (zemberek.corpus.WebCorpus): 1; WebDocument (zemberek.corpus.WebDocument): 1