Search in sources:

Example 1 with TargetCounter

Use of org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter in project Ignite by Apache.

Source: class EncoderTrainer, method updateTargetCountersForNextRow.

/**
 * Updates frequencies by values and features.
 *
 * @param row Feature vector.
 * @param targetCounters Holds the frequencies of categories by values and features.
 * @return target counter.
 */
/**
 * Accounts the given row in the per-feature target statistics.
 *
 * <p>For every handled feature index, the raw feature value is converted to a category key.
 * Then the overall target count/sum and the per-category count/target sum of the matching
 * {@link TargetCounter} are incremented. A {@code null} or {@code NaN} feature value is mapped
 * to {@link EncoderPreprocessor#KEY_FOR_NULL_VALUES}, and that key is written back into the row.
 *
 * @param row Labeled feature vector to account for; its features may be modified in place (see above).
 * @param targetCounters Accumulated counters, or {@code null} to initialize them from this row.
 * @return Updated target counters (the same array instance when one was passed in).
 * @throws RuntimeException If a handled feature value has an unsupported type.
 */
private TargetCounter[] updateTargetCountersForNextRow(LabeledVector row, TargetCounter[] targetCounters) {
    if (targetCounters == null)
        targetCounters = initializeTargetCounters(row);
    else
        assert targetCounters.length == row.size() : "Base preprocessor must return exactly " + targetCounters.length + " features";

    // NOTE(review): the target is read from the feature vector at targetLabelIndex rather than from row.label().
    double targetValue = row.features().get(targetLabelIndex);

    for (int i = 0; i < targetCounters.length; i++) {
        if (!handledIndices.contains(i))
            continue;

        String strVal;
        Object featureVal = row.features().getRaw(i);

        // Missing values (null or NaN) share a single dedicated category key.
        // The null check must come first: calling equals() on null would throw an NPE.
        if (featureVal == null || featureVal.equals(Double.NaN)) {
            strVal = EncoderPreprocessor.KEY_FOR_NULL_VALUES;
            row.features().setRaw(i, strVal);
        }
        else if (featureVal instanceof String)
            strVal = (String) featureVal;
        else if (featureVal instanceof Number || featureVal instanceof Boolean)
            strVal = String.valueOf(featureVal);
        else
            throw new RuntimeException("The type " + featureVal.getClass() + " is not supported for the feature values.");

        TargetCounter targetCounter = targetCounters[i];
        targetCounter.setTargetCount(targetCounter.getTargetCount() + 1);
        targetCounter.setTargetSum(targetCounter.getTargetSum() + targetValue);

        // Per-category statistics: occurrence count and sum of targets.
        targetCounter.getCategoryCounts().merge(strVal, 1L, Long::sum);
        targetCounter.getCategoryTargetSum().merge(strVal, targetValue, Double::sum);
    }

    return targetCounters;
}
Also used : TargetCounter(org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter)

Example 2 with TargetCounter

Use of org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter in project Ignite by Apache.

Source: class EncoderTrainer, method calculateTargetEncodingFrequencies.

/**
 * Calculates encoding frequencies as avarage category target on amount of rows in dataset.
 *
 * NOTE: The amount of rows is calculated as sum of absolute frequencies.
 *
 * @param dataset Dataset.
 * @return Encoding frequency for each feature.
 */
/**
 * Calculates target-encoding statistics for each handled feature.
 *
 * <p>The per-partition {@link TargetCounter}s are merged first. Then, for every handled feature,
 * two values are computed: the global target mean (target sum divided by target count) and the
 * per-category target means.
 *
 * <p>NOTE: The number of rows is taken as the sum of absolute frequencies (target counts).
 *
 * @param dataset Dataset.
 * @return Encoding metadata per feature; entries at non-handled indices are {@code null}.
 * @throws IllegalStateException If the merged counters are {@code null}, i.e. no partition produced any target counters.
 */
private TargetEncodingMeta[] calculateTargetEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
    TargetCounter[] targetCounters = dataset.compute(EncoderPartitionData::targetCounters, (a, b) -> {
        if (a == null)
            return b;
        if (b == null)
            return a;

        assert a.length == b.length;

        // Merge the statistics of 'a' into 'b' in place and return 'b'.
        for (int i = 0; i < a.length; i++) {
            if (!handledIndices.contains(i))
                continue;

            TargetCounter from = a[i];
            TargetCounter to = b[i];

            to.setTargetSum(from.getTargetSum() + to.getTargetSum());
            to.setTargetCount(from.getTargetCount() + to.getTargetCount());
            from.getCategoryCounts().forEach((k, v) -> to.getCategoryCounts().merge(k, v, Long::sum));
            from.getCategoryTargetSum().forEach((k, v) -> to.getCategoryTargetSum().merge(k, v, Double::sum));
        }

        return b;
    });

    // Every partition returned null counters; fail with a clear message instead of an NPE below.
    if (targetCounters == null)
        throw new IllegalStateException("Failed to calculate target encoding frequencies: no target counters were collected from the dataset.");

    TargetEncodingMeta[] targetEncodingMetas = new TargetEncodingMeta[targetCounters.length];

    for (int i = 0; i < targetCounters.length; i++) {
        if (!handledIndices.contains(i))
            continue;

        TargetCounter targetCounter = targetCounters[i];

        targetEncodingMetas[i] = new TargetEncodingMeta()
            .withGlobalMean(targetCounter.getTargetSum() / targetCounter.getTargetCount())
            .withCategoryMean(calculateCategoryTargetEncodingFrequency(targetCounter));
    }

    return targetEncodingMetas;
}
Also used : TargetEncodingMeta(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta) TargetCounter(org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter)

Aggregations

TargetCounter (org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter)2 TargetEncodingMeta (org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta)1