Use of org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap in the Apache Ignite project.
Class OneVsRestTrainer, method extractClassLabels.
/**
 * Collects the distinct class labels found in the dataset.
 * <p>
 * A label-only dataset is built from the given builder and preprocessor, each partition contributes the set of
 * labels it holds, and the per-partition sets are merged into one. The temporary dataset is closed before the
 * method returns.
 *
 * @param datasetBuilder Dataset builder, must not be {@code null} (checked with an {@code assert}).
 * @param preprocessor Preprocessor handed to the label partition data builder.
 * @param <K> Type of a key in the upstream data.
 * @param <V> Type of a value in the upstream data.
 * @return List of distinct labels, empty when the computation yields no result.
 * @throws RuntimeException Wrapping any exception raised while building, computing on or closing the dataset.
 */
private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> labelDataBuilder =
        new LabelPartitionDataBuilderOnHeap<>(preprocessor);

    List<Double> labels = new ArrayList<>();

    try (Dataset<EmptyContext, LabelPartitionDataOnHeap> labelDataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        labelDataBuilder,
        learningEnvironment()
    )) {
        Set<Double> distinctLabels = labelDataset.compute(
            // Map step: distinct labels of a single partition.
            partData -> {
                double[] partLabels = partData.getY();

                Set<Double> uniq = new HashSet<>();

                for (int i = 0; i < partLabels.length; i++)
                    uniq.add(partLabels[i]);

                return uniq;
            },
            // Reduce step: union of two partial results, either of which may be null.
            (left, right) -> {
                if (left != null && right != null)
                    return Stream.concat(left.stream(), right.stream()).collect(Collectors.toSet());

                if (left != null)
                    return left;

                return right != null ? right : new HashSet<>();
            }
        );

        if (distinctLabels != null)
            labels.addAll(distinctLabels);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    return labels;
}
Aggregations