Usage example of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project: class SimpleDatasetDataBuilder, method build.
/**
 * {@inheritDoc}
 */
@Override
public SimpleDatasetData build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
    // Flat column-major feature matrix: element (row, col) lives at [col * upstreamDataSize + row].
    double[] featureStorage = null;
    // Number of features per row; resolved lazily from the first extracted vector.
    int featureCnt = -1;
    // Index of the row currently being written.
    int rowIdx = 0;
    while (upstreamData.hasNext()) {
        UpstreamEntry<K, V> e = upstreamData.next();
        Vector featureRow = preprocessor.apply(e.getKey(), e.getValue()).features();
        if (!featureRow.isNumeric())
            throw new NonDoubleVectorException(featureRow);
        if (featureCnt < 0) {
            // First row defines the column count and lets us allocate the whole matrix at once.
            featureCnt = featureRow.size();
            featureStorage = new double[Math.toIntExact(upstreamDataSize * featureCnt)];
        }
        else
            assert featureRow.size() == featureCnt : "Feature extractor must return exactly " + featureCnt + " features";
        for (int col = 0; col < featureCnt; col++)
            featureStorage[Math.toIntExact(col * upstreamDataSize + rowIdx)] = featureRow.get(col);
        rowIdx++;
    }
    return new SimpleDatasetData(featureStorage, Math.toIntExact(upstreamDataSize));
}
Usage example of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project: class DataStreamGeneratorFillCacheTest, method testCacheFilling.
/**
 * Fills a cache from a Gaussian data stream generator and checks the row count and
 * sample mean computed over the resulting dataset.
 */
@Test
public void testCacheFilling() {
    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));
    String cacheName = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));
    int expRows = 5000;
    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> upstream = ignite.getOrCreateCache(cacheCfg);
        // Stream one-dimensional vectors drawn from a zero-mean Gaussian into the cache.
        DataStreamGenerator stream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        stream.fillCacheWithVecUUIDAsKey(expRows, upstream);
        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> builder =
            new CacheBasedDatasetBuilder<>(ignite, upstream);
        // Per-partition statistic: sum of all feature values plus the number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> map =
            data -> new StatPair(DoubleStream.of(data.getFeatures()).sum(), data.getRows());
        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(map);
        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
            builder.build(LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(),
                new SimpleDatasetDataBuilder<>(vectorizer), env)) {
            StatPair res = dataset.compute(map, StatPair::sum);
            assertEquals(expRows, res.cntOfRows);
            // Zero-mean distribution: the sample mean should be close to 0.
            assertEquals(0.0, res.elementsSum / res.cntOfRows, 1e-2);
        }
        ignite.destroyCache(cacheName);
    }
}
Usage example of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project: class CacheBasedDatasetTest, method testPartitionExchangeDuringComputeWithCtxCall.
/**
 * Tests that partitions of the upstream cache and the partition {@code context} cache are reserved during
 * computations on dataset. Reservation means that partitions won't be unloaded from the node before computation is
 * completed.
 */
@Test
public void testPartitionExchangeDuringComputeWithCtxCall() {
    int partitions = 4;
    IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0);
    CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
    // Context builder returns the upstream size; the data builder returns an empty SimpleDatasetData —
    // the payload is irrelevant here, only the partition reservation behavior is under test.
    CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(TestUtils.testEnvBuilder(), (env, upstream, upstreamSize) -> upstreamSize, (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0), TestUtils.testEnvBuilder().buildForTrainer());
    assertTrue("Before computation all partitions should not be reserved", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    // Distributed counter and lock used to coordinate with the computation thread started below.
    UUID numOfStartedComputationsId = UUID.randomUUID();
    IgniteAtomicLong numOfStartedComputations = ignite.atomicLong(numOfStartedComputationsId.toString(), 0, true);
    UUID computationsLockId = UUID.randomUUID();
    IgniteLock computationsLock = ignite.reentrantLock(computationsLockId.toString(), false, true, true);
    // lock computations lock to stop computations in the middle
    computationsLock.lock();
    try {
        new Thread(() -> dataset.computeWithCtx((ctx, data, partIndex) -> {
            // track number of started computations
            ignite.atomicLong(numOfStartedComputationsId.toString(), 0, false).incrementAndGet();
            // Block on the lock held by the main thread, so each per-partition computation
            // stays "in progress" (and its partitions stay reserved) until the main thread unlocks.
            ignite.reentrantLock(computationsLockId.toString(), false, true, false).lock();
            ignite.reentrantLock(computationsLockId.toString(), false, true, false).unlock();
        })).start();
        // Busy-wait until a computation has started on every partition.
        while (numOfStartedComputations.get() < partitions) {
        }
        assertTrue("During computation all partitions should be reserved", areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    } finally {
        // Releasing the lock lets all blocked computations finish.
        computationsLock.unlock();
    }
    assertTrue("All partitions should be released", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
}
Usage example of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project: class CacheBasedDatasetTest, method testPartitionExchangeDuringComputeCall.
/**
 * Tests that partitions of the upstream cache and the partition {@code context} cache are reserved during
 * computations on dataset. Reservation means that partitions won't be unloaded from the node before computation is
 * completed.
 */
@Test
public void testPartitionExchangeDuringComputeCall() {
    int partitions = 4;
    IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0);
    CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
    // Context builder returns the upstream size; the data builder returns an empty SimpleDatasetData —
    // the payload is irrelevant here, only the partition reservation behavior is under test.
    CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(TestUtils.testEnvBuilder(), (env, upstream, upstreamSize) -> upstreamSize, (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0), TestUtils.testEnvBuilder().buildForTrainer());
    assertEquals("Upstream cache name from dataset", upstreamCache.getName(), dataset.getUpstreamCache().getName());
    assertTrue("Before computation all partitions should not be reserved", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    // Distributed counter and lock used to coordinate with the computation thread started below.
    UUID numOfStartedComputationsId = UUID.randomUUID();
    IgniteAtomicLong numOfStartedComputations = ignite.atomicLong(numOfStartedComputationsId.toString(), 0, true);
    UUID computationsLockId = UUID.randomUUID();
    IgniteLock computationsLock = ignite.reentrantLock(computationsLockId.toString(), false, true, true);
    // lock computations lock to stop computations in the middle
    computationsLock.lock();
    try {
        new Thread(() -> dataset.compute((data, partIndex) -> {
            // track number of started computations
            ignite.atomicLong(numOfStartedComputationsId.toString(), 0, false).incrementAndGet();
            // Block on the lock held by the main thread, so each per-partition computation
            // stays "in progress" (and its partitions stay reserved) until the main thread unlocks.
            ignite.reentrantLock(computationsLockId.toString(), false, true, false).lock();
            ignite.reentrantLock(computationsLockId.toString(), false, true, false).unlock();
        })).start();
        // Busy-wait until a computation has started on every partition.
        while (numOfStartedComputations.get() < partitions) {
        }
        assertTrue("During computation all partitions should be reserved", areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    } finally {
        // Releasing the lock lets all blocked computations finish.
        computationsLock.unlock();
    }
    assertTrue("All partitions should be released", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
}
Aggregations