Search in sources :

Example 1 with OneHotEncoderPreprocessor

use of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in project ignite by apache.

the class EncoderTrainer method fit.

/**
 * {@inheritDoc}
 *
 * <p>Scans every partition once to collect the statistics required by the configured
 * {@code encoderType} (label frequencies, target counters or per-feature category frequencies)
 * and then wraps {@code basePreprocessor} into the matching encoder preprocessor.
 *
 * @throws RuntimeException If no handled indices were registered for an encoder other than
 *      {@code LABEL_ENCODER}, or as a wrapper around any exception raised while the dataset
 *      is built, processed or closed.
 */
@Override
public EncoderPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    // Only the label encoder can work without an explicit list of feature indices.
    if (handledIndices.isEmpty() && encoderType != EncoderType.LABEL_ENCODER) {
        throw new RuntimeException("Add indices of handled features");
    }

    try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            EncoderPartitionData res = new EncoderPartitionData();

            if (encoderType == EncoderType.LABEL_ENCODER) {
                // Accumulator stays null until the first row is seen.
                Map<String, Integer> labelFreqs = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entry = upstream.next();
                    LabeledVector<Double> vec = basePreprocessor.apply(entry.getKey(), entry.getValue());

                    labelFreqs = updateLabelFrequenciesForNextRow(vec, labelFreqs);
                }

                res.withLabelFrequencies(labelFreqs);
            }
            else if (encoderType == EncoderType.TARGET_ENCODER) {
                TargetCounter[] counters = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entry = upstream.next();
                    LabeledVector<Double> vec = basePreprocessor.apply(entry.getKey(), entry.getValue());

                    counters = updateTargetCountersForNextRow(vec, counters);
                }

                res.withTargetCounters(counters);
            }
            else {
                // This array will contain not null values for handled indices.
                Map<String, Integer>[] featureFreqs = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entry = upstream.next();
                    LabeledVector<Double> vec = basePreprocessor.apply(entry.getKey(), entry.getValue());

                    featureFreqs = updateFeatureFrequenciesForNextRow(vec, featureFreqs);
                }

                res.withCategoryFrequencies(featureFreqs);
            }

            return res;
        },
        learningEnvironment(basePreprocessor))
    ) {
        // Pick the preprocessor implementation that corresponds to the requested encoder.
        switch (encoderType) {
            case ONE_HOT_ENCODER:
                return new OneHotEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);

            case STRING_ENCODER:
                return new StringEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);

            case LABEL_ENCODER:
                return new LabelEncoderPreprocessor<>(calculateEncodingValuesForLabelsByFrequencies(dataset), basePreprocessor);

            case FREQUENCY_ENCODER:
                return new FrequencyEncoderPreprocessor<>(calculateEncodingFrequencies(dataset), basePreprocessor, handledIndices);

            case TARGET_ENCODER:
                return new TargetEncoderPreprocessor<>(calculateTargetEncodingFrequencies(dataset), basePreprocessor, handledIndices);

            default:
                throw new IllegalStateException("Define the type of the resulting prerocessor.");
        }
    }
    catch (Exception e) {
        // Everything thrown inside the try (including the IllegalStateException above) is rewrapped.
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) StringEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) TargetEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor) FrequencyEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.frequency.FrequencyEncoderPreprocessor) UndefinedLabelException(org.apache.ignite.ml.math.exceptions.preprocessing.UndefinedLabelException) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) LabelEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.label.LabelEncoderPreprocessor) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry)

Example 2 with OneHotEncoderPreprocessor

use of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in project ignite by apache.

the class OneHotEncoderPreprocessorTest method testTwoCategorialFeatureAndTwoDoubleFeatures.

/**
 * Checks one-hot encoding when features 0 and 2 are categorical and features 1 and 3 are numeric.
 *
 * Each expected row lists the two numeric values first, followed by the one-hot columns of
 * feature 0 (two categories) and of feature 2 (three categories).
 */
@Test
public void testTwoCategorialFeatureAndTwoDoubleFeatures() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"42", 1.0, "M", 2.0}),
        new DenseVector(new Serializable[] {"43", 2.0, "F", 3.0}),
        new DenseVector(new Serializable[] {"42", 3.0, Double.NaN, 4.0}),
        new DenseVector(new Serializable[] {"42", 4.0, "F", 5.0})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2, 3);

    // Plain typed maps instead of double-brace initialization: no anonymous subclasses, no raw types.
    HashMap<String, Integer> firstFeatureCodes = new HashMap<>();
    firstFeatureCodes.put("42", 0);
    firstFeatureCodes.put("43", 1);

    HashMap<String, Integer> thirdFeatureCodes = new HashMap<>();
    thirdFeatureCodes.put("F", 0);
    thirdFeatureCodes.put("M", 1);
    // The row holding Double.NaN is expected to activate this column (see the third expected row).
    thirdFeatureCodes.put("", 2);

    // Generic arrays cannot be created directly, hence the unchecked assignment.
    // Slots 1 and 3 stay null because those features are not encoded.
    @SuppressWarnings("unchecked")
    HashMap<String, Integer>[] encodingValues = new HashMap[4];
    encodingValues[0] = firstFeatureCodes;
    encodingValues[2] = thirdFeatureCodes;

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0},
        {2.0, 3.0, 0.0, 1.0, 1.0, 0.0, 0.0},
        {3.0, 4.0, 1.0, 0.0, 0.0, 0.0, 1.0},
        {4.0, 5.0, 1.0, 0.0, 1.0, 0.0, 0.0}
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Also used : HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 3 with OneHotEncoderPreprocessor

