Search in sources:

Example 1 with Join

use of com.yahoo.tensor.functions.Join in project vespa by vespa-engine.

Class TensorFlowFeatureConverter, method expandBatchDimensionsAtOutput.

/**
 * Brings back dimensions of size 1 that exist in {@code before} but are missing from
 * {@code after} (for example batch dimensions that were reduced away above), so that any
 * following computation of the tensor sees them again.
 * <p>
 * The missing dimensions are restored by joining {@code node} with a generated tensor of
 * ones spanning exactly those dimensions, combined with multiplication.
 * TODO: determine when this is not necessary!
 *
 * @param node   the expression whose result should get the missing dimensions back
 * @param before the tensor type before the reduction
 * @param after  the tensor type after the reduction
 * @return {@code node} itself when the types are equal or no size-1 dimension was dropped,
 *         otherwise a new node wrapping the expanding join
 */
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
    if (after.equals(before)) return node;

    // Collect every size-1 dimension of 'before' that no longer appears in 'after'.
    TensorType.Builder droppedBuilder = new TensorType.Builder();
    for (TensorType.Dimension candidate : before.dimensions()) {
        if (candidate.size().orElse(-1L) != 1) continue;
        if (after.dimensionNames().contains(candidate.name())) continue;
        droppedBuilder.indexed(candidate.name(), 1);
    }
    TensorType droppedType = droppedBuilder.build();
    if (droppedType.dimensions().size() == 0) return node;

    // A tensor filled with 1.0 over the dropped dimensions; multiplying by it adds them back.
    ExpressionNode one = new ConstantNode(new DoubleValue(1.0));
    GeneratorLambdaFunctionNode onesLambda = new GeneratorLambdaFunctionNode(droppedType, one);
    Generate ones = new Generate(droppedType, onesLambda.asLongListToDoubleOperator());
    Join broadcast = new Join(TensorFunctionNode.wrapArgument(node), ones, ScalarFunctions.multiply());
    return new TensorFunctionNode(broadcast);
}
Also used : GeneratorLambdaFunctionNode(com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) TensorFunctionNode(com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) Generate(com.yahoo.tensor.functions.Generate) Join(com.yahoo.tensor.functions.Join) TensorType(com.yahoo.tensor.TensorType)

Example 2 with Join

use of com.yahoo.tensor.functions.Join in project vespa by vespa-engine.

Class TensorFunctionBenchmark, method dotProduct.

/**
 * Evaluates {@code sum(tensor * argument)} with {@code argument} bound in turn to each element
 * of {@code tensors}, and returns the largest result.
 *
 * @param tensor  the constant left-hand operand of every dot product
 * @param tensors the right-hand operands to evaluate against
 * @return the largest dot product, or {@link Double#NEGATIVE_INFINITY} if {@code tensors} is empty
 */
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
    // Double.MIN_VALUE is the smallest *positive* double, not the most negative one; starting
    // from it would wrongly report ~4.9e-324 whenever every dot product is <= 0.
    double largest = Double.NEGATIVE_INFINITY;
    TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), new VariableTensor("argument"), (a, b) -> a * b), Reduce.Aggregator.sum).toPrimitive();
    MapEvaluationContext context = new MapEvaluationContext();
    for (Tensor tensorElement : tensors) {
        // tensors.size() = 1 for larger tensor
        context.put("argument", tensorElement);
        double dotProduct = dotProductFunction.evaluate(context).asDouble();
        // Plain '>' (rather than Math.max) keeps the original behavior of ignoring NaN results.
        if (dotProduct > largest) {
            largest = dotProduct;
        }
    }
    return largest;
}
Also used: VariableTensor(com.yahoo.tensor.evaluation.VariableTensor) ConstantTensor(com.yahoo.tensor.functions.ConstantTensor) TensorFunction(com.yahoo.tensor.functions.TensorFunction) Join(com.yahoo.tensor.functions.Join) Reduce(com.yahoo.tensor.functions.Reduce) MapEvaluationContext(com.yahoo.tensor.evaluation.MapEvaluationContext)

Example 3 with Join

use of com.yahoo.tensor.functions.Join in project vespa by vespa-engine.

Class TensorTestCase, method testTensorComputation.

/**
 * Spot-checks a range of tensor operations on small inputs: multiply, join, larger, matmul,
 * rename, add, generate, range, diag and argmax.
 * All functions are tested more thoroughly in searchlib EvaluationTestCase.
 */
