Search in sources :

Example 6 with PointInSetQuery

use of org.apache.lucene.search.PointInSetQuery in project lucene-solr by apache.

the class JoinUtil method createJoinQuery.

/**
 * Creates a query-time join query for numeric fields. Single- and multi-valued ints, longs, floats and
 * doubles are supported.
 * All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)}
 * are applicable here too, though memory consumption might be higher.
 * <p>
 * The from-side values are read from doc values as raw {@code long}s. For {@link java.lang.Float} and
 * {@link java.lang.Double} they are reinterpreted as bit patterns ({@link Float#intBitsToFloat(int)} /
 * {@link Double#longBitsToDouble(long)}) before being encoded as points. In the single-valued case a
 * matching document that has no value for {@code fromField} contributes the value {@code 0}.
 *
 * @param fromField                 The from field to join from
 * @param multipleValuesPerDocument Whether the from field has multiple terms per document
 *                                  when true fromField might be {@link DocValuesType#SORTED_NUMERIC},
 *                                  otherwise fromField should be {@link DocValuesType#NUMERIC}
 * @param toField                   The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link FloatPoint}
 *                                  or {@link DoublePoint}.
 * @param numericType               either {@link java.lang.Integer}, {@link java.lang.Long}, {@link java.lang.Float}
 *                                  or {@link java.lang.Double} it should correspond to toField types
 * @param fromQuery                 The query to match documents on the from side
 * @param fromSearcher              The searcher that executed the specified fromQuery
 * @param scoreMode                 Instructs how scores from the fromQuery are mapped to the returned query
 * @return a {@link Query} instance that can be used to join documents based on the
 *         terms in the from and to field
 * @throws IOException              If I/O related errors occur
 * @throws IllegalArgumentException If {@code numericType} is not one of the four supported classes; this is
 *                                  checked before the from-side search is executed
 */
