Search in sources :

Example 1 with Environment

use of org.nd4j.linalg.heartbeat.reports.Environment in project deeplearning4j by deeplearning4j.

The class ComputationGraph, method update.

/**
 * Sends a one-off STANDALONE heartbeat describing this model and the current environment.
 *
 * <p>Guarded by the {@code initDone} flag: the first call sets it and reports, and every
 * later call returns immediately.
 *
 * @param task ignored; NOTE(review): the reported task is always rebuilt from this model via
 *     {@code ModelSerializer.taskByModel(this)} — confirm whether the argument should be used
 */
private void update(Task task) {
    if (initDone) {
        return;
    }
    initDone = true;
    Heartbeat heartbeat = Heartbeat.getInstance();
    Task modelTask = ModelSerializer.taskByModel(this);
    Environment environment = EnvironmentUtils.buildEnvironment();
    heartbeat.reportEvent(Event.STANDALONE, environment, modelTask);
}
Also used : Heartbeat(org.nd4j.linalg.heartbeat.Heartbeat) Environment(org.nd4j.linalg.heartbeat.reports.Environment)

Example 2 with Environment

use of org.nd4j.linalg.heartbeat.reports.Environment in project deeplearning4j by deeplearning4j.

The class Word2Vec, method train.

/**
 * Trains the word2vec model on a given text corpus using Spark.
 *
 * <p>The work happens in four stages:
 * <ol>
 *   <li>Tokenize the corpus and build the vocabulary with {@link TextPipeline}.</li>
 *   <li>Compute cumulative per-sentence word counts.</li>
 *   <li>Run the first training iteration over every partition.</li>
 *   <li>On the driver, sum the syn0 row updates returned by the workers, then divide each row
 *       by the number of updates it received.</li>
 * </ol>
 * The result is installed as an {@link InMemoryLookupTable}.
 *
 * @param corpusRDD training corpus; each element is presumably one sentence (TODO confirm)
 * @throws Exception if any stage of the Spark pipeline fails
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0) {
        // RDDs are immutable: repartition() returns a NEW RDD instead of changing this one.
        // The result must be kept, otherwise the 'workers' setting silently has no effect.
        corpusRDD = corpusRDD.repartition(workers);
    }
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Largest number of updates any single word received (computed during averaging)
    int maxRep = 1;

    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build the VocabCache that later feeds the lookup table
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // 2 RDDs: (vocab words list) and (sentence count); already cached by the pipeline
    final JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    final JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    final Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    final VocabCache<VocabWord> vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());

    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Nothing to do here: the Huffman tree (code and point of each VocabWord) was already
    // built during TextPipeline.buildVocabCache().

    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    final JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");

    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked")
    JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Bring all syn0 updates to the driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();

    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    log.info("Averaging results...");
    // First pass: sum the vectors obtained from the different nodes and count the
    // contributions per word, so the sums can be turned into means afterwards.
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        VocabWord word = syn0UpdateEntry.getFirst();
        syn0.getRow(word.getIndex()).addi(syn0UpdateEntry.getSecond());
        AtomicInteger contributions = updates.get(word);
        if (contributions == null) {
            updates.put(word, new AtomicInteger(1));
        } else {
            contributions.incrementAndGet();
        }
    }
    // Second pass: average rows that received more than one update
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        int contributions = entry.getValue().get();
        if (contributions > 1) {
            maxRep = Math.max(maxRep, contributions);
            syn0.getRow(entry.getKey().getIndex()).divi(contributions);
        }
    }

    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    // Heartbeat telemetry. NOTE(review): numCores is reported as maxRep (the largest per-word
    // update count, not a CPU core count) and availableMemory as 0 — confirm this is intended.
    long totals = 0;
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) Environment(org.nd4j.linalg.heartbeat.reports.Environment) Map(java.util.Map)

Example 3 with Environment

use of org.nd4j.linalg.heartbeat.reports.Environment in project nd4j by deeplearning4j.

The class EnvironmentUtils, method buildEnvironment.

/**
 * Builds an {@link Environment} snapshot of the current JVM and host for heartbeat reporting.
 *
 * <p>Populated fields:
 * <ul>
 *   <li>Java specification version</li>
 *   <li>number of available processors</li>
 *   <li>maximum JVM heap ({@link Runtime#maxMemory()})</li>
 *   <li>OS architecture</li>
 *   <li>OS name</li>
 *   <li>simple class name of the active ND4J executioner, which identifies the backend</li>
 * </ul>
 *
 * @return a freshly populated {@link Environment}
 */
