Use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.
From the class WordVectorSerializer, method readParagraphVectors.
/**
 * This method restores a ParagraphVectors model previously saved with writeParagraphVectors().
 *
 * @param file file containing the serialized model
 * @return restored ParagraphVectors model
 */
public static ParagraphVectors readParagraphVectors(File file) throws IOException {
    File tmpFileL = File.createTempFile("paravec", "l");
    tmpFileL.deleteOnExit();

    Word2Vec w2v = readWord2Vec(file);

    // and "convert" it to a ParaVec model, optionally trying to restore labels information
    ParagraphVectors vectors = new ParagraphVectors.Builder(w2v.getConfiguration())
                    .vocabCache(w2v.getVocab())
                    .lookupTable(w2v.getLookupTable())
                    .resetModel(false)
                    .build();

    ZipFile zipFile = new ZipFile(file);

    // now we try to restore labels information
    ZipEntry labels = zipFile.getEntry("labels.txt");
    if (labels != null) {
        InputStream stream = zipFile.getInputStream(labels);
        Files.copy(stream, Paths.get(tmpFileL.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
        try (BufferedReader reader = new BufferedReader(new FileReader(tmpFileL))) {
            String line;
            while ((line = reader.readLine()) != null) {
                VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim()));
                if (word != null) {
                    word.markAsLabel(true);
                }
            }
        }
    }

    vectors.extractLabels();
    return vectors;
}
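As a point of reference, here is a minimal sketch of how this reader pairs with writeParagraphVectors(). The file name and the assumption that a serialized model already exists at that path are made up for illustration; only the two serializer calls come from the snippet above.

import java.io.File;
import java.io.IOException;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;

public class RestoreParaVecSketch {
    public static void main(String[] args) throws IOException {
        // Assumption: this path points at an archive previously written by
        // WordVectorSerializer.writeParagraphVectors(); adjust to your own location.
        File modelFile = new File("paravec_model.zip");

        // Restore the model; vocabulary elements listed in labels.txt inside the
        // archive are marked as labels again during the read.
        ParagraphVectors vectors = WordVectorSerializer.readParagraphVectors(modelFile);

        System.out.println("Vocab size after restore: " + vectors.getVocab().numWords());
    }
}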
Use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.
From the class WordVectorSerializerTest, method testMalformedLabels1.
@Test
public void testMalformedLabels1() throws Exception {
    List<String> words = new ArrayList<>();
    words.add("test A");
    words.add("test B");
    words.add("test\nC");
    words.add("test`D");
    words.add("test_E");
    words.add("test 5");

    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    int cnt = 0;
    for (String word : words) {
        vocabCache.addToken(new VocabWord(1.0, word));
        vocabCache.addWordToIndex(cnt, word);
        cnt++;
    }
    vocabCache.elementAtIndex(1).markAsLabel(true);

    InMemoryLookupTable<VocabWord> lookupTable =
                    new InMemoryLookupTable<>(vocabCache, 10, false, 0.01, Nd4j.getRandom(), 0.0);
    lookupTable.resetWeights(true);

    assertNotEquals(null, lookupTable.getSyn0());
    assertNotEquals(null, lookupTable.getSyn1());
    assertNotEquals(null, lookupTable.getExpTable());
    assertEquals(null, lookupTable.getSyn1Neg());

    ParagraphVectors vec = new ParagraphVectors.Builder()
                    .lookupTable(lookupTable)
                    .vocabCache(vocabCache)
                    .build();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(vec, tempFile);

    ParagraphVectors restoredVec = WordVectorSerializer.readParagraphVectors(tempFile);

    for (String word : words) {
        assertEquals(true, restoredVec.hasWord(word));
    }

    assertTrue(restoredVec.getVocab().elementAtIndex(1).isLabel());
}
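The round trip above works even for labels containing spaces, backticks, and newlines because readParagraphVectors() decodes each line of labels.txt with decodeB64() before looking the token up. The following is a minimal, self-contained illustration of that idea using plain java.util.Base64 rather than the library's own helper, just to show why encoding makes a newline-bearing label safe to store one per line.

import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class LabelEncodingSketch {
    public static void main(String[] args) {
        String label = "test\nC"; // a label that would break a one-label-per-line file

        // Encode before writing a line, decode after reading it back.
        String encoded = Base64.getEncoder().encodeToString(label.getBytes(StandardCharsets.UTF_8));
        String decoded = new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8);

        System.out.println("encoded line: " + encoded);           // a single line, no raw newline
        System.out.println("round-trip ok: " + label.equals(decoded)); // true
    }
}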
Use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.
From the class WordVectorSerializerTest, method testParaVecSerialization1.
@Test
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            codes.add(org.apache.commons.lang3.RandomUtils.nextBytes(10)[0]);
        }
        if (RandomUtils.nextInt(10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(configuration.getLayersSize()).cache(cache).build();
    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors = new ParagraphVectors.Builder(configuration)
                    .vocabCache(cache).lookupTable(lookupTable).build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);

    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);
    InMemoryLookupTable<VocabWord> restoredLookupTable =
                    (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    assertEquals(restoredLookupTable.getSyn0(), lookupTable.getSyn0());
    assertEquals(restoredLookupTable.getSyn1(), lookupTable.getSyn1());

    for (int i = 0; i < cache.numWords(); i++) {
        assertEquals(cache.elementAtIndex(i).isLabel(), restoredVocab.elementAtIndex(i).isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(cache.elementAtIndex(i).getElementFrequency(),
                        restoredVocab.elementAtIndex(i).getElementFrequency(), 0.1f);

        List<Integer> originalPoints = cache.elementAtIndex(i).getPoints();
        List<Integer> restoredPoints = restoredVocab.elementAtIndex(i).getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int x = 0; x < originalPoints.size(); x++) {
            assertEquals(originalPoints.get(x), restoredPoints.get(x));
        }

        List<Byte> originalCodes = cache.elementAtIndex(i).getCodes();
        List<Byte> restoredCodes = restoredVocab.elementAtIndex(i).getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int x = 0; x < originalCodes.size(); x++) {
            assertEquals(originalCodes.get(x), restoredCodes.get(x));
        }
    }
}
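The test builds its vocabulary and lookup table by hand, with each row of syn0/syn1 belonging to the vocabulary element at the same index. Below is a minimal, self-contained sketch of that wiring; the vocabulary size, vector length, and token names are illustrative assumptions, not values taken from the test.

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LookupTableWiringSketch {
    public static void main(String[] args) {
        int vectorLength = 8; // illustrative; the test above uses configuration.getLayersSize()

        // Tiny vocabulary of three elements, the first one marked as a document label.
        AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
        for (int i = 0; i < 3; i++) {
            VocabWord word = new VocabWord(1.0, "word_" + i);
            word.setIndex(i);
            if (i == 0)
                word.markAsLabel(true);
            cache.addToken(word);
            cache.addWordToIndex(i, word.getLabel());
        }

        // Row i of syn0 is the vector for the vocabulary element with index i.
        INDArray syn0 = Nd4j.rand(3, vectorLength);
        InMemoryLookupTable<VocabWord> lookupTable =
                        (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                        .vectorLength(vectorLength).cache(cache).build();
        lookupTable.setSyn0(syn0);

        System.out.println("vector for word_1: " + lookupTable.vector("word_1"));
    }
}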
Use of org.deeplearning4j.models.paragraphvectors.ParagraphVectors in project deeplearning4j by deeplearning4j.
From the class WordVectorSerializer, method readParagraphVectorsFromText.
/**
 * Restores a previously serialized ParagraphVectors model.
 *
 * Deprecation note: please consider using the readParagraphVectors() method instead.
 *
 * @param stream InputStream that contains a previously serialized model
 * @return restored ParagraphVectors model
 */
@Deprecated
public static ParagraphVectors readParagraphVectorsFromText(@NonNull InputStream stream) {
    try {
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
        ArrayList<String> labels = new ArrayList<>();
        ArrayList<INDArray> arrays = new ArrayList<>();
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();

        String line = "";
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            split[1] = split[1].replaceAll(whitespaceReplacement, " ");
            VocabWord word = new VocabWord(1.0, split[1]);
            if (split[0].equals("L")) {
                // we have a label element here
                word.setSpecial(true);
                word.markAsLabel(true);
                labels.add(word.getLabel());
            } else if (split[0].equals("E")) {
                // we have a usual element, aka word, here
                word.setSpecial(false);
                word.markAsLabel(false);
            } else
                throw new IllegalStateException("Source stream doesn't looks like ParagraphVectors serialized model");

            // this particular line is just for backward compatibility with InMemoryLookupCache
            word.setIndex(vocabCache.numWords());

            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());

            // backward compatibility code
            vocabCache.putVocabWord(word.getLabel());

            float[] vector = new float[split.length - 2];
            for (int i = 2; i < split.length; i++) {
                vector[i - 2] = Float.parseFloat(split[i]);
            }

            INDArray row = Nd4j.create(vector);
            arrays.add(row);
        }

        // now we create the syn0 matrix, using the previously fetched rows
        /*INDArray syn = Nd4j.create(new int[]{arrays.size(), arrays.get(0).columns()});
        for (int i = 0; i < syn.rows(); i++) {
            syn.putRow(i, arrays.get(i));
        }*/
        INDArray syn = Nd4j.vstack(arrays);

        InMemoryLookupTable<VocabWord> lookupTable =
                        (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                        .vectorLength(arrays.get(0).columns())
                                        .useAdaGrad(false)
                                        .cache(vocabCache)
                                        .build();

        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);

        LabelsSource source = new LabelsSource(labels);

        ParagraphVectors vectors = new ParagraphVectors.Builder()
                        .labelsSource(source)
                        .vocabCache(vocabCache)
                        .lookupTable(lookupTable)
                        .modelUtils(new BasicModelUtils<VocabWord>())
                        .build();

        try {
            reader.close();
        } catch (Exception e) {
            // ignore close failures
        }

        vectors.extractLabels();

        return vectors;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
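For reference, this deprecated reader parses a line-oriented text format where each line starts with "L" (label) or "E" (ordinary word element), followed by the token and its vector components. Below is a hedged, self-contained sketch of feeding it such a stream; the sample tokens and vector values are illustrative assumptions, not data from the project.

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;

public class ReadParaVecFromTextSketch {
    public static void main(String[] args) {
        // Illustrative payload in the format the deprecated reader expects:
        // "<L|E> <token> <vector components...>", one entry per line.
        String serialized = "L DOC_0 0.1 0.2 0.3\n"
                        + "E hello 0.4 0.5 0.6\n"
                        + "E world 0.7 0.8 0.9\n";

        InputStream stream = new ByteArrayInputStream(serialized.getBytes(StandardCharsets.UTF_8));
        ParagraphVectors vectors = WordVectorSerializer.readParagraphVectorsFromText(stream);

        System.out.println("has label DOC_0: " + vectors.hasWord("DOC_0"));
        System.out.println("has word hello: " + vectors.hasWord("hello"));
    }
}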