Usage of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa (by vespa-engine).
From class EvaluationTestCase, method testProgrammaticBuildingAndPrecedence.
/**
 * Builds expression trees directly (without the parser) and checks that both evaluation and
 * {@code toString()} follow the tree shape: {@code 2 + 3 * 4} is rendered without parentheses,
 * while {@code (2 + 3) * 4} needs them to preserve its meaning.
 */
@Test
public void testProgrammaticBuildingAndPrecedence() {
    // 2 + (3 * 4): the product is the nested node, which matches natural precedence
    ArithmeticNode threeTimesFour = new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4));
    RankingExpression standardPrecedence =
            new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, threeTimesFour));

    // (2 + 3) * 4: the sum is the nested node, which goes against natural precedence
    ArithmeticNode twoPlusThree = new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3));
    RankingExpression oppositePrecedence =
            new RankingExpression(new ArithmeticNode(twoPlusThree, ArithmeticOperator.MULTIPLY, constant(4)));

    // Only constants are involved, so the expressions are evaluated with a null context
    assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance);
    assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance);

    // Parentheses must be emitted only where the tree deviates from default precedence
    assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString());
    assertEquals("(2.0 + 3.0) * 4.0", oppositePrecedence.toString());
}
Usage of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa (by vespa-engine).
From class SimplifierTestCase, method testSimplifyComplexExpression.
// A black box test verifying we are not screwing up real expressions
@Test
public void testSimplifyComplexExpression() throws ParseException {
RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)");
TransformContext c = new TransformContext(Collections.emptyMap());
RankingExpression simplified = new Simplifier().transform(initial, c);
Context context = new MapContext();
context.put("INFERRED", 0.5);
context.put("GMP", 80.0);
context.put("value", 50.0);
assertEquals(initial.evaluate(context), simplified.evaluate(context));
context.put("INFERRED", 38.0);
context.put("GMP", 80.0);
context.put("value", 50.0);
assertEquals(initial.evaluate(context), simplified.evaluate(context));
context.put("INFERRED", 38.0);
context.put("GMP", 90.0);
context.put("value", 100.0);
assertEquals(initial.evaluate(context), simplified.evaluate(context));
context.put("INFERRED", 500.0);
context.put("GMP", 90.0);
context.put("value", 100.0);
assertEquals(initial.evaluate(context), simplified.evaluate(context));
}
Usage of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa (by vespa-engine).
From class GbdtConverterTestCase, method assertConvert.
/**
 * Runs {@code GbdtConverter.main} on a GBDT model file and captures what it prints to standard
 * output. Asserts that the output equals the expected ranking expression and that it parses
 * as a {@link RankingExpression}.
 *
 * @param gbdtModelFile      path of the model file passed to the converter
 * @param expectedExpression exact text the converter is expected to print
 * @throws ParseException               propagated from the converter or from parsing the printed text
 * @throws UnsupportedEncodingException if UTF-8 is unsupported (every Java platform is required to support it)
 */
private static void assertConvert(String gbdtModelFile, String expectedExpression) throws ParseException, UnsupportedEncodingException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    PrintStream originalOut = System.out;
    // Encode explicitly as UTF-8 so the captured bytes match the UTF-8 decoding below,
    // regardless of the platform default charset.
    System.setOut(new PrintStream(out, true, "UTF-8"));
    try {
        GbdtConverter.main(new String[] { gbdtModelFile });
    } finally {
        // Always restore stdout; otherwise it stays redirected for every later test in this JVM.
        System.setOut(originalOut);
    }
    String actualExpression = out.toString("UTF-8");
    assertEquals(expectedExpression, actualExpression);
    assertNotNull(new RankingExpression(actualExpression));
}
Usage of com.yahoo.searchlib.rankingexpression.RankingExpression in project vespa (by vespa-engine).
From class MnistSoftmaxImportTestCase, method testMnistSoftmaxImport.
/**
 * Imports the saved MNIST softmax TensorFlow model. Verifies its large constants, provided and
 * required macros, the "serving_default" signature and the generated output expression, then
 * checks execution results for selected nodes.
 */
@Test
public void testMnistSoftmaxImport() {
    TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
    // Type shared by the required "Placeholder" macro and the signature input "x":
    // an unbound d0 dimension and a 784-wide d1 dimension
    TensorType placeholderType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();

    // Large constants. Per the output expression below, the first is multiplied in and reduced
    // over d2 (hence "weights"); the second is added afterwards (hence "bias").
    assertEquals(2, model.get().largeConstants().size());
    Tensor weights = model.get().largeConstants().get("test_Variable_read");
    assertNotNull(weights);
    assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), weights.type());
    assertEquals(7840, weights.size());
    Tensor bias = model.get().largeConstants().get("test_Variable_1_read");
    assertNotNull(bias);
    assertEquals(new TensorType.Builder().indexed("d1", 10).build(), bias.type());
    assertEquals(10, bias.size());

    // Macros: the model provides none and requires exactly one, "Placeholder"
    assertEquals(0, model.get().macros().size());
    assertEquals(1, model.get().requiredMacros().size());
    assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
    assertEquals(placeholderType, model.get().requiredMacros().get("Placeholder"));

    // Signatures: a single "serving_default" with one input "x" and one output "y"
    assertEquals(1, model.get().signatures().size());
    TensorFlowModel.Signature servingDefault = model.get().signatures().get("serving_default");
    assertNotNull(servingDefault);

    assertEquals(1, servingDefault.inputs().size());
    TensorType inputX = servingDefault.inputArgument("x");
    assertNotNull(inputX);
    assertEquals(placeholderType, inputX);

    assertEquals(1, servingDefault.outputs().size());
    RankingExpression outputY = servingDefault.outputExpression("y");
    assertNotNull(outputY);
    assertEquals("add", outputY.getName());
    assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
                 outputY.getRoot().toString());

    // Execution checks from "Placeholder" to the intermediate "MatMul" node and the final
    // "add" node (presumably comparing against TensorFlow's own results -- see assertEqualResult)
    model.assertEqualResult("Placeholder", "MatMul");
    model.assertEqualResult("Placeholder", "add");
}
Aggregations