Search in sources:

Example 26 with Tuple

use of org.opensearch.common.collect.Tuple in project OpenSearch by opensearch-project.

the class InternalTopHitsTests method assertReduced.

/**
 * Rebuilds the expected reduction result from the per-shard inputs and compares it with {@code reduced}.
 *
 * Every hit of every input is paired with a copy of its score doc tagged with the index of the input it came
 * from, the pairs are ordered with the same comparator the inputs use (ties broken by input index), and the
 * first {@code inputs.get(0).getSize()} hits are expected to be the hits of the reduced aggregation.
 */
@Override
protected void assertReduced(InternalTopHits reduced, List<InternalTopHits> inputs) {
    final InternalTopHits firstInput = inputs.get(0);
    final boolean sortedByFields = firstInput.getTopDocs().topDocs instanceof TopFieldDocs;
    final Comparator<ScoreDoc> perShardOrder = sortedByFields
        ? sortFieldsComparator(((TopFieldDocs) firstInput.getTopDocs().topDocs).fields)
        : scoreComparator();
    // ties between shards are resolved by the position of the shard in the input list
    final Comparator<ScoreDoc> mergedOrder = perShardOrder.thenComparingInt(scoreDoc -> scoreDoc.shardIndex);

    final SearchHits actualHits = reduced.getHits();
    final List<Tuple<ScoreDoc, SearchHit>> taggedHits = new ArrayList<>();
    float bestScore = Float.NEGATIVE_INFINITY;
    long hitCount = 0;
    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;

    for (int shard = 0; shard < inputs.size(); shard++) {
        final InternalTopHits shardResult = inputs.get(shard);
        final SearchHits shardHits = shardResult.getHits();
        hitCount += shardHits.getTotalHits().value;
        // a single lower-bound count makes the merged count a lower bound as well
        if (shardHits.getTotalHits().relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
            relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
        }
        bestScore = max(bestScore, shardHits.getMaxScore());

        final SearchHit[] hits = shardHits.getHits();
        for (int position = 0; position < hits.length; position++) {
            final ScoreDoc source = shardResult.getTopDocs().topDocs.scoreDocs[position];
            // copy the doc so that it carries the index of the input it originated from
            final ScoreDoc tagged = sortedByFields
                ? new FieldDoc(source.doc, source.score, ((FieldDoc) source).fields, shard)
                : new ScoreDoc(source.doc, source.score, shard);
            taggedHits.add(new Tuple<>(tagged, hits[position]));
        }
    }

    taggedHits.sort(comparing(Tuple::v1, mergedOrder));

    final int expectedCount = min(firstInput.getSize(), taggedHits.size());
    final SearchHit[] expectedHitsHits = new SearchHit[expectedCount];
    for (int rank = 0; rank < expectedCount; rank++) {
        expectedHitsHits[rank] = taggedHits.get(rank).v2();
    }

    // Lucene's TopDocs initializes the maxScore to Float.NaN, if there is no maxScore
    final float expectedMaxScore = maxScore(bestScore);
    final SearchHits expectedHits = new SearchHits(expectedHitsHits, new TotalHits(hitCount, relation), expectedMaxScore);
    assertEqualsWithErrorMessageFromXContent(expectedHits, actualHits);
}

/** Maps the "no score seen" sentinel to {@code Float.NaN} and leaves every other value untouched. */
private static float maxScore(float bestScore) {
    return bestScore == Float.NEGATIVE_INFINITY ? Float.NaN : bestScore;
}
Also used : TotalHits(org.apache.lucene.search.TotalHits) FieldDoc(org.apache.lucene.search.FieldDoc) SearchHit(org.opensearch.search.SearchHit) ArrayList(java.util.ArrayList) TopFieldDocs(org.apache.lucene.search.TopFieldDocs) ScoreDoc(org.apache.lucene.search.ScoreDoc) SearchHits(org.opensearch.search.SearchHits) Tuple(org.opensearch.common.collect.Tuple)

Example 27 with Tuple

use of org.opensearch.common.collect.Tuple in project OpenSearch by opensearch-project.

the class S3HttpHandler method handle.

