Search in sources :

Example 16 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class TensorFieldValueTestCase method requireThatSameTensorValueIsEqual.

/**
 * Checks that a tensor field value is equal to itself, to a second field value
 * wrapping the same tensor instance, and to the value returned by
 * {@code createFieldValue} for the same tensor string.
 */
@Test
public void requireThatSameTensorValueIsEqual() {
    final String tensorSpec = "{{x:0}:2.0}";
    Tensor sharedTensor = Tensor.from(tensorSpec);
    TensorFieldValue first = new TensorFieldValue(sharedTensor);
    TensorFieldValue second = new TensorFieldValue(sharedTensor);
    // Reflexive comparison
    assertTrue(first.equals(first));
    // Two wrappers around the same tensor instance
    assertTrue(first.equals(second));
    // A value built independently from the same tensor string
    assertTrue(first.equals(createFieldValue(tensorSpec)));
}
Also used : Tensor(com.yahoo.tensor.Tensor) Test(org.junit.Test)

Example 17 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class TensorFlowImporter method importConstant.

/**
 * Registers the constant backing the given operation on the model, unless a constant
 * with the operation's Vespa name is already registered.
 *
 * Constants whose value is not a {@code TensorValue} are not registered at all.
 * Tensors of rank 0 or with at most one cell are registered as small constants,
 * all others as large constants.
 *
 * @param model     the model receiving the constant
 * @param operation the operation whose constant is imported
 * @param bundle    the saved model the variable is read from when the operation
 *                  has no constant value yet
 * @return the function of the operation
 */
private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) {
    String constantName = operation.vespaName();
    boolean alreadyRegistered = model.largeConstants().containsKey(constantName) || model.smallConstants().containsKey(constantName);
    if (alreadyRegistered) {
        return operation.function();
    }

    Tensor constant;
    if (!operation.getConstantValue().isPresent()) {
        // Here we use the type from the operation, which will have correct dimension names after name resolving
        constant = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), operation.type().get());
        operation.setConstantValue(new TensorValue(constant));
    } else {
        Value constantValue = operation.getConstantValue().get();
        if (!(constantValue instanceof TensorValue)) {
            // scalar values are inserted directly into the expression
            return operation.function();
        }
        constant = constantValue.asTensor();
    }

    boolean isSmall = constant.type().rank() == 0 || constant.size() <= 1;
    if (isSmall) {
        model.smallConstant(constantName, constant);
    } else {
        model.largeConstant(constantName, constant);
    }
    return operation.function();
}
Also used : TensorValue(com.yahoo.searchlib.rankingexpression.evaluation.TensorValue) Tensor(com.yahoo.tensor.Tensor) Value(com.yahoo.searchlib.rankingexpression.evaluation.Value)

Example 18 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class Mean method lazyGetType.

