use of zemberek.lm.LmVocabulary in project zemberek-nlp by ahmetaa.
the class SmoothLmTest method testProbabilities.
@Test
public void testProbabilities() throws IOException {
  SmoothLm lm = getTinyLm();
  System.out.println(lm.info());
  LmVocabulary v = lm.getVocabulary();
  final double delta = 0.0001;

  // Unigram: <s>
  int[] sentenceStart = {v.indexOf("<s>")};
  Assert.assertEquals(-1.716003, lm.getProbabilityValue(sentenceStart), delta);
  Assert.assertEquals(-1.716003, lm.getProbability(sentenceStart), delta);

  // Bigram: <s> kedi
  int[] bigram = {v.indexOf("<s>"), v.indexOf("kedi")};
  Assert.assertEquals(-0.796249, lm.getProbabilityValue(bigram), delta);
  Assert.assertEquals(-0.796249, lm.getProbability(bigram), delta);

  // Trigram: Ahmet dondurma yedi
  int[] trigram = {v.indexOf("Ahmet"), v.indexOf("dondurma"), v.indexOf("yedi")};
  Assert.assertEquals(-0.602060, lm.getProbabilityValue(trigram), delta);
  Assert.assertEquals(-0.602060, lm.getProbability(trigram), delta);
}
use of zemberek.lm.LmVocabulary in project zemberek-nlp by ahmetaa.
the class WordDistances method saveDistanceListBin.
/**
 * Computes nearest-distance lists for all word vectors in {@code vectorFile} in parallel and
 * writes the result as a binary file.
 * <p>
 * Output layout: word count (int), distance amount (int), then for each word: its vocabulary
 * index (int) followed by {@code distanceAmount} pairs of related-word index (int) and
 * distance (float). The vocabulary built from the vector words is saved to {@code vocabFile}.
 *
 * @param vectorFile     binary word-vector input.
 * @param outFile        binary distance-list output.
 * @param vocabFile      vocabulary output, written with {@link LmVocabulary#saveBinary}.
 * @param distanceAmount number of closest words to keep per word.
 * @param blockSize      number of vectors processed per task.
 * @param threadSize     size of the fixed thread pool.
 * @throws Exception if loading, computation or writing fails.
 */
