Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class DistributedLinearRegressionWithLSQRTrainerExample, method main.
/**
 * Entry point of the example: starts an Ignite node, trains a linear regression model with the LSQR trainer on a
 * cache-backed dataset and prints every prediction next to the corresponding ground truth value.
 *
 * @param args Command line arguments (not used).
 * @throws InterruptedException If the calling thread is interrupted while waiting for the worker thread.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println();
    System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");

    // The node is closed automatically when this try-with-resources block is left.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        String instanceName = ignite.configuration().getIgniteInstanceName();

        // NOTE(review): the worker is named after SparseDistributedMatrixExample, which looks like a
        // copy-paste leftover from another example - confirm whether this is intended.
        String threadName = SparseDistributedMatrixExample.class.getSimpleName();

        // The work is done inside an IgniteThread because an Ignite cache is created internally.
        IgniteThread worker = new IgniteThread(instanceName, threadName, () -> {
            IgniteCache<Integer, double[]> cache = getTestCache(ignite);

            System.out.println(">>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer<Integer, double[]> lsqrTrainer = new LinearRegressionLSQRTrainer<>();

            System.out.println(">>> Perform the training to get the model.");
            // Each cached row is laid out as {label, feature_1, ..., feature_n}.
            // The trailing 4 is presumably the number of feature columns - TODO confirm.
            LinearRegressionModel model = lsqrTrainer.fit(
                new CacheBasedDatasetBuilder<>(ignite, cache),
                (key, row) -> Arrays.copyOfRange(row, 1, row.length),
                (key, row) -> row[0],
                4
            );

            System.out.println(">>> Linear regression model: " + model);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            // Scan the whole cache and compare the model output with the stored label.
            try (QueryCursor<Cache.Entry<Integer, double[]>> cursor = cache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> entry : cursor) {
                    double[] sample = entry.getValue();
                    double[] features = Arrays.copyOfRange(sample, 1, sample.length);
                    double expected = sample[0];
                    double predicted = model.apply(new DenseLocalOnHeapVector(features));

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, expected);
                }
            }

            System.out.println(">>> ---------------------------------");
        });

        worker.start();
        // Keep the node alive until the worker is done.
        worker.join();
    }
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class DecisionTreesExample, method main.
/**
 * Launches the decision trees example: parses the command line, checks that the MNIST files are available,
 * trains a column decision tree on the training set and prints the percentage of errors on the test set.
 *
 * @param args Program arguments.
 * @throws IOException In case of an I/O failure outside of the guarded training section.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision trees example started.");

    CommandLineParser cliParser = new BasicParser();

    // Maps each logical part of the MNIST dataset to the file name expected under MNIST_DIR.
    Map<String, String> mnistFiles = new HashMap<>();
    mnistFiles.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte");
    mnistFiles.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte");
    mnistFiles.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte");
    mnistFiles.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte");

    CommandLine cmd;

    try {
        // Parse the command line arguments.
        cmd = cliParser.parse(buildOptions(), args);
    } catch (ParseException e) {
        e.printStackTrace();
        return;
    }

    if (cmd.hasOption(MLExamplesCommonArgs.UNATTENDED)) {
        System.out.println(">>> Skipped example execution because 'unattended' mode is used.");
        System.out.println(">>> Decision trees example finished.");
        return;
    }

    String igniteCfgPath = cmd.getOptionValue(CONFIG, DEFAULT_CONFIG);

    if (!getMNIST(mnistFiles.values())) {
        System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example.");
        return;
    }

    // Resolve all four dataset files up front; an unresolvable path fails fast with a NullPointerException.
    String mnistPrefix = MNIST_DIR + "/";
    String trainingImagesPath =
        Objects.requireNonNull(IgniteUtils.resolveIgnitePath(mnistPrefix + mnistFiles.get(MNIST_TRAIN_IMAGES))).getPath();
    String trainingLabelsPath =
        Objects.requireNonNull(IgniteUtils.resolveIgnitePath(mnistPrefix + mnistFiles.get(MNIST_TRAIN_LABELS))).getPath();
    String testImagesPath =
        Objects.requireNonNull(IgniteUtils.resolveIgnitePath(mnistPrefix + mnistFiles.get(MNIST_TEST_IMAGES))).getPath();
    String testLabelsPath =
        Objects.requireNonNull(IgniteUtils.resolveIgnitePath(mnistPrefix + mnistFiles.get(MNIST_TEST_LABELS))).getPath();

    try (Ignite ignite = Ignition.start(igniteCfgPath)) {
        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

        int trainingSize = 60_000;
        int pixelsPerImage = 28 * 28;

        Stream<DenseLocalOnHeapVector> trainingSamples =
            MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), trainingSize);
        Stream<DenseLocalOnHeapVector> testSamples =
            MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);

        // Every vector holds the image pixels followed by one extra coordinate with the label.
        IgniteCache<BiIndex, Double> biIndexedCache = createBiIndexedCache(ignite);
        loadVectorsIntoBiIndexedCache(biIndexedCache.getName(), trainingSamples.iterator(), pixelsPerImage + 1, ignite);

        ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> treeTrainer = new ColumnDecisionTreeTrainer<>(
            10,
            ContinuousSplitCalculators.GINI.apply(ignite),
            RegionCalculators.GINI,
            RegionCalculators.MOST_COMMON,
            ignite
        );

        System.out.println(">>> Training started");
        long startedAt = System.currentTimeMillis();

        DecisionTreeModel model = treeTrainer.train(
            new BiIndexedCacheColumnDecisionTreeTrainerInput(biIndexedCache, new HashMap<>(), trainingSize, pixelsPerImage));

        System.out.println(">>> Training finished in " + (System.currentTimeMillis() - startedAt));

        IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> errsEstimator =
            Estimators.errorsPercentage();

        // Split every test vector into (features, label) before handing it to the estimator.
        Stream<IgniteBiTuple<Vector, Double>> labeledTestSamples =
            testSamples.map(v -> new IgniteBiTuple<>(v.viewPart(0, pixelsPerImage), v.getX(pixelsPerImage)));

        Double errsPercentage = errsEstimator.apply(model, labeledTestSamples, Function.identity());

        System.out.println(">>> Errs percentage: " + errsPercentage);
    } catch (IOException e) {
        e.printStackTrace();
    }

    System.out.println(">>> Decision trees example finished.");
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class AbstractMultipleLinearRegressionTest, method testNewSample.
/**
 * Checks that loading a sample from a flat array through {@code newSampleData} produces the same X and Y
 * as supplying them separately through {@code newXSampleData} and {@code newYSampleData}, first with the
 * regression as created and then with the intercept switched off. Regression test for MATH-411.
 */