use of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in project ignite by apache.

the class OneHotEncoderPreprocessorTest method testOneCategorialFeature.

/**
 * Checks one-hot encoding of a single categorical feature with two known categories.
 */
@Test
public void testOneCategorialFeature() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"42"}),
        new DenseVector(new Serializable[] {"43"}),
        new DenseVector(new Serializable[] {"42"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0);

    // Plain typed collections instead of double-brace initialization (no anonymous subclasses).
    HashMap<String, Integer> categoryCodes = new HashMap<>();
    categoryCodes.put("42", 0);
    categoryCodes.put("43", 1);

    // Generic arrays cannot be created directly, hence the unchecked assignment.
    @SuppressWarnings("unchecked")
    HashMap<String, Integer>[] encodingValues = new HashMap[] {categoryCodes};

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {1.0, 0.0},
        {0.0, 1.0},
        {1.0, 0.0}
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Also used : HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 4 with OneHotEncoderPreprocessor

use of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in project ignite by apache.

the class OneHotEncoderPreprocessorTest method testApplyWithUnknownCategorialValues.

/**
 * Verifies that {@code apply()} fails with {@link UnknownCategorialValueException} when a row
 * contains a category that is missing from the encoding values.
 *
 * Here the first row carries the value "1" for feature 0, while the encoding map for that
 * feature only knows "2".
 *
 * @see UnknownCategorialValueException
 */
@Test
public void testApplyWithUnknownCategorialValues() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"1", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "B"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);

    // Plain typed collections instead of double-brace initialization (no anonymous subclasses).
    // "1" is deliberately absent from the first map.
    HashMap<String, Integer> idCodes = new HashMap<>();
    idCodes.put("2", 0);

    HashMap<String, Integer> cityCodes = new HashMap<>();
    cityCodes.put("Moscow", 0);

    HashMap<String, Integer> gradeCodes = new HashMap<>();
    gradeCodes.put("A", 0);
    gradeCodes.put("B", 1);

    // Generic arrays cannot be created directly, hence the unchecked assignment.
    @SuppressWarnings("unchecked")
    HashMap<String, Integer>[] encodingValues = new HashMap[] {idCodes, cityCodes, gradeCodes};

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(1);
    handledIndices.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {0.0, 1.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 0.0, 1.0}
    };

    try {
        for (int i = 0; i < data.length; i++) {
            assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
        }
    }
    catch (UnknownCategorialValueException e) {
        // Expected outcome: the unknown category was rejected.
        return;
    }

    // Reached only when every row was encoded without the expected exception.
    fail("UnknownCategorialFeatureValue");
}
Also used : HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) UnknownCategorialValueException(org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 5 with OneHotEncoderPreprocessor

use of org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor in project ignite by apache.

the class OneHotEncoderPreprocessorTest method testApplyWithStringValues.

/**
 * Tests {@code apply()} on three categorical string features that are all one-hot encoded.
 *
 * Each expected row concatenates the one-hot columns of feature 0 (two categories),
 * feature 1 (one category) and feature 2 (two categories).
 */
@Test
public void testApplyWithStringValues() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"1", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "B"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);

    // Plain typed collections instead of double-brace initialization (no anonymous subclasses).
    HashMap<String, Integer> idCodes = new HashMap<>();
    idCodes.put("1", 1);
    idCodes.put("2", 0);

    HashMap<String, Integer> cityCodes = new HashMap<>();
    cityCodes.put("Moscow", 0);

    HashMap<String, Integer> gradeCodes = new HashMap<>();
    gradeCodes.put("A", 0);
    gradeCodes.put("B", 1);

    // Generic arrays cannot be created directly, hence the unchecked assignment.
    @SuppressWarnings("unchecked")
    HashMap<String, Integer>[] encodingValues = new HashMap[] {idCodes, cityCodes, gradeCodes};

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(1);
    handledIndices.add(2);

    OneHotEncoderPreprocessor<Integer, Vector> preprocessor =
        new OneHotEncoderPreprocessor<Integer, Vector>(encodingValues, vectorizer, handledIndices);

    double[][] postProcessedData = new double[][] {
        {0.0, 1.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 1.0, 0.0},
        {1.0, 0.0, 1.0, 0.0, 1.0}
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Also used : HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) HashSet(java.util.HashSet) Test(org.junit.Test)

Aggregations

OneHotEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor)5 HashMap (java.util.HashMap)4 HashSet (java.util.HashSet)4 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)4 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)4 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)4 Test (org.junit.Test)4 UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry)1 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)1 UndefinedLabelException (org.apache.ignite.ml.math.exceptions.preprocessing.UndefinedLabelException)1 UnknownCategorialValueException (org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException)1 FrequencyEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.frequency.FrequencyEncoderPreprocessor)1 LabelEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.label.LabelEncoderPreprocessor)1 StringEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor)1 TargetEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor)1 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)1