Search in sources:

Example 1 with TfIdfObject

Use of com.graphaware.nlp.domain.TfIdfObject in project neo4j-nlp by graphaware.

From the class TextRank, method checkNextKeyword.

/**
 * Looks up, for one keyword occurrence, the tags that co-occur with it at its
 * start position and collects those that qualify as follow-up keywords.
 *
 * @param keywordOccurrence the occurrence whose neighbours are examined
 * @param coOccurrences     source tagId -> (destination tagId -> co-occurrence item)
 * @param keywords          tagId -> extracted keyword item
 * @return matching keywords keyed by their value; empty when nothing qualifies
 */
private Map<String, Keyword> checkNextKeyword(KeywordExtractedItem keywordOccurrence, Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences, Map<Long, KeywordExtractedItem> keywords) {
    final Map<String, Keyword> results = new HashMap<>();
    final long tagId = keywordOccurrence.getTagId();
    if (!coOccurrences.containsKey(tagId)) {
        return results;
    }
    // mapping: sourceStartPosition -> Set(destination tagIDs)
    Map<Integer, Set<Long>> startToDestinations = createThisMapping(coOccurrences.get(tagId));
    Set<Long> destinations = startToDestinations.get(keywordOccurrence.getStartPosition());
    if (destinations == null) {
        return results;
    }
    for (Long candidateId : destinations) {
        // skip self-references and tags that were not extracted as keywords
        if (candidateId == tagId || !keywords.containsKey(candidateId)) {
            continue;
        }
        KeywordExtractedItem candidate = keywords.get(candidateId);
        String relValue = candidate.getValue();
        // intersection of the two related-tag lists
        List<Long> sharedRelated = new ArrayList<>(candidate.getRelatedTags());
        sharedRelated.retainAll(keywordOccurrence.getRelatedTags());
        // TO DO: even when using dependencies, we should be able to merge words that are next to each other but that have no dependency (?)
        if (!useDependencies
                || keywordOccurrence.getRelatedTags().contains(candidate.getTagId())
                || !sharedRelated.isEmpty()) {
            addToResults(relValue, candidate.getRelevance(), new TfIdfObject(0., 0.), 0, results, 1);
        }
    }
    return results;
}
Also used : KeywordPersister(com.graphaware.nlp.persistence.persisters.KeywordPersister) NLPManager(com.graphaware.nlp.NLPManager) java.util(java.util) DynamicConfiguration(com.graphaware.nlp.configuration.DynamicConfiguration) Log(org.neo4j.logging.Log) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject) Keyword(com.graphaware.nlp.domain.Keyword) Collectors(java.util.stream.Collectors) AtomicReference(java.util.concurrent.atomic.AtomicReference) DESCRIBES(com.graphaware.nlp.persistence.constants.Relationships.DESCRIBES) Labels(com.graphaware.nlp.persistence.constants.Labels) org.neo4j.graphdb(org.neo4j.graphdb) LoggerFactory(com.graphaware.common.log.LoggerFactory) Pair(com.graphaware.common.util.Pair) PipelineSpecification(com.graphaware.nlp.dsl.request.PipelineSpecification) Keyword(com.graphaware.nlp.domain.Keyword) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject)

Example 2 with TfIdfObject

Use of com.graphaware.nlp.domain.TfIdfObject in project neo4j-nlp by graphaware.

From the class TextRank, method initializeNodeWeights_TfIdf.

/**
 * Fills {@code nodeWeights} with tf*idf values for the tags of the given
 * annotated text.
 *
 * <p>When {@code coOccurrences} is non-null, only tags that appear there (as
 * source or destination) are initialized; all of them are first seeded with a
 * neutral weight of (1.0, 1.0) so tags missing from the query result still
 * carry a value. Results are written in place into {@code nodeWeights}.
 *
 * @param nodeWeights   output map, tagId -> tf/idf weights (mutated in place)
 * @param annotatedText the AnnotatedText node whose tags are weighted
 * @param coOccurrences optional filter: source tagId -> destination items; null = weight all tags
 */