@Test
public void testNewSample() {
    // Flat layout: each group of four values is a response followed by its three regressors.
    double[] flatSample = {1, 19, 22, 33, 2, 20, 30, 40, 3, 25, 35, 45, 4, 27, 37, 47};
    double[] responses = {1, 2, 3, 4};
    double[][] regressors = {{19, 22, 33}, {20, 30, 40}, {25, 35, 45}, {27, 37, 47}};

    AbstractMultipleLinearRegression regression = createRegression();

    // First pass leaves the intercept setting untouched, second pass disables the intercept.
    for (boolean disableIntercept : new boolean[] {false, true}) {
        if (disableIntercept) {
            regression.setNoIntercept(true);
        }

        // Reference state: everything loaded at once from the flat array.
        regression.newSampleData(flatSample, 4, 3, new DenseLocalOnHeapMatrix());

        Matrix expectedX = regression.getX().copy();
        Vector expectedY = regression.getY().copy();

        // State under test: X and Y loaded one after another.
        regression.newXSampleData(new DenseLocalOnHeapMatrix(regressors));
        regression.newYSampleData(new DenseLocalOnHeapVector(responses));

        Assert.assertEquals(expectedX, regression.getX());
        Assert.assertEquals(expectedY, regression.getY());
    }
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class OLSMultipleLinearRegressionTest, method testMathIllegalArgumentException.
/**
 * Checks that {@code validateSampleData} rejects the sample built below by throwing
 * {@link MathIllegalArgumentException}.
 */
@Test(expected = MathIllegalArgumentException.class)
public void testMathIllegalArgumentException() {
    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();

    // Presumably a 1 x 2 design matrix paired with a single observation - TODO confirm argument order.
    DenseLocalOnHeapMatrix x = new DenseLocalOnHeapMatrix(1, 2);
    DenseLocalOnHeapVector y = new DenseLocalOnHeapVector(1);

    regression.validateSampleData(x, y);
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class OLSMultipleLinearRegressionTest, method testNewSample2.
/**
 * Checks that supplying X and Y one after another through {@code newXSampleData} and {@code newYSampleData}
 * leaves the regression in the same state as a single {@code newSampleData(y, x)} call, first with the
 * regression as created and then with the intercept switched off.
 */
@Test
public void testNewSample2() {
    double[] responses = {1, 2, 3, 4};
    double[][] regressors = {{19, 22, 33}, {20, 30, 40}, {25, 35, 45}, {27, 37, 47}};

    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();

    // First pass leaves the intercept setting untouched, second pass disables the intercept.
    for (boolean disableIntercept : new boolean[] {false, true}) {
        if (disableIntercept) {
            regression.setNoIntercept(true);
        }

        // Reference state: X and Y loaded together.
        regression.newSampleData(new DenseLocalOnHeapVector(responses), new DenseLocalOnHeapMatrix(regressors));

        Matrix expectedX = regression.getX().copy();
        Vector expectedY = regression.getY().copy();

        // State under test: X and Y loaded separately.
        regression.newXSampleData(new DenseLocalOnHeapMatrix(regressors));
        regression.newYSampleData(new DenseLocalOnHeapVector(responses));

        Assert.assertEquals(expectedX, regression.getX());
        Assert.assertEquals(expectedY, regression.getY());
    }
}
Aggregations