/**
 * Computes the output type of this mean operation and, as a side effect, records
 * the names of the dimensions to reduce in {@code reduceDimensions}.
 *
 * The reduction indices are read from the constant value of the second input.
 * A negative index counts from the end of the input's dimensions, so -1 refers
 * to the last dimension.
 *
 * @return the reduced type, or null if the types of both inputs are not yet present
 * @throws IllegalArgumentException if the reduction indices input has no constant value
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation reductionIndices = inputs.get(1);
    if (!reductionIndices.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("Mean in " + node.getName() + ": " + "reduction indices must be a constant.");
    }
    Tensor indices = reductionIndices.getConstantValue().get().asTensor();
    reduceDimensions = new ArrayList<>();
    OrderedTensorType inputType = inputs.get(0).type().get();
    for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell cell = cellIterator.next();
        int dimensionIndex = cell.getValue().intValue();
        if (dimensionIndex < 0) {
            // Negative indices are relative to the end: add (not subtract) the negative
            // value to the dimension count so that -1 maps to the last dimension.
            dimensionIndex = inputType.dimensions().size() + dimensionIndex;
        }
        reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
    }
    return reducedType(inputType, shouldKeepDimensions());
}
Also used : Tensor(com.yahoo.tensor.Tensor) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 19 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class Join method mappedGeneralJoin.

/**
 * Joins two tensors by pairing every cell of {@code a} with every cell of {@code b}.
 * For each pair whose addresses can be combined into an address of the joined type,
 * a cell holding the combinator applied to the two values is added to the result.
 *
 * @param a          the first tensor
 * @param b          the second tensor
 * @param joinedType the type of the resulting tensor
 * @return the tensor built from all combinable cell pairs
 */
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
    int[] indexesOfA = mapIndexes(a.type(), joinedType);
    int[] indexesOfB = mapIndexes(b.type(), joinedType);
    Tensor.Builder result = Tensor.Builder.of(joinedType);

    Iterator<Tensor.Cell> outer = a.cellIterator();
    while (outer.hasNext()) {
        Map.Entry<TensorAddress, Double> left = outer.next();
        Iterator<Tensor.Cell> inner = b.cellIterator();
        while (inner.hasNext()) {
            Map.Entry<TensorAddress, Double> right = inner.next();
            TensorAddress joinedAddress = joinAddresses(left.getKey(), indexesOfA, right.getKey(), indexesOfB, joinedType);
            // A null address means the pair is not combinable and contributes no cell
            if (joinedAddress != null) {
                result.cell(joinedAddress, combinator.applyAsDouble(left.getValue(), right.getValue()));
            }
        }
    }
    return result.build();
}
Also used : TensorAddress(com.yahoo.tensor.TensorAddress) IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) HashMap(java.util.HashMap) Map(java.util.Map)

Example 20 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class Join method evaluate.

/**
 * Evaluates both arguments and joins the resulting tensors, selecting a join
 * algorithm from the dimensions of the two argument types and the joined type.
 *
 * @param context the context the arguments are evaluated in
 * @return the joined tensor
 */
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor a = argumentA.evaluate(context);
    Tensor b = argumentB.evaluate(context);
    TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();

    // Both are single-dimension indexed tensors over the same dimension name
    if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
        return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
    }

    // The joined type has as many dimensions as each argument
    if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) {
        return singleSpaceJoin(a, b, joinedType);
    }

    // The dimensions of b are all contained in those of a
    if (a.type().dimensions().containsAll(b.type().dimensions())) {
        return subspaceJoin(b, a, joinedType, true);
    }

    // The dimensions of a are all contained in those of b
    if (b.type().dimensions().containsAll(a.type().dimensions())) {
        return subspaceJoin(a, b, joinedType, false);
    }

    // No specialized algorithm applies
    return generalJoin(a, b, joinedType);
}
Also used : IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) TensorType(com.yahoo.tensor.TensorType)

Aggregations

Tensor (com.yahoo.tensor.Tensor)58 Test (org.junit.Test)26 TensorType (com.yahoo.tensor.TensorType)17 IndexedTensor (com.yahoo.tensor.IndexedTensor)10 TensorAddress (com.yahoo.tensor.TensorAddress)7 MixedTensor (com.yahoo.tensor.MixedTensor)5 HashMap (java.util.HashMap)5 Map (java.util.Map)4 MapContext (com.yahoo.searchlib.rankingexpression.evaluation.MapContext)3 TensorValue (com.yahoo.searchlib.rankingexpression.evaluation.TensorValue)3 OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)3 DimensionSizes (com.yahoo.tensor.DimensionSizes)3 JsonNode (com.fasterxml.jackson.databind.JsonNode)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)2 ImmutableList (com.google.common.collect.ImmutableList)2 MappedTensor (com.yahoo.tensor.MappedTensor)2 IOException (java.io.IOException)2 List (java.util.List)2 GrowableByteBuffer (com.yahoo.io.GrowableByteBuffer)1 Path (com.yahoo.path.Path)1