Usage of org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap in the Apache Ignite project.
Class LinearRegressionLSQRTrainer, method fit:
/**
 * {@inheritDoc}
 * <p>
 * Every feature row produced by {@code featureExtractor} is copied into an array of length {@code cols + 1}
 * (truncated or zero-padded to {@code cols}) whose last entry is set to {@code 1.0}. The last component of the
 * LSQR solution therefore becomes the model intercept and the first {@code cols} components become the weights.
 *
 * @throws IllegalStateException If the LSQR solver returns no result.
 * @throws RuntimeException Wrapping any checked exception raised while solving or closing the solver;
 *      unchecked exceptions are propagated unchanged.
 */
@Override
public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
    LSQRResult res;

    try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(datasetBuilder, new LinSysPartitionDataBuilderOnHeap<>((k, v) -> {
        // Extra trailing column of ones: its coefficient is the intercept.
        double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1);
        row[cols] = 1.0;
        return row;
    }, lbExtractor, cols + 1))) {
        // NOTE(review): argument order presumably is (damp, atol, btol, conlim, iterLim, calcVar, x0);
        // confirm against LSQROnHeap#solve.
        res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
    }
    catch (RuntimeException e) {
        // Do not hide unchecked failures inside yet another RuntimeException.
        throw e;
    }
    catch (Exception e) {
        // Checked exceptions (presumably from closing the solver) are wrapped, as before.
        throw new RuntimeException(e);
    }

    // Fail with a descriptive error instead of an opaque NullPointerException on res.getX().
    if (res == null)
        throw new IllegalStateException("LSQR solver returned no result");

    double[] x = res.getX();

    Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(x, 0, cols));

    return new LinearRegressionModel(weights, x[cols]);
}
Usage of org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap in the Apache Ignite project.
Class LinearRegressionLSQRTrainer, method updateModel:
/**
 * {@inheritDoc}
 * <p>
 * Labeled vectors produced by {@code extractor} are passed through
 * {@code LinearRegressionLSQRTrainer::extendLabeledVector} (presumably appending the bias column, TODO confirm)
 * before being solved with LSQR. The last component of the solution is used as the intercept and all preceding
 * components as the weights. When {@code mdl} is not {@code null}, its weights followed by its intercept are
 * passed to the solver as the initial guess. When the solver returns {@code null}, the result of
 * {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} is returned instead.
 *
 * @throws RuntimeException Wrapping any checked exception raised while solving or closing the solver;
 *      unchecked exceptions (including any thrown by {@code getLastTrainedModelOrThrowEmptyDatasetException})
 *      are propagated unchanged.
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    LSQRResult res;

    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, extractor);

    try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(datasetBuilder, envBuilder, new SimpleLabeledDatasetDataBuilder<>(patchedPreprocessor), learningEnvironment())) {
        double[] x0 = null;

        if (mdl != null) {
            // Warm start: previous weights plus one trailing slot holding the previous intercept.
            int x0Size = mdl.weights().size() + 1;
            Vector initGuess = mdl.weights().like(x0Size);

            // Only non-zero entries are copied; the rest presumably stay zero in the fresh vector (TODO confirm).
            mdl.weights().nonZeroes().forEach(ith -> initGuess.set(ith.index(), ith.get()));
            initGuess.set(initGuess.size() - 1, mdl.intercept());

            x0 = initGuess.asArray();
        }

        // NOTE(review): argument order presumably is (damp, atol, btol, conlim, iterLim, calcVar, x0);
        // confirm against LSQROnHeap#solve.
        res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);

        if (res == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
    }
    catch (RuntimeException e) {
        // Propagate unchecked exceptions as-is so callers can still catch their specific type;
        // previously they were re-wrapped in a generic RuntimeException.
        throw e;
    }
    catch (Exception e) {
        // Checked exceptions (presumably from closing the solver) are wrapped, as before.
        throw new RuntimeException(e);
    }

    double[] x = res.getX();

    Vector weights = new DenseVector(Arrays.copyOfRange(x, 0, x.length - 1));

    return new LinearRegressionModel(weights, x[x.length - 1]);
}
Aggregations