Usage of com.yahoo.searchdefinition.derived.AttributeFields in the vespa project (vespa-engine).
Example from class RankingExpressionConstantsTestCase, method testConstants:
@Test
public void testConstants() throws ParseException {
    // "parent" declares constants p1/p2 and a first phase using them.
    // "child1" inherits parent, declares its own constants a/b/c and overrides both phases.
    // "child2" inherits parent, overrides p2 and adds a macro that uses p1/p2.
    String schema =
            "search test {\n" +
            " document test { \n" +
            " field a type string { \n" +
            " indexing: index \n" +
            " }\n" +
            " }\n" +
            " \n" +
            " rank-profile parent {\n" +
            " constants {\n" +
            " p1: 7 \n" +
            " p2: 0 \n" +
            " }\n" +
            " first-phase {\n" +
            " expression: p2 * (1.3 + p1 )\n" +
            " }\n" +
            " }\n" +
            " rank-profile child1 inherits parent {\n" +
            " first-phase {\n" +
            " expression: a + b + c \n" +
            " }\n" +
            " second-phase {\n" +
            " expression: a + p1 + c \n" +
            " }\n" +
            " constants {\n" +
            " a: 1.0 \n" +
            " b: 2 \n" +
            " c: 3.5 \n" +
            " }\n" +
            " }\n" +
            " rank-profile child2 inherits parent {\n" +
            " constants {\n" +
            " p2: 2.0 \n" +
            " }\n" +
            " macro foo() {\n" +
            " expression: p2*p1\n" +
            " }\n" +
            " }\n" +
            "\n" +
            "}\n";

    RankProfileRegistry profiles = new RankProfileRegistry();
    QueryProfileRegistry queryProfiles = new QueryProfileRegistry();
    SearchBuilder searchBuilder = new SearchBuilder(profiles);
    searchBuilder.importString(schema);
    searchBuilder.build();
    Search search = searchBuilder.getSearch();

    // parent: p2 * (1.3 + p1) with p1 = 7, p2 = 0  ->  0 * 8.3
    RankProfile parentProfile = profiles.getRankProfile(search, "parent").compile(queryProfiles);
    assertEquals("0.0", parentProfile.getFirstPhaseRanking().getRoot().toString());

    // child1: a + b + c = 1.0 + 2 + 3.5; a + p1 + c = 1.0 + 7 + 3.5 (p1 comes from parent)
    RankProfile firstChild = profiles.getRankProfile(search, "child1").compile(queryProfiles);
    assertEquals("6.5", firstChild.getFirstPhaseRanking().getRoot().toString());
    assertEquals("11.5", firstChild.getSecondPhaseRanking().getRoot().toString());

    // child2: inherited first phase with p2 overridden to 2.0 -> 2.0 * 8.3; foo = p2 * p1 = 2.0 * 7
    RankProfile secondChild = profiles.getRankProfile(search, "child2").compile(queryProfiles);
    assertEquals("16.6", secondChild.getFirstPhaseRanking().getRoot().toString());
    assertEquals("foo: 14.0", secondChild.getMacros().get("foo").getRankingExpression().toString());

    // The raw config properties must carry the constant-folded expressions as well.
    AttributeFields attributes = new AttributeFields(search);
    List<Pair<String, String>> rawProperties =
            new RawRankProfile(secondChild, queryProfiles, attributes).configProperties();
    assertEquals("(rankingExpression(foo).rankingScript,14.0)", rawProperties.get(0).toString());
    assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rawProperties.get(2).toString());
}
Usage of com.yahoo.searchdefinition.derived.AttributeFields in the vespa project (vespa-engine).
Example from class RankingExpressionShadowingTestCase, method testNeuralNetworkSetup:
@Test
public void testNeuralNetworkSetup() throws ParseException {
    // Note: the type assigned to query profile and constant tensors here is not the correct type.
    // The rank profile below shadows the built-in function relu with a macro of the same name
    // and builds a small two-layer network out of macros and constant tensors.
    String schema =
            "search test {\n" +
            " document test { \n" +
            " field a type string { \n" +
            " indexing: index \n" +
            " }\n" +
            " }\n" +
            " \n" +
            " rank-profile test {\n" +
            // relu is a built in function, redefined here
            " macro relu(x) {\n" +
            " expression: max(1.0, x)\n" +
            " }\n" +
            " macro hidden_layer() {\n" +
            " expression: relu(sum(query(q) * constant(W_hidden), input) + constant(b_input))\n" +
            " }\n" +
            " macro final_layer() {\n" +
            " expression: sigmoid(sum(hidden_layer * constant(W_final), hidden) + constant(b_final))\n" +
            " }\n" +
            " second-phase {\n" +
            " expression: sum(final_layer)\n" +
            " }\n" +
            " }\n" +
            " constant W_hidden {\n" +
            " type: tensor(x[])\n" +
            " file: ignored.json\n" +
            " }\n" +
            " constant b_input {\n" +
            " type: tensor(x[])\n" +
            " file: ignored.json\n" +
            " }\n" +
            " constant W_final {\n" +
            " type: tensor(x[])\n" +
            " file: ignored.json\n" +
            " }\n" +
            " constant b_final {\n" +
            " type: tensor(x[])\n" +
            " file: ignored.json\n" +
            " }\n" +
            "}\n";

    RankProfileRegistry profiles = new RankProfileRegistry();
    QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])");
    SearchBuilder searchBuilder = new SearchBuilder(profiles, queryProfiles);
    searchBuilder.importString(schema);
    searchBuilder.build();
    Search search = searchBuilder.getSearch();

    RankProfile compiled = profiles.getRankProfile(search, "test").compile(queryProfiles);
    AttributeFields attributes = new AttributeFields(search);
    List<Pair<String, String>> properties =
            new RawRankProfile(compiled, queryProfiles, attributes).configProperties();

    // The user-defined relu macro is emitted as-is ...
    assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", properties.get(0).toString());
    // ... plus a bound instance for the call inside hidden_layer. censorBindingHash (defined
    // elsewhere in this class) presumably masks the binding-specific suffix after '@' — verify.
    assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))",
                 censorBindingHash(properties.get(1).toString()));
    assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))",
                 censorBindingHash(properties.get(2).toString()));
    assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))",
                 properties.get(3).toString());
    // The second phase refers to its own named expression, which reduces final_layer.
    assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", properties.get(4).toString());
    assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))",
                 properties.get(5).toString());
}
Aggregations