private void initializeNodeWeights_TfIdf(Map<Long, TfIdfObject> nodeWeights, Node annotatedText, Map<Long, Map<Long, CoOccurrenceItem>> coOccurrences) {
    if (coOccurrences != null) {
        coOccurrences.entrySet().stream().forEach((coOccurrence) -> {
            // source
            nodeWeights.put(coOccurrence.getKey(), new TfIdfObject(1.0d, 1.0d));
            coOccurrence.getValue().entrySet().stream().filter((entry) -> !nodeWeights.containsKey(entry.getKey())).forEach((entry) -> {
                // destination
                nodeWeights.put(entry.getValue().getDestination(), new TfIdfObject(1.0d, 1.0d));
            });
        });
    }
    String query = "MATCH (doc:AnnotatedText)\n" + "WITH count(doc) as documentsCount\n" + "MATCH (a:AnnotatedText)-[:CONTAINS_SENTENCE]->(:Sentence)-[ht:HAS_TAG]->(t:Tag)\n" + "WHERE id(a) = {id} \n" + "WITH t, sum(ht.tf) as tf, documentsCount\n" + "MATCH (a:AnnotatedText)-[:CONTAINS_SENTENCE]->(:Sentence)-[:HAS_TAG]->(t)\n" + "RETURN id(t) as tag, t.id as tagVal, tf, count(distinct a) as docCountForTag, documentsCount\n";
    try (Transaction tx = database.beginTx()) {
        Result res = database.execute(query, Collections.singletonMap("id", annotatedText.getId()));
        while (res != null && res.hasNext()) {
            Map<String, Object> next = res.next();
            Long tag = (Long) next.get("tag");
            // initialize only those that are needed
            if (coOccurrences != null && !nodeWeights.containsKey(tag))
                continue;
            long tf = ((Long) next.get("tf"));
            long docCount = (long) next.get("documentsCount");
            long docCountTag = (long) next.get("docCountForTag");
            double idf = Math.log10(1.0d * docCount / docCountTag);
            if (nodeWeights.containsKey(tag)) {
                nodeWeights.get(tag).setTf(tf);
                nodeWeights.get(tag).setIdf(idf);
            } else {
                nodeWeights.put(tag, new TfIdfObject(tf, idf));
            }
        // LOG.info((String) next.get("tagVal") + ": tf = " + tf + ", idf = " + idf + " (docCountTag = " + docCountTag + "), tf*idf = " + tf*idf);
        }
        // close the query result explicitly, consistent with evaluate()
        if (res != null) {
            res.close();
        }
        tx.success();
    } catch (Exception e) {
        // best-effort: leave whatever weights were already computed in nodeWeights
        LOG.error("Error while initializing node weights: ", e);
    }
}
Also used : TfIdfObject(com.graphaware.nlp.domain.TfIdfObject) KeywordPersister(com.graphaware.nlp.persistence.persisters.KeywordPersister) NLPManager(com.graphaware.nlp.NLPManager) java.util(java.util) DynamicConfiguration(com.graphaware.nlp.configuration.DynamicConfiguration) Log(org.neo4j.logging.Log) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject) Keyword(com.graphaware.nlp.domain.Keyword) Collectors(java.util.stream.Collectors) AtomicReference(java.util.concurrent.atomic.AtomicReference) DESCRIBES(com.graphaware.nlp.persistence.constants.Relationships.DESCRIBES) Labels(com.graphaware.nlp.persistence.constants.Labels) org.neo4j.graphdb(org.neo4j.graphdb) LoggerFactory(com.graphaware.common.log.LoggerFactory) Pair(com.graphaware.common.util.Pair) PipelineSpecification(com.graphaware.nlp.dsl.request.PipelineSpecification) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject)

Example 3 with TfIdfObject

Use of com.graphaware.nlp.domain.TfIdfObject in project neo4j-nlp by graphaware.

From the class TextRank, method evaluate.

/**
 * Runs the TextRank keyword-extraction pipeline for one annotated text:
 * builds a co-occurrence graph, ranks tags with PageRank, merges adjacent
 * top-ranked keywords into key phrases, optionally expands named entities,
 * and persists the resulting keywords.
 *
 * @param annotatedText the AnnotatedText node to process
 * @param iter          maximum PageRank iterations
 * @param damp          PageRank damping factor
 * @param threshold     PageRank convergence threshold
 * @return true on success, false when ranking or tag retrieval failed
 */
