Search in sources:

Example 1 with DatasetBuilder

Use of org.apache.ignite.ml.dataset.DatasetBuilder in project ignite by apache.

The method fit of the class LinearRegressionLSQRTrainer.

/**
 * {@inheritDoc}
 * <p>
 * Every row returned by {@code featureExtractor} is resized to {@code cols + 1} entries and the entry at
 * index {@code cols} is set to {@code 1.0}. The resulting system is handed to an {@code LSQROnHeap} solver.
 * The first {@code cols} components of the solution become the weight vector of the returned model; the
 * component at index {@code cols} is passed as the model's second constructor argument (presumably the
 * intercept term - confirm against {@code LinearRegressionModel}).
 *
 * @throws RuntimeException wrapping any exception raised while solving or while closing the solver.
 */
@Override
public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
    // Width of a feature row once the trailing constant column has been appended.
    final int extendedCols = cols + 1;

    LSQRResult solverRes;

    try (LSQROnHeap<K, V> solver = new LSQROnHeap<>(
        datasetBuilder,
        new LinSysPartitionDataBuilderOnHeap<>(
            (key, val) -> {
                // Copy (truncating or zero-padding) the extracted features, then pin the last slot to 1.0.
                double[] extended = Arrays.copyOf(featureExtractor.apply(key, val), extendedCols);
                extended[cols] = 1.0;
                return extended;
            },
            lbExtractor,
            extendedCols))) {
        solverRes = solver.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
    } catch (Exception e) {
        // The solver is AutoCloseable, so close() may throw a checked exception; surface it unchecked.
        throw new RuntimeException(e);
    }

    double[] solution = solverRes.getX();

    // Leading `cols` entries are the weights; the entry that follows them goes to the model separately.
    Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(solution, 0, cols));

    return new LinearRegressionModel(weights, solution[cols]);
}
Also used : DatasetTrainer(org.apache.ignite.ml.DatasetTrainer) Arrays(java.util.Arrays) Vector(org.apache.ignite.ml.math.Vector) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) LinSysPartitionDataBuilderOnHeap(org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap) AbstractLSQR(org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR) LSQROnHeap(org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap) LSQRResult(org.apache.ignite.ml.math.isolve.lsqr.LSQRResult) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) LSQROnHeap(org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap) LSQRResult(org.apache.ignite.ml.math.isolve.lsqr.LSQRResult) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)

Example 2 with DatasetBuilder

Use of org.apache.ignite.ml.dataset.DatasetBuilder in project ignite by apache.

The method main of the class NormalizationExample.

/**
 * Runs the normalization example.
 * <p>
 * Starts an Ignite node from {@code examples/config/example-ignite.xml}, builds a cache-based dataset of
 * {@code Person} entries, fits a normalization preprocessor on the (age, salary) features and prints the
 * mean, standard deviation, covariance matrix and correlation matrix of the preprocessed dataset.
 *
 * @param args Command line arguments; not read by this method.
 * @throws Exception If starting the node, building the dataset or closing either resource fails.
 */
public static void main(String[] args) throws Exception {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Normalization example started.");

        IgniteCache<Integer, Person> personCache = createCache(ignite);

        // Number of columns handed to the trainer and the dataset factory; the extractor below yields two values.
        final int featureCnt = 2;

        // First preprocessing stage: pull the raw features (age and salary) out of an upstream entry.
        IgniteBiFunction<Integer, Person, double[]> featureExtractor = (key, person) -> new double[] { person.getAge(), person.getSalary() };

        DatasetBuilder<Integer, Person> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, personCache);

        // Second preprocessing stage: normalize the extracted features.
        NormalizationPreprocessor<Integer, Person> normalizer = new NormalizationTrainer<Integer, Person>().fit(datasetBuilder, featureExtractor, featureCnt);

        // Cache-based simple dataset that holds the features and exposes the standard dataset API.
        try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(datasetBuilder, normalizer, featureCnt)) {
            // Mean value, computed in map-reduce manner.
            double[] meanVals = dataset.mean();
            System.out.println("Mean \n\t" + Arrays.toString(meanVals));

            // Standard deviation, computed in map-reduce manner.
            double[] stdVals = dataset.std();
            System.out.println("Standard deviation \n\t" + Arrays.toString(stdVals));

            // Covariance matrix, computed in map-reduce manner.
            double[][] covMatrix = dataset.cov();
            System.out.println("Covariance matrix ");
            for (double[] covRow : covMatrix) {
                System.out.println("\t" + Arrays.toString(covRow));
            }

            // Correlation matrix, computed in map-reduce manner.
            double[][] corrMatrix = dataset.corr();
            System.out.println("Correlation matrix ");
            for (double[] corrRow : corrMatrix) {
                System.out.println("\t" + Arrays.toString(corrRow));
            }
        }

        System.out.println(">>> Normalization example completed.");
    }
}
Also used : Arrays(java.util.Arrays) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignition(org.apache.ignite.Ignition) SimpleDataset(org.apache.ignite.ml.dataset.primitive.SimpleDataset) DatasetFactory(org.apache.ignite.ml.dataset.DatasetFactory) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) NormalizationPreprocessor(org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor) Person(org.apache.ignite.examples.ml.dataset.model.Person) NormalizationTrainer(org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) Ignite(org.apache.ignite.Ignite) Person(org.apache.ignite.examples.ml.dataset.model.Person)

Aggregations

Arrays (java.util.Arrays)2 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)2 IgniteBiFunction (org.apache.ignite.ml.math.functions.IgniteBiFunction)2 Ignite (org.apache.ignite.Ignite)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)1 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)1 Person (org.apache.ignite.examples.ml.dataset.model.Person)1 DatasetTrainer (org.apache.ignite.ml.DatasetTrainer)1 DatasetFactory (org.apache.ignite.ml.dataset.DatasetFactory)1 CacheBasedDatasetBuilder (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder)1 SimpleDataset (org.apache.ignite.ml.dataset.primitive.SimpleDataset)1 Vector (org.apache.ignite.ml.math.Vector)1 DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)1 LinSysPartitionDataBuilderOnHeap (org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap)1 AbstractLSQR (org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR)1 LSQROnHeap (org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap)1 LSQRResult (org.apache.ignite.ml.math.isolve.lsqr.LSQRResult)1 NormalizationPreprocessor (org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor)1