Search in sources:

Example 56 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

the class LinearRegressionLSQRTrainer method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Solves the regression with {@code LSQROnHeap} over data passed through a preprocessor patched with
 * {@code LinearRegressionLSQRTrainer::extendLabeledVector}. When a previous model is supplied, its weights
 * followed by its intercept are used as the starting point of the solver. The last component of the solution
 * becomes the intercept of the returned model and the preceding components become its weights.
 *
 * @throws RuntimeException wrapping any exception raised while the solver is open, including one raised by
 *     the empty-result fallback.
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    PatchedPreprocessor<K, V, Double, double[]> labelPatcher = new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, extractor);
    LSQRResult solution;
    try (LSQROnHeap<K, V> solver = new LSQROnHeap<>(datasetBuilder, envBuilder, new SimpleLabeledDatasetDataBuilder<>(labelPatcher), learningEnvironment())) {
        // No previous model means the solver gets no initial approximation.
        double[] initialGuess = null;
        if (mdl != null) {
            // One extra slot at the end holds the intercept of the previous model.
            Vector previous = mdl.weights();
            Vector extended = previous.like(previous.size() + 1);
            previous.nonZeroes().forEach(elem -> extended.set(elem.index(), elem.get()));
            extended.set(extended.size() - 1, mdl.intercept());
            initialGuess = extended.asArray();
        }
        solution = solver.solve(0, 1e-12, 1e-12, 1e8, -1, false, initialGuess);
        // A null result is delegated to the shared fallback; it runs inside the try, so whatever it throws is wrapped below.
        if (solution == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    double[] coefficients = solution.getX();
    int interceptIdx = coefficients.length - 1;
    // Everything before the last component is a feature weight; the last component is the intercept.
    Vector weights = new DenseVector(Arrays.copyOfRange(coefficients, 0, interceptIdx));
    return new LinearRegressionModel(weights, coefficients[interceptIdx]);
}
Also used : LSQROnHeap(org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap) LSQRResult(org.apache.ignite.ml.math.isolve.lsqr.LSQRResult) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 57 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

the class LogisticRegressionSGDTrainer method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Trains (or continues training) a perceptron with a single added sigmoid layer of one neuron using
 * {@code MLPTrainer} and the L2 loss, then converts its parameters into a {@code LogisticRegressionModel}:
 * every parameter except the last becomes a weight and the last one becomes the intercept.
 *
 * @throws IllegalStateException if no partition reports any features, i.e. the dataset is empty.
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        // Feature count per row, taken from the first partition that actually holds features.
        Integer featureCnt = dataset.compute(data -> {
            if (data.getFeatures() != null)
                return data.getFeatures().length / data.getRows();
            return null;
        },
            // Keep nulls as nulls: if both sides are null a zero must not be propagated.
            (left, right) -> left == null ? right : left);
        if (featureCnt == null)
            throw new IllegalStateException("Cannot train on empty dataset");
        return new MLPArchitecture(featureCnt).withAddedLayer(1, true, Activators.SIGMOID);
    };
    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations, batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);
    // The perceptron trainer works with array labels, so each scalar label is wrapped into a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> labelWrapper = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(labelWrapper, extractor);
    MultilayerPerceptron mlp;
    if (mdl == null)
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);
    else {
        // Continue from the state restored from the previous model.
        mlp = trainer.update(restoreMLPState(mdl), datasetBuilder, patchedPreprocessor);
    }
    double[] flatParams = mlp.parameters().getStorage().data();
    int interceptIdx = flatParams.length - 1;
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(flatParams, interceptIdx)), flatParams[interceptIdx]);
}
Also used : SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Arrays(java.util.Arrays) Activators(org.apache.ignite.ml.nn.Activators) SimpleGDParameterUpdate(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) SimpleLabeledDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) LossFunctions(org.apache.ignite.ml.optimization.LossFunctions) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer)

Example 58 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

the class Deltas method calculateUpdates.

/**
 * Computes the combined weight update over all partitions of the dataset.
 *
 * <p>For every partition a local copy of {@code weights} is refined for {@code getAmountOfLocIterations()}
 * iterations: on each iteration a row index is drawn from a {@link Random} seeded with {@code seed}, the deltas
 * for that row are obtained from {@code getDeltas}, and both the local weights and the accumulated update are
 * advanced by the weight delta. The per-partition updates are then summed.
 *
 * <p>A partition that reports no rows contributes a zero update instead of failing in
 * {@link Random#nextInt(int)}, which rejects a non-positive bound.
 *
 * @param weights Current weights; not modified, each partition works on a copy.
 * @param dataset Dataset to compute the update on.
 * @return Sum of the per-partition weight updates; an empty {@code DenseVector} if both sides of a reduction
 *     step are {@code null}.
 */