public static void saveDistanceListBin(Path vectorFile, Path outFile, Path vocabFile, int distanceAmount, int blockSize, int threadSize) throws Exception {
  Log.info("Loading vectors.");
  List<WordVector> wordVectors = WordVector.loadFromBinary(vectorFile);

  Log.info("Writing vocabulary.");
  // Vocabulary order follows the vector order, so indexOf() below resolves correctly.
  List<String> words = new ArrayList<>(wordVectors.size());
  wordVectors.forEach(s -> words.add(s.word));
  LmVocabulary vocabulary = new LmVocabulary(words);
  vocabulary.saveBinary(vocabFile.toFile());

  Log.info("Calculating distances.");
  ExecutorService es = Executors.newFixedThreadPool(threadSize);
  CompletionService<List<BlockUnit>> completionService = new ExecutorCompletionService<>(es);
  int blockCounter = 0;
  for (int i = 0; i < wordVectors.size(); i += blockSize) {
    // Last block may be shorter than blockSize.
    int endIndex = Math.min(i + blockSize, wordVectors.size());
    completionService.submit(new BlockDistanceTask(wordVectors, distanceAmount, blockSize, i, endIndex));
    blockCounter++;
  }
  // No more tasks will be submitted; workers finish the queued blocks and exit.
  es.shutdown();

  List<WordDistances> distancesToWrite = new ArrayList<>(wordVectors.size());
  int i = 0;
  while (i < blockCounter) {
    // take() blocks until the next finished task; get() rethrows any task failure.
    List<BlockUnit> units = completionService.take().get();
    for (BlockUnit unit : units) {
      String source = unit.vector.word;
      // Queue holds the closest distances; sort descending for the output order.
      List<Distance> distList = new ArrayList<>(unit.distQueue);
      Collections.sort(distList);
      Collections.reverse(distList);
      distancesToWrite.add(new WordDistances(source, distList.toArray(new Distance[0])));
    }
    i++;
    // NOTE(review): this fires whenever i*blockSize is a multiple of 10, which is every
    // iteration for most block sizes — possibly meant to be a larger interval. Confirm intent.
    if ((i * blockSize % 10) == 0) {
      Log.info("%d of %d completed", i * blockSize, wordVectors.size());
    }
  }

  Log.info("Writing.");
  try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(outFile.toFile())))) {
    dos.writeInt(wordVectors.size());
    dos.writeInt(distanceAmount);
    for (WordDistances w : distancesToWrite) {
      dos.writeInt(vocabulary.indexOf(w.source));
      for (Distance token : w.distances) {
        dos.writeInt(vocabulary.indexOf(token.word));
        dos.writeFloat(token.distance);
      }
    }
  }
}
use of zemberek.lm.LmVocabulary in project zemberek-nlp by ahmetaa.
the class WordVectorLookup method loadFromBinary.
/**
 * Loads a word-vector lookup from a binary vector file and a binary vocabulary file.
 * <p>
 * Vector file layout: word count (int), vector dimension (int), then per word a word index
 * (int) followed by {@code vectorDimension} raw floats.
 *
 * @param vectorFile     binary vector data.
 * @param vocabularyFile binary vocabulary, read with {@link LmVocabulary#loadFromBinary}.
 * @return lookup combining the vocabulary with the loaded vectors.
 * @throws IOException           if reading either file fails.
 * @throws IllegalStateException if a stored word index is out of range.
 */
public static WordVectorLookup loadFromBinary(Path vectorFile, Path vocabularyFile) throws IOException {
  LmVocabulary vocabulary = LmVocabulary.loadFromBinary(vocabularyFile.toFile());
  try (DataInputStream dis = IOUtil.getDataInputStream(vectorFile)) {
    int wordCount = dis.readInt();
    int vectorDimension = dis.readInt();
    Vector[] vectors = new Vector[wordCount];
    for (int i = 0; i < wordCount; i++) {
      int index = dis.readInt();
      // Valid word indexes are 0..wordCount-1; the original check used '>' which
      // incorrectly accepted index == wordCount.
      if (index >= wordCount || index < 0) {
        throw new IllegalStateException("Bad word index " + index);
      }
      float[] vec = FloatArrays.deserializeRaw(dis, vectorDimension);
      vectors[i] = new Vector(index, vec);
    }
    return new WordVectorLookup(vocabulary, vectors);
  }
}
use of zemberek.lm.LmVocabulary in project zemberek-nlp by ahmetaa.
the class WordVectorLookup method loadFromBinaryFast.
/**
 * Loads a word-vector lookup using memory-mapped reads for speed.
 * <p>
 * Reads the header (word count, vector dimension) with a stream, then maps the rest of the
 * file in chunks of up to 100,000 word records. Each record is a word index (int) followed by
 * {@code vectorDimension} floats.
 *
 * @param vectorFile     binary vector data.
 * @param vocabularyFile binary vocabulary, read with {@link LmVocabulary#loadFromBinary}.
 * @return lookup combining the vocabulary with the loaded vectors.
 * @throws IOException if reading or mapping fails.
 */
