Use of zemberek.core.embeddings.FastText in the project zemberek-nlp by ahmetaa.
The `load` method of the class `FastTextClassifier`.
/**
 * Loads a {@link FastTextClassifier} from a serialized FastText model file.
 *
 * @param modelPath path of the model file; must point to an existing file
 * @return a classifier wrapping the loaded FastText model
 * @throws IOException if reading the model fails
 * @throws IllegalArgumentException if {@code modelPath} does not exist
 */
public static FastTextClassifier load(Path modelPath) throws IOException {
  boolean modelExists = modelPath.toFile().exists();
  Preconditions.checkArgument(modelExists, "%s does not exist.", modelPath);
  return new FastTextClassifier(FastText.load(modelPath));
}
Use of zemberek.core.embeddings.FastText in the project zemberek-nlp by ahmetaa.
The `run` method of the class `TrainClassifier`.
@Override
public void run() throws IOException {
  // Train a text classification model from the configured input corpus.
  Log.info("Generating classification model from %s", input);
  FastTextClassifierTrainer classifierTrainer =
      FastTextClassifierTrainer.builder()
          .epochCount(epochCount)
          .learningRate(learningRate)
          .lossType(lossType)
          .quantizationCutOff(cutOff)
          .minWordCount(minWordCount)
          .threadCount(threadCount)
          .wordNgramOrder(wordNGrams)
          .dimension(dimension)
          .contextWindowSize(contextWindowSize)
          .build();
  Log.info("Training Started.");
  // Subscribe to trainer events (e.g. progress notifications).
  classifierTrainer.getEventBus().register(this);
  FastTextClassifier classifier = classifierTrainer.train(input);
  Log.info("Saving classification model to %s", output);
  FastText model = classifier.getFastText();
  model.saveModel(output);
  if (applyQuantization) {
    Log.info("Applying quantization.");
    if (cutOff > 0) {
      Log.info("Quantization dictionary cut-off value = %d", cutOff);
    }
    // Quantized model is written next to the full model, with a ".q" suffix.
    Path outputDir = output.getParent();
    String quantizedName = output.toFile().getName() + ".q";
    Path quantizedModelPath =
        outputDir == null ? Paths.get(quantizedName) : outputDir.resolve(quantizedName);
    Log.info("Saving quantized classification model to %s", quantizedModelPath);
    FastText quantizedModel = model.quantize(output, model.getArgs());
    quantizedModel.saveModel(quantizedModelPath);
  }
}
Use of zemberek.core.embeddings.FastText in the project zemberek-nlp by ahmetaa.
The `generateVectorModel` method of the class `DocumentSimilarityExperiment`.
/**
 * Trains a skip-gram word-vector model from {@code input} and writes it to {@code modelFile}.
 *
 * @param input corpus file to train on
 * @param modelFile destination file for the trained model
 * @throws Exception if training or saving fails
 */
public void generateVectorModel(Path input, Path modelFile) throws Exception {
  Args trainingArgs = Args.forWordVectors(Args.model_name.skipGram);
  trainingArgs.thread = 16;
  trainingArgs.epoch = 10;
  trainingArgs.dim = 250;
  trainingArgs.bucket = 10;
  trainingArgs.minCount = 10;
  // Character n-grams are disabled (minn = maxn = 0); only whole-word embeddings are trained.
  trainingArgs.minn = 0;
  trainingArgs.maxn = 0;
  trainingArgs.subWordHashProvider = new EmbeddingHashProviders.EmptySubwordHashProvider();
  FastText vectorModel = new FastTextTrainer(trainingArgs).train(input);
  Log.info("Saving vmodel to %s", modelFile);
  vectorModel.saveModel(modelFile);
}
Use of zemberek.core.embeddings.FastText in the project zemberek-nlp by ahmetaa.
The `quantizationTest` method of the class `FastTextTest`.
/**
 * Runs the dbpedia classification task. run with -Xms8G or more.
 */