private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset) {
    return dataset.compute(data -> {
        final int amountOfObservation = data.rowSize();
        Vector deltaWeights = initializeWeightsWithZeros(weights.size());
        // Nothing to sample from: report "no change" rather than letting nextInt(0) throw.
        if (amountOfObservation <= 0)
            return deltaWeights;
        Vector copiedWeights = weights.copy();
        // Running per-row alpha values fed back into getDeltas on later iterations.
        Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);
        Random random = new Random(seed);
        for (int i = 0; i < this.getAmountOfLocIterations(); i++) {
            int randomIdx = random.nextInt(amountOfObservation);
            Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx);
            // plus() creates a new vector, so the results have to be reassigned.
            copiedWeights = copiedWeights.plus(deltas.deltaWeights);
            deltaWeights = deltaWeights.plus(deltas.deltaWeights);
            tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha);
        }
        return deltaWeights;
    }, (a, b) -> {
        if (a == null)
            return b == null ? new DenseVector() : b;
        if (b == null)
            return a;
        return a.plus(b);
    });
}
Also used : Random(java.util.Random) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 59 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

the class SVMModelTest method testPredictWithErasedLabels.

/**
 * Checks the predictions of a model with weights (1, 1) and intercept 1 that does not keep raw labels:
 * every prediction is expected to be either 1.0 or 0.0. Afterwards the intercept and weights are replaced
 * through the {@code with*} methods and the original reference is checked again.
 */
@Test
public void testPredictWithErasedLabels() {
    Vector weights = new DenseVector(new double[] { 1.0, 1.0 });
    SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
    // Each row of inputs is paired with the label at the same position in expectedLabels.
    double[][] inputs = {
        { 1.0, 1.0 },
        { 3.0, 4.0 },
        { -1.0, -1.0 },
        { -2.0, 1.0 },
        { -1.0, -2.0 }
    };
    double[] expectedLabels = { 1.0, 1.0, 0.0, 0.0, 0.0 };
    for (int i = 0; i < inputs.length; i++) {
        Vector observation = new DenseVector(inputs[i]);
        TestUtils.assertEquals(expectedLabels[i], mdl.predict(observation), PRECISION);
    }
    SVMLinearClassificationModel withNewIntercept = mdl.withIntercept(-2.0);
    final SVMLinearClassificationModel mdlWithNewData = withNewIntercept.withWeights(new DenseVector(new double[] { -2.0, -2.0 }));
    System.out.println("The SVM model is " + mdlWithNewData);
    // NOTE(review): the checks below go through mdl, not mdlWithNewData - presumably the with* methods update mdl in place; confirm.
    Vector lastObservation = new DenseVector(new double[] { -1.0, -2.0 });
    TestUtils.assertEquals(1.0, mdl.predict(lastObservation), PRECISION);
    TestUtils.assertEquals(-2.0, mdl.intercept(), PRECISION);
}
Also used : Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Example 60 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

the class SVMModelTest method testPredictWithRawLabels.

/**
 * Checks that a model with weights (2, 3) and intercept 1 configured with {@code withRawLabels(true)}
 * predicts {@code intercept + w0 * x0 + w1 * x1} for each observation, reports that it keeps raw labels,
 * and produces non-empty string representations.
 */
@Test
public void testPredictWithRawLabels() {
    final double intercept = 1.0;
    final double w0 = 2.0;
    final double w1 = 3.0;
    Vector weights = new DenseVector(new double[] { w0, w1 });
    SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, intercept).withRawLabels(true);
    double[][] inputs = {
        { 1.0, 1.0 },
        { 2.0, 1.0 },
        { 1.0, 2.0 },
        { -2.0, 1.0 },
        { 1.0, -2.0 }
    };
    for (double[] input : inputs) {
        // The expected value is computed from the same intercept and weights the model was built with.
        double expected = intercept + w0 * input[0] + w1 * input[1];
        Vector observation = new DenseVector(input);
        TestUtils.assertEquals(expected, mdl.predict(observation), PRECISION);
    }
    Assert.assertTrue(mdl.isKeepingRawLabels());
    Assert.assertFalse(mdl.toString().isEmpty());
    Assert.assertFalse(mdl.toString(true).isEmpty());
    Assert.assertFalse(mdl.toString(false).isEmpty());
}
Also used : Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Aggregations

DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)101 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)59 Test (org.junit.Test)59 Serializable (java.io.Serializable)16 SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector)14 HashMap (java.util.HashMap)13 DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix)13 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)10 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)10 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)9 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)9 HashSet (java.util.HashSet)7 TrainerTest (org.apache.ignite.ml.common.TrainerTest)7 KMeansModel (org.apache.ignite.ml.clustering.kmeans.KMeansModel)5 LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder)5 EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance)5 IgniteDifferentiableVectorToDoubleFunction (org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction)5 MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture)5 OneHotEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor)4 Random (java.util.Random)3