public boolean evaluate(Node annotatedText, int iter, double damp, double threshold) {
    Map<Long, Map<Long, CoOccurrenceItem>> coOccurrence = createCooccurrences(annotatedText, cooccurrencesFromDependencies);
    PageRank pageRank = new PageRank(database);
    // if (useTfIdfWeights) {
    // pageRank.setNodeWeights(initializeNodeWeights_TfIdf(annotatedText, coOccurrence));
    // }
    Map<Long, Double> pageRanks = pageRank.run(coOccurrence, iter, damp, threshold);
    if (cooccurrencesFromDependencies) {
        coOccurrence.clear();
        // co-occurrences from natural word flow; needed for merging keywords into key phrases
        coOccurrence = createCooccurrences(annotatedText, false);
    }
    if (pageRanks == null) {
        LOG.error("Page ranks not retrieved, aborting evaluate() method ...");
        return false;
    }
    // get tf*idf: useful for cleanFinalKeywords()
    final Map<Long, TfIdfObject> tfidfMap = new HashMap<>();
    if (useDependencies)
        initializeNodeWeights_TfIdf(tfidfMap, annotatedText, null);
    // for z-scores: calculate mean and sigma of relevances and tf*idf
    relevanceAvg = pageRanks.entrySet().stream().mapToDouble(e -> e.getValue()).average().orElse(0.);
    relevanceSigma = Math.sqrt(pageRanks.entrySet().stream().mapToDouble(e -> Math.pow((e.getValue() - relevanceAvg), 2)).average().orElse(0.));
    tfidfAvg = tfidfMap.entrySet().stream().mapToDouble(e -> e.getValue().getTfIdf()).average().orElse(0.);
    tfidfSigma = Math.sqrt(tfidfMap.entrySet().stream().mapToDouble(e -> Math.pow(e.getValue().getTfIdf() - tfidfAvg, 2)).average().orElse(0.));
    int n_oneThird = (int) (pageRanks.size() * topxTags);
    List<Long> topThird = getTopX(pageRanks, n_oneThird);
    pageRanks.entrySet().stream().sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).forEach(en -> System.out.println("   " + idToValue.get(en.getKey()) + ": " + en.getValue()));
    Map<String, Object> params = new HashMap<>();
    params.put("id", annotatedText.getId());
    params.put("posList", admittedPOSs);
    params.put("stopwords", removeStopWords ? stopWords : new ArrayList<>());
    List<KeywordExtractedItem> keywordsOccurrences = new ArrayList<>();
    Map<Long, KeywordExtractedItem> keywordMap = new HashMap<>();
    List<Long> wrongNEs = new ArrayList<>();
    try (Transaction tx = database.beginTx()) {
        Result res = database.execute(GET_TAG_QUERY, params);
        while (res != null && res.hasNext()) {
            Map<String, Object> next = res.next();
            long tagId = (long) next.get("tagId");
            // remove stop-NEs
            if (iterableToList((Iterable<String>) next.get("labels")).stream().anyMatch(el -> forbiddenNEs.contains(el))) {
                wrongNEs.add(tagId);
                continue;
            }
            KeywordExtractedItem item = new KeywordExtractedItem(tagId);
            item.setStartPosition(((Number) next.get("sP")).intValue());
            item.setValue(((String) next.get("tag")));
            item.setEndPosition(((Number) next.get("eP")).intValue());
            item.setRelatedTags(iterableToList((Iterable<Long>) next.get("rel_tags")));
            item.setRelTagStartingPoints(iterableToList((Iterable<Number>) next.get("rel_tos")));
            item.setRelTagEndingPoints(iterableToList((Iterable<Number>) next.get("rel_toe")));
            item.setRelevance(pageRanks.containsKey(tagId) ? pageRanks.get(tagId) : 0);
            keywordsOccurrences.add(item);
            if (!keywordMap.containsKey(tagId)) {
                keywordMap.put(tagId, item);
            } else {
                // merge repeated occurrences of the same tag into one entry
                keywordMap.get(tagId).update(item);
            }
        // System.out.println(" Adding for " + item.getValue() + ": " + item.getRelatedTags());
        }
        if (res != null) {
            res.close();
        }
        tx.success();
    } catch (Exception e) {
        LOG.error("Error while running TextRank evaluation: ", e);
        return false;
    }
    Map<String, Keyword> results = new HashMap<>();
    // consume occurrences in document order, greedily chaining adjacent keywords into phrases
    while (!keywordsOccurrences.isEmpty()) {
        final AtomicReference<KeywordExtractedItem> keywordOccurrence = new AtomicReference<>(keywordsOccurrences.remove(0));
        final AtomicReference<String> currValue = new AtomicReference<>(keywordOccurrence.get().getValue());
        final AtomicReference<Double> currRelevance = new AtomicReference<>(keywordOccurrence.get().getRelevance());
        final AtomicReference<TfIdfObject> currTfIdf = new AtomicReference<>(!tfidfMap.isEmpty() ? tfidfMap.get(keywordOccurrence.get().getTagId()) : new TfIdfObject(1.0d, 1.0d));
        final AtomicReference<Integer> currNTopRated = new AtomicReference<>(0);
        Set<Long> relTagIDs = getRelTagsIntoDepth(keywordOccurrence.get(), keywordsOccurrences);
        // System.out.println("\n val: " + keywordOccurrence.get().getValue() + ", relTagIDs: " + relTagIDs.stream().map(el -> idToValue.get(el)).collect(Collectors.joining(", ")));
        // keep only those that are among top 1/3
        relTagIDs.retainAll(topThird);
        // System.out.println("   relTagIDs among top 1/3: " + relTagIDs.stream().map(el -> idToValue.get(el)).collect(Collectors.joining(", ")));
        // if useDependencies==false, keep only those keywords that are among top 1/3
        if (!useDependencies && !topThird.contains(keywordOccurrence.get().getTagId()))
            continue;
        if (useDependencies && !topThird.contains(keywordOccurrence.get().getTagId()) && relTagIDs.size() == 0)
            continue;
        // System.out.println("\n> " + currValue.get() + " - " + keywordOccurrence.get().getStartPosition());
        Map<String, Keyword> localResults;
        if (topThird.contains(keywordOccurrence.get().getTagId()))
            currNTopRated.set(currNTopRated.get() + 1);
        do {
            int endPosition = keywordOccurrence.get().getEndPosition();
            // System.out.println("  cur: " + currValue.get() + ". Examining next level");
            localResults = checkNextKeyword(keywordOccurrence.get(), coOccurrence, keywordMap);
            if (localResults.size() > 0) {
                // System.out.println("    related tags: " + localResults.entrySet().stream().map(en -> en.getKey()).collect(Collectors.joining(", ")));
                keywordOccurrence.set(null);
                localResults.entrySet().stream().forEach((item) -> {
                    // BUG FIX: get(0) on an empty list throws IndexOutOfBoundsException;
                    // the null-check below shows a null was expected instead
                    KeywordExtractedItem nextKeyword = keywordsOccurrences.isEmpty() ? null : keywordsOccurrences.get(0);
                    // System.out.println("      " + nextKeyword.getValue() + ": " + nextKeyword.getStartPosition());
                    if (nextKeyword != null && nextKeyword.getValue().equalsIgnoreCase(item.getKey()) && (topThird.contains(nextKeyword.getTagId()) || useDependencies) && // crucial condition for graphs from co-occurrences, but very useful also for graphs from dependencies
                    (nextKeyword.getStartPosition() - endPosition) == 1) // && ((nextKeyword.getStartPosition() - endPosition) == 1 || useDependencies))
                    {
                        String newCurrValue = currValue.get().trim().split("_")[0] + " " + item.getKey();
                        // System.out.println(">> " + newCurrValue);
                        double newCurrRelevance = currRelevance.get() + item.getValue().getRelevance();
                        if (topThird.contains(nextKeyword.getTagId()))
                            currNTopRated.set(currNTopRated.get() + 1);
                        currValue.set(newCurrValue);
                        currRelevance.set(newCurrRelevance);
                        if (tfidfMap != null && tfidfMap.containsKey(nextKeyword.getTagId())) {
                            // tf and idf are sums of tf and idf of all words in a phrase
                            double tf = currTfIdf.get().getTf() + tfidfMap.get(nextKeyword.getTagId()).getTf();
                            double idf = currTfIdf.get().getIdf() + tfidfMap.get(nextKeyword.getTagId()).getIdf();
                            // minimal tf and idf
                            // double tf  = currTfIdf.get().getTf() < tfidfMap.get(nextKeyword.getTagId()).getTf() ? currTfIdf.get().getTf() : tfidfMap.get(nextKeyword.getTagId()).getTf();
                            // double idf = currTfIdf.get().getIdf() < tfidfMap.get(nextKeyword.getTagId()).getIdf() ? currTfIdf.get().getIdf() : tfidfMap.get(nextKeyword.getTagId()).getIdf();
                            currTfIdf.set(new TfIdfObject(tf, idf));
                        }
                        keywordOccurrence.set(nextKeyword);
                        keywordsOccurrences.remove(0);
                    }
                // else {
                // LOG.warn("Next keyword not found!");
                // keywordOccurrence.set(null);
                // }
                });
            }
        } while (!localResults.isEmpty() && keywordOccurrence.get() != null);
        if (currNTopRated.get() > 0)
            addToResults(currValue.get(), currRelevance.get(), currTfIdf.get(), currNTopRated.get(), results, 1);
    // System.out.println("< " + currValue.get());
    }
    if (expandNEs) {
        // add named entities that contain at least some of the top 1/3 of words
        for (Long key : neExpanded.keySet()) {
            if (neExpanded.get(key).stream().filter(v -> topThird.contains(v)).count() == 0)
                continue;
            if (wrongNEs.contains(key))
                continue;
            // .toLowerCase();
            String keystr = idToValue.get(key);
            double pr = pageRanks.containsKey(key) ? pageRanks.get(key) : 0.;
            // set PageRank value of a NE to max value of PR of it's composite words
            if (pr == 0.)
                pr = (double) pageRanks.entrySet().stream().filter(en -> neExpanded.get(key).contains(en.getKey())).mapToDouble(en -> en.getValue()).max().orElse(0.);
            addToResults(keystr, pr, tfidfMap != null && tfidfMap.containsKey(key) ? tfidfMap.get(key) : new TfIdfObject(1., 1.), (int) (neExpanded.get(key).stream().filter(v -> topThird.contains(v)).count()), results, 1);
        }
    }
    computeTotalOccurrence(results);
    if (cleanKeywords) {
        results = cleanFinalKeywords(results, n_oneThird);
    }
    peristKeyword(results, annotatedText);
    return true;
}
Also used : KeywordPersister(com.graphaware.nlp.persistence.persisters.KeywordPersister) NLPManager(com.graphaware.nlp.NLPManager) java.util(java.util) DynamicConfiguration(com.graphaware.nlp.configuration.DynamicConfiguration) Log(org.neo4j.logging.Log) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject) Keyword(com.graphaware.nlp.domain.Keyword) Collectors(java.util.stream.Collectors) AtomicReference(java.util.concurrent.atomic.AtomicReference) DESCRIBES(com.graphaware.nlp.persistence.constants.Relationships.DESCRIBES) Labels(com.graphaware.nlp.persistence.constants.Labels) org.neo4j.graphdb(org.neo4j.graphdb) LoggerFactory(com.graphaware.common.log.LoggerFactory) Pair(com.graphaware.common.util.Pair) PipelineSpecification(com.graphaware.nlp.dsl.request.PipelineSpecification) Keyword(com.graphaware.nlp.domain.Keyword) AtomicReference(java.util.concurrent.atomic.AtomicReference) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject) TfIdfObject(com.graphaware.nlp.domain.TfIdfObject)

Aggregations

LoggerFactory (com.graphaware.common.log.LoggerFactory)3 Pair (com.graphaware.common.util.Pair)3 NLPManager (com.graphaware.nlp.NLPManager)3 DynamicConfiguration (com.graphaware.nlp.configuration.DynamicConfiguration)3 Keyword (com.graphaware.nlp.domain.Keyword)3 TfIdfObject (com.graphaware.nlp.domain.TfIdfObject)3 PipelineSpecification (com.graphaware.nlp.dsl.request.PipelineSpecification)3 Labels (com.graphaware.nlp.persistence.constants.Labels)3 DESCRIBES (com.graphaware.nlp.persistence.constants.Relationships.DESCRIBES)3 KeywordPersister (com.graphaware.nlp.persistence.persisters.KeywordPersister)3 java.util (java.util)3 AtomicReference (java.util.concurrent.atomic.AtomicReference)3 Collectors (java.util.stream.Collectors)3 org.neo4j.graphdb (org.neo4j.graphdb)3 Log (org.neo4j.logging.Log)3