Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class Reshape, method reshape:
/**
 * Builds a function that reshapes the tensor produced by {@code inputFunction} from {@code inputType}
 * to {@code outputType}.
 * <p>
 * Conceptually, a reshape unrolls a tensor into a flat array following its dimension order. It then
 * rolls that array back up following the dimension order of the new shape. This is expressed here
 * as a tensor computation:
 * <ol>
 *   <li>A transformation tensor spanning the dimensions of both types is generated. Its value
 *       at each address is the result of comparing, for equality, the expression
 *       {@code unrollTensorExpression(inputType)} with {@code unrollTensorExpression(outputType)}.
 *       Presumably these are the flat (unrolled) positions of the input and output cell, so the
 *       tensor is true (1) where an input cell maps onto an output cell.</li>
 *   <li>The input is joined with this tensor by multiplication, and the result is sum-reduced over
 *       all input dimensions. Only the output dimensions remain.</li>
 * </ol>
 * NOTE(review): the transformation type is built from the dimensions of both types together, and no
 * temporary renaming is done here. This therefore assumes that input and output dimension names do
 * not overlap. Confirm that callers guarantee distinct names.
 *
 * @param inputFunction the function producing the tensor to reshape
 * @param inputType     the type of the tensor produced by {@code inputFunction}
 * @param outputType    the desired type (shape) of the result
 * @return a function producing the reshaped tensor
 * @throws IllegalArgumentException if {@code tensorSize} differs between the two types
 */
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
    boolean sameSize = tensorSize(inputType).equals(tensorSize(outputType));
    if ( ! sameSize)
        throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");

    // Expressions locating a cell in the unrolled input and in the unrolled output respectively
    ExpressionNode inputPosition = unrollTensorExpression(inputType);
    ExpressionNode outputPosition = unrollTensorExpression(outputType);
    ExpressionNode samePosition = new ComparisonNode(inputPosition, TruthOperator.EQUAL, outputPosition);

    // The mapping tensor spans both the input and the output dimensions
    TensorType mappingType = new TensorType.Builder(inputType, outputType).build();
    GeneratorLambdaFunctionNode mappingGenerator = new GeneratorLambdaFunctionNode(mappingType, samePosition);
    Generate mappingTensor = new Generate(mappingType, mappingGenerator.asLongListToDoubleOperator());

    // Multiply each input value into its new position, then sum away the input dimensions
    TensorFunction placed = new com.yahoo.tensor.functions.Join(inputFunction, mappingTensor, ScalarFunctions.multiply());
    return new Reduce(placed,
                      Reduce.Aggregator.sum,
                      inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class JsonWriterTestCase, method registerTensorDocumentType:
/**
 * Registers the document type {@code testtensor} in {@code types}. The type has one field,
 * {@code tensorfield}, whose tensor type has the two mapped dimensions {@code x} and {@code y}.
 */
private void registerTensorDocumentType() {
    DocumentType tensorDocumentType = new DocumentType("testtensor");
    TensorType xyType = new TensorType.Builder()
            .mapped("x")
            .mapped("y")
            .build();
    Field tensorField = new Field("tensorfield", new TensorDataType(xyType));
    tensorDocumentType.addField(tensorField);
    types.registerDocumentType(tensorDocumentType);
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class Mean, method lazyGetFunction:
// todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
/**
 * Builds the function computing the average of the first input over {@code reduceDimensions}.
 * <p>
 * If {@code shouldKeepDimensions()} holds, the reduced dimensions are re-introduced with size 1.
 * This is done by multiplying the average with a generated tensor of ones that has exactly those
 * dimensions, each indexed with size 1.
 *
 * @return the mean function, or null unless {@code allInputTypesPresent(2)} holds
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputTypesPresent(2))
        return null;

    TensorFunction average = new Reduce(inputs.get(0).function().get(), Reduce.Aggregator.avg, reduceDimensions);
    if ( ! shouldKeepDimensions())
        return average;

    // A tensor of ones spanning the reduced dimensions, each with size 1
    TensorType.Builder onesTypeBuilder = new TensorType.Builder();
    for (String reducedDimension : reduceDimensions)
        onesTypeBuilder.indexed(reducedDimension, 1);
    TensorType onesType = onesTypeBuilder.build();
    GeneratorLambdaFunctionNode onesGenerator =
            new GeneratorLambdaFunctionNode(onesType, new ConstantNode(new DoubleValue(1)));
    Generate ones = new Generate(onesType, onesGenerator.asLongListToDoubleOperator());

    return new com.yahoo.tensor.functions.Join(average, ones, ScalarFunctions.multiply());
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class Join, method evaluate:
/**
 * Evaluates both arguments and joins them. The join algorithm is chosen from the argument types,
 * in this order:
 * <ul>
 *   <li>{@code indexedVectorJoin} when {@code hasSingleIndexedDimension} holds for both arguments
 *       and their first dimensions have the same name,</li>
 *   <li>{@code singleSpaceJoin} when the joined type has as many dimensions as each argument,</li>
 *   <li>{@code subspaceJoin} when the dimensions of one argument contain all dimensions of the other.
 *       The contained argument is passed first, with the flag set to true when argument B is
 *       the contained one,</li>
 *   <li>{@code generalJoin} otherwise.</li>
 * </ul>
 */
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor left = argumentA.evaluate(context);
    Tensor right = argumentB.evaluate(context);
    TensorType resultType = new TensorType.Builder(left.type(), right.type()).build();

    if (hasSingleIndexedDimension(left) && hasSingleIndexedDimension(right)) {
        String leftName = left.type().dimensions().get(0).name();
        String rightName = right.type().dimensions().get(0).name();
        if (leftName.equals(rightName)) {
            return indexedVectorJoin((IndexedTensor) left, (IndexedTensor) right, resultType);
        }
    }

    int resultRank = resultType.dimensions().size();
    boolean sameDimensions = resultRank == left.type().dimensions().size()
                             && resultRank == right.type().dimensions().size();
    if (sameDimensions) {
        return singleSpaceJoin(left, right, resultType);
    }
    if (left.type().dimensions().containsAll(right.type().dimensions())) {
        return subspaceJoin(right, left, resultType, true);
    }
    if (right.type().dimensions().containsAll(left.type().dimensions())) {
        return subspaceJoin(left, right, resultType, false);
    }
    return generalJoin(left, right, resultType);
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class Concat, method evaluate:
/**
 * Evaluates both arguments and concatenates them along {@code dimension}.
 * <p>
 * Each result is first passed through {@code ensureIndexedDimension}, and both must then be
 * {@link IndexedTensor}s. The result type comes from {@code type(...)}, and its dimension sizes
 * come from {@code concatSize}.
 */
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor evaluatedA = argumentA.evaluate(context);
    Tensor evaluatedB = argumentB.evaluate(context);
    Tensor withDimensionA = ensureIndexedDimension(dimension, evaluatedA);
    Tensor withDimensionB = ensureIndexedDimension(dimension, evaluatedB);

    // These casts fail with a ClassCastException if a mixed tensor reaches this point
    IndexedTensor left = (IndexedTensor) withDimensionA;
    IndexedTensor right = (IndexedTensor) withDimensionB;

    TensorType resultType = type(left.type(), right.type());
    DimensionSizes resultSizes = concatSize(resultType, left, right, dimension);
    Tensor.Builder result = Tensor.Builder.of(resultType, resultSizes);

    // Size of the left tensor along the concatenation dimension
    long leftLength = left.type()
                          .indexOfDimension(dimension)
                          .map(index -> left.dimensionSizes().size(index))
                          .orElseThrow(RuntimeException::new);
    int[] leftToResult = mapIndexes(left.type(), resultType);
    int[] rightToResult = mapIndexes(right.type(), resultType);

    // NOTE(review): the call with `left` first receives offset leftLength, and the call with `right`
    // first receives offset 0. This is only correct if concatenateTo writes the cells of its SECOND
    // tensor at the given offset. Confirm against concatenateTo before reordering these arguments.
    concatenateTo(left, right, leftLength, resultType, leftToResult, rightToResult, result);
    concatenateTo(right, left, 0, resultType, rightToResult, leftToResult, result);
    return result.build();
}
Aggregations