Use of org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder in the Apache Ignite project.
Class DatasetFactory, method createSimpleDataset.
/**
 * Builds a distributed {@link SimpleDataset} from the supplied {@code partCtxBuilder} and {@code featureExtractor}.
 * The partition {@code data} is fixed to {@link SimpleDatasetData}, while the partition {@code context} may be of
 * any desired type.
 * <p>
 * A trainer learning environment is obtained from the default {@link LearningEnvironmentBuilder} and its deploying
 * context is initialized with {@code featureExtractor}; {@code envBuilder} itself is passed on unchanged to
 * {@code create}.
 *
 * @param datasetBuilder Dataset builder.
 * @param envBuilder Learning environment builder.
 * @param partCtxBuilder Partition {@code context} builder.
 * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @param <C> Type of a partition {@code context}.
 * @param <CO> Serializable type parameter; it is not referenced by the parameters or the body of this method.
 * @return Dataset.
 */
public static <K, V, C extends Serializable, CO extends Serializable> SimpleDataset<C> createSimpleDataset(
    DatasetBuilder<K, V> datasetBuilder,
    LearningEnvironmentBuilder envBuilder,
    PartitionContextBuilder<K, V, C> partCtxBuilder,
    Preprocessor<K, V> featureExtractor
) {
    // Local environment whose deploying context is derived from the feature extractor.
    LearningEnvironment locEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
    locEnv.initDeployingContext(featureExtractor);

    return create(
        datasetBuilder,
        envBuilder,
        partCtxBuilder,
        new SimpleDatasetDataBuilder<>(featureExtractor),
        locEnv
    ).wrap(SimpleDataset::new);
}
Use of org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder in the Apache Ignite project.
Class DataStreamGeneratorFillCacheTest, method testCacheFilling.
/**
 * Starts a local Ignite node, fills a cache from a {@link DataStreamGenerator} and verifies the cache content
 * through a {@link CacheBasedDataset}: the total number of rows must equal the requested dataset size and the sum
 * of features divided by the number of rows must be within {@code 1e-2} of {@code 0.0}.
 */
@Test
public void testCacheFilling() {
    String cacheName = "TEST_CACHE";
    int datasetSize = 5000;

    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(
            new TcpDiscoverySpi().setIpFinder(
                new TcpDiscoveryVmIpFinder().setAddresses(Arrays.asList("127.0.0.1:47500..47509"))
            )
        );

    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        // Populate the cache with datasetSize generated vectors keyed by UUID.
        DataStreamGenerator dataStream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        dataStream.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per-partition statistic: sum of all feature values and the number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> map =
            part -> new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());

        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(map);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 datasetBuilder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     env
                 )) {
            StatPair total = dataset.compute(map, StatPair::sum);

            assertEquals(datasetSize, total.cntOfRows);
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Aggregations