Use of org.apache.flink.types.DoubleValue in project flink by apache.
The class HITS, method runInternal:
@Override
public DataSet<Result<K>> runInternal(Graph<K, VV, EV> input) throws Exception {
    DataSet<Tuple2<K, K>> edges = input.getEdges()
        .map(new ExtractEdgeIDs<K, EV>()).setParallelism(parallelism).name("Extract edge IDs");

    // ID, hub, authority
    DataSet<Tuple3<K, DoubleValue, DoubleValue>> initialScores = edges
        .map(new InitializeScores<K>()).setParallelism(parallelism).name("Initial scores")
        .groupBy(0)
        .reduce(new SumScores<K>()).setCombineHint(CombineHint.HASH).setParallelism(parallelism).name("Sum");

    IterativeDataSet<Tuple3<K, DoubleValue, DoubleValue>> iterative = initialScores.iterate(maxIterations);

    // ID, hubbiness
    DataSet<Tuple2<K, DoubleValue>> hubbiness = iterative
        .coGroup(edges).where(0).equalTo(1)
        .with(new Hubbiness<K>()).setParallelism(parallelism).name("Hub")
        .groupBy(0)
        .reduce(new SumScore<K>()).setCombineHint(CombineHint.HASH).setParallelism(parallelism).name("Sum");

    // sum-of-hubbiness-squared
    DataSet<DoubleValue> hubbinessSumSquared = hubbiness
        .map(new Square<K>()).setParallelism(parallelism).name("Square")
        .reduce(new Sum()).setCombineHint(CombineHint.HASH).setParallelism(parallelism).name("Sum");

    // ID, new authority
    DataSet<Tuple2<K, DoubleValue>> authority = hubbiness
        .coGroup(edges).where(0).equalTo(0)
        .with(new Authority<K>()).setParallelism(parallelism).name("Authority")
        .groupBy(0)
        .reduce(new SumScore<K>()).setCombineHint(CombineHint.HASH).setParallelism(parallelism).name("Sum");

    // sum-of-authority-squared
    DataSet<DoubleValue> authoritySumSquared = authority
        .map(new Square<K>()).setParallelism(parallelism).name("Square")
        .reduce(new Sum()).setCombineHint(CombineHint.HASH).setParallelism(parallelism).name("Sum");

    // ID, normalized hubbiness, normalized authority
    DataSet<Tuple3<K, DoubleValue, DoubleValue>> scores = hubbiness
        .fullOuterJoin(authority, JoinHint.REPARTITION_SORT_MERGE).where(0).equalTo(0)
        .with(new JoinAndNormalizeHubAndAuthority<K>())
        .withBroadcastSet(hubbinessSumSquared, HUBBINESS_SUM_SQUARED)
        .withBroadcastSet(authoritySumSquared, AUTHORITY_SUM_SQUARED)
        .setParallelism(parallelism).name("Join scores");

    DataSet<Tuple3<K, DoubleValue, DoubleValue>> passThrough;

    if (convergenceThreshold < Double.MAX_VALUE) {
        passThrough = iterative
            .fullOuterJoin(scores, JoinHint.REPARTITION_SORT_MERGE).where(0).equalTo(0)
            .with(new ChangeInScores<K>()).setParallelism(parallelism).name("Change in scores");

        iterative.registerAggregationConvergenceCriterion(
            CHANGE_IN_SCORES, new DoubleSumAggregator(), new ScoreConvergence(convergenceThreshold));
    } else {
        passThrough = scores;
    }

    return iterative.closeWith(passThrough)
        .map(new TranslateResult<K>()).setParallelism(parallelism).name("Map result");
}
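The Square, Sum, and SumScore helpers referenced above are small user functions over DoubleValue, Flink's mutable wrapper for a primitive double. A minimal sketch of what Square and Sum could look like, assuming Square maps an (ID, score) pair to the squared score and Sum adds two scores in object-reuse style; the actual implementations in HITS may differ:

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.DoubleValue;

// Sketch: emits the square of the score field, reusing one output instance
// to avoid allocating a new DoubleValue per record.
class Square<T> implements MapFunction<Tuple2<T, DoubleValue>, DoubleValue> {
    private final DoubleValue output = new DoubleValue();

    @Override
    public DoubleValue map(Tuple2<T, DoubleValue> value) throws Exception {
        double score = value.f1.getValue();
        output.setValue(score * score);
        return output;
    }
}

// Sketch: adds two partial sums in place, returning the first instance.
class Sum implements ReduceFunction<DoubleValue> {
    @Override
    public DoubleValue reduce(DoubleValue first, DoubleValue second) throws Exception {
        first.setValue(first.getValue() + second.getValue());
        return first;
    }
}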
Use of org.apache.flink.types.DoubleValue in project flink by apache.
The class ToNullValueTest, method testTranslation:
@Test
public void testTranslation() throws Exception {
    NullValue reuse = NullValue.getInstance();

    assertEquals(NullValue.getInstance(), new ToNullValue<>().translate(new DoubleValue(), reuse));
    assertEquals(NullValue.getInstance(), new ToNullValue<>().translate(new FloatValue(), reuse));
    assertEquals(NullValue.getInstance(), new ToNullValue<>().translate(new IntValue(), reuse));
    assertEquals(NullValue.getInstance(), new ToNullValue<>().translate(new LongValue(), reuse));
    assertEquals(NullValue.getInstance(), new ToNullValue<>().translate(new StringValue(), reuse));
}
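ToNullValue discards whatever Value it is given and always returns the NullValue singleton, which is why a single reuse instance works for every input type in the test. A sketch of a translator along these lines, assuming Gelly's TranslateFunction interface (the shipped ToNullValue may differ in detail):

import org.apache.flink.graph.asm.translate.TranslateFunction;
import org.apache.flink.types.NullValue;

// Sketch: maps any input to NullValue; both the input and the reuse
// object are ignored because NullValue is a stateless singleton.
class ToNullValueSketch<T> implements TranslateFunction<T, NullValue> {
    @Override
    public NullValue translate(T value, NullValue reuse) throws Exception {
        return NullValue.getInstance();
    }
}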
Use of org.apache.flink.types.DoubleValue in project flink by apache.
The class OutputEmitterTest, method testMultiKeys:
@Test
public void testMultiKeys() {
    @SuppressWarnings({ "unchecked", "rawtypes" })
    final TypeComparator<Record> multiComp = new RecordComparatorFactory(
        new int[] { 0, 1, 3 },
        new Class[] { IntValue.class, StringValue.class, DoubleValue.class }).createComparator();

    final ChannelSelector<SerializationDelegate<Record>> oe1 =
        new OutputEmitter<Record>(ShipStrategyType.PARTITION_HASH, multiComp);
    final SerializationDelegate<Record> delegate =
        new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());

    int numChannels = 100;
    int numRecords = 5000;
    int[] hit = new int[numChannels];

    for (int i = 0; i < numRecords; i++) {
        Record rec = new Record(4);
        rec.setField(0, new IntValue(i));
        rec.setField(1, new StringValue("AB" + i + "CD" + i));
        rec.setField(3, new DoubleValue(i * 3.141d));
        delegate.setInstance(rec);

        int[] chans = oe1.selectChannels(delegate, hit.length);
        for (int chan : chans) {
            hit[chan]++;
        }
    }

    int cnt = 0;
    for (int aHit : hit) {
        assertTrue(aHit > 0);
        cnt += aHit;
    }
    assertTrue(cnt == numRecords);
}
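Both tests lean on DoubleValue being a cheap, mutable box around a primitive double. A minimal standalone sketch of that API, using only the constructor, getValue(), and setValue(double):

import org.apache.flink.types.DoubleValue;

public class DoubleValueDemo {
    public static void main(String[] args) {
        // Wrap a primitive double; the instance can be mutated and reused
        // across records instead of allocating a new object per value.
        DoubleValue score = new DoubleValue(3.141d);
        double raw = score.getValue();        // read the wrapped primitive
        score.setValue(raw * 2.0);            // overwrite in place
        System.out.println(score.getValue()); // 6.282
    }
}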