Use of org.opensearch.common.collect.Tuple in the OpenSearch project (opensearch-project): class InternalTopHitsTests, method assertReduced.
@Override
protected void assertReduced(InternalTopHits reduced, List<InternalTopHits> inputs) {
    // Determine how the data nodes sorted their hits: by explicit sort fields or by score.
    final boolean sortedByFields = inputs.get(0).getTopDocs().topDocs instanceof TopFieldDocs;
    final Comparator<ScoreDoc> dataNodeComparator = sortedByFields
        ? sortFieldsComparator(((TopFieldDocs) inputs.get(0).getTopDocs().topDocs).fields)
        : scoreComparator();
    // The reduce phase breaks ties between equal docs by shard index.
    final Comparator<ScoreDoc> reducedComparator = dataNodeComparator.thenComparing(d -> d.shardIndex);

    final SearchHits actualHits = reduced.getHits();
    final List<Tuple<ScoreDoc, SearchHit>> allHits = new ArrayList<>();
    float maxScore = Float.NEGATIVE_INFINITY;
    long totalHits = 0;
    TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
    for (int shard = 0; shard < inputs.size(); shard++) {
        final SearchHits shardHits = inputs.get(shard).getHits();
        totalHits += shardHits.getTotalHits().value;
        // One lower-bound shard count makes the whole reduced count a lower bound.
        if (shardHits.getTotalHits().relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
            relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
        }
        maxScore = max(maxScore, shardHits.getMaxScore());
        for (int i = 0; i < shardHits.getHits().length; i++) {
            ScoreDoc doc = inputs.get(shard).getTopDocs().topDocs.scoreDocs[i];
            // Rebuild each doc carrying its shard index so the tie-breaking comparator works.
            doc = sortedByFields
                ? new FieldDoc(doc.doc, doc.score, ((FieldDoc) doc).fields, shard)
                : new ScoreDoc(doc.doc, doc.score, shard);
            allHits.add(new Tuple<>(doc, shardHits.getHits()[i]));
        }
    }
    // Globally order all hits exactly as the reduce phase would, then keep the top `size`.
    allHits.sort(comparing(Tuple::v1, reducedComparator));
    final SearchHit[] expectedHitsHits = new SearchHit[min(inputs.get(0).getSize(), allHits.size())];
    for (int i = 0; i < expectedHitsHits.length; i++) {
        expectedHitsHits[i] = allHits.get(i).v2();
    }
    // Lucene's TopDocs initializes the maxScore to Float.NaN, if there is no maxScore
    final SearchHits expectedHits = new SearchHits(
        expectedHitsHits,
        new TotalHits(totalHits, relation),
        maxScore == Float.NEGATIVE_INFINITY ? Float.NaN : maxScore
    );
    assertEqualsWithErrorMessageFromXContent(expectedHits, actualHits);
}
Use of org.opensearch.common.collect.Tuple in the OpenSearch project (opensearch-project): class S3HttpHandler, method handle.
// Minimal in-memory S3 endpoint for repository tests: dispatches on "<METHOD> <URI>"
// and serves blobs from the `blobs` map. Responses mimic just enough of the S3 XML API.
@Override
public void handle(final HttpExchange exchange) throws IOException {
// Dispatch key: HTTP method plus full request URI (query string included).
final String request = exchange.getRequestMethod() + " " + exchange.getRequestURI().toString();
// GET/HEAD/DELETE requests must not carry a body; read() == -1 means the stream is empty.
if (request.startsWith("GET") || request.startsWith("HEAD") || request.startsWith("DELETE")) {
int read = exchange.getRequestBody().read();
assert read == -1 : "Request body should have been empty but saw [" + read + "]";
}
try {
// HEAD object: existence check only — 200 or 404, never a body (length -1).
if (Regex.simpleMatch("HEAD /" + path + "/*", request)) {
final BytesReference blob = blobs.get(exchange.getRequestURI().getPath());
if (blob == null) {
exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
} else {
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
}
// Initiate Multipart Upload: mint an upload id and store an empty part-0 sentinel
// so subsequent part uploads can validate the id.
} else if (Regex.simpleMatch("POST /" + path + "/*?uploads", request)) {
final String uploadId = UUIDs.randomBase64UUID();
byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + "<InitiateMultipartUploadResult>\n" + " <Bucket>" + bucket + "</Bucket>\n" + " <Key>" + exchange.getRequestURI().getPath() + "</Key>\n" + " <UploadId>" + uploadId + "</UploadId>\n" + "</InitiateMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
blobs.put(multipartKey(uploadId, 0), BytesArray.EMPTY);
exchange.getResponseHeaders().add("Content-Type", "application/xml");
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
exchange.getResponseBody().write(response);
// Upload Part: store the part body under a key derived from the upload id and part number.
} else if (Regex.simpleMatch("PUT /" + path + "/*?uploadId=*&partNumber=*", request)) {
final Map<String, String> params = new HashMap<>();
RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
final String uploadId = params.get("uploadId");
// The part-0 sentinel written at initiation proves the upload id is known.
if (blobs.containsKey(multipartKey(uploadId, 0))) {
final Tuple<String, BytesReference> blob = parseRequestBody(exchange);
final int partNumber = Integer.parseInt(params.get("partNumber"));
blobs.put(multipartKey(uploadId, partNumber), blob.v2());
// v1 of the parsed body presumably carries the ETag value — confirm in parseRequestBody.
exchange.getResponseHeaders().add("ETag", blob.v1());
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
} else {
exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
}
// Complete Multipart Upload: concatenate all stored parts, in part-number order,
// into the final object, then delete the part entries.
} else if (Regex.simpleMatch("POST /" + path + "/*?uploadId=*", request)) {
Streams.readFully(exchange.getRequestBody());
final Map<String, String> params = new HashMap<>();
RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
final String uploadId = params.get("uploadId");
// Highest part number seen for this upload; assumes multipartKey produces
// uploadId + '\n' + partNumber — TODO confirm against multipartKey.
final int nbParts = blobs.keySet().stream().filter(blobName -> blobName.startsWith(uploadId)).map(blobName -> blobName.replaceFirst(uploadId + '\n', "")).mapToInt(Integer::parseInt).max().orElse(0);
final ByteArrayOutputStream blob = new ByteArrayOutputStream();
// Part 0 is the (empty) sentinel, so including it in the loop is harmless.
for (int partNumber = 0; partNumber <= nbParts; partNumber++) {
BytesReference part = blobs.remove(multipartKey(uploadId, partNumber));
if (part == null) {
throw new AssertionError("Upload part is null");
}
part.writeTo(blob);
}
blobs.put(exchange.getRequestURI().getPath(), new BytesArray(blob.toByteArray()));
byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" + "<CompleteMultipartUploadResult>\n" + " <Bucket>" + bucket + "</Bucket>\n" + " <Key>" + exchange.getRequestURI().getPath() + "</Key>\n" + "</CompleteMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "application/xml");
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
exchange.getResponseBody().write(response);
// Put Object: single-shot upload keyed by the full request URI.
} else if (Regex.simpleMatch("PUT /" + path + "/*", request)) {
final Tuple<String, BytesReference> blob = parseRequestBody(exchange);
blobs.put(exchange.getRequestURI().toString(), blob.v2());
exchange.getResponseHeaders().add("ETag", blob.v1());
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
// List Objects (V1) with optional prefix/delimiter, emulating S3's CommonPrefixes grouping.
} else if (Regex.simpleMatch("GET /" + bucket + "/?prefix=*", request)) {
final Map<String, String> params = new HashMap<>();
RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
// Only the V1 list API is emulated; fail loudly if a client sends V2.
if (params.get("list-type") != null) {
throw new AssertionError("Test must be adapted for GET Bucket (List Objects) Version 2");
}
final StringBuilder list = new StringBuilder();
list.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
list.append("<ListBucketResult>");
final String prefix = params.get("prefix");
if (prefix != null) {
list.append("<Prefix>").append(prefix).append("</Prefix>");
}
final Set<String> commonPrefixes = new HashSet<>();
final String delimiter = params.get("delimiter");
if (delimiter != null) {
list.append("<Delimiter>").append(delimiter).append("</Delimiter>");
}
for (Map.Entry<String, BytesReference> blob : blobs.entrySet()) {
// Keys are stored with a leading "/<bucket>/" — strip it for listing.
if (prefix != null && blob.getKey().startsWith("/" + bucket + "/" + prefix) == false) {
continue;
}
String blobPath = blob.getKey().replace("/" + bucket + "/", "");
if (delimiter != null) {
// Keys with the delimiter after the prefix are rolled up into CommonPrefixes
// instead of being listed individually, as S3 does.
int fromIndex = (prefix != null ? prefix.length() : 0);
int delimiterPosition = blobPath.indexOf(delimiter, fromIndex);
if (delimiterPosition > 0) {
commonPrefixes.add(blobPath.substring(0, delimiterPosition) + delimiter);
continue;
}
}
list.append("<Contents>");
list.append("<Key>").append(blobPath).append("</Key>");
list.append("<Size>").append(blob.getValue().length()).append("</Size>");
list.append("</Contents>");
}
if (commonPrefixes.isEmpty() == false) {
list.append("<CommonPrefixes>");
commonPrefixes.forEach(commonPrefix -> list.append("<Prefix>").append(commonPrefix).append("</Prefix>"));
list.append("</CommonPrefixes>");
}
list.append("</ListBucketResult>");
byte[] response = list.toString().getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "application/xml");
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
exchange.getResponseBody().write(response);
// Get Object: full body, or a byte range when a "Range: bytes=a-b" header is present.
} else if (Regex.simpleMatch("GET /" + path + "/*", request)) {
final BytesReference blob = blobs.get(exchange.getRequestURI().toString());
if (blob != null) {
final String range = exchange.getRequestHeaders().getFirst("Range");
if (range == null) {
exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), blob.length());
blob.writeTo(exchange.getResponseBody());
} else {
// Only the closed form "bytes=start-end" is supported; anything else is a test bug.
final Matcher matcher = Pattern.compile("^bytes=([0-9]+)-([0-9]+)$").matcher(range);
if (matcher.matches() == false) {
throw new AssertionError("Bytes range does not match expected pattern: " + range);
}
final int start = Integer.parseInt(matcher.group(1));
final int end = Integer.parseInt(matcher.group(2));
// Range end is inclusive, hence end + 1 - start bytes.
final BytesReference rangeBlob = blob.slice(start, end + 1 - start);
exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
// NOTE(review): the Content-Range total is the slice length, not the full blob
// length; real S3 reports the full object size here — acceptable for tests.
exchange.getResponseHeaders().add("Content-Range", String.format(Locale.ROOT, "bytes %d-%d/%d", start, end, rangeBlob.length()));
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), rangeBlob.length());
rangeBlob.writeTo(exchange.getResponseBody());
}
} else {
exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1);
}
// Delete Object(s): removes every blob whose key starts with the request URI (prefix delete).
} else if (Regex.simpleMatch("DELETE /" + path + "/*", request)) {
int deletions = 0;
for (Iterator<Map.Entry<String, BytesReference>> iterator = blobs.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<String, BytesReference> blob = iterator.next();
if (blob.getKey().startsWith(exchange.getRequestURI().toString())) {
iterator.remove();
deletions++;
}
}
exchange.sendResponseHeaders((deletions > 0 ? RestStatus.OK : RestStatus.NO_CONTENT).getStatus(), -1);
// Multi-Object Delete: removes every stored blob whose key appears as a <Key> element
// in the request body, echoing each deletion in the XML response.
} else if (Regex.simpleMatch("POST /" + bucket + "/?delete", request)) {
final String requestBody = Streams.copyToString(new InputStreamReader(exchange.getRequestBody(), UTF_8));
final StringBuilder deletes = new StringBuilder();
deletes.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
deletes.append("<DeleteResult>");
for (Iterator<Map.Entry<String, BytesReference>> iterator = blobs.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry<String, BytesReference> blob = iterator.next();
String key = blob.getKey().replace("/" + path + "/", "");
if (requestBody.contains("<Key>" + key + "</Key>")) {
deletes.append("<Deleted><Key>").append(key).append("</Key></Deleted>");
iterator.remove();
}
}
deletes.append("</DeleteResult>");
byte[] response = deletes.toString().getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "application/xml");
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), response.length);
exchange.getResponseBody().write(response);
} else {
// Unrecognized request shape: surface it as a 500 so the test fails visibly.
exchange.sendResponseHeaders(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), -1);
}
} finally {
// Always release the exchange, even if a handler branch threw.
exchange.close();
}
}
Use of org.opensearch.common.collect.Tuple in the OpenSearch project (opensearch-project): class LinearizabilityChecker, method isLinearizable.
// Backtracking search for a valid linearization of the recorded history against the
// sequential specification — this appears to follow the Wing & Gong algorithm (with
// Lowe's state-caching refinement); confirm against the class-level references.
// Returns true iff a linearization exists; returns false when none exists or when
// terminateEarly fires (so a false result may mean "gave up", not "not linearizable").
private boolean isLinearizable(SequentialSpec spec, List<Event> history, BooleanSupplier terminateEarly) {
logger.debug("Checking history of size: {}: {}", history.size(), history);
// the current state of the datatype
Object state = spec.initialState();
// the linearized prefix of the history
// (history.size() / 2 bits: each operation contributes an invocation and a response event)
final FixedBitSet linearized = new FixedBitSet(history.size() / 2);
// cache of explored <state, linearized prefix> pairs
final Cache cache = new Cache();
// path we're currently exploring
final Deque<Tuple<Entry, Object>> calls = new LinkedList<>();
// Doubly-linked list of entries built from the history; headEntry is a sentinel.
final Entry headEntry = createLinkedEntries(history);
// current entry
Entry entry = headEntry.next;
// Loop until every entry has been lifted out of the list (all operations linearized).
while (headEntry.next != null) {
if (terminateEarly.getAsBoolean()) {
return false;
}
// entry.match != null means this entry is an invocation with a paired response:
// try to linearize the operation at this point.
if (entry.match != null) {
final Optional<Object> maybeNextState = spec.nextState(state, entry.event.value, entry.match.event.value);
boolean shouldExploreNextState = false;
if (maybeNextState.isPresent()) {
// check if we have already explored this linearization
final FixedBitSet updatedLinearized = linearized.clone();
updatedLinearized.set(entry.id);
shouldExploreNextState = cache.add(maybeNextState.get(), updatedLinearized);
}
if (shouldExploreNextState) {
// Commit this operation: push the previous state for backtracking, detach the
// entry from the list (lift), and restart the scan from the head.
calls.push(new Tuple<>(entry, state));
state = maybeNextState.get(); // safe: shouldExploreNextState implies isPresent()
linearized.set(entry.id);
entry.lift();
entry = headEntry.next;
} else {
// Either the spec rejected this ordering or we've been here before: skip ahead.
entry = entry.next;
}
} else {
// No matching response (presumably a response event or a pending call): we cannot
// advance past this point, so backtrack to the most recent committed operation.
if (calls.isEmpty()) {
// Nothing left to undo — no linearization exists.
return false;
}
final Tuple<Entry, Object> top = calls.pop();
entry = top.v1();
state = top.v2();
// Undo the commit: clear its bit, re-insert it into the list, and resume after it.
linearized.clear(entry.id);
entry.unlift();
entry = entry.next;
}
}
// Every operation was linearized in some order the spec accepts.
return true;
}
Use of org.opensearch.common.collect.Tuple in the ml-commons project (opensearch-project): class TribuoUtil, method generateDatasetWithTarget.
/**
 * Generate tribuo dataset from data frame with target.
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @param target target name
 * @return tribuo dataset
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
if (StringUtils.isEmpty(target)) {
throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
}
List<Example<T>> dataset = new ArrayList<>();
// v1: column names (feature names), v2: per-row feature values — see transformDataFrame.
Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
// Find the index of the target column by exact name match.
int targetIndex = -1;
for (int i = 0; i < featureNamesValues.v1().length; ++i) {
if (featureNamesValues.v1()[i].equals(target)) {
targetIndex = i;
break;
}
}
if (targetIndex == -1) {
throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
}
ArrayExample<T> example;
// Effectively-final copy so the index can be captured by the lambdas below.
final int finalTargetIndex = targetIndex;
// Feature names are all columns except the target, in original column order.
String[] featureNames = IntStream.range(0, featureNamesValues.v1().length).filter(e -> e != finalTargetIndex).mapToObj(e -> featureNamesValues.v1()[e]).toArray(String[]::new);
for (int i = 0; i < dataFrame.size(); ++i) {
// Only REGRESSOR is implemented; any other output type is rejected.
switch(outputType) {
case REGRESSOR:
// Effectively-final row index for capture in the value-extraction lambda.
final int finalI = i;
double targetValue = featureNamesValues.v2()[finalI][finalTargetIndex];
// Row values with the target column filtered out, matching featureNames order.
double[] featureValues = IntStream.range(0, featureNamesValues.v2()[i].length).filter(e -> e != finalTargetIndex).mapToDouble(e -> featureNamesValues.v2()[finalI][e]).toArray();
// NOTE(review): unchecked cast — only sound when T is Regressor, i.e. when the
// caller's outputFactory matches TribuoOutputType.REGRESSOR; verify at call sites.
example = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
break;
default:
throw new IllegalArgumentException("unknown type:" + outputType);
}
dataset.add(example);
}
SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
return new MutableDataset<>(new ListDataSource<>(dataset, outputFactory, provenance));
}
Use of org.opensearch.common.collect.Tuple in the ml-commons project (opensearch-project): class TribuoUtil, method transformDataFrame.
/**
 * Flattens a data frame into a pair of parallel arrays: column names (v1) and
 * one double[] of column values per row (v2), both in original order.
 */
public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
    // Column metadata supplies the feature names, in column order.
    final String[] names = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).toArray(String[]::new);
    final double[][] values = new double[dataFrame.size()][];
    int rowIndex = 0;
    // Each row is converted to a dense double array via its column values.
    for (Iterator<Row> rows = dataFrame.iterator(); rows.hasNext();) {
        final Row row = rows.next();
        values[rowIndex++] = StreamSupport.stream(row.spliterator(), false)
            .mapToDouble(ColumnValue::doubleValue)
            .toArray();
    }
    return new Tuple<>(names, values);
}
Aggregations