Usage of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class DistributedLinearRegressionWithQRTrainerExample, method main.
/**
 * Runs the example.
 * <p>
 * Starts an Ignite node from {@code examples/config/example-ignite.xml}, trains a {@link LinearRegressionModel}
 * with {@link LinearRegressionQRTrainer} over a {@link SparseDistributedMatrix} built from {@code data}, and
 * prints every prediction next to its ground truth. Each observation in {@code data} stores the ground truth at
 * index 0, followed by the input features.
 *
 * @param args Command line arguments (not used).
 * @throws InterruptedException If interrupted while waiting for the worker thread to finish.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println();
    System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // SparseDistributedMatrix must be used inside an IgniteThread because it creates an Ignite cache
        // internally. The thread is named after this example class so it is easy to identify in diagnostics.
        IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
            DistributedLinearRegressionWithQRTrainerExample.class.getSimpleName(), () -> {

            // Create SparseDistributedMatrix, new cache will be created automagically.
            System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
            SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data);

            System.out.println(">>> Create new linear regression trainer object.");
            Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer();

            System.out.println(">>> Perform the training to get the model.");
            LinearRegressionModel model = trainer.train(distributedMatrix);
            System.out.println(">>> Linear regression model: " + model);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            for (double[] observation : data) {
                // Features are everything after the ground-truth value at index 0.
                Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length));
                double prediction = model.apply(inputs);
                double groundTruth = observation[0];

                // %n emits the platform line separator instead of a hard-coded '\n'.
                System.out.printf(">>> | %.4f\t\t| %.4f\t\t|%n", prediction, groundTruth);
            }

            System.out.println(">>> ---------------------------------");
        });

        igniteThread.start();
        igniteThread.join();
    }
}
Usage of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class LinearRegressionQRTrainer, method train.
/**
 * {@inheritDoc}
 * <p>
 * Solves the least-squares problem via QR decomposition of the input part of {@code data}. The first component of
 * the solution becomes the model intercept; the remaining components become the weights.
 */
@Override
public LinearRegressionModel train(Matrix data) {
    final Vector labels = extractGroundTruth(data);
    final Matrix features = extractInputs(data);

    final QRDecomposition qr = new QRDecomposition(features);
    final Vector solution = new QRDSolver(qr.getQ(), qr.getR()).solve(labels);

    // Component 0 is the intercept; components [1, size) are the feature weights.
    final Vector coefficients = solution.viewPart(1, solution.size() - 1);
    final double bias = solution.get(0);

    return new LinearRegressionModel(coefficients, bias);
}
Usage of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class LocalBatchTrainer, method train.
/**
 * {@inheritDoc}
 * <p>
 * Runs at most {@code maxIterations} iterations. Each iteration pulls a batch from {@code data.batchSupplier()},
 * computes and applies a parameter update, and then evaluates the loss of the updated model on that batch. The
 * error is the sum of per-column losses divided by {@code input.columnSize()}. Training stops early once the error
 * drops below {@code errorThreshold}.
 */
@Override
public M train(LocalBatchTrainerInput<M> data) {
    M mdl = data.mdl();
    ParameterUpdateCalculator<? super M, P> updater = updaterSupplier.get();
    P updaterParams = updater.init(mdl, loss);

    for (int iteration = 0; iteration < maxIterations; iteration++) {
        IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
        Matrix batchInput = batch.get1();
        Matrix batchTruth = batch.get2();

        updaterParams = updater.calculateNewUpdate(mdl, updaterParams, iteration, batchInput, batchTruth);

        // Update mdl with updater parameters.
        mdl = updater.update(mdl, updaterParams);

        // NOTE(review): dividing by columnSize suggests each column is one sample — confirm.
        Matrix predicted = mdl.apply(batchInput);
        double batchError = MatrixUtil.zipFoldByColumns(predicted, batchTruth,
            (predCol, truthCol) -> loss.apply(truthCol).apply(predCol)).sum() / batchInput.columnSize();

        debug("Error: " + batchError);

        if (batchError < errorThreshold)
            break;
    }

    return mdl;
}
Usage of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class TestUtils, method assertEquals.
/**
 * Verifies that two matrices are close.
 * <p>
 * The matrices must have the same dimensions, and the maximum absolute row sum (the infinity norm) of
 * {@code exp - actual} must not exceed {@code tolerance}. A norm that is NaN is treated as a failure.
 *
 * @param msg The identifying message for the assertion error.
 * @param exp Expected matrix.
 * @param actual Actual matrix.
 * @param tolerance Comparison tolerance value; {@code 0} requires an exact match.
 */
public static void assertEquals(String msg, Matrix exp, Matrix actual, double tolerance) {
    Assert.assertNotNull(msg + "\nObserved should not be null", actual);

    if (exp.columnSize() != actual.columnSize() || exp.rowSize() != actual.rowSize()) {
        String msgBuff = msg + "\nObserved has incorrect dimensions." + "\nobserved is " + actual.rowSize() + " x " + actual.columnSize() + "\nexpected " + exp.rowSize() + " x " + exp.columnSize();

        Assert.fail(msgBuff);
    }

    Matrix delta = exp.minus(actual);
    double norm = TestUtils.maximumAbsoluteRowSum(delta);

    // Written as !(norm <= tolerance) so that identical matrices pass with tolerance == 0,
    // and so that a NaN norm fails instead of silently passing.
    if (!(norm <= tolerance)) {
        String msgBuff = msg + "\nExpected: " + exp + "\nObserved: " + actual + "\nexpected - observed: " + delta;

        Assert.fail(msgBuff);
    }
}
Usage of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class VectorToMatrixTest, method assertCross.
/**
 * Asserts that {@code v1.cross(v2)} is a non-null {@code v1.size() x v2.size()} matrix. Each element at
 * ({@code row}, {@code col}) must pass {@link Metric#closeEnough()} against {@code v1.get(row) * v2.get(col)}.
 *
 * @param v1 Left vector; must not be null.
 * @param v2 Right vector; must not be null.
 * @param desc Description of the case under test, appended to assertion messages.
 */
private void assertCross(Vector v1, Vector v2, String desc) {
    assertNotNull(v1);
    assertNotNull(v2);

    final Matrix res = v1.cross(v2);

    assertNotNull("Cross matrix is expected to be not null in " + desc, res);

    final int rows = v1.size();
    final int cols = v2.size();

    assertEquals("Unexpected number of rows in cross Matrix in " + desc, rows, res.rowSize());
    assertEquals("Unexpected number of cols in cross Matrix in " + desc, cols, res.columnSize());

    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            final Metric metric = new Metric(v1.get(r) * v2.get(c), res.get(r, c));

            assertTrue("Not close enough cross " + metric + " at row " + r + " at col " + c + " in " + desc,
                metric.closeEnough());
        }
    }
}
Aggregations