public static Query createJoinQuery(String fromField, boolean multipleValuesPerDocument, String toField, Class<? extends Number> numericType, Query fromQuery, IndexSearcher fromSearcher, ScoreMode scoreMode) throws IOException {
    // Fail fast: reject an unsupported type before paying for the from-side search.
    final String unsupportedTypeMessage = "unsupported numeric type, only Integer, Long, Float and Double are supported";
    if (!Integer.class.equals(numericType) && !Long.class.equals(numericType)
            && !Float.class.equals(numericType) && !Double.class.equals(numericType)) {
        throw new IllegalArgumentException(unsupportedTypeMessage);
    }

    // Sorted set of distinct join values; iterated in ascending order when building the point stream.
    TreeSet<Long> joinValues = new TreeSet<>();
    // Per join value: the running max / min / sum of scores, depending on scoreMode.
    Map<Long, Float> aggregatedScores = new HashMap<>();
    // Per join value: number of contributing scores; only maintained for ScoreMode.Avg.
    Map<Long, Integer> occurrences = new HashMap<>();
    boolean needsScore = scoreMode != ScoreMode.None;

    // Folds one (join value, score) observation into aggregatedScores.
    BiConsumer<Long, Float> scoreAggregator;
    if (scoreMode == ScoreMode.Max) {
        scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Float::max);
    } else if (scoreMode == ScoreMode.Min) {
        scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Float::min);
    } else if (scoreMode == ScoreMode.Total) {
        scoreAggregator = (key, score) -> aggregatedScores.merge(key, score, Float::sum);
    } else if (scoreMode == ScoreMode.Avg) {
        scoreAggregator = (key, score) -> {
            aggregatedScores.merge(key, score, Float::sum);
            occurrences.merge(key, 1, Integer::sum);
        };
    } else {
        // The collectors only invoke the aggregator when needsScore is true, so this is a guard.
        scoreAggregator = (key, score) -> {
            throw new UnsupportedOperationException();
        };
    }

    // Maps a join value to the final score attached to it in the resulting query.
    LongFunction<Float> joinScorer;
    if (scoreMode == ScoreMode.Avg) {
        joinScorer = (joinValue) -> {
            Float aggregatedScore = aggregatedScores.get(joinValue);
            Integer occurrence = occurrences.get(joinValue);
            return aggregatedScore / occurrence;
        };
    } else {
        joinScorer = aggregatedScores::get;
    }

    Collector collector;
    if (multipleValuesPerDocument) {
        collector = new SimpleCollector() {

            SortedNumericDocValues sortedNumericDocValues;

            Scorer scorer;

            @Override
            public void collect(int doc) throws IOException {
                if (doc > sortedNumericDocValues.docID()) {
                    sortedNumericDocValues.advance(doc);
                }
                if (doc == sortedNumericDocValues.docID()) {
                    final int valueCount = sortedNumericDocValues.docValueCount();
                    // The score belongs to the document, not to a value: fetch it once per document.
                    final float docScore = needsScore ? scorer.score() : 0f;
                    for (int i = 0; i < valueCount; i++) {
                        long value = sortedNumericDocValues.nextValue();
                        joinValues.add(value);
                        if (needsScore) {
                            scoreAggregator.accept(value, docScore);
                        }
                    }
                }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
                sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
            }

            @Override
            public void setScorer(Scorer scorer) throws IOException {
                this.scorer = scorer;
            }

            @Override
            public boolean needsScores() {
                return needsScore;
            }
        };
    } else {
        collector = new SimpleCollector() {

            NumericDocValues numericDocValues;

            Scorer scorer;

            private int lastDocID = -1;

            // Assertion helper: always returns true, throws if doc ids go backwards within a segment.
            private boolean docsInOrder(int docID) {
                if (docID < lastDocID) {
                    throw new AssertionError("docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
                }
                lastDocID = docID;
                return true;
            }

            @Override
            public void collect(int doc) throws IOException {
                assert docsInOrder(doc);
                int dvDocID = numericDocValues.docID();
                if (dvDocID < doc) {
                    dvDocID = numericDocValues.advance(doc);
                }
                long value;
                if (dvDocID == doc) {
                    value = numericDocValues.longValue();
                } else {
                    // No value for this document: it is joined on 0.
                    value = 0;
                }
                joinValues.add(value);
                if (needsScore) {
                    scoreAggregator.accept(value, scorer.score());
                }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
                numericDocValues = DocValues.getNumeric(context.reader(), fromField);
                lastDocID = -1;
            }

            @Override
            public void setScorer(Scorer scorer) throws IOException {
                this.scorer = scorer;
            }

            @Override
            public boolean needsScores() {
                return needsScore;
            }
        };
    }
    fromSearcher.search(fromQuery, collector);

    // The iterator must be created after the search: joinValues is populated by the collector.
    Iterator<Long> iterator = joinValues.iterator();
    final int bytesPerDim;
    // Single scratch buffer returned by every next() call; its byte array is allocated below once
    // bytesPerDim is known.
    final BytesRef encoded = new BytesRef();
    final PointInSetIncludingScoreQuery.Stream stream;
    if (Integer.class.equals(numericType)) {
        bytesPerDim = Integer.BYTES;
        stream = new PointInSetIncludingScoreQuery.Stream() {

            @Override
            public BytesRef next() {
                if (iterator.hasNext()) {
                    long value = iterator.next();
                    IntPoint.encodeDimension((int) value, encoded.bytes, 0);
                    if (needsScore) {
                        // score is not declared in this method; presumably a field of Stream.
                        score = joinScorer.apply(value);
                    }
                    return encoded;
                } else {
                    return null;
                }
            }
        };
    } else if (Long.class.equals(numericType)) {
        bytesPerDim = Long.BYTES;
        stream = new PointInSetIncludingScoreQuery.Stream() {

            @Override
            public BytesRef next() {
                if (iterator.hasNext()) {
                    long value = iterator.next();
                    LongPoint.encodeDimension(value, encoded.bytes, 0);
                    if (needsScore) {
                        score = joinScorer.apply(value);
                    }
                    return encoded;
                } else {
                    return null;
                }
            }
        };
    } else if (Float.class.equals(numericType)) {
        bytesPerDim = Float.BYTES;
        stream = new PointInSetIncludingScoreQuery.Stream() {

            @Override
            public BytesRef next() {
                if (iterator.hasNext()) {
                    long value = iterator.next();
                    // The doc value is treated as the int bit pattern of a float.
                    FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0);
                    if (needsScore) {
                        score = joinScorer.apply(value);
                    }
                    return encoded;
                } else {
                    return null;
                }
            }
        };
    } else if (Double.class.equals(numericType)) {
        bytesPerDim = Double.BYTES;
        stream = new PointInSetIncludingScoreQuery.Stream() {

            @Override
            public BytesRef next() {
                if (iterator.hasNext()) {
                    long value = iterator.next();
                    // The doc value is treated as the long bit pattern of a double.
                    DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0);
                    if (needsScore) {
                        score = joinScorer.apply(value);
                    }
                    return encoded;
                } else {
                    return null;
                }
            }
        };
    } else {
        // Unreachable after the up-front check; kept so bytesPerDim and stream are definitely assigned.
        throw new IllegalArgumentException(unsupportedTypeMessage);
    }
    encoded.bytes = new byte[bytesPerDim];
    encoded.length = bytesPerDim;
    if (needsScore) {
        return new PointInSetIncludingScoreQuery(scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {

            @Override
            protected String toString(byte[] value) {
                return toString.apply(value, numericType);
            }
        };
    } else {
        return new PointInSetQuery(toField, 1, bytesPerDim, stream) {

            @Override
            protected String toString(byte[] value) {
                return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
            }
        };
    }
}
Also used : Query(org.apache.lucene.search.Query) LongPoint(org.apache.lucene.document.LongPoint) MatchNoDocsQuery(org.apache.lucene.search.MatchNoDocsQuery) NumericDocValues(org.apache.lucene.index.NumericDocValues) HashMap(java.util.HashMap) TreeSet(java.util.TreeSet) DoublePoint(org.apache.lucene.document.DoublePoint) PointInSetQuery(org.apache.lucene.search.PointInSetQuery) Locale(java.util.Locale) Map(java.util.Map) BiConsumer(java.util.function.BiConsumer) SortedSetDocValues(org.apache.lucene.index.SortedSetDocValues) IntPoint(org.apache.lucene.document.IntPoint) LeafReaderContext(org.apache.lucene.index.LeafReaderContext) SortedDocValues(org.apache.lucene.index.SortedDocValues) SimpleCollector(org.apache.lucene.search.SimpleCollector) Scorer(org.apache.lucene.search.Scorer) Iterator(java.util.Iterator) LongFunction(java.util.function.LongFunction) FloatPoint(org.apache.lucene.document.FloatPoint) MultiDocValues(org.apache.lucene.index.MultiDocValues) BytesRef(org.apache.lucene.util.BytesRef) IOException(java.io.IOException) Collector(org.apache.lucene.search.Collector) SortedNumericDocValues(org.apache.lucene.index.SortedNumericDocValues) Function(org.apache.lucene.search.join.DocValuesTermsCollector.Function) DocValues(org.apache.lucene.index.DocValues) DocValuesType(org.apache.lucene.index.DocValuesType) LeafReader(org.apache.lucene.index.LeafReader) BinaryDocValues(org.apache.lucene.index.BinaryDocValues) IndexSearcher(org.apache.lucene.search.IndexSearcher) NumericDocValues(org.apache.lucene.index.NumericDocValues) SortedNumericDocValues(org.apache.lucene.index.SortedNumericDocValues) SortedNumericDocValues(org.apache.lucene.index.SortedNumericDocValues) HashMap(java.util.HashMap) Scorer(org.apache.lucene.search.Scorer) SimpleCollector(org.apache.lucene.search.SimpleCollector) TreeSet(java.util.TreeSet) SimpleCollector(org.apache.lucene.search.SimpleCollector) Collector(org.apache.lucene.search.Collector) 
LeafReaderContext(org.apache.lucene.index.LeafReaderContext) BytesRef(org.apache.lucene.util.BytesRef) PointInSetQuery(org.apache.lucene.search.PointInSetQuery) IOException(java.io.IOException) LongPoint(org.apache.lucene.document.LongPoint) DoublePoint(org.apache.lucene.document.DoublePoint) IntPoint(org.apache.lucene.document.IntPoint) FloatPoint(org.apache.lucene.document.FloatPoint)

