use of io.trino.operator.aggregation.TypedSet in project trino by trinodb.
the class MathFunctions method mapDotProduct.
private static double mapDotProduct(BlockPositionEqual varcharEqual, BlockPositionHashCode varcharHashCode, Block leftMap, Block rightMap) {
TypedSet rightMapKeys = createEqualityTypedSet(VARCHAR, varcharEqual, varcharHashCode, rightMap.getPositionCount(), "cosine_similarity");
for (int i = 0; i < rightMap.getPositionCount(); i += 2) {
rightMapKeys.add(rightMap, i);
}
double result = 0.0;
for (int i = 0; i < leftMap.getPositionCount(); i += 2) {
int position = rightMapKeys.positionOf(leftMap, i);
if (position != -1) {
result += DOUBLE.getDouble(leftMap, i + 1) * DOUBLE.getDouble(rightMap, 2 * position + 1);
}
}
return result;
}
use of io.trino.operator.aggregation.TypedSet in project trino by trinodb.
the class MultimapFromEntriesFunction method multimapFromEntries.
@TypeParameter("K")
@TypeParameter("V")
@SqlType("map(K,array(V))")
@SqlNullable
public Block multimapFromEntries(@TypeParameter("map(K,array(V))") MapType mapType, @OperatorDependency(operator = EQUAL, argumentTypes = { "K", "K" }, convention = @Convention(arguments = { BLOCK_POSITION, BLOCK_POSITION }, result = NULLABLE_RETURN)) BlockPositionEqual keyEqual, @OperatorDependency(operator = HASH_CODE, argumentTypes = "K", convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode keyHashCode, @SqlType("array(row(K,V))") Block mapEntries) {
Type keyType = mapType.getKeyType();
Type valueType = ((ArrayType) mapType.getValueType()).getElementType();
RowType mapEntryType = RowType.anonymous(ImmutableList.of(keyType, valueType));
if (pageBuilder.isFull()) {
pageBuilder.reset();
}
int entryCount = mapEntries.getPositionCount();
if (entryCount > entryIndicesList.length) {
initializeEntryIndicesList(entryCount);
}
TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, NAME);
for (int i = 0; i < entryCount; i++) {
if (mapEntries.isNull(i)) {
clearEntryIndices(keySet.size());
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map entry cannot be null");
}
Block mapEntryBlock = mapEntryType.getObject(mapEntries, i);
if (mapEntryBlock.isNull(0)) {
clearEntryIndices(keySet.size());
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null");
}
if (keySet.add(mapEntryBlock, 0)) {
entryIndicesList[keySet.size() - 1].add(i);
} else {
entryIndicesList[keySet.positionOf(mapEntryBlock, 0)].add(i);
}
}
BlockBuilder multimapBlockBuilder = pageBuilder.getBlockBuilder(0);
BlockBuilder mapWriter = multimapBlockBuilder.beginBlockEntry();
for (int i = 0; i < keySet.size(); i++) {
keyType.appendTo(mapEntryType.getObject(mapEntries, entryIndicesList[i].getInt(0)), 0, mapWriter);
BlockBuilder valuesArray = mapWriter.beginBlockEntry();
for (int entryIndex : entryIndicesList[i]) {
valueType.appendTo(mapEntryType.getObject(mapEntries, entryIndex), 1, valuesArray);
}
mapWriter.closeEntry();
}
multimapBlockBuilder.closeEntry();
pageBuilder.declarePosition();
clearEntryIndices(keySet.size());
return mapType.getObject(multimapBlockBuilder, multimapBlockBuilder.getPositionCount() - 1);
}
use of io.trino.operator.aggregation.TypedSet in project trino by trinodb.
the class MapConcatFunction method mapConcat.
@UsedByGeneratedCode
public static Block mapConcat(MapType mapType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Object state, Block[] maps) {
int entries = 0;
int lastMapIndex = maps.length - 1;
int firstMapIndex = lastMapIndex;
for (int i = 0; i < maps.length; i++) {
entries += maps[i].getPositionCount();
if (maps[i].getPositionCount() > 0) {
lastMapIndex = i;
firstMapIndex = min(firstMapIndex, i);
}
}
if (lastMapIndex == firstMapIndex) {
return maps[lastMapIndex];
}
PageBuilder pageBuilder = (PageBuilder) state;
if (pageBuilder.isFull()) {
pageBuilder.reset();
}
// TODO: we should move TypedSet into user state as well
Type keyType = mapType.getKeyType();
Type valueType = mapType.getValueType();
TypedSet typedSet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entries / 2, FUNCTION_NAME);
BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0);
BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry();
// the last map
Block map = maps[lastMapIndex];
for (int i = 0; i < map.getPositionCount(); i += 2) {
typedSet.add(map, i);
keyType.appendTo(map, i, blockBuilder);
valueType.appendTo(map, i + 1, blockBuilder);
}
// the map between the last and the first
for (int idx = lastMapIndex - 1; idx > firstMapIndex; idx--) {
map = maps[idx];
for (int i = 0; i < map.getPositionCount(); i += 2) {
if (typedSet.add(map, i)) {
keyType.appendTo(map, i, blockBuilder);
valueType.appendTo(map, i + 1, blockBuilder);
}
}
}
// the first map
map = maps[firstMapIndex];
for (int i = 0; i < map.getPositionCount(); i += 2) {
if (!typedSet.contains(map, i)) {
keyType.appendTo(map, i, blockBuilder);
valueType.appendTo(map, i + 1, blockBuilder);
}
}
mapBlockBuilder.closeEntry();
pageBuilder.declarePosition();
return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1);
}
use of io.trino.operator.aggregation.TypedSet in project trino by trinodb.
the class MapToMapCast method mapCast.
@UsedByGeneratedCode
public static Block mapCast(MethodHandle keyProcessFunction, MethodHandle valueProcessFunction, Type targetType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, ConnectorSession session, Block fromMap) {
checkState(targetType.getTypeParameters().size() == 2, "Expect two type parameters for targetType");
Type toKeyType = targetType.getTypeParameters().get(0);
TypedSet resultKeys = createEqualityTypedSet(toKeyType, keyEqual, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast");
// Cast the keys into a new block
BlockBuilder keyBlockBuilder = toKeyType.createBlockBuilder(null, fromMap.getPositionCount() / 2);
for (int i = 0; i < fromMap.getPositionCount(); i += 2) {
try {
keyProcessFunction.invokeExact(fromMap, i, session, keyBlockBuilder);
} catch (Throwable t) {
throw internalError(t);
}
}
Block keyBlock = keyBlockBuilder.build();
BlockBuilder mapBlockBuilder = targetType.createBlockBuilder(null, 1);
BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry();
for (int i = 0; i < fromMap.getPositionCount(); i += 2) {
if (resultKeys.add(keyBlock, i / 2)) {
toKeyType.appendTo(keyBlock, i / 2, blockBuilder);
if (fromMap.isNull(i + 1)) {
blockBuilder.appendNull();
continue;
}
try {
valueProcessFunction.invokeExact(fromMap, i + 1, session, blockBuilder);
} catch (Throwable t) {
throw internalError(t);
}
} else {
// if there are duplicated keys, fail it!
throw new TrinoException(INVALID_CAST_ARGUMENT, "duplicate keys");
}
}
mapBlockBuilder.closeEntry();
return (Block) targetType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1);
}
use of io.trino.operator.aggregation.TypedSet in project trino by trinodb.
the class ArrayExceptFunction method except.
@TypeParameter("E")
@SqlType("array(E)")
public static Block except(@TypeParameter("E") Type type, @OperatorDependency(operator = EQUAL, argumentTypes = { "E", "E" }, convention = @Convention(arguments = { BLOCK_POSITION, BLOCK_POSITION }, result = NULLABLE_RETURN)) BlockPositionEqual elementEqual, @OperatorDependency(operator = HASH_CODE, argumentTypes = "E", convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode elementHashCode, @SqlType("array(E)") Block leftArray, @SqlType("array(E)") Block rightArray) {
int leftPositionCount = leftArray.getPositionCount();
int rightPositionCount = rightArray.getPositionCount();
if (leftPositionCount == 0 || rightPositionCount == 0) {
return leftArray;
}
TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftPositionCount, "array_except");
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount);
for (int i = 0; i < rightPositionCount; i++) {
typedSet.add(rightArray, i);
}
for (int i = 0; i < leftPositionCount; i++) {
if (typedSet.add(leftArray, i)) {
type.appendTo(leftArray, i, distinctElementBlockBuilder);
}
}
return distinctElementBlockBuilder.build();
}
Aggregations