public static WordVectorLookup loadFromBinaryFast(Path vectorFile, Path vocabularyFile) throws IOException {
  LmVocabulary vocabulary = LmVocabulary.loadFromBinary(vocabularyFile.toFile());
  int wordCount;
  int vectorDimension;
  try (DataInputStream dis = IOUtil.getDataInputStream(vectorFile)) {
    wordCount = dis.readInt();
    vectorDimension = dis.readInt();
  }
  // try-with-resources: the original leaked the RandomAccessFile and its channel.
  try (RandomAccessFile aFile = new RandomAccessFile(vectorFile.toFile(), "r");
      FileChannel inChannel = aFile.getChannel()) {
    long start = 8; // skip the two-int header.
    int blockSize = 4 + vectorDimension * 4; // bytes per word record.
    Vector[] vectors = new Vector[wordCount];
    int wordCounter = 0;
    int wordBlockSize = 100_000;
    while (wordCounter < wordCount) {
      if (wordCounter + wordBlockSize > wordCount) {
        wordBlockSize = (wordCount - wordCounter);
      }
      // Widen before multiplying: int * int could overflow for large dimensions.
      long size = (long) blockSize * wordBlockSize;
      MappedByteBuffer buffer = inChannel.map(FileChannel.MapMode.READ_ONLY, start, size);
      buffer.load();
      start += size;
      int blockCounter = 0;
      while (blockCounter < wordBlockSize) {
        int wordIndex = buffer.getInt();
        float[] data = new float[vectorDimension];
        buffer.asFloatBuffer().get(data);
        vectors[wordCounter] = new Vector(wordIndex, data);
        wordCounter++;
        blockCounter++;
        // Jump to the start of the next record; asFloatBuffer() did not advance position.
        buffer.position(blockCounter * blockSize);
      }
    }
    return new WordVectorLookup(vocabulary, vectors);
  }
}
use of zemberek.lm.LmVocabulary in project zemberek-nlp by ahmetaa.
the class TurkishSpellChecker method suggestForWord.
/**
 * Suggests corrections for {@code word}, ranked by language-model score in the given context.
 * <p>
 * Falls back to unranked suggestions when no model is given, and to unigram ranking when the
 * model order is less than 2. With an order-2 model the score is the sum of the two bigram
 * probabilities around the candidate; with higher orders a single trigram probability is used.
 *
 * @param word         the word to correct.
 * @param leftContext  word to the left, or null to use the sentence-start token.
 * @param rightContext word to the right, or null to use the sentence-end token.
 * @param lm           n-gram language model used for ranking; may be null.
 * @return candidate words, best first.
 */
public List<String> suggestForWord(String word, String leftContext, String rightContext, NgramLanguageModel lm) {
  List<String> unRanked = getUnrankedSuggestions(word);
  if (lm == null) {
    Log.warn("No language model provided. Returning unranked results.");
    return unRanked;
  }
  if (lm.getOrder() < 2) {
    Log.warn("Language model order is 1. For context ranking it should be at least 2. " + "Unigram ranking will be applied.");
    return suggestForWord(word, lm);
  }
  LmVocabulary vocabulary = lm.getVocabulary();
  // Context normalization is loop-invariant; the original re-normalized (and re-assigned)
  // the parameters on every iteration, which could change results for later candidates if
  // normalizeForLm is not idempotent.
  if (leftContext == null) {
    leftContext = vocabulary.getSentenceStart();
  } else {
    leftContext = normalizeForLm(leftContext);
  }
  if (rightContext == null) {
    rightContext = vocabulary.getSentenceEnd();
  } else {
    rightContext = normalizeForLm(rightContext);
  }
  int leftIndex = vocabulary.indexOf(leftContext);
  int rightIndex = vocabulary.indexOf(rightContext);
  List<ScoredItem<String>> results = new ArrayList<>(unRanked.size());
  for (String str : unRanked) {
    String w = normalizeForLm(str);
    int wordIndex = vocabulary.indexOf(w);
    float score;
    if (lm.getOrder() == 2) {
      score = lm.getProbability(leftIndex, wordIndex) + lm.getProbability(wordIndex, rightIndex);
    } else {
      score = lm.getProbability(leftIndex, wordIndex, rightIndex);
    }
    results.add(new ScoredItem<>(str, score));
  }
  results.sort(ScoredItem.STRING_COMP_DESCENDING);
  return results.stream().map(s -> s.item).collect(Collectors.toList());
}
Aggregations