Usage of org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap in the Apache Ignite project.
The fit method of the class LinearRegressionLSQRTrainer.
/**
 * {@inheritDoc}
 *
 * <p>Fits the model by solving the least-squares linear system with {@link LSQROnHeap}. Every feature row
 * produced by {@code featureExtractor} is copied to length {@code cols + 1} (zero-padded if shorter,
 * truncated if longer). Its last element is then set to {@code 1.0}. As a result, the last component of
 * the LSQR solution is used as the model intercept, and the first {@code cols} components are used as the
 * feature weights.
 *
 * @param datasetBuilder Dataset builder the solver reads its partitions from.
 * @param featureExtractor Extracts the feature row from a key-value pair.
 * @param lbExtractor Extracts the label from a key-value pair.
 * @param cols Number of feature columns; must be non-negative.
 * @return Linear regression model whose weights and intercept come from the LSQR solution.
 * @throws IllegalArgumentException If {@code cols} is negative.
 * @throws RuntimeException If solving or closing the solver fails. Unchecked exceptions propagate
 *     unchanged; checked exceptions are wrapped with the original exception as the cause.
 */
@Override
public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
    IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
    if (cols < 0) {
        throw new IllegalArgumentException("Number of feature columns must be non-negative: " + cols);
    }

    LSQRResult res;

    // The extra trailing column (constant 1.0) lets the solver fit the intercept alongside the weights.
    try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(datasetBuilder, new LinSysPartitionDataBuilderOnHeap<>((k, v) -> {
        double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1);
        row[cols] = 1.0;
        return row;
    }, lbExtractor, cols + 1))) {
        // NOTE(review): argument order assumed to be (damp, atol, btol, conlim, iterLim, calcVar, x0),
        // mirroring SciPy's lsqr; iterLim = -1 presumably selects the solver's default limit. Confirm
        // against AbstractLSQR.solve.
        res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
    } catch (RuntimeException e) {
        // Do not hide unchecked failures from the solver/dataset behind an extra wrapper.
        throw e;
    } catch (Exception e) {
        // AutoCloseable.close() is declared to throw a checked Exception; surface it unchecked.
        throw new RuntimeException(e);
    }

    double[] x = res.getX();

    Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(x, 0, cols));

    return new LinearRegressionModel(weights, x[cols]);
}
Aggregations