Example 7 with PointInSetQuery

use of org.apache.lucene.search.PointInSetQuery in project lucene-solr by apache.

the class FloatPoint method newSetQuery.

/**
 * Builds a query that matches documents whose 1D {@code float} point equals any of the given values.
 * This is the points counterpart of {@code TermsQuery}.
 * <p>
 * The supplied array is copied and the copy is sorted; the caller's array is not modified.
 *
 * @param field  field name. must not be {@code null}.
 * @param values all values to match
 * @return a {@link PointInSetQuery} over the sorted copy of {@code values}
 */
public static Query newSetQuery(String field, float... values) {
    // Sort a private copy so the caller's array keeps its original order.
    final float[] ordered = values.clone();
    Arrays.sort(ordered);

    // One scratch buffer, re-encoded and handed out on every next() call.
    final BytesRef scratch = new BytesRef(new byte[Float.BYTES]);

    final PointInSetQuery.Stream packedValues = new PointInSetQuery.Stream() {

        private int cursor;

        @Override
        public BytesRef next() {
            if (cursor < ordered.length) {
                encodeDimension(ordered[cursor++], scratch.bytes, 0);
                return scratch;
            }
            // Exhausted.
            return null;
        }
    };

    return new PointInSetQuery(field, 1, Float.BYTES, packedValues) {

        @Override
        protected String toString(byte[] value) {
            assert value.length == Float.BYTES;
            final float decoded = decodeDimension(value, 0);
            return Float.toString(decoded);
        }
    };
}
Also used : PointInSetQuery(org.apache.lucene.search.PointInSetQuery) BytesRef(org.apache.lucene.util.BytesRef)