/**
 * Serves one HTTP exchange by emulating a subset of the S3 REST API on top of the handler's {@code blobs} map.
 *
 * The request line ({@code "<METHOD> <URI>"}) is matched against a fixed sequence of simple wildcard patterns;
 * the first matching branch handles the request:
 * <ul>
 *   <li>{@code HEAD /path/*} - existence check</li>
 *   <li>{@code POST /path/*?uploads} - initiate a multipart upload</li>
 *   <li>{@code PUT /path/*?uploadId=*&partNumber=*} - upload one part</li>
 *   <li>{@code POST /path/*?uploadId=*} - complete a multipart upload</li>
 *   <li>{@code PUT /path/*} - upload a whole blob</li>
 *   <li>{@code GET /bucket/?prefix=*} - list objects</li>
 *   <li>{@code GET /path/*} - download a blob, optionally a byte range of it</li>
 *   <li>{@code DELETE /path/*} - delete every blob whose key starts with the request URI</li>
 *   <li>{@code POST /bucket/?delete} - multi-object delete</li>
 * </ul>
 * Anything else is answered with an internal server error. The exchange is always closed before returning.
 *
 * @param exchange the HTTP exchange to serve
 * @throws IOException if reading the request or writing the response fails
 * @throws AssertionError if the request violates an expectation of this fixture (missing upload part,
 *                        list-type parameter present, malformed {@code Range} header)
 */
