Search in sources:

Example 1 with LSQROnHeap

Use of org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap in project Ignite by Apache.

Class LinearRegressionLSQRTrainer, method fit.

/**
 * {@inheritDoc}
 *
 * <p>Fits the model by solving a least-squares system with {@code LSQROnHeap}. Every feature row
 * is widened by one trailing column fixed at {@code 1.0}. The solution component for that extra
 * column is used as the model intercept; the first {@code cols} components become the weights.
 *
 * @param datasetBuilder Builder of the dataset to train on.
 * @param featureExtractor Maps a key-value pair to its feature array.
 * @param lbExtractor Maps a key-value pair to its label.
 * @param cols Number of feature columns taken from each extracted row.
 * @return Model built from the LSQR solution.
 * @throws RuntimeException Wrapping any exception raised while building, running or closing the solver.
 */
@Override
public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
    // Row width once the constant bias column has been appended.
    final int augmentedCols = cols + 1;

    // Pads (or truncates) the extracted features to augmentedCols and sets the bias slot to 1.0.
    IgniteBiFunction<K, V, double[]> augmentedExtractor = (k, v) -> {
        double[] augmented = Arrays.copyOf(featureExtractor.apply(k, v), augmentedCols);
        augmented[cols] = 1.0;
        return augmented;
    };

    LSQRResult solution;
    try (LSQROnHeap<K, V> solver = new LSQROnHeap<>(datasetBuilder,
        new LinSysPartitionDataBuilderOnHeap<>(augmentedExtractor, lbExtractor, augmentedCols))) {
        solution = solver.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    double[] coefficients = solution.getX();
    double intercept = coefficients[cols];
    Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(coefficients, 0, cols));
    return new LinearRegressionModel(weights, intercept);
}
Also used: DatasetTrainer(org.apache.ignite.ml.DatasetTrainer) Arrays(java.util.Arrays) Vector(org.apache.ignite.ml.math.Vector) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) LinSysPartitionDataBuilderOnHeap(org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap) AbstractLSQR(org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR) LSQROnHeap(org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap) LSQRResult(org.apache.ignite.ml.math.isolve.lsqr.LSQRResult) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)

Example 2 with LSQROnHeap

Use of org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap in project Ignite by Apache.

Class LinearRegressionLSQRTrainer, method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Solves the least-squares system with {@code LSQROnHeap} over a dataset whose labeled vectors
 * are patched by {@code extendLabeledVector}. The last component of the LSQR solution is used as
 * the intercept; all preceding components become the weights. If {@code mdl} is non-null, its
 * weights followed by its intercept are passed to the solver as the initial point.
 *
 * @param mdl Previously trained model used as a warm start, or {@code null}.
 * @param datasetBuilder Builder of the dataset to train on.
 * @param extractor Preprocessor producing labeled vectors from key-value pairs.
 * @return Newly trained model, or whatever
 *     {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} yields when the solver returns
 *     {@code null}.
 * @throws RuntimeException Wrapping any exception raised while building, running or closing the solver.
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    LSQRResult res;
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, extractor);
    try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(datasetBuilder, envBuilder, new SimpleLabeledDatasetDataBuilder<>(patchedPreprocessor), learningEnvironment())) {
        double[] x0 = null;
        if (mdl != null) {
            // Warm start: previous weights followed by the previous intercept in the extra last slot.
            int x0Size = mdl.weights().size() + 1;
            Vector initWeights = mdl.weights().like(x0Size);
            mdl.weights().nonZeroes().forEach(ith -> initWeights.set(ith.index(), ith.get()));
            initWeights.set(initWeights.size() - 1, mdl.intercept());
            x0 = initWeights.asArray();
        }
        res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Checked outside the try block: an exception thrown here for an empty dataset must reach the
    // caller as-is instead of being caught and wrapped by the generic catch above.
    if (res == null)
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

    double[] x = res.getX();
    Vector weights = new DenseVector(Arrays.copyOfRange(x, 0, x.length - 1));
    return new LinearRegressionModel(weights, x[x.length - 1]);
}
Also used: LSQROnHeap(org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap) LSQRResult(org.apache.ignite.ml.math.isolve.lsqr.LSQRResult) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Aggregations

LSQROnHeap (org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap)2 LSQRResult (org.apache.ignite.ml.math.isolve.lsqr.LSQRResult)2 Arrays (java.util.Arrays)1 DatasetTrainer (org.apache.ignite.ml.DatasetTrainer)1 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)1 Vector (org.apache.ignite.ml.math.Vector)1 IgniteBiFunction (org.apache.ignite.ml.math.functions.IgniteBiFunction)1 DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)1 LinSysPartitionDataBuilderOnHeap (org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap)1 AbstractLSQR (org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR)1 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)1 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)1 PatchedPreprocessor (org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor)1 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)1