Search in sources :

Example 1 with SimpleDatasetData

Use of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project.

Class SimpleDatasetDataBuilder, method build.

/**
 * {@inheritDoc}
 */
@Override
public SimpleDatasetData build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
    // The feature matrix is kept flat and column-major: the value of column c in the r-th processed
    // row is stored at index c * upstreamDataSize + r.
    double[] flat = null;
    int colCnt = -1;
    int rowIdx = 0;

    while (upstreamData.hasNext()) {
        UpstreamEntry<K, V> upstreamEntry = upstreamData.next();
        Vector featureVec = preprocessor.apply(upstreamEntry.getKey(), upstreamEntry.getValue()).features();

        // Only numeric vectors can be copied into the double array.
        if (!featureVec.isNumeric()) {
            throw new NonDoubleVectorException(featureVec);
        }

        if (colCnt >= 0) {
            // Every subsequent row must have the same width as the first one.
            assert featureVec.size() == colCnt : "Feature extractor must return exactly " + colCnt + " features";
        } else {
            // The first row fixes the column count and therefore the size of the backing array.
            colCnt = featureVec.size();
            flat = new double[Math.toIntExact(upstreamDataSize * colCnt)];
        }

        for (int col = 0; col < colCnt; col++) {
            flat[Math.toIntExact(col * upstreamDataSize + rowIdx)] = featureVec.get(col);
        }

        rowIdx++;
    }

    // When the iterator yields no entries the feature array stays null.
    return new SimpleDatasetData(flat, Math.toIntExact(upstreamDataSize));
}
Also used : NonDoubleVectorException(org.apache.ignite.ml.math.exceptions.preprocessing.NonDoubleVectorException) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 2 with SimpleDatasetData

Use of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project.

Class DataStreamGeneratorFillCacheTest, method testCacheFilling.

/**
 * Fills a cache through a {@link DataStreamGenerator} and then checks, using a dataset built over that cache,
 * that the expected number of rows is present and that the sum of all features divided by the row count is
 * within 1e-2 of zero.
 */
@Test
public void testCacheFilling() {
    int datasetSize = 5000;
    String cacheName = "TEST_CACHE";

    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));

    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        // Put datasetSize generated vectors into the cache.
        DataStreamGenerator dataStream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        dataStream.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, cache);
        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();

        // Per-partition statistic: sum of all feature values together with the number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> partStats =
            part -> new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());

        LearningEnvironment learningEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        learningEnv.deployingContext().initByClientObject(partStats);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset = datasetBuilder.build(
            LearningEnvironmentBuilder.defaultBuilder(),
            new EmptyContextBuilder<>(),
            new SimpleDatasetDataBuilder<>(vectorizer),
            learningEnv)) {
            StatPair total = dataset.compute(partStats, StatPair::sum);

            assertEquals(datasetSize, total.cntOfRows);
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Also used : Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) 
LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignite(org.apache.ignite.Ignite) UUID(java.util.UUID) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Example 3 with SimpleDatasetData

Use of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project.

Class CacheBasedDatasetTest, method testPartitionExchangeDuringComputeWithCtxCall.

/**
 * Checks that partitions of the upstream cache and of the partition {@code context} cache are reserved while a
 * {@code computeWithCtx} call on the dataset is in progress, and are not reserved before it starts or after it is
 * allowed to finish. A reserved partition is one that won't be unloaded from the node until the computation
 * completes.
 */
@Test
public void testPartitionExchangeDuringComputeWithCtxCall() {
    int partitions = 4;

    IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0);

    CacheBasedDatasetBuilder<Integer, String> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);

    CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = datasetBuilder.build(
        TestUtils.testEnvBuilder(),
        (env, upstream, upstreamSize) -> upstreamSize,
        (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0),
        TestUtils.testEnvBuilder().buildForTrainer());

    assertTrue("Before computation all partitions should not be reserved", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));

    // Named counter that each computation increments as soon as its callback is entered.
    UUID startedCntId = UUID.randomUUID();
    IgniteAtomicLong startedCnt = ignite.atomicLong(startedCntId.toString(), 0, true);

    // Named lock that the callbacks try to acquire; the test holds it to keep them from finishing.
    UUID gateLockId = UUID.randomUUID();
    IgniteLock gateLock = ignite.reentrantLock(gateLockId.toString(), false, true, true);

    // Take the lock first so that the computations stall in the middle of their callbacks.
    gateLock.lock();

    try {
        Thread worker = new Thread(() -> {
            dataset.computeWithCtx((ctx, data, partIndex) -> {
                // Record that this computation has started.
                ignite.atomicLong(startedCntId.toString(), 0, false).incrementAndGet();

                // Wait for the test to release the lock, then give it back immediately.
                ignite.reentrantLock(gateLockId.toString(), false, true, false).lock();
                ignite.reentrantLock(gateLockId.toString(), false, true, false).unlock();
            });
        });
        worker.start();

        while (startedCnt.get() < partitions) {
            // Busy-wait until the counter reaches the expected number of started computations.
        }

        assertTrue("During computation all partitions should be reserved", areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    } finally {
        // Let the stalled computations proceed.
        gateLock.unlock();
    }

    assertTrue("All partitions should be released", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
}
Also used : IgniteAtomicLong(org.apache.ignite.IgniteAtomicLong) IgniteAtomicLong(org.apache.ignite.IgniteAtomicLong) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) IgniteLock(org.apache.ignite.IgniteLock) UUID(java.util.UUID) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Example 4 with SimpleDatasetData