@Override
public void handle(final HttpExchange exchange) throws IOException {
    final String request = exchange.getRequestMethod() + " " + exchange.getRequestURI().toString();
    if (request.startsWith("GET") || request.startsWith("HEAD") || request.startsWith("DELETE")) {
        // these methods are expected to carry no body: the first read must already report end of stream
        int read = exchange.getRequestBody().read();
        assert read == -1 : "Request body should have been empty but saw [" + read + "]";
    }
    try {
        if (Regex.simpleMatch("HEAD /" + path + "/*", request)) {
            // existence check: status only, no response body
            final BytesReference blob = blobs.get(exchange.getRequestURI().getPath());
            if (blob == null) {
                exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
            } else {
                exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
            }
        } else if (Regex.simpleMatch("POST /" + path + "/*?uploads", request)) {
            // initiate multipart upload: part 0 is stored as an empty marker that identifies a known upload id
            final String uploadId = UUIDs.randomBase64UUID();
            byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + "<InitiateMultipartUploadResult>\n" + "  <Bucket>" + bucket + "</Bucket>\n" + "  <Key>" + exchange.getRequestURI().getPath() + "</Key>\n" + "  <UploadId>" + uploadId + "</UploadId>\n" + "</InitiateMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
            blobs.put(multipartKey(uploadId, 0), BytesArray.EMPTY);
            exchange.getResponseHeaders().add("Content-Type", "application/xml");
            exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
            exchange.getResponseBody().write(response);
        } else if (Regex.simpleMatch("PUT /" + path + "/*?uploadId=*&partNumber=*", request)) {
            // upload part: only accepted when the upload id was initiated before (marker part 0 exists)
            final Map<String, String> params = new HashMap<>();
            RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
            final String uploadId = params.get("uploadId");
            if (blobs.containsKey(multipartKey(uploadId, 0))) {
                final Tuple<String, BytesReference> blob = parseRequestBody(exchange);
                final int partNumber = Integer.parseInt(params.get("partNumber"));
                blobs.put(multipartKey(uploadId, partNumber), blob.v2());
                exchange.getResponseHeaders().add("ETag", blob.v1());
                exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
            } else {
                exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
            }
        } else if (Regex.simpleMatch("POST /" + path + "/*?uploadId=*", request)) {
            // complete multipart upload: the request body is drained and ignored, parts are concatenated in order
            Streams.readFully(exchange.getRequestBody());
            final Map<String, String> params = new HashMap<>();
            RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
            final String uploadId = params.get("uploadId");
            // highest part number stored for this upload id
            final int nbParts = blobs.keySet().stream().filter(blobName -> blobName.startsWith(uploadId)).map(blobName -> blobName.replaceFirst(uploadId + '\n', "")).mapToInt(Integer::parseInt).max().orElse(0);
            final ByteArrayOutputStream blob = new ByteArrayOutputStream();
            for (int partNumber = 0; partNumber <= nbParts; partNumber++) {
                // parts are removed while being assembled, so a completed upload leaves no part entries behind
                BytesReference part = blobs.remove(multipartKey(uploadId, partNumber));
                if (part == null) {
                    throw new AssertionError("Upload part is null");
                }
                part.writeTo(blob);
            }
            blobs.put(exchange.getRequestURI().getPath(), new BytesArray(blob.toByteArray()));
            byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + "<CompleteMultipartUploadResult>\n" + "  <Bucket>" + bucket + "</Bucket>\n" + "  <Key>" + exchange.getRequestURI().getPath() + "</Key>\n" + "</CompleteMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
            exchange.getResponseHeaders().add("Content-Type", "application/xml");
            exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
            exchange.getResponseBody().write(response);
        } else if (Regex.simpleMatch("PUT /" + path + "/*", request)) {
            // single-request upload, keyed by the full request URI
            final Tuple<String, BytesReference> blob = parseRequestBody(exchange);
            blobs.put(exchange.getRequestURI().toString(), blob.v2());
            exchange.getResponseHeaders().add("ETag", blob.v1());
            exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
        } else if (Regex.simpleMatch("GET /" + bucket + "/?prefix=*", request)) {
            // list objects (version 1 only): optional prefix filter and delimiter-based grouping
            final Map<String, String> params = new HashMap<>();
            RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
            if (params.get("list-type") != null) {
                throw new AssertionError("Test must be adapted for GET Bucket (List Objects) Version 2");
            }
            final StringBuilder list = new StringBuilder();
            list.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
            list.append("<ListBucketResult>");
            final String prefix = params.get("prefix");
            if (prefix != null) {
                list.append("<Prefix>").append(prefix).append("</Prefix>");
            }
            final Set<String> commonPrefixes = new HashSet<>();
            final String delimiter = params.get("delimiter");
            if (delimiter != null) {
                list.append("<Delimiter>").append(delimiter).append("</Delimiter>");
            }
            for (Map.Entry<String, BytesReference> blob : blobs.entrySet()) {
                if (prefix != null && blob.getKey().startsWith("/" + bucket + "/" + prefix) == false) {
                    continue;
                }
                String blobPath = blob.getKey().replace("/" + bucket + "/", "");
                if (delimiter != null) {
                    // a delimiter found after the prefix groups the blob under a common prefix instead of listing it
                    int fromIndex = (prefix != null ? prefix.length() : 0);
                    int delimiterPosition = blobPath.indexOf(delimiter, fromIndex);
                    if (delimiterPosition > 0) {
                        commonPrefixes.add(blobPath.substring(0, delimiterPosition) + delimiter);
                        continue;
                    }
                }
                list.append("<Contents>");
                list.append("<Key>").append(blobPath).append("</Key>");
                list.append("<Size>").append(blob.getValue().length()).append("</Size>");
                list.append("</Contents>");
            }
            if (commonPrefixes.isEmpty() == false) {
                list.append("<CommonPrefixes>");
                commonPrefixes.forEach(commonPrefix -> list.append("<Prefix>").append(commonPrefix).append("</Prefix>"));
                list.append("</CommonPrefixes>");
            }
            list.append("</ListBucketResult>");
            byte[] response = list.toString().getBytes(StandardCharsets.UTF_8);
            exchange.getResponseHeaders().add("Content-Type", "application/xml");
            exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
            exchange.getResponseBody().write(response);
        } else if (Regex.simpleMatch("GET /" + path + "/*", request)) {
            // download: whole blob, or the inclusive byte range given by a "bytes=<start>-<end>" Range header
            final BytesReference blob = blobs.get(exchange.getRequestURI().toString());
            if (blob != null) {
                final String range = exchange.getRequestHeaders().getFirst("Range");
                if (range == null) {
                    exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                    exchange.sendResponseHeaders(RestStatus.OK.getStatus(), blob.length());
                    blob.writeTo(exchange.getResponseBody());
                } else {
                    final Matcher matcher = Pattern.compile("^bytes=([0-9]+)-([0-9]+)$").matcher(range);
                    if (matcher.matches() == false) {
                        throw new AssertionError("Bytes range does not match expected pattern: " + range);
                    }
                    final int start = Integer.parseInt(matcher.group(1));
                    final int end = Integer.parseInt(matcher.group(2));
                    // the end offset is inclusive, hence the +1 when computing the slice length
                    final BytesReference rangeBlob = blob.slice(start, end + 1 - start);
                    exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                    // the value after '/' is the complete length of the blob, not the length of the returned slice
                    exchange.getResponseHeaders().add("Content-Range", String.format(Locale.ROOT, "bytes %d-%d/%d", start, end, blob.length()));
                    // NOTE(review): a ranged response is answered with 200 rather than 206 (Partial Content);
                    // left unchanged because clients of this fixture may rely on it - confirm before changing
                    exchange.sendResponseHeaders(RestStatus.OK.getStatus(), rangeBlob.length());
                    rangeBlob.writeTo(exchange.getResponseBody());
                }
            } else {
                exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
            }
        } else if (Regex.simpleMatch("DELETE /" + path + "/*", request)) {
            // prefix delete: removes every blob whose key starts with the request URI
            int deletions = 0;
            for (Iterator<Map.Entry<String, BytesReference>> iterator = blobs.entrySet().iterator(); iterator.hasNext(); ) {
                Map.Entry<String, BytesReference> blob = iterator.next();
                if (blob.getKey().startsWith(exchange.getRequestURI().toString())) {
                    iterator.remove();
                    deletions++;
                }
            }
            // OK when something was deleted, NO_CONTENT otherwise
            exchange.sendResponseHeaders((deletions > 0 ? RestStatus.OK : RestStatus.NO_CONTENT).getStatus(), -1);
        } else if (Regex.simpleMatch("POST /" + bucket + "/?delete", request)) {
            // multi-object delete: a blob is removed when the request body contains its key in a <Key> element
            final String requestBody = Streams.copyToString(new InputStreamReader(exchange.getRequestBody(), UTF_8));
            final StringBuilder deletes = new StringBuilder();
            deletes.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
            deletes.append("<DeleteResult>");
            for (Iterator<Map.Entry<String, BytesReference>> iterator = blobs.entrySet().iterator(); iterator.hasNext(); ) {
                Map.Entry<String, BytesReference> blob = iterator.next();
                String key = blob.getKey().replace("/" + path + "/", "");
                if (requestBody.contains("<Key>" + key + "</Key>")) {
                    deletes.append("<Deleted><Key>").append(key).append("</Key></Deleted>");
                    iterator.remove();
                }
            }
            deletes.append("</DeleteResult>");
            byte[] response = deletes.toString().getBytes(StandardCharsets.UTF_8);
            exchange.getResponseHeaders().add("Content-Type", "application/xml");
            exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
            exchange.getResponseBody().write(response);
        } else {
            // no pattern matched: report the unsupported request as a server error
            exchange.sendResponseHeaders(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), -1);
        }
    } finally {
        exchange.close();
    }
}
Also used : BytesReference(org.opensearch.common.bytes.BytesReference) CheckedInputStream(java.util.zip.CheckedInputStream) BufferedInputStream(java.io.BufferedInputStream) BytesReference(org.opensearch.common.bytes.BytesReference) ByteArrayOutputStream(java.io.ByteArrayOutputStream) MessageDigest(java.security.MessageDigest) HashMap(java.util.HashMap) Regex(org.opensearch.common.regex.Regex) ConcurrentMap(java.util.concurrent.ConcurrentMap) HashSet(java.util.HashSet) Checksum(java.util.zip.Checksum) Matcher(java.util.regex.Matcher) Streams(org.opensearch.common.io.Streams) Locale(java.util.Locale) Map(java.util.Map) Headers(com.sun.net.httpserver.Headers) UUIDs(org.opensearch.common.UUIDs) SuppressForbidden(org.opensearch.common.SuppressForbidden) Iterator(java.util.Iterator) UTF_8(java.nio.charset.StandardCharsets.UTF_8) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) RestUtils(org.opensearch.rest.RestUtils) Set(java.util.Set) IOException(java.io.IOException) RestStatus(org.opensearch.rest.RestStatus) InputStreamReader(java.io.InputStreamReader) Nullable(org.opensearch.common.Nullable) StandardCharsets(java.nio.charset.StandardCharsets) Tuple(org.opensearch.common.collect.Tuple) Objects(java.util.Objects) MessageDigests(org.opensearch.common.hash.MessageDigests) BytesArray(org.opensearch.common.bytes.BytesArray) HttpHandler(com.sun.net.httpserver.HttpHandler) HttpExchange(com.sun.net.httpserver.HttpExchange) Pattern(java.util.regex.Pattern) InputStream(java.io.InputStream) BytesArray(org.opensearch.common.bytes.BytesArray) HashSet(java.util.HashSet) Set(java.util.Set) InputStreamReader(java.io.InputStreamReader) HashMap(java.util.HashMap) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) Matcher(java.util.regex.Matcher) ByteArrayOutputStream(java.io.ByteArrayOutputStream) Iterator(java.util.Iterator) HashMap(java.util.HashMap) ConcurrentMap(java.util.concurrent.ConcurrentMap) Map(java.util.Map) 
ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap)

