Usage of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in the Apache Ignite project.
Class EncoderTrainer, method fit.
/**
 * {@inheritDoc}
 *
 * <p>Builds a dataset from {@code datasetBuilder} in which every partition collects statistics selected by
 * {@code encoderType}: label frequencies for {@link EncoderType#LABEL_ENCODER}, target counters for
 * {@link EncoderType#TARGET_ENCODER}, and per-feature category frequencies for every other encoder type.
 * The statistics are then aggregated over the dataset and used to construct the matching encoder
 * preprocessor.
 *
 * @param envBuilder Learning environment builder handed to the dataset builder.
 * @param datasetBuilder Dataset builder that supplies the upstream entries.
 * @param basePreprocessor Preprocessor applied to every upstream entry to obtain a labeled row; it is also
 *     passed to the returned encoder preprocessor.
 * @return Encoder preprocessor for the configured {@code encoderType}.
 * @throws RuntimeException If no handled feature indices were added for an encoder other than
 *     {@link EncoderType#LABEL_ENCODER}, or, wrapping the original cause, if anything inside the dataset
 *     block fails (including an encoder type that has no branch in the switch below).
 */
@Override
public EncoderPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    // Every encoder except the label encoder works on explicitly selected feature columns.
    if (handledIndices.isEmpty() && encoderType != EncoderType.LABEL_ENCODER)
        throw new RuntimeException("Add indices of handled features");

    // The dataset is closed by try-with-resources once the preprocessor has been built.
    try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        EncoderPartitionData partData = new EncoderPartitionData();

        // In every branch the accumulator starts as null and is only assigned while rows are read,
        // so a partition without rows stores null statistics.
        if (encoderType == EncoderType.LABEL_ENCODER) {
            Map<String, Integer> lbFrequencies = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                lbFrequencies = updateLabelFrequenciesForNextRow(row, lbFrequencies);
            }

            partData.withLabelFrequencies(lbFrequencies);
        }
        else if (encoderType == EncoderType.TARGET_ENCODER) {
            TargetCounter[] targetCounter = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                targetCounter = updateTargetCountersForNextRow(row, targetCounter);
            }

            partData.withTargetCounters(targetCounter);
        }
        else {
            // This array will contain not null values for handled indices
            Map<String, Integer>[] categoryFrequencies = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                categoryFrequencies = updateFeatureFrequenciesForNextRow(row, categoryFrequencies);
            }

            partData.withCategoryFrequencies(categoryFrequencies);
        }

        return partData;
    }, learningEnvironment(basePreprocessor))) {
        switch (encoderType) {
            case ONE_HOT_ENCODER:
                return new OneHotEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);
            case STRING_ENCODER:
                return new StringEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);
            case LABEL_ENCODER:
                return new LabelEncoderPreprocessor<>(calculateEncodingValuesForLabelsByFrequencies(dataset), basePreprocessor);
            case FREQUENCY_ENCODER:
                return new FrequencyEncoderPreprocessor<>(calculateEncodingFrequencies(dataset), basePreprocessor, handledIndices);
            case TARGET_ENCODER:
                return new TargetEncoderPreprocessor<>(calculateTargetEncodingFrequencies(dataset), basePreprocessor, handledIndices);
            default:
                throw new IllegalStateException("Define the type of the resulting preprocessor.");
        }
    }
    catch (Exception e) {
        // Wraps every failure from the block above (including the IllegalStateException from the
        // default branch) so that callers always receive a RuntimeException carrying the original cause.
        throw new RuntimeException(e);
    }
}
Usage of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testTwoCategorialFeatureAndTwoDoubleFeatures.
/**
 * One-hot encodes the categorical columns 0 and 2 of four-column rows whose columns 1 and 3 are numeric.
 *
 * <p>The expected output lists the two numeric values first, followed by the two-slot block for column 0
 * and the three-slot block for column 2. The row whose column 2 holds {@code Double.NaN} is expected to set
 * the slot that the encoding map assigns to the empty-string category.
 */
