Search in sources :

Example 1 with TargetEncodingMeta

use of org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta in project ignite by apache.

the class EncoderTrainer method calculateTargetEncodingFrequencies.

/**
 * Builds target encoding metadata for every handled feature: the global target mean
 * (total target sum divided by total target count) plus the per-category target means.
 *
 * NOTE: The amount of rows is taken from the accumulated target count of each feature.
 *
 * @param dataset Dataset.
 * @return Encoding metadata per feature index; entries for indices that are not handled stay {@code null}.
 */
private TargetEncodingMeta[] calculateTargetEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
    TargetCounter[] mergedCounters = dataset.compute(EncoderPartitionData::targetCounters, (left, right) -> {
        // A missing side contributes nothing, so the other side is the result.
        if (left == null)
            return right;
        if (right == null)
            return left;

        assert left.length == right.length;

        // Fold the left counters into the right ones in place; the right array is the accumulator.
        for (int featureIdx = 0; featureIdx < left.length; featureIdx++) {
            if (!handledIndices.contains(featureIdx))
                continue;

            TargetCounter src = left[featureIdx];
            TargetCounter acc = right[featureIdx];

            acc.setTargetSum(src.getTargetSum() + acc.getTargetSum());
            acc.setTargetCount(src.getTargetCount() + acc.getTargetCount());

            src.getCategoryCounts().forEach(
                (category, cnt) -> acc.getCategoryCounts().merge(category, cnt, Long::sum));
            src.getCategoryTargetSum().forEach(
                (category, sum) -> acc.getCategoryTargetSum().merge(category, sum, Double::sum));
        }

        return right;
    });

    TargetEncodingMeta[] result = new TargetEncodingMeta[mergedCounters.length];

    for (int featureIdx = 0; featureIdx < mergedCounters.length; featureIdx++) {
        if (!handledIndices.contains(featureIdx))
            continue;

        TargetCounter counter = mergedCounters[featureIdx];

        result[featureIdx] = new TargetEncodingMeta()
            .withGlobalMean(counter.getTargetSum() / counter.getTargetCount())
            .withCategoryMean(calculateCategoryTargetEncodingFrequency(counter));
    }

    return result;
}
Also used : TargetEncodingMeta(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta) TargetCounter(org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter)

Example 2 with TargetEncodingMeta

use of org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta in project ignite by apache.

the class TargetEncoderPreprocessorTest method testApply.

/**
 * Tests {@code apply()} method.
 *
 * Expected behaviour checked here: a feature value present in the feature's category-mean dictionary
 * is encoded as that category mean, any other value is encoded as the feature's global mean.
 */
@Test
public void testApply() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] { "1", "Moscow", "A" }),
        new DenseVector(new Serializable[] { "2", "Moscow", "B" }),
        new DenseVector(new Serializable[] { "3", "Moscow", "B" })
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);

    // Plain collections are populated explicitly instead of using double-brace initialization,
    // which would create an anonymous HashMap/HashSet subclass for every literal.

    // Feature 0: category "1" has mean 1.0, category "2" has mean 0.0.
    HashMap<String, Double> feature0Means = new HashMap<>();
    feature0Means.put("1", 1.0);
    feature0Means.put("2", 0.0);

    // Feature 1: empty dictionary, so every value falls back to the global mean.
    HashMap<String, Double> feature1Means = new HashMap<>();

    // Feature 2: category "A" has mean 1.0, category "B" has mean 2.0.
    HashMap<String, Double> feature2Means = new HashMap<>();
    feature2Means.put("A", 1.0);
    feature2Means.put("B", 2.0);

    // All three features are target-encoded.
    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(1);
    handledIndices.add(2);

    TargetEncoderPreprocessor<Integer, Vector> preprocessor = new TargetEncoderPreprocessor<>(
        new TargetEncodingMeta[] {
            new TargetEncodingMeta().withGlobalMean(0.5).withCategoryMean(feature0Means),
            new TargetEncodingMeta().withGlobalMean(0.1).withCategoryMean(feature1Means),
            new TargetEncodingMeta().withGlobalMean(0.1).withCategoryMean(feature2Means)
        },
        vectorizer,
        handledIndices);

    double[][] postProcessedData = new double[][] {
        // "1" is in the dictionary => 1.0; "Moscow" is not => global 0.1; "A" is in the dictionary => 1.0.
        { 1.0, 0.1, 1.0 },
        // "2" is in the dictionary => 0.0; "Moscow" is not => global 0.1; "B" is in the dictionary => 2.0.
        { 0.0, 0.1, 2.0 },
        // "3" is not in the dictionary => global 0.5; "Moscow" is not => global 0.1; "B" => 2.0.
        { 0.5, 0.1, 2.0 }
    };

    for (int i = 0; i < data.length; i++)
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
}
Also used : HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) TargetEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor) TargetEncodingMeta(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Aggregations

TargetEncodingMeta (org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta)2 HashMap (java.util.HashMap)1 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)1 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)1 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)1 TargetCounter (org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter)1 TargetEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor)1 Test (org.junit.Test)1