Search in sources:

Example 1 with RankingExpression

Use of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa by vespa-engine.

In class TensorFlowFeatureConverter, method reduceBatchDimensions.

/**
 * Checks whether the batch dimensions of the model inputs can be reduced out.
 * If the input macro specifies that a single exemplar should be evaluated,
 * the batch dimension can be reduced out.
 *
 * <p>Generated macros referenced from {@code expression} are rewritten in place in
 * {@code profile}. Then {@code expression} itself is rewritten: it is first reduced at
 * its inputs and then expanded again at its output so its result type stays consistent.
 *
 * @param expression    the expression to rewrite; its root is replaced
 * @param model         the imported TensorFlow model the expression refers to
 * @param profile       the rank profile holding the generated macros
 * @param queryProfiles the query profile registry used to build the type context
 * @throws IllegalArgumentException if the model refers to a generated macro
 *                                  that is not present in {@code profile}
 */
private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, RankProfile profile, QueryProfileRegistry queryProfiles) {
    TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
    // Capture the type before any macro is rewritten: the macro rewrites below can
    // change the type this expression resolves to.
    TensorType typeBeforeReducing = expression.getRoot().type(typeContext);

    // Check generated macros for inputs to reduce
    Set<String> macroNames = new HashSet<>();
    addMacroNamesIn(expression.getRoot(), macroNames, model);
    for (String macroName : macroNames) {
        if (!model.macros().containsKey(macroName)) {
            continue; // not a macro generated from this model
        }
        RankProfile.Macro macro = profile.getMacros().get(macroName);
        if (macro == null) {
            throw new IllegalArgumentException("Model refers to generated macro '" + macroName + "' but this macro is not present in " + profile);
        }
        RankingExpression macroExpression = macro.getRankingExpression();
        macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
    }

    // Check expression for inputs to reduce, then restore the output dimensions
    // that the reduction removed (based on the before/after types).
    ExpressionNode root = expression.getRoot();
    root = reduceBatchDimensionsAtInput(root, model, typeContext);
    TensorType typeAfterReducing = root.type(typeContext);
    root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
    expression.setRoot(root);
}
Also used : RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression) Reference(com.yahoo.searchlib.rankingexpression.Reference) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) TensorType(com.yahoo.tensor.TensorType) RankProfile(com.yahoo.searchdefinition.RankProfile) HashSet(java.util.HashSet)

Example 2 with RankingExpression

Use of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa by vespa-engine.

In class GBDTForestOptimizerTestCase, method testForestOptimization.

public void testForestOptimization() throws ParseException {
    String gbdtString =
            "if (LW_NEWS_SEARCHES_RATIO < 1.72971, 0.0697159, if (LW_USERS < 0.10496, if (SEARCHES < 0.0329127, 0.151257, 0.117501), if (SUGG_OVERLAP < 18.5, 0.0897622, 0.0756903))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.73156, if (NEWS_USERS < 0.0737993, -0.00481646, 0.00110018), if (LW_USERS < 0.0844616, 0.0488919, if (SUGG_OVERLAP < 32.5, 0.0136917, 9.85328E-4))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.74451, -0.00298257, if (LW_USERS < 0.116207, if (SEARCHES < 0.0329127, 0.0676105, 0.0340198), if (NUM_WORDS < 1.5, -8.55514E-5, 0.0112406))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.72995, if (NEWS_USERS < 0.0737993, -0.00407515, 0.00139088), if (LW_USERS == 0.0509035, 0.0439466, if (LW_USERS < 0.325818, 0.0187156, 0.00236949)))";
    RankingExpression gbdt = new RankingExpression(gbdtString);

    // One row per evaluation round: { LW_NEWS_SEARCHES_RATIO, SUGG_OVERLAP }.
    // The contexts are reused, so each round overwrites the previous values.
    double[][] rounds = { { 1d, 17d }, { 2d, 20d }, { 2d, 40d } };

    // Baseline: evaluate the unoptimized expression against a map-backed context
    MapContext mapContext = new MapContext();
    double[] baseline = new double[rounds.length];
    for (int i = 0; i < rounds.length; i++) {
        mapContext.put("LW_NEWS_SEARCHES_RATIO", rounds[i][0]);
        mapContext.put("SUGG_OVERLAP", rounds[i][1]);
        baseline[i] = gbdt.evaluate(mapContext).asDouble();
    }

    // Optimize the expression bound to an array-backed context, then check the report
    ArrayContext arrayContext = new ArrayContext(gbdt);
    OptimizationReport report = new ExpressionOptimizer().optimize(gbdt, arrayContext);
    assertEquals(4, report.getMetric("Optimized GDBT trees"));
    assertEquals(4, report.getMetric("GBDT trees optimized to forests"));
    assertEquals(1, report.getMetric("Number of forests"));

    double[] optimized = new double[rounds.length];
    for (int i = 0; i < rounds.length; i++) {
        arrayContext.put("LW_NEWS_SEARCHES_RATIO", rounds[i][0]);
        arrayContext.put("SUGG_OVERLAP", rounds[i][1]);
        optimized[i] = gbdt.evaluate(arrayContext).asDouble();
    }

    // Summing linearly into one double does not round exactly like summing a
    // tree of stack frames, so the comparison is approximate.
    for (int i = 0; i < rounds.length; i++) {
        assertEqualish(baseline[i], optimized[i]);
    }
}
Also used : RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression)

