Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class LinearRegressionLSQRTrainer, method updateModel.
/**
 * {@inheritDoc}
 *
 * <p>Fits a linear regression model with the LSQR solver. Each labeled vector is first transformed by
 * {@code extendLabeledVector}. The solver's last solution component is read as the intercept, and all
 * preceding components are read as the weights.
 *
 * <p>When {@code mdl} is non-null, its weights followed by its intercept are passed to the solver as the
 * initial guess (warm start).
 *
 * @throws RuntimeException wrapping any checked exception raised while solving or closing the solver;
 *     unchecked exceptions propagate unchanged.
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    LSQRResult res;
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, extractor);
    try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(datasetBuilder, envBuilder, new SimpleLabeledDatasetDataBuilder<>(patchedPreprocessor), learningEnvironment())) {
        double[] x0 = null;
        if (mdl != null) {
            // Warm start: previous weights, then the previous intercept in the extra last slot.
            int x0Size = mdl.weights().size() + 1;
            Vector weights = mdl.weights().like(x0Size);
            mdl.weights().nonZeroes().forEach(ith -> weights.set(ith.index(), ith.get()));
            weights.set(weights.size() - 1, mdl.intercept());
            x0 = weights.asArray();
        }
        // NOTE(review): arguments presumably are damp, atol, btol, conlim, iterLim, calcVar, x0 -- confirm
        // against LSQROnHeap#solve.
        res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);
        if (res == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
    }
    catch (RuntimeException e) {
        // Keep the original type of unchecked exceptions (including anything thrown by
        // getLastTrainedModelOrThrowEmptyDatasetException) so callers can catch them specifically.
        throw e;
    }
    catch (Exception e) {
        // Only checked exceptions (e.g. from AutoCloseable#close) need to be wrapped.
        throw new RuntimeException(e);
    }
    double[] x = res.getX();
    // Last component is the intercept; everything before it is the weight vector.
    Vector weights = new DenseVector(Arrays.copyOfRange(x, 0, x.length - 1));
    return new LinearRegressionModel(weights, x[x.length - 1]);
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class LogisticRegressionSGDTrainer, method updateModel.
/**
 * {@inheritDoc}
 *
 * <p>Trains logistic regression as a multilayer perceptron with a single sigmoid output neuron, using L2 loss.
 * The input width is inferred from the dataset. When {@code mdl} is non-null, its state is restored into an MLP
 * and updated; otherwise a fresh MLP is fitted. The last MLP parameter becomes the intercept, and the preceding
 * parameters become the weights.
 *
 * @throws IllegalStateException if no partition contains any feature rows.
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        Integer cols = dataset.compute(data -> {
            // A partition without rows cannot tell us the column count; the getRows() guard also
            // prevents a division by zero on an empty (but non-null) features array.
            if (data.getFeatures() == null || data.getRows() == 0)
                return null;
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // Keep the first known column count. If both sides are null (every partition empty),
            // null is propagated and rejected below.
            return a == null ? b : a;
        });
        if (cols == null)
            throw new IllegalStateException("Cannot train on empty dataset");
        MLPArchitecture architecture = new MLPArchitecture(cols);
        // One output neuron with sigmoid activation; the boolean presumably enables a bias term -- confirm.
        architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);
        return architecture;
    };
    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations, batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);
    MultilayerPerceptron mlp;
    // The MLP expects vector labels, so wrap the scalar label in a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);
    if (mdl != null) {
        mlp = restoreMLPState(mdl);
        mlp = trainer.update(mlp, datasetBuilder, patchedPreprocessor);
    }
    else
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);
    double[] params = mlp.parameters().getStorage().data();
    // Last parameter is the intercept; the rest are the feature weights.
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)), params[params.length - 1]);
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class Deltas, method calculateUpdates.
/**
 * Computes the summed weight update across all dataset partitions.
 *
 * <p>In each partition, a local copy of {@code weights} is refined for {@code getAmountOfLocIterations()}
 * iterations. Each iteration picks a random row (seeded with {@code seed}), asks {@code getDeltas} for a weight
 * delta and an alpha delta, and applies both locally. The accumulated weight deltas of all partitions are then
 * summed.
 *
 * @param weights Current weights; not modified.
 * @param dataset Partitioned training data.
 * @return Sum of per-partition weight deltas. Empty partitions contribute a zero vector. If the reducer receives
 *     only nulls, an empty {@code DenseVector} is returned.
 */
