Use of org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData in project ignite by apache.
The class AlgorithmSpecificDatasetExample, method main.
/**
 * Runs the example: starts an Ignite node from {@code examples/config/example-ignite.xml}, creates a cache of
 * {@link Person} entries, builds an {@link AlgorithmSpecificDataset} on top of it (age as the single feature,
 * salary as the label) and fits a two-coefficient linear regression model with gradient descent.
 *
 * @param args Command line arguments (not used).
 * @throws Exception If example execution failed.
 */
public static void main(String[] args) throws Exception {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Algorithm Specific Dataset example started.");

        IgniteCache<Integer, Person> persons = createCache(ignite);

        // Creates an algorithm specific dataset. Here we define the way features and
        // labels are extracted, and partition data and context are created.
        try (AlgorithmSpecificDataset dataset = DatasetFactory.create(
            ignite,
            persons,
            (upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
            new SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext>(
                (k, v) -> new double[] {v.getAge()},
                (k, v) -> v.getSalary(),
                1
            ).andThen((data, ctx) -> {
                double[] features = data.getFeatures();
                int rows = data.getRows();

                // Makes a copy of features to supplement it by a column with values equal to 1.0.
                // NOTE(review): placing the ones in the first `rows` slots assumes the features array
                // is stored column by column - confirm against SimpleLabeledDatasetData.
                double[] a = new double[features.length + rows];

                Arrays.fill(a, 0, rows, 1.0);

                System.arraycopy(features, 0, a, rows, features.length);

                return new SimpleLabeledDatasetData(a, rows, data.getCols() + 1, data.getLabels());
            })
        ).wrap(AlgorithmSpecificDataset::new)) {
            // Trains linear regression model using gradient descent.
            double[] linearRegressionMdl = new double[2];

            // The cache is not modified while training, so the size-scaled step is computed once
            // instead of querying the cache size for every coordinate on every iteration.
            double step = 0.1 / persons.size();

            for (int i = 0; i < 1000; i++) {
                double[] gradient = dataset.gradient(linearRegressionMdl);

                // Stops early once the Euclidean norm of the gradient is small enough.
                if (BLAS.getInstance().dnrm2(gradient.length, gradient, 1) < 1e-4) {
                    break;
                }

                for (int j = 0; j < gradient.length; j++) {
                    linearRegressionMdl[j] -= step * gradient[j];
                }
            }

            System.out.println("Linear Regression Model: " + Arrays.toString(linearRegressionMdl));
        }

        System.out.println(">>> Algorithm Specific Dataset example completed.");
    }
}
Aggregations