Use of org.datavec.api.util.ClassPathResource in project deeplearning4j by deeplearning4j.
From the class ManualTests, method testWord2VecPlot:
@Test
public void testWord2VecPlot() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(2)
            .batchSize(1000)
            .learningRate(0.025)
            .layerSize(100)
            .seed(42)
            .sampling(0)
            .negativeSample(0)
            .windowSize(5)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .useAdaGrad(false)
            .iterate(iter)
            .workers(10)
            .tokenizerFactory(t)
            .build();
    vec.fit();
    // UiConnectionInfo connectionInfo = UiServer.getInstance().getConnectionInfo();
    // vec.getLookupTable().plotVocab(100, connectionInfo);
    // Keep the JVM alive so the (commented-out) UI plot can be inspected manually.
    Thread.sleep(10000000000L);
    fail("Not implemented");
}
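The test above fits the model but, with the UI plot commented out, never inspects it. A minimal sketch of querying the fitted vec instance, using the same wordsNearest/similarity calls the other tests on this page rely on ("day" and "night" are illustrative queries):

// Sketch only: probe the fitted model with the query APIs used elsewhere on this page.
Collection<String> nearest = vec.wordsNearest("day", 10);
System.out.println("nearest to 'day': " + nearest);
System.out.println("day/night similarity: " + vec.similarity("day", "night"));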
Use of org.datavec.api.util.ClassPathResource in project deeplearning4j by deeplearning4j.
From the class GloveTest, method testGlove:
@Test
public void testGlove() throws Exception {
    Glove glove = new Glove(true, 5, 100);
    JavaRDD<String> corpus = sc.textFile(new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath())
            .map(new Function<String, String>() {

                @Override
                public String call(String s) throws Exception {
                    return s.toLowerCase();
                }
            });
    Pair<VocabCache<VocabWord>, GloveWeightLookupTable> table = glove.train(corpus);
    WordVectors vectors = WordVectorSerializer.fromPair(
            new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst()));
    Collection<String> words = vectors.wordsNearest("day", 20);
    assertTrue(words.contains("week"));
}
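The trained GloVe table lives only in memory here. To keep it around, one could serialize the lookup table; a minimal sketch, assuming the writeWordVectors(lookupTable, path) overload that the Spark test below already uses also accepts the GloveWeightLookupTable (the file name glove_vectors.txt is illustrative):

// Sketch only: persist the GloVe lookup table in text format (illustrative path).
File gloveFile = new File("glove_vectors.txt");
WordVectorSerializer.writeWordVectors(table.getSecond(), gloveFile.getAbsolutePath());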
Use of org.datavec.api.util.ClassPathResource in project deeplearning4j by deeplearning4j.
From the class Word2VecTest, method testSparkW2VonBiggerCorpus:
@Ignore
@Test
public void testSparkW2VonBiggerCorpus() throws Exception {
    SparkConf sparkConf = new SparkConf().setMaster("local[8]")
            .setAppName("sparktest")
            .set("spark.driver.maxResultSize", "4g")
            .set("spark.driver.memory", "8g")
            .set("spark.executor.memory", "8g");
    // Set SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // Path of data part-00000
    //String dataPath = new ClassPathResource("/big/raw_sentences.txt").getFile().getAbsolutePath();
    // String dataPath = "/ext/Temp/SampleRussianCorpus.txt";
    String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
    // Read in data
    JavaRDD<String> corpus = sc.textFile(dataPath);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new LowCasePreProcessor());
    Word2Vec word2Vec = new Word2Vec.Builder()
            .setNGrams(1)
            .tokenizerFactory(t)
            .seed(42L)
            .negative(3)
            .useAdaGrad(false)
            .layerSize(100)
            .windowSize(5)
            .learningRate(0.025)
            .minLearningRate(0.0001)
            .iterations(1)
            .batchSize(100)
            .minWordFrequency(5)
            .useUnknown(true)
            .build();
    word2Vec.train(corpus);
    sc.stop();
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), "/ext/Temp/sparkRuModel.txt");
}
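The vectors written at the end of the test can later be restored with the same text-format loader that testConcepts below uses; a minimal sketch, mirroring the path from the test above:

// Sketch only: reload the text-format vectors written above.
WordVectors restored = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/sparkRuModel.txt"));
System.out.println("day/night similarity: " + restored.similarity("day", "night"));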
Use of org.datavec.api.util.ClassPathResource in project deeplearning4j by deeplearning4j.
From the class Word2VecTest, method testConcepts:
@Test
public void testConcepts() throws Exception {
    // These are all default values for word2vec
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    // Set SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // Path of data part-00000
    String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
    // dataPath = "/ext/Temp/part-00000";
    // String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
    // Read in data
    JavaRDD<String> corpus = sc.textFile(dataPath);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec word2Vec = new Word2Vec.Builder()
            .setNGrams(1)
            .tokenizerFactory(t)
            .seed(42L)
            .negative(10)
            .useAdaGrad(false)
            .layerSize(150)
            .windowSize(5)
            .learningRate(0.025)
            .minLearningRate(0.0001)
            .iterations(1)
            .batchSize(100)
            .minWordFrequency(5)
            .stopWords(Arrays.asList("three"))
            .useUnknown(true)
            .build();
    word2Vec.train(corpus);
    //word2Vec.setModelUtils(new FlatModelUtils());
    System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
    InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();
    double sim = word2Vec.similarity("day", "night");
    System.out.println("day/night similarity: " + sim);
    /*
    System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
    System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));
    Collection<String> portu = word2Vec.wordsNearest("carro", 10);
    printWords("carro", portu, word2Vec);
    portu = word2Vec.wordsNearest("davi", 10);
    printWords("davi", portu, word2Vec);
    System.out.println("---------------------------------------");
    */
    Collection<String> words = word2Vec.wordsNearest("day", 10);
    printWords("day", words, word2Vec);
    assertTrue(words.contains("night"));
    assertTrue(words.contains("week"));
    assertTrue(words.contains("year"));
    sim = word2Vec.similarity("two", "four");
    System.out.println("two/four similarity: " + sim);
    words = word2Vec.wordsNearest("two", 10);
    printWords("two", words, word2Vec);
    // three should be absent due to stopWords
    assertFalse(words.contains("three"));
    assertTrue(words.contains("five"));
    assertTrue(words.contains("four"));
    sc.stop();
    // test serialization
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
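The commented-out setModelUtils line in the test above points at a variant worth noting: FlatModelUtils scans the whole vocabulary instead of using the default heuristic, giving exact wordsNearest results at the cost of speed. A minimal sketch, following that commented-out line and reusing the method's word2Vec instance:

// Sketch only: swap in brute-force lookups, as the commented-out line above suggests.
word2Vec.setModelUtils(new FlatModelUtils());
Collection<String> exactWords = word2Vec.wordsNearest("day", 10);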
Use of org.datavec.api.util.ClassPathResource in project tutorials by eugenp.
From the class IrisClassifier, method main:
public static void main(String[] args) throws IOException, InterruptedException {
    DataSet allData;
    try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
        allData = iterator.next();
    }
    allData.shuffle(42);
    // Standardize features to zero mean and unit variance.
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(allData);
    normalizer.transform(allData);
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .iterations(1000)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true).l2(0.0001)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nIn(3).nOut(CLASSES_COUNT).build())
            .backprop(true).pretrain(false)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(configuration);
    model.init();
    model.fit(trainingData);
    INDArray output = model.output(testData.getFeatureMatrix());
    Evaluation eval = new Evaluation(CLASSES_COUNT);
    eval.eval(testData.getLabels(), output);
    System.out.println(eval.stats());
}
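Beyond printing the evaluation, a trained network is usually persisted for reuse; a minimal sketch, assuming deeplearning4j's ModelSerializer utility is on the classpath (the iris-model.zip name is illustrative):

// Sketch only: persist and restore the trained network (illustrative file name).
File modelFile = new File("iris-model.zip");
ModelSerializer.writeModel(model, modelFile, true); // true = also save updater state
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile);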