Example usage of org.apache.ignite.ml.dataset.DatasetBuilder in the Apache Ignite project: the fit method of the LinearRegressionLSQRTrainer class.
/**
 * {@inheritDoc}
 */
@Override
public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
    // Wrap the user-supplied extractor so every row gets a trailing constant 1.0
    // column; the intercept is then learned as one extra weight of the linear system.
    IgniteBiFunction<K, V, double[]> extractorWithIntercept = (k, v) -> {
        double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1);
        row[cols] = 1.0;
        return row;
    };

    LSQRResult result;

    // Build the on-heap linear system (cols + 1 columns including the intercept)
    // and solve it with LSQR; the solver is closed automatically.
    try (LSQROnHeap<K, V> solver = new LSQROnHeap<>(
        datasetBuilder,
        new LinSysPartitionDataBuilderOnHeap<>(extractorWithIntercept, lbExtractor, cols + 1))) {
        result = solver.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    // The first 'cols' entries of the solution vector are the feature weights;
    // the last entry is the intercept term.
    Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(result.getX(), 0, cols));

    return new LinearRegressionModel(weights, result.getX()[cols]);
}
Example usage of org.apache.ignite.ml.dataset.DatasetBuilder in the Apache Ignite project: the main method of the NormalizationExample class.
/**
* Run example.
*/
/**
 * Run example.
 */
public static void main(String[] args) throws Exception {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Normalization example started.");

        IgniteCache<Integer, Person> personCache = createCache(ignite);

        DatasetBuilder<Integer, Person> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, personCache);

        // First preprocessor: extracts the raw feature vector from an upstream cache entry.
        IgniteBiFunction<Integer, Person, double[]> extractor = (k, v) -> new double[] { v.getAge(), v.getSalary() };

        // Second preprocessor: normalizes the extracted features.
        NormalizationPreprocessor<Integer, Person> normalizer = new NormalizationTrainer<Integer, Person>().fit(datasetBuilder, extractor, 2);

        // Cache-based simple dataset exposing the standard statistics API over normalized features.
        try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(datasetBuilder, normalizer, 2)) {
            // Mean value, computed in a map-reduce manner.
            System.out.println("Mean \n\t" + Arrays.toString(dataset.mean()));

            // Standard deviation, computed in a map-reduce manner.
            System.out.println("Standard deviation \n\t" + Arrays.toString(dataset.std()));

            // Covariance matrix, computed in a map-reduce manner.
            System.out.println("Covariance matrix ");
            for (double[] row : dataset.cov())
                System.out.println("\t" + Arrays.toString(row));

            // Correlation matrix, computed in a map-reduce manner.
            System.out.println("Correlation matrix ");
            for (double[] row : dataset.corr())
                System.out.println("\t" + Arrays.toString(row));
        }

        System.out.println(">>> Normalization example completed.");
    }
}
Aggregations