Usage of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testTwoCategorialFeatureAndTwoDoubleFeatures.
/**
 * Verifies one-hot encoding of a dataset that mixes two categorical features (indices 0 and 2)
 * with two numeric features (indices 1 and 3).
 * <p>
 * In the expected rows the numeric features come first. They are followed by the one-hot columns
 * for feature 0 ({@code "42"} -> 0, {@code "43"} -> 1) and then by those for feature 2
 * ({@code "F"} -> 0, {@code "M"} -> 1, {@code ""} -> 2). The third row carries {@code Double.NaN}
 * in feature 2, and the expected output encodes it in the column of the empty-string category
 * (index 2).
 */
@Test
public void testTwoCategorialFeatureAndTwoDoubleFeatures() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"42", 1.0, "M", 2.0}),
        new DenseVector(new Serializable[] {"43", 2.0, "F", 3.0}),
        new DenseVector(new Serializable[] {"42", 3.0, Double.NaN, 4.0}),
        new DenseVector(new Serializable[] {"42", 4.0, "F", 5.0})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2, 3);

    // Populate the encodings explicitly rather than via double-brace initialisation, which
    // would create anonymous HashMap/HashSet subclasses that capture this test instance.
    HashMap<String, Integer> feature0Encoding = new HashMap<>();
    feature0Encoding.put("42", 0);
    feature0Encoding.put("43", 1);

    HashMap<String, Integer> feature2Encoding = new HashMap<>();
    feature2Encoding.put("F", 0);
    feature2Encoding.put("M", 1);
    feature2Encoding.put("", 2);

    // One slot per input feature; the slots of the numeric features (1 and 3) stay null.
    HashMap[] encodingValues = new HashMap[4];
    encodingValues[0] = feature0Encoding;
    encodingValues[2] = feature2Encoding;

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0},
        {2.0, 3.0, 0.0, 1.0, 1.0, 0.0, 0.0},
        {3.0, 4.0, 1.0, 0.0, 0.0, 0.0, 1.0},
        {4.0, 5.0, 1.0, 0.0, 1.0, 0.0, 0.0}
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Usage of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testOneCategorialFeature.
/**
 * Verifies one-hot encoding of a single categorical feature with two categories.
 * {@code "42"} maps to index 0 and {@code "43"} to index 1, so each row is expected to become a
 * two-element indicator vector.
 */
@Test
public void testOneCategorialFeature() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"42"}),
        new DenseVector(new Serializable[] {"43"}),
        new DenseVector(new Serializable[] {"42"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0);

    // Explicit population instead of double-brace initialisation: no anonymous collection
    // subclasses holding a reference to this test instance.
    HashMap<String, Integer> encoding = new HashMap<>();
    encoding.put("42", 0);
    encoding.put("43", 1);

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(new HashMap[] {encoding}, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {1.0, 0.0},
        {0.0, 1.0},
        {1.0, 0.0}
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Usage of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class VectorImplementationsTest, method assignVectorWrongCardinality.
/**
 * Checks that {@code v.assign(Vector)} throws {@link CardinalityException} in two cases.
 * First, when the source vector is one element longer than {@code v}. Second, for vectors of
 * size 2 or more, when the source vector is one element shorter.
 * <p>
 * This is the same check {@link #assertWrongCardinality} performs for an arbitrary binary
 * operation, so it delegates there instead of duplicating the try/catch logic.
 *
 * @param v Vector under test.
 * @param desc Description of {@code v}, used in assertion messages.
 */
private void assignVectorWrongCardinality(Vector v, String desc) {
    assertWrongCardinality(v, desc, (target, src) -> target.assign(src));
}
Usage of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class VectorImplementationsTest, method assertWrongCardinality.
/**
 * Asserts that {@code vecOperation} throws {@link CardinalityException} when its second operand's
 * size does not match {@code v}.
 * <p>
 * The second operand is always a fresh {@code DenseVector}. It is first one element larger than
 * {@code v}; then, only when {@code v} has at least two elements, it is one element smaller.
 *
 * @param v Vector under test, passed as the first operand.
 * @param desc Description of {@code v}, used in assertion messages.
 * @param vecOperation Binary vector operation expected to reject mismatched sizes.
 */
private void assertWrongCardinality(Vector v, String desc, BiFunction<Vector, Vector, Vector> vecOperation) {
    boolean thrown;

    try {
        vecOperation.apply(v, new DenseVector(v.size() + 1));
        thrown = false;
    }
    catch (CardinalityException ce) {
        thrown = true;
    }

    assertTrue("Expect exception at too large size in " + desc, thrown);

    // The "too small" probe is only attempted for vectors with at least two elements.
    if (v.size() >= 2) {
        try {
            vecOperation.apply(v, new DenseVector(v.size() - 1));
            thrown = false;
        }
        catch (CardinalityException ce) {
            thrown = true;
        }

        assertTrue("Expect exception at too small size in " + desc, thrown);
    }
}
Usage of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class VectorNormTest, method dotTest.
/**
 * For every sample vector, compares {@code Vector#dot} against a manually accumulated dot product.
 * The second operand is a {@code DenseVector} of the same size, filled by {@code invertValues}.
 */
@Test
public void dotTest() {
    consumeSampleVectors((vec, vecDesc) -> {
        // IMPL NOTE: constructing the checker initialises the vector's values.
        new VectorImplementationsTest.ElementsChecker(vec, vecDesc);

        Vector other = new DenseVector(vec.size());
        invertValues(vec, other);

        double actual = vec.dot(other);

        // Reference value: sum of element-wise products, walking vec.all() in iteration order.
        double expected = 0;
        for (Vector.Element el : vec.all())
            expected += el.get() * other.get(el.index());

        VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(expected, actual);

        assertTrue("Dot product not close enough at " + vecDesc + ", " + metric, metric.closeEnough());
    });
}
Aggregations