Example 28 with Tuple

use of org.opensearch.common.collect.Tuple in project OpenSearch by opensearch-project.

the class LinearizabilityChecker method isLinearizable.

/**
 * Searches for a linearization of {@code history} that is valid under {@code spec}.
 *
 * The search walks the linked list of entries built from the history. An entry with a match is tentatively
 * applied to the current state; if the specification accepts it and the resulting (state, linearized set)
 * pair has not been explored before, the entry is lifted out of the list and the search restarts from the
 * head. When an entry without a match is reached, the most recent choice is undone and the search resumes
 * after it.
 *
 * @param spec           the sequential specification providing the initial state and the state transitions
 * @param history        the events to check
 * @param terminateEarly polled on every iteration; when it returns {@code true} the search is abandoned
 * @return {@code true} if every entry could be lifted, {@code false} if the search was exhausted or abandoned
 */
private boolean isLinearizable(SequentialSpec spec, List<Event> history, BooleanSupplier terminateEarly) {
    logger.debug("Checking history of size: {}: {}", history.size(), history);
    // state of the datatype after applying the operations linearized so far
    Object currentState = spec.initialState();
    // ids of the entries that are part of the linearized prefix
    final FixedBitSet linearizedIds = new FixedBitSet(history.size() / 2);
    // <state, linearized prefix> pairs that were already visited
    final Cache explored = new Cache();
    // choices made on the path currently being explored, together with the state before each choice
    final Deque<Tuple<Entry, Object>> path = new LinkedList<>();
    final Entry head = createLinkedEntries(history);
    Entry cursor = head.next;

    while (head.next != null) {
        if (terminateEarly.getAsBoolean()) {
            return false;
        }

        if (cursor.match == null) {
            // nothing more can be tried from here: undo the latest choice, or give up if there is none
            if (path.isEmpty()) {
                return false;
            }
            final Tuple<Entry, Object> lastChoice = path.pop();
            final Entry undone = lastChoice.v1();
            currentState = lastChoice.v2();
            linearizedIds.clear(undone.id);
            undone.unlift();
            cursor = undone.next;
            continue;
        }

        final Optional<Object> candidateState = spec.nextState(currentState, cursor.event.value, cursor.match.event.value);
        boolean unexplored = false;
        if (candidateState.isPresent()) {
            // only follow this choice if the resulting configuration has not been seen before
            final FixedBitSet withCursor = linearizedIds.clone();
            withCursor.set(cursor.id);
            unexplored = explored.add(candidateState.get(), withCursor);
        }

        if (unexplored == false) {
            cursor = cursor.next;
            continue;
        }

        // commit to this choice and restart the scan from the beginning of the remaining entries
        path.push(new Tuple<>(cursor, currentState));
        currentState = candidateState.get();
        linearizedIds.set(cursor.id);
        cursor.lift();
        cursor = head.next;
    }
    return true;
}
Also used : FixedBitSet(org.apache.lucene.util.FixedBitSet) Tuple(org.opensearch.common.collect.Tuple) LinkedList(java.util.LinkedList)