Example 8 with PointInSetQuery

use of org.apache.lucene.search.PointInSetQuery in project lucene-solr by apache.

the class IntPoint method newSetQuery.

/**
 * Builds a query that matches documents whose 1D {@code int} point equals any of the given values.
 * This is the points counterpart of {@code TermsQuery}.
 * <p>
 * The supplied array is copied and the copy is sorted; the caller's array is not modified.
 *
 * @param field  field name. must not be {@code null}.
 * @param values all values to match
 * @return a {@link PointInSetQuery} over the sorted copy of {@code values}
 */
public static Query newSetQuery(String field, int... values) {
    // Sort a private copy so the caller's array keeps its original order.
    final int[] ordered = values.clone();
    Arrays.sort(ordered);

    // One scratch buffer, re-encoded and handed out on every next() call.
    final BytesRef scratch = new BytesRef(new byte[Integer.BYTES]);

    final PointInSetQuery.Stream packedValues = new PointInSetQuery.Stream() {

        private int cursor;

        @Override
        public BytesRef next() {
            if (cursor < ordered.length) {
                encodeDimension(ordered[cursor++], scratch.bytes, 0);
                return scratch;
            }
            // Exhausted.
            return null;
        }
    };

    return new PointInSetQuery(field, 1, Integer.BYTES, packedValues) {

        @Override
        protected String toString(byte[] value) {
            assert value.length == Integer.BYTES;
            final int decoded = decodeDimension(value, 0);
            return Integer.toString(decoded);
        }
    };
}
Also used : PointInSetQuery(org.apache.lucene.search.PointInSetQuery) BytesRef(org.apache.lucene.util.BytesRef)

Example 9 with PointInSetQuery

use of org.apache.lucene.search.PointInSetQuery in project lucene-solr by apache.

the class LongPoint method newSetQuery.

/**
 * Builds a query that matches documents whose 1D {@code long} point equals any of the given values.
 * This is the points counterpart of {@code TermsQuery}.
 * <p>
 * The supplied array is copied and the copy is sorted; the caller's array is not modified.
 *
 * @param field  field name. must not be {@code null}.
 * @param values all values to match
 * @return a {@link PointInSetQuery} over the sorted copy of {@code values}
 */