public static Environment buildEnvironment() {
    Environment environment = new Environment();
    environment.setJavaVersion(System.getProperty("java.specification.version"));
    environment.setNumCores(Runtime.getRuntime().availableProcessors());
    environment.setAvailableMemory(Runtime.getRuntime().maxMemory());
    environment.setOsArch(System.getProperty("os.arch"));
    // "os.name" is the standard system property. The previous key "os.opName" does not exist,
    // so the OS name was always reported as null.
    environment.setOsName(System.getProperty("os.name"));
    environment.setBackendUsed(Nd4j.getExecutioner().getClass().getSimpleName());
    return environment;
}
Also used : Environment(org.nd4j.linalg.heartbeat.reports.Environment)

Example 4 with Environment

use of org.nd4j.linalg.heartbeat.reports.Environment in project deeplearning4j by deeplearning4j.

The class MultiLayerTest, method testCid.

/**
 * Manual check (ignored by default) of client-id generation and heartbeat reporting.
 *
 * <p>Prints a client id, then sends a single STANDALONE heartbeat whose environment carries a
 * client id and whose task is built from a small 6-element array.
 */
@Test
@Ignore
public void testCid() throws Exception {
    System.out.println(EnvironmentUtils.buildCId());

    Environment env = EnvironmentUtils.buildEnvironment();
    env.setSerialVersionID(EnvironmentUtils.buildCId());

    double[] sample = { 1, 2, 3, 4, 5, 6 };
    Task task = TaskUtils.buildTask(Nd4j.create(sample));

    Heartbeat.getInstance().reportEvent(Event.STANDALONE, env, task);
    // NOTE(review): presumably reporting is asynchronous and this sleep keeps the JVM alive
    // long enough for the event to be sent — confirm.
    Thread.sleep(25000);
}
Also used : Task(org.nd4j.linalg.heartbeat.reports.Task) Environment(org.nd4j.linalg.heartbeat.reports.Environment) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 5 with Environment

use of org.nd4j.linalg.heartbeat.reports.Environment in project deeplearning4j by deeplearning4j.

The class MultiLayerNetwork, method update.

/**
 * Sends a one-off STANDALONE heartbeat describing this model and the current environment.
 *
 * <p>Guarded by the {@code initDone} flag: the first call sets it and reports, and every
 * later call returns immediately.
 *
 * @param task ignored; NOTE(review): the reported task is always rebuilt from this model via
 *     {@code ModelSerializer.taskByModel(this)} — confirm whether the argument should be used
 */
private void update(Task task) {
    if (initDone) {
        return;
    }
    initDone = true;
    Heartbeat heartbeat = Heartbeat.getInstance();
    Task modelTask = ModelSerializer.taskByModel(this);
    Environment environment = EnvironmentUtils.buildEnvironment();
    heartbeat.reportEvent(Event.STANDALONE, environment, modelTask);
}
Also used : Heartbeat(org.nd4j.linalg.heartbeat.Heartbeat) Environment(org.nd4j.linalg.heartbeat.reports.Environment)

Aggregations

Environment (org.nd4j.linalg.heartbeat.reports.Environment): 7
Task (org.nd4j.linalg.heartbeat.reports.Task): 3
Heartbeat (org.nd4j.linalg.heartbeat.Heartbeat): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
List (java.util.List): 1
Map (java.util.Map): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
AtomicLong (java.util.concurrent.atomic.AtomicLong): 1
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1
FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction): 1
Pair (org.deeplearning4j.berkeley.Pair): 1
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 1
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 1
VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache): 1
CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum): 1
TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 1
Ignore (org.junit.Ignore): 1
Test (org.junit.Test): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1