Example 29 with Tuple

use of org.opensearch.common.collect.Tuple in project ml-commons by opensearch-project.

the class TribuoUtil method generateDatasetWithTarget.

/**
 * Builds a tribuo dataset from a data frame, using one of its columns as the target and the rest as features.
 *
 * @param dataFrame     data frame holding both the feature columns and the target column
 * @param outputFactory the tribuo output factory
 * @param desc          description used for the tribuo provenance
 * @param outputType    the tribuo output type; only {@code REGRESSOR} is handled
 * @param target        name of the column to use as target
 * @param <T>           the tribuo output type of the generated examples
 * @return a mutable tribuo dataset with one example per data frame row
 * @throws IllegalArgumentException if {@code target} is empty, matches no column, or {@code outputType}
 *                                  is not handled
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }

    final List<Example<T>> examples = new ArrayList<>();
    final Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    final String[] columnNames = namesAndValues.v1();
    final double[][] rows = namesAndValues.v2();

    // position of the first column whose name equals the target, -1 if there is none
    final int targetColumn = IntStream.range(0, columnNames.length)
        .filter(column -> columnNames[column].equals(target))
        .findFirst()
        .orElse(-1);
    if (targetColumn == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // every column except the target is a feature
    final String[] featureNames = new String[columnNames.length - 1];
    for (int column = 0, next = 0; column < columnNames.length; column++) {
        if (column != targetColumn) {
            featureNames[next++] = columnNames[column];
        }
    }

    for (int rowIndex = 0; rowIndex < dataFrame.size(); ++rowIndex) {
        final ArrayExample<T> example;
        switch (outputType) {
            case REGRESSOR: {
                final double[] row = rows[rowIndex];
                final double targetValue = row[targetColumn];
                // the row without its target value, in the same order as featureNames
                final double[] featureValues = new double[row.length - 1];
                for (int column = 0, next = 0; column < row.length; column++) {
                    if (column != targetColumn) {
                        featureValues[next++] = row[column];
                    }
                }
                example = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
                break;
            }
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(example);
    }

    final SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, provenance));
}
Also used : IntStream(java.util.stream.IntStream) Example(org.tribuo.Example) Arrays(java.util.Arrays) Row(org.opensearch.ml.common.dataframe.Row) Iterator(java.util.Iterator) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) ClusterID(org.tribuo.clustering.ClusterID) StringUtils(org.apache.commons.lang3.StringUtils) OutputFactory(org.tribuo.OutputFactory) Event(org.tribuo.anomaly.Event) Tuple(org.opensearch.common.collect.Tuple) ArrayList(java.util.ArrayList) UtilityClass(lombok.experimental.UtilityClass) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) List(java.util.List) TribuoOutputType(org.opensearch.ml.engine.contants.TribuoOutputType) Output(org.tribuo.Output) ListDataSource(org.tribuo.datasource.ListDataSource) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) StreamSupport(java.util.stream.StreamSupport) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) MutableDataset(org.tribuo.MutableDataset) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 30 with Tuple

use of org.opensearch.common.collect.Tuple in project ml-commons by opensearch-project.

the class TribuoUtil method transformDataFrame.

/**
 * Splits a data frame into its column names and its values.
 *
 * @param dataFrame the data frame to convert
 * @return a tuple whose first element holds the column names and whose second element holds one
 *         {@code double[]} per row, each value obtained through {@code ColumnValue.doubleValue()}
 */