public static Query newSetQuery(String field, long... values) {
    // Sort a private copy so the caller's array keeps its original order.
    final long[] ordered = values.clone();
    Arrays.sort(ordered);

    // One scratch buffer, re-encoded and handed out on every next() call.
    final BytesRef scratch = new BytesRef(new byte[Long.BYTES]);

    final PointInSetQuery.Stream packedValues = new PointInSetQuery.Stream() {

        private int cursor;

        @Override
        public BytesRef next() {
            if (cursor < ordered.length) {
                encodeDimension(ordered[cursor++], scratch.bytes, 0);
                return scratch;
            }
            // Exhausted.
            return null;
        }
    };

    return new PointInSetQuery(field, 1, Long.BYTES, packedValues) {

        @Override
        protected String toString(byte[] value) {
            assert value.length == Long.BYTES;
            final long decoded = decodeDimension(value, 0);
            return Long.toString(decoded);
        }
    };
}
Also used : PointInSetQuery(org.apache.lucene.search.PointInSetQuery) BytesRef(org.apache.lucene.util.BytesRef)

Example 10 with PointInSetQuery

use of org.apache.lucene.search.PointInSetQuery in project lucene-solr by apache.

the class BigIntegerPoint method newSetQuery.

/**
 * Builds a query that matches documents whose 1D {@link BigInteger} point equals any of the given values.
 * This is the points counterpart of {@code TermsQuery}.
 * <p>
 * The supplied array is copied and the copy is sorted; the caller's array is not modified.
 *
 * @param field  field name. must not be {@code null}.
 * @param values all values to match
 * @return a {@link PointInSetQuery} over the sorted copy of {@code values}
 */
public static Query newSetQuery(String field, BigInteger... values) {
    // Sort a private copy so the caller's array keeps its original order.
    final BigInteger[] ordered = values.clone();
    Arrays.sort(ordered);

    // One scratch buffer, re-encoded and handed out on every next() call.
    final BytesRef scratch = new BytesRef(new byte[BYTES]);

    final PointInSetQuery.Stream packedValues = new PointInSetQuery.Stream() {

        private int cursor;

        @Override
        public BytesRef next() {
            if (cursor < ordered.length) {
                encodeDimension(ordered[cursor++], scratch.bytes, 0);
                return scratch;
            }
            // Exhausted.
            return null;
        }
    };

    return new PointInSetQuery(field, 1, BYTES, packedValues) {

        @Override
        protected String toString(byte[] value) {
            assert value.length == BYTES;
            final BigInteger decoded = decodeDimension(value, 0);
            return decoded.toString();
        }
    };
}
Also used : PointInSetQuery(org.apache.lucene.search.PointInSetQuery) BigInteger(java.math.BigInteger) BytesRef(org.apache.lucene.util.BytesRef)

Aggregations

PointInSetQuery (org.apache.lucene.search.PointInSetQuery)10 BytesRef (org.apache.lucene.util.BytesRef)9 HashMap (java.util.HashMap)2 MatchNoDocsQuery (org.apache.lucene.search.MatchNoDocsQuery)2 Query (org.apache.lucene.search.Query)2 IOException (java.io.IOException)1 BigInteger (java.math.BigInteger)1 Iterator (java.util.Iterator)1 Locale (java.util.Locale)1 Map (java.util.Map)1 TreeSet (java.util.TreeSet)1 BiConsumer (java.util.function.BiConsumer)1 LongFunction (java.util.function.LongFunction)1 DoublePoint (org.apache.lucene.document.DoublePoint)1 FloatPoint (org.apache.lucene.document.FloatPoint)1 IntPoint (org.apache.lucene.document.IntPoint)1 LongPoint (org.apache.lucene.document.LongPoint)1 BinaryDocValues (org.apache.lucene.index.BinaryDocValues)1 DocValues (org.apache.lucene.index.DocValues)1 DocValuesType (org.apache.lucene.index.DocValuesType)1