@Test
public void testTensorComputation() {
    Tensor xTensor = Tensor.from("{ {x:1}:3, {x:2}:7 }");
    Tensor yTensor = Tensor.from("{ {y:1}:5 }");

    // Binary operations between an x-only and a y-only tensor
    assertEquals(Tensor.from("{ {x:1,y:1}:15, {x:2,y:1}:35 }"), xTensor.multiply(yTensor));
    assertEquals(Tensor.from("{ {x:1,y:1}:12, {x:2,y:1}:28 }"), xTensor.join(yTensor, (a, b) -> a * b - a));
    assertEquals(Tensor.from("{ {x:1,y:1}:0, {x:2,y:1}:1 }"), xTensor.larger(yTensor));
    assertEquals(Tensor.from("{ {y:1}:50.0 }"), xTensor.matmul(yTensor, "x"));

    // Renaming, single and multiple dimensions
    assertEquals(Tensor.from("{ {z:1}:3, {z:2}:7 }"), xTensor.rename("x", "z"));
    assertEquals(Tensor.from("{ {y:1,x:1}:8, {x:1,y:2}:12 }"),
                 xTensor.add(yTensor).rename(ImmutableList.of("x", "y"), ImmutableList.of("y", "x")));

    // Tensor creation from a type
    Tensor expectedGenerated = Tensor.from("{ {x:0,y:0}:0, {x:0,y:1}:0, {x:1,y:0}:0, {x:1,y:1}:1, {x:2,y:0}:0, {x:2,y:1}:2, }");
    TensorType xyType = new TensorType.Builder().indexed("x", 3).indexed("y", 2).build();
    assertEquals(expectedGenerated,
                 Tensor.generate(xyType, (List<Long> coordinates) -> (double) coordinates.get(0) * coordinates.get(1)));

    Tensor expectedRange = Tensor.from("{ {x:0,y:0,z:0}:0, {x:0,y:1,z:0}:1, {x:1,y:0,z:0}:1, {x:1,y:1,z:0}:2, {x:2,y:0,z:0}:2, {x:2,y:1,z:0}:3, " + "  {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }");
    TensorType xyzType = new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build();
    assertEquals(expectedRange, Tensor.range(xyzType));

    Tensor expectedDiag = Tensor.from("{ {x:0,y:0,z:0}:1, {x:0,y:1,z:0}:0, {x:1,y:0,z:0}:0, {x:1,y:1,z:0}:0, {x:2,y:0,z:0}:0, {x:2,y:1,z:0}:0, " + "  {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 }");
    assertEquals(expectedDiag, Tensor.diag(xyzType));

    // argmax marks the position of the maximum along the given dimension
    assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x"));
}
Also used: ConstantTensor(com.yahoo.tensor.functions.ConstantTensor) MapEvaluationContext(com.yahoo.tensor.evaluation.MapEvaluationContext) Join(com.yahoo.tensor.functions.Join) Set(java.util.Set) Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) Reduce(com.yahoo.tensor.functions.Reduce) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) Type(com.yahoo.tensor.TensorType.Dimension.Type) Assert.fail(org.junit.Assert.fail) TensorFunction(com.yahoo.tensor.functions.TensorFunction) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) VariableTensor(com.yahoo.tensor.evaluation.VariableTensor)

Example 4 with Join

use of com.yahoo.tensor.functions.Join in project vespa by vespa-engine.

Class TensorTestCase, method dotProduct.

/**
 * Sums {@code sum(tensor * argument)} over every element of {@code tensors}, binding each
 * element in turn to the variable {@code argument}.
 *
 * @param tensor  the constant left-hand operand of every dot product
 * @param tensors the right-hand operands
 * @return the total of all dot products (0 when {@code tensors} is empty)
 */
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
    double total = 0;
    // Build the function once; only the "argument" binding changes per element.
    Join pairwiseProduct = new Join(new ConstantTensor(tensor), new VariableTensor("argument"), (a, b) -> a * b);
    TensorFunction summedProduct = new Reduce(pairwiseProduct, Reduce.Aggregator.sum).toPrimitive();
    MapEvaluationContext bindings = new MapEvaluationContext();
    for (Tensor operand : tensors) {
        // tensors.size() = 1 for larger tensor
        bindings.put("argument", operand);
        double product = summedProduct.evaluate(bindings).asDouble();
        total += product;
    }
    return total;
}
Also used: VariableTensor(com.yahoo.tensor.evaluation.VariableTensor) ConstantTensor(com.yahoo.tensor.functions.ConstantTensor) TensorFunction(com.yahoo.tensor.functions.TensorFunction) Join(com.yahoo.tensor.functions.Join) Reduce(com.yahoo.tensor.functions.Reduce) MapEvaluationContext(com.yahoo.tensor.evaluation.MapEvaluationContext)

Aggregations

Join (com.yahoo.tensor.functions.Join)4 MapEvaluationContext (com.yahoo.tensor.evaluation.MapEvaluationContext)3 VariableTensor (com.yahoo.tensor.evaluation.VariableTensor)3 ConstantTensor (com.yahoo.tensor.functions.ConstantTensor)3 Reduce (com.yahoo.tensor.functions.Reduce)3 TensorFunction (com.yahoo.tensor.functions.TensorFunction)3 ImmutableList (com.google.common.collect.ImmutableList)1 DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue)1 ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode)1 ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)1 GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode)1 TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode)1 TensorType (com.yahoo.tensor.TensorType)1 Type (com.yahoo.tensor.TensorType.Dimension.Type)1 Generate (com.yahoo.tensor.functions.Generate)1 ArrayList (java.util.ArrayList)1 Collections (java.util.Collections)1 List (java.util.List)1 Set (java.util.Set)1 Assert.assertEquals (org.junit.Assert.assertEquals)1