public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
    final String[] columnNames = Arrays.stream(dataFrame.columnMetas())
        .map(meta -> meta.getName())
        .toArray(String[]::new);

    final double[][] rowValues = new double[dataFrame.size()][];
    int rowIndex = 0;
    for (Iterator<Row> rows = dataFrame.iterator(); rows.hasNext(); rowIndex++) {
        final Row current = rows.next();
        rowValues[rowIndex] = StreamSupport.stream(current.spliterator(), false)
            .mapToDouble(value -> value.doubleValue())
            .toArray();
    }
    return new Tuple<>(columnNames, rowValues);
}
Also used : ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) Row(org.opensearch.ml.common.dataframe.Row) Tuple(org.opensearch.common.collect.Tuple)

Aggregations

Tuple (org.opensearch.common.collect.Tuple)151 ArrayList (java.util.ArrayList)65 List (java.util.List)49 IOException (java.io.IOException)45 Collections (java.util.Collections)44 HashMap (java.util.HashMap)40 Map (java.util.Map)40 Settings (org.opensearch.common.settings.Settings)38 ClusterState (org.opensearch.cluster.ClusterState)34 HashSet (java.util.HashSet)28 ShardId (org.opensearch.index.shard.ShardId)28 Arrays (java.util.Arrays)27 Collectors (java.util.stream.Collectors)26 Set (java.util.Set)25 Index (org.opensearch.index.Index)25 BytesReference (org.opensearch.common.bytes.BytesReference)24 OpenSearchTestCase (org.opensearch.test.OpenSearchTestCase)24 CountDownLatch (java.util.concurrent.CountDownLatch)22 Version (org.opensearch.Version)21 Strings (org.opensearch.common.Strings)21