Example 3 with RankingExpression

Use of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa by vespa-engine.

In class GBDTForestOptimizerTestCase, method testForestOptimizationWithSetMembershipConditions.

public void testForestOptimizationWithSetMembershipConditions() throws ParseException {
    String gbdtString =
            "if (MYSTRING in [\"string 1\",\"string 2\"], 0.0697159, if (LW_USERS < 0.10496, if (SEARCHES < 0.0329127, 0.151257, 0.117501), if (MYSTRING in [\"string 2\"], 0.0897622, 0.0756903))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.73156, if (NEWS_USERS < 0.0737993, -0.00481646, 0.00110018), if (LW_USERS < 0.0844616, 0.0488919, if (SUGG_OVERLAP < 32.5, 0.0136917, 9.85328E-4))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.74451, -0.00298257, if (LW_USERS < 0.116207, if (SEARCHES < 0.0329127, 0.0676105, 0.0340198), if (NUM_WORDS < 1.5, -8.55514E-5, 0.0112406))) + \n" +
            "if (LW_NEWS_SEARCHES_RATIO < 1.72995, if (NEWS_USERS < 0.0737993, -0.00407515, 0.00139088), if (LW_USERS == 0.0509035, 0.0439466, if (LW_USERS < 0.325818, 0.0187156, 0.00236949)))";
    RankingExpression gbdt = new RankingExpression(gbdtString);

    // One row per evaluation round: { LW_NEWS_SEARCHES_RATIO, SUGG_OVERLAP }.
    // MYSTRING is set once and stays fixed across all rounds.
    double[][] rounds = { { 1d, 17d }, { 2d, 20d }, { 2d, 40d } };

    // Baseline: evaluate the unoptimized expression against a map-backed context
    MapContext mapContext = new MapContext();
    mapContext.put("MYSTRING", new StringValue("string 1"));
    double[] baseline = new double[rounds.length];
    for (int i = 0; i < rounds.length; i++) {
        mapContext.put("LW_NEWS_SEARCHES_RATIO", rounds[i][0]);
        mapContext.put("SUGG_OVERLAP", rounds[i][1]);
        baseline[i] = gbdt.evaluate(mapContext).asDouble();
    }

    // Optimize the expression bound to an array-backed context, then check the report
    ArrayContext arrayContext = new ArrayContext(gbdt);
    OptimizationReport report = new ExpressionOptimizer().optimize(gbdt, arrayContext);
    assertEquals(4, report.getMetric("Optimized GDBT trees"));
    assertEquals(4, report.getMetric("GBDT trees optimized to forests"));
    assertEquals(1, report.getMetric("Number of forests"));

    arrayContext.put("MYSTRING", new StringValue("string 1"));
    double[] optimized = new double[rounds.length];
    for (int i = 0; i < rounds.length; i++) {
        arrayContext.put("LW_NEWS_SEARCHES_RATIO", rounds[i][0]);
        arrayContext.put("SUGG_OVERLAP", rounds[i][1]);
        optimized[i] = gbdt.evaluate(arrayContext).asDouble();
    }

    // Summing linearly into one double does not round exactly like summing a
    // tree of stack frames, so the comparison is approximate.
    for (int i = 0; i < rounds.length; i++) {
        assertEqualish(baseline[i], optimized[i]);
    }
}
Also used : RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression)