@Test
public void testTwoCategorialFeatureAndTwoDoubleFeatures() {
    Vector[] data = new Vector[4];
    data[0] = new DenseVector(new Serializable[] {"42", 1.0, "M", 2.0});
    data[1] = new DenseVector(new Serializable[] {"43", 2.0, "F", 3.0});
    data[2] = new DenseVector(new Serializable[] {"42", 3.0, Double.NaN, 4.0});
    data[3] = new DenseVector(new Serializable[] {"42", 4.0, "F", 5.0});

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2, 3);

    HashMap<String, Integer> firstColumnCodes = new HashMap<>();
    firstColumnCodes.put("42", 0);
    firstColumnCodes.put("43", 1);

    HashMap<String, Integer> thirdColumnCodes = new HashMap<>();
    thirdColumnCodes.put("F", 0);
    thirdColumnCodes.put("M", 1);
    thirdColumnCodes.put("", 2);

    // Only the categorical columns get an encoding map; the numeric slots stay null.
    HashMap[] encodingValues = new HashMap[4];
    encodingValues[0] = firstColumnCodes;
    encodingValues[2] = thirdColumnCodes;

    HashSet<Integer> categorialColumns = new HashSet<>();
    categorialColumns.add(0);
    categorialColumns.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, categorialColumns);

    double[][] expected = {
        {1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0},
        {2.0, 3.0, 0.0, 1.0, 1.0, 0.0, 0.0},
        {3.0, 4.0, 1.0, 0.0, 0.0, 0.0, 1.0},
        {4.0, 5.0, 1.0, 0.0, 1.0, 0.0, 0.0}
    };

    for (int row = 0; row < data.length; row++) {
        double[] actual = preprocessor.apply(row, data[row]).features().asArray();
        assertArrayEquals(expected[row], actual, 1e-8);
    }
}
Usage of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testOneCategorialFeature.
/**
 * One-hot encodes a single categorical column with the two known values {@code "42"} and {@code "43"}
 * and checks that each row becomes the matching two-slot vector.
 */
@Test
public void testOneCategorialFeature() {
    Vector[] data = {
        new DenseVector(new Serializable[] {"42"}),
        new DenseVector(new Serializable[] {"43"}),
        new DenseVector(new Serializable[] {"42"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0);

    HashMap<String, Integer> codes = new HashMap<>();
    codes.put("42", 0);
    codes.put("43", 1);

    HashSet<Integer> handledColumns = new HashSet<>();
    handledColumns.add(0);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(new HashMap[] {codes}, vectorizer, handledColumns);

    double[][] expected = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}};

    for (int row = 0; row < data.length; row++) {
        double[] actual = preprocessor.apply(row, data[row]).features().asArray();
        assertArrayEquals(expected[row], actual, 1e-8);
    }
}
Usage of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testApplyWithUnknownCategorialValues.
/**
 * Checks that {@code apply()} fails with {@link UnknownCategorialValueException} when a row contains a
 * categorical value that has no entry in {@code encodingValues}.
 *
 * <p>The first column of the first row is {@code "1"}, but the encoding map for that column only knows
 * {@code "2"}, so encoding the rows must raise the exception.
 *
 * @see UnknownCategorialValueException
 */
@Test
public void testApplyWithUnknownCategorialValues() {
    Vector[] data = new Vector[] { new DenseVector(new Serializable[] { "1", "Moscow", "A" }), new DenseVector(new Serializable[] { "2", "Moscow", "A" }), new DenseVector(new Serializable[] { "2", "Moscow", "B" }) };
    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);
    OneHotEncoderPreprocessor<Integer, Vector> preprocessor = new OneHotEncoderPreprocessor<Integer, Vector>(new HashMap[] { new HashMap() {
        {
            put("2", 0);
        }
    }, new HashMap() {
        {
            put("Moscow", 0);
        }
    }, new HashMap() {
        {
            put("A", 0);
            put("B", 1);
        }
    } }, vectorizer, new HashSet() {
        {
            add(0);
            add(1);
            add(2);
        }
    });

    try {
        // Only the exception matters here. The maps above yield 1 + 1 + 2 = 4 output slots, so comparing
        // against 5-slot expectations (as this test previously did) could only mask the real failure.
        for (int i = 0; i < data.length; i++)
            preprocessor.apply(i, data[i]);
    } catch (UnknownCategorialValueException e) {
        return;
    }

    fail("UnknownCategorialFeatureValue");
}
Usage of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in the Apache Ignite project.
Class OneHotEncoderPreprocessorTest, method testApplyWithStringValues.
/**
 * Checks {@code apply()} on rows made only of string categories.
 *
 * <p>Each of the three columns is one-hot encoded with its own map: two values for the first column, one
 * for the second and two for the third, giving five output slots per row.
 */
@Test
public void testApplyWithStringValues() {
    Vector[] data = {
        new DenseVector(new Serializable[] {"1", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "B"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);

    HashMap<String, Integer> idCodes = new HashMap<>();
    idCodes.put("1", 1);
    idCodes.put("2", 0);

    HashMap<String, Integer> cityCodes = new HashMap<>();
    cityCodes.put("Moscow", 0);

    HashMap<String, Integer> letterCodes = new HashMap<>();
    letterCodes.put("A", 0);
    letterCodes.put("B", 1);

    HashSet<Integer> handledColumns = new HashSet<>();
    handledColumns.add(0);
    handledColumns.add(1);
    handledColumns.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor = new OneHotEncoderPreprocessor<Integer, Vector>(
        new HashMap[] {idCodes, cityCodes, letterCodes}, vectorizer, handledColumns);

    double[][] expected = {
        {0.0, 1.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 0.0, 1.0}
    };

    for (int row = 0; row < data.length; row++) {
        double[] actual = preprocessor.apply(row, data[row]).features().asArray();
        assertArrayEquals(expected[row], actual, 1e-8);
    }
}
Aggregations