Usage of org.apache.hadoop.mrunit.mapreduce.ReduceFeeder in the Apache project incubator-rya.
Class ForwardChainTest, method testTransitiveChain.
/**
 * Drives several forward-chaining steps over a {@code subOrganizationOf} chain
 * org0 -> org1 -> ... -> org8 and checks the facts each step derives.
 * <p>
 * MultipleOutputs support is minimal, so every map/reduce step is run and
 * checked explicitly. Each step works as follows:
 * <ol>
 * <li>The mapper is fed every fact known so far.</li>
 * <li>Its output is sorted and grouped with {@link ReduceFeeder}, using
 * ResourceWritable's secondary and primary comparators.</li>
 * <li>The reducer's {@code "intermediate"} named output is compared against the
 * paths that step is expected to derive.</li>
 * </ol>
 */
@Test
public void testTransitiveChain() throws Exception {
    final int chainLength = 8;
    final int steps = 4;
    final URI prop = TestUtils.uri("subOrganizationOf");
    // pathsFrom.get(start).get(end) holds the fact asserting start -> end:
    // either an original edge or a path derived in an earlier step.
    final Map<Integer, Map<Integer, Pair<Fact, NullWritable>>> pathsFrom = new HashMap<>();
    for (int node = 0; node <= chainLength; node++) {
        pathsFrom.put(node, new HashMap<Integer, Pair<Fact, NullWritable>>());
    }
    // Initial input: one explicit edge between each pair of consecutive orgs.
    for (int node = 0; node < chainLength; node++) {
        Fact edge = new Fact(TestUtils.uri("org" + node), prop, TestUtils.uri("org" + (node + 1)));
        pathsFrom.get(node).put(node + 1, new Pair<>(edge, NullWritable.get()));
    }
    for (int step = 1; step <= steps; step++) {
        // Map phase: feed every fact known so far.
        MapDriver<Fact, NullWritable, ResourceWritable, Fact> mapDriver = new MapDriver<>();
        mapDriver.getConfiguration().setInt(MRReasoningUtils.STEP_PROP, step);
        mapDriver.setMapper(new ForwardChain.FileMapper(schema));
        for (Map<Integer, Pair<Fact, NullWritable>> outgoing : pathsFrom.values()) {
            for (Pair<Fact, NullWritable> fact : outgoing.values()) {
                mapDriver.addInput(fact);
            }
        }
        List<Pair<ResourceWritable, Fact>> mapOutput = mapDriver.run();

        // Sort and group the map output so it can be handed to the reduce driver.
        ReduceFeeder<ResourceWritable, Fact> feeder = new ReduceFeeder<>(mapDriver.getConfiguration());
        List<KeyValueReuseList<ResourceWritable, Fact>> reduceInput = feeder.sortAndGroup(mapOutput,
                new ResourceWritable.SecondaryComparator(), new ResourceWritable.PrimaryComparator());

        // Reduce phase, checked against the expected "intermediate" output.
        ReduceDriver<ResourceWritable, Fact, Fact, NullWritable> reduceDriver = new ReduceDriver<>();
        reduceDriver.getConfiguration().setInt(MRReasoningUtils.STEP_PROP, step);
        reduceDriver.setReducer(new ForwardChain.ReasoningReducer(schema));
        reduceDriver.addAllElements(reduceInput);

        // This step is expected to derive every path of length half+1 .. 2*half,
        // where half = 2^(step-1). Each path is joined at the node half hops
        // after its start.
        int half = 1 << (step - 1);
        int longest = half * 2;
        for (int start = 0; start < chainLength; start++) {
            int middle = start + half;
            int lastEnd = Math.min(start + longest, chainLength);
            for (int end = middle + 1; end <= lastEnd; end++) {
                Fact derived = new Fact(TestUtils.uri("org" + start), prop, TestUtils.uri("org" + end),
                        step, OwlRule.PRP_TRP, TestUtils.uri("org" + middle));
                derived.addSource(pathsFrom.get(start).get(middle).getFirst());
                derived.addSource(pathsFrom.get(middle).get(end).getFirst());
                Pair<Fact, NullWritable> expected = new Pair<>(derived, NullWritable.get());
                // Remember the new path so the next step's mapper receives it as input.
                pathsFrom.get(start).put(end, expected);
                reduceDriver.addMultiOutput("intermediate", expected);
            }
        }
        reduceDriver.runTest();
    }
}
Aggregations