private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset) {
    return dataset.compute(data -> {
        Vector copiedWeights = weights.copy();
        Vector deltaWeights = initializeWeightsWithZeros(weights.size());
        final int amountOfObservation = data.rowSize();
        // Nothing to sample in an empty partition: Random#nextInt(0) would throw, so contribute a zero delta.
        if (amountOfObservation == 0)
            return deltaWeights;
        // Per-row alpha values, updated locally and passed to getDeltas on every iteration.
        Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);
        Random random = new Random(seed);
        for (int i = 0; i < this.getAmountOfLocIterations(); i++) {
            int randomIdx = random.nextInt(amountOfObservation);
            Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx);
            // plus() returns new vectors; reassign both the working weights and the accumulated delta.
            copiedWeights = copiedWeights.plus(deltas.deltaWeights);
            deltaWeights = deltaWeights.plus(deltas.deltaWeights);
            tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha);
        }
        return deltaWeights;
    }, (a, b) -> {
        if (a == null)
            return b == null ? new DenseVector() : b;
        if (b == null)
            return a;
        return a.plus(b);
    });
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class SVMModelTest, method testPredictWithErasedLabels.
/**
 * Checks 0/1 class predictions of an SVM model, then verifies that a model reconfigured through
 * {@code withIntercept}/{@code withWeights} predicts with and reports the new parameters.
 */
@Test
public void testPredictWithErasedLabels() {
    // Margin = 1 + x0 + x1.
    Vector weights = new DenseVector(new double[] { 1.0, 1.0 });
    SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
    Vector observation = new DenseVector(new double[] { 1.0, 1.0 });
    TestUtils.assertEquals(1.0, mdl.predict(observation), PRECISION);
    observation = new DenseVector(new double[] { 3.0, 4.0 });
    TestUtils.assertEquals(1.0, mdl.predict(observation), PRECISION);
    observation = new DenseVector(new double[] { -1.0, -1.0 });
    TestUtils.assertEquals(0.0, mdl.predict(observation), PRECISION);
    // Margin is exactly 0 here; the test expects this to map to class 0.0.
    observation = new DenseVector(new double[] { -2.0, 1.0 });
    TestUtils.assertEquals(0.0, mdl.predict(observation), PRECISION);
    observation = new DenseVector(new double[] { -1.0, -2.0 });
    TestUtils.assertEquals(0.0, mdl.predict(observation), PRECISION);

    final SVMLinearClassificationModel mdlWithNewData = mdl.withIntercept(-2.0).withWeights(new DenseVector(new double[] { -2.0, -2.0 }));
    // Assert on the returned model, so the test holds whether the with* methods mutate or copy.
    // Margin = -2 + (-2)(-1) + (-2)(-2) = 4 > 0.
    observation = new DenseVector(new double[] { -1.0, -2.0 });
    TestUtils.assertEquals(1.0, mdlWithNewData.predict(observation), PRECISION);
    TestUtils.assertEquals(-2.0, mdlWithNewData.intercept(), PRECISION);
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class SVMModelTest, method testPredictWithRawLabels.
/**
 * Verifies that a model configured to keep raw labels returns the linear margin
 * {@code intercept + w . x} instead of a class label, and that it reports the raw-label
 * mode and renders non-empty string forms.
 */
@Test
public void testPredictWithRawLabels() {
    SVMLinearClassificationModel rawMdl =
        new SVMLinearClassificationModel(new DenseVector(new double[] { 2.0, 3.0 }), 1.0).withRawLabels(true);

    double[][] inputs = {
        { 1.0, 1.0 },
        { 2.0, 1.0 },
        { 1.0, 2.0 },
        { -2.0, 1.0 },
        { 1.0, -2.0 }
    };
    double[] expectedMargins = {
        1.0 + 2.0 * 1.0 + 3.0 * 1.0,
        1.0 + 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 + 3.0 * 2.0,
        1.0 - 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 - 3.0 * 2.0
    };

    for (int k = 0; k < inputs.length; k++) {
        Vector point = new DenseVector(inputs[k]);
        TestUtils.assertEquals(expectedMargins[k], rawMdl.predict(point), PRECISION);
    }

    Assert.assertTrue(rawMdl.isKeepingRawLabels());
    Assert.assertFalse(rawMdl.toString().isEmpty());
    Assert.assertFalse(rawMdl.toString(true).isEmpty());
    Assert.assertFalse(rawMdl.toString(false).isEmpty());
}
Aggregations