Use of org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData in the Apache Ignite project.

Class CacheBasedDatasetTest, method testPartitionExchangeDuringComputeCall.

/**
 * Checks that partitions of the upstream cache and of the partition {@code context} cache are reserved while a
 * {@code compute} call on the dataset is in progress, and are not reserved before it starts or after it is allowed
 * to finish. A reserved partition is one that won't be unloaded from the node until the computation completes.
 * Also checks that the dataset reports the same upstream cache name it was built from.
 */
@Test
public void testPartitionExchangeDuringComputeCall() {
    int partitions = 4;

    IgniteCache<Integer, String> upstreamCache = generateTestData(4, 0);

    CacheBasedDatasetBuilder<Integer, String> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);

    CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = datasetBuilder.build(
        TestUtils.testEnvBuilder(),
        (env, upstream, upstreamSize) -> upstreamSize,
        (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0),
        TestUtils.testEnvBuilder().buildForTrainer());

    assertEquals("Upstream cache name from dataset", upstreamCache.getName(), dataset.getUpstreamCache().getName());
    assertTrue("Before computation all partitions should not be reserved", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));

    // Named counter that each computation increments as soon as its callback is entered.
    UUID startedCntId = UUID.randomUUID();
    IgniteAtomicLong startedCnt = ignite.atomicLong(startedCntId.toString(), 0, true);

    // Named lock that the callbacks try to acquire; the test holds it to keep them from finishing.
    UUID gateLockId = UUID.randomUUID();
    IgniteLock gateLock = ignite.reentrantLock(gateLockId.toString(), false, true, true);

    // Take the lock first so that the computations stall in the middle of their callbacks.
    gateLock.lock();

    try {
        Thread worker = new Thread(() -> {
            dataset.compute((data, partIndex) -> {
                // Record that this computation has started.
                ignite.atomicLong(startedCntId.toString(), 0, false).incrementAndGet();

                // Wait for the test to release the lock, then give it back immediately.
                ignite.reentrantLock(gateLockId.toString(), false, true, false).lock();
                ignite.reentrantLock(gateLockId.toString(), false, true, false).unlock();
            });
        });
        worker.start();

        while (startedCnt.get() < partitions) {
            // Busy-wait until the counter reaches the expected number of started computations.
        }

        assertTrue("During computation all partitions should be reserved", areAllPartitionsReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
    } finally {
        // Let the stalled computations proceed.
        gateLock.unlock();
    }

    assertTrue("All partitions should be released", areAllPartitionsNotReserved(upstreamCache.getName(), dataset.getDatasetCache().getName()));
}
Also used : IgniteAtomicLong(org.apache.ignite.IgniteAtomicLong) IgniteAtomicLong(org.apache.ignite.IgniteAtomicLong) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) IgniteLock(org.apache.ignite.IgniteLock) UUID(java.util.UUID) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Aggregations

SimpleDatasetData (org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData)4 UUID (java.util.UUID)3 GridCommonAbstractTest (org.apache.ignite.testframework.junits.common.GridCommonAbstractTest)3 Test (org.junit.Test)3 IgniteAtomicLong (org.apache.ignite.IgniteAtomicLong)2 IgniteLock (org.apache.ignite.IgniteLock)2 Arrays (java.util.Arrays)1 DoubleStream (java.util.stream.DoubleStream)1 Ignite (org.apache.ignite.Ignite)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)1 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)1 IgniteConfiguration (org.apache.ignite.configuration.IgniteConfiguration)1 LabeledDummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer)1 CacheBasedDataset (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset)1 CacheBasedDatasetBuilder (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder)1 EmptyContextBuilder (org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder)1 SimpleDatasetDataBuilder (org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder)1 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)1