Example 4 with RankingExpression

Use of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa by vespa-engine.

In class BatchNormImportTestCase, method testBatchNormImport.

@Test
public void testBatchNormImport() {
    TestableTensorFlowModel testModel = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
    TensorFlowModel.Signature servingSignature = testModel.get().signature("serving_default");

    // Every output of the serving signature must have been imported
    int skippedCount = testModel.get().signature("serving_default").skippedOutputs().size();
    assertEquals("Has skipped outputs", 0, skippedCount);

    // Output "y" maps to the final batch-norm node, and importing it must
    // reproduce the model's own result for input "X"
    RankingExpression yExpression = servingSignature.outputExpression("y");
    assertNotNull(yExpression);
    assertEquals("dnn/batch_normalization_3/batchnorm/add_1", yExpression.getName());
    testModel.assertEqualResult("X", yExpression.getName());
}
Also used : RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression) Test(org.junit.Test)

Example 5 with RankingExpression

Use of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa by vespa-engine.

In class DropoutImportTestCase, method testDropoutImport.

@Test
public void testDropoutImport() {
    TestableTensorFlowModel testModel = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");

    // Exactly one required macro, "X": an indexed dimension d0 plus an indexed
    // dimension d1 of size 784
    assertEquals(1, testModel.get().requiredMacros().size());
    assertTrue(testModel.get().requiredMacros().containsKey("X"));
    TensorType expectedInputType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();
    assertEquals(expectedInputType, testModel.get().requiredMacros().get("X"));

    TensorFlowModel.Signature servingSignature = testModel.get().signature("serving_default");
    assertEquals("Has skipped outputs", 0, testModel.get().signature("serving_default").skippedOutputs().size());

    // Output "y" is the Maximum node; check both its name and the imported expression
    RankingExpression yExpression = servingSignature.outputExpression("y");
    assertNotNull(yExpression);
    assertEquals("outputs/Maximum", yExpression.getName());
    String expectedExpression = "join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))";
    assertEquals(expectedExpression, yExpression.getRoot().toString());
    testModel.assertEqualResult("X", yExpression.getName());
}
Also used : RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression) Test(org.junit.Test)

Aggregations

RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression)34 Test (org.junit.Test)10 ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException)6 ArrayContext (com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext)5 OptimizationReport (com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport)4 ExpressionOptimizer (com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer)3 MapContext (com.yahoo.searchlib.rankingexpression.evaluation.MapContext)3 TensorType (com.yahoo.tensor.TensorType)3 LinkedList (java.util.LinkedList)3 GBDTForestOptimizer (com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestOptimizer)2 ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)2 HashMap (java.util.HashMap)2 HashSet (java.util.HashSet)2 RankProfile (com.yahoo.searchdefinition.RankProfile)1 ExpressionFunction (com.yahoo.searchlib.rankingexpression.ExpressionFunction)1 Reference (com.yahoo.searchlib.rankingexpression.Reference)1 Context (com.yahoo.searchlib.rankingexpression.evaluation.Context)1 TensorValue (com.yahoo.searchlib.rankingexpression.evaluation.TensorValue)1 Value (com.yahoo.searchlib.rankingexpression.evaluation.Value)1 TensorFlowModel (com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel)1