@Test
@Ignore("Not an actual Test.")
public void quantizationTest() throws Exception {
// Experiment fixture paths, hard-coded to the author's local machine.
Path inputRoot = Paths.get("/home/ahmetaa/projects/fastText/data");
Path trainFile = inputRoot.resolve("dbpedia.train");
// Path trainFile = inputRoot.resolve("train.10k");
Path modelPath = inputRoot.resolve("10k.model.bin");
Path quantizedModelPath = inputRoot.resolve("10k.model.qbin");
Path testFile = inputRoot.resolve("dbpedia.test");
// NOTE(review): with the training line below commented out, these supervised
// hyper-parameters are unused except qnorm/cutoff set further down for quantize().
Args argz = Args.forSupervised();
argz.thread = 4;
argz.epoch = 15;
argz.wordNgrams = 2;
argz.minCount = 5;
argz.lr = 0.1;
argz.dim = 30;
argz.bucket = 1000_000;
// Load a previously trained model instead of retraining (training line kept for reference).
FastText fastText = FastText.load(modelPath);
// FastText fastText = FastText.train(trainFile, argz);
// NOTE(review): re-saving immediately after load rewrites the same file; presumably a
// leftover from the commented-out training path — confirm before relying on it.
fastText.saveModel(modelPath);
Log.info("Testing started.");
// Evaluate with precision/recall at 1 (k = 1).
test(fastText, testFile, 1);
// Round-trip check: reload the saved model and evaluate again.
fastText = FastText.load(modelPath);
test(fastText, testFile, 1);
// Quantize with a 15k dictionary cut-off and no vector normalization.
argz.qnorm = false;
argz.cutoff = 15000;
fastText = fastText.quantize(modelPath, argz);
fastText.saveModel(quantizedModelPath);
Log.info("Testing quantization result.");
test(fastText, testFile, 1);
// Round-trip check for the quantized model as well.
fastText = FastText.load(quantizedModelPath);
Log.info("Testing after loading quantized model.");
test(fastText, testFile, 1);
}
Use of zemberek.core.embeddings.FastText in the project zemberek-nlp by ahmetaa.
The `classificationTest` method of the class `FastTextTest`.
/**
 * Runs the cooking label guessing task. run with -Xms8G or more.
 */
@Test
@Ignore("Not an actual Test.")
public void classificationTest() throws Exception {
  // Experiment fixture paths, hard-coded to the author's local machine.
  Path dataRoot = Paths.get("/home/ahmetaa/data/fasttext");
  Path trainingFile = dataRoot.resolve("cooking.train");
  Path modelFile = dataRoot.resolve("cooking.model.bin");
  Path quantizedFile = dataRoot.resolve("cooking.model.qbin");
  Path validationFile = dataRoot.resolve("cooking.valid");
  // Supervised training hyper-parameters.
  Args trainingArgs = Args.forSupervised();
  trainingArgs.thread = 4;
  trainingArgs.epoch = 25;
  trainingArgs.wordNgrams = 2;
  trainingArgs.minCount = 1;
  trainingArgs.lr = 1.0;
  trainingArgs.dim = 100;
  trainingArgs.bucket = 1000_000;
  FastText model;
  // Always retrains; loading an existing model instead was considered but left disabled.
  model = new FastTextTrainer(trainingArgs).train(trainingFile);
  model.saveModel(modelFile);
  Log.info("Testing started.");
  // Evaluate with precision/recall at 1 (k = 1).
  test(model, validationFile, 1);
  // Round-trip check: reload the saved model and evaluate again.
  model = FastText.load(modelFile);
  test(model, validationFile, 1);
  // Quantize with a 3k dictionary cut-off and no vector normalization.
  trainingArgs.qnorm = false;
  trainingArgs.cutoff = 3000;
  model = model.quantize(modelFile, trainingArgs);
  model.saveModel(quantizedFile);
  Log.info("Testing quantization result.");
  test(model, validationFile, 1);
  // Round-trip check for the quantized model as well.
  model = FastText.load(quantizedFile);
  Log.info("Testing after loading quantized model.");
  test(model, validationFile, 1);
}
Aggregations