Search in sources:

Example 1 with Mean

use of org.apache.commons.math.stat.descriptive.moment.Mean in project sidewinder by srotya.

the class TestMathUtils method testMean.

/**
 * Checks that {@code MathUtils.mean} agrees with the Commons Math {@code Mean}
 * statistic on a small fixed sample, within a tolerance of 1e-4.
 */
@Test
public void testMean() {
    final double[] samples = { 2, 3, 4, 5, 6 };
    final double actual = MathUtils.mean(samples);
    // reference value computed by the library implementation
    final double expected = new Mean().evaluate(samples);
    assertEquals(expected, actual, 0.0001);
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) Test(org.junit.Test)

Example 2 with Mean

use of org.apache.commons.math.stat.descriptive.moment.Mean in project knime-core by knime.

the class TreeEnsembleRegressionPredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 *
 * <p>Produces the mean of the matching leaf means over all models that take part in
 * the prediction, optionally followed by the variance of those leaf means (when the
 * configuration asks for the prediction confidence) and by the number of models used
 * (when the configuration asks for the model count). All cells are missing when no
 * predictor record can be created for the row.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject portObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration config = m_predictor.getConfiguration();
    final TreeEnsembleModel ensemble = portObject.getEnsembleModel();
    final boolean withConfidence = config.isAppendPredictionConfidence();
    final boolean withModelCount = config.isAppendModelCount();
    // one cell for the prediction itself plus one per enabled optional output
    final int cellCount = 1 + (withConfidence ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean outOfBag = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];
    final DataRow learnColumnsOnly = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord predictorRecord = ensemble.createPredictorRecord(learnColumnsOnly, m_learnSpec);
    if (predictorRecord == null) {
        // no record could be built (missing value): every output cell is missing
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }

    final Mean meanAcc = new Mean();
    final Variance varianceAcc = new Variance();
    final int modelCount = ensemble.getNrModels();
    for (int modelIdx = 0; modelIdx < modelCount; modelIdx++) {
        if (outOfBag && m_predictor.isRowPartOfTrainingData(row.getKey(), modelIdx)) {
            // row was used to train this model, so the model must not vote
            continue;
        }
        final TreeModelRegression tree = ensemble.getTreeModelRegression(modelIdx);
        final TreeNodeRegression leaf = tree.findMatchingNode(predictorRecord);
        final double leafMean = leaf.getMean();
        meanAcc.increment(leafMean);
        varianceAcc.increment(leafMean);
    }

    final int usedModels = (int) meanAcc.getN();
    int pos = 0;
    if (usedModels == 0) {
        cells[pos++] = DataType.getMissingCell();
    } else {
        cells[pos++] = new DoubleCell(meanAcc.getResult());
    }
    if (withConfidence) {
        if (usedModels == 0) {
            cells[pos++] = DataType.getMissingCell();
        } else {
            cells[pos++] = new DoubleCell(varianceAcc.getResult());
        }
    }
    if (withModelCount) {
        cells[pos++] = new IntCell(usedModels);
    }
    return cells;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 3 with Mean

use of org.apache.commons.math.stat.descriptive.moment.Mean in project knime-core by knime.

the class TreeEnsembleRegressionPredictorCellFactory method getCells.

/**
 * {@inheritDoc}
 *
 * <p>The first cell holds the average of the matching nodes' means across the models
 * that are allowed to predict this row. If prediction confidence is requested, the
 * next cell holds the variance of those node means; if the model count is requested,
 * the last cell holds how many models contributed. When the predictor record for the
 * row is {@code null}, every cell is a missing cell.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject modelPort = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration settings = m_predictor.getConfiguration();
    final TreeEnsembleModel forest = modelPort.getEnsembleModel();

    final boolean wantConfidence = settings.isAppendPredictionConfidence();
    final boolean wantModelCount = settings.isAppendModelCount();
    int outputWidth = 1;
    outputWidth += wantConfidence ? 1 : 0;
    outputWidth += wantModelCount ? 1 : 0;

    final boolean oobFilterActive = m_predictor.hasOutOfBagFilter();
    final DataCell[] output = new DataCell[outputWidth];
    final DataRow projectedRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord rec = forest.createPredictorRecord(projectedRow, m_learnSpec);
    if (rec == null) {
        // missing value: nothing can be predicted for this row
        Arrays.fill(output, DataType.getMissingCell());
        return output;
    }

    final Mean avg = new Mean();
    final Variance spread = new Variance();
    final int treeCount = forest.getNrModels();
    for (int t = 0; t < treeCount; t++) {
        final boolean usedForTraining = oobFilterActive && m_predictor.isRowPartOfTrainingData(row.getKey(), t);
        if (!usedForTraining) {
            // only models that did not see this row during training contribute
            final TreeModelRegression treeModel = forest.getTreeModelRegression(t);
            final TreeNodeRegression node = treeModel.findMatchingNode(rec);
            final double nodeValue = node.getMean();
            avg.increment(nodeValue);
            spread.increment(nodeValue);
        }
    }

    final int contributing = (int) avg.getN();
    final boolean nothingContributed = contributing == 0;
    int next = 0;
    output[next] = nothingContributed ? DataType.getMissingCell() : new DoubleCell(avg.getResult());
    next++;
    if (wantConfidence) {
        output[next] = nothingContributed ? DataType.getMissingCell() : new DoubleCell(spread.getResult());
        next++;
    }
    if (wantModelCount) {
        output[next] = new IntCell(contributing);
        next++;
    }
    return output;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 4 with Mean

use of org.apache.commons.math.stat.descriptive.moment.Mean in project knime-core by knime.

the class NumericScorerNodeModel method execute.

/**
 * {@inheritDoc}
 *
 * <p>Reads the reference and prediction columns of the first input table and writes
 * one output row per score: R^2, mean absolute error, mean squared error, root mean
 * squared deviation and mean signed difference. The same values are stored in the
 * corresponding member fields.
 *
 * <p>Rows whose reference cell is missing are skipped and reported through a warning
 * message. A missing prediction cell in a row that is not skipped is an error.
 *
 * @throws IllegalArgumentException if a non-skipped row has a missing prediction value
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    DataTableSpec spec = inData[0].getSpec();
    BufferedDataContainer container = exec.createDataContainer(createOutputSpec(spec));
    int referenceIdx = spec.findColumnIndex(m_numericScorerSettings.getReferenceColumnName());
    int predictionIdx = spec.findColumnIndex(m_numericScorerSettings.getPredictionColumnName());
    final Mean meanObserved = new Mean();
    final Mean absError = new Mean(), squaredError = new Mean();
    final Mean signedDiff = new Mean();
    final SumOfSquares ssTot = new SumOfSquares(), ssRes = new SumOfSquares();
    int skippedRowCount = 0;
    // First pass: everything that only needs the current row, including the residual
    // sum of squares.
    for (DataRow row : inData[0]) {
        DataCell refCell = row.getCell(referenceIdx);
        DataCell predCell = row.getCell(predictionIdx);
        if (refCell.isMissing()) {
            skippedRowCount++;
            continue;
        }
        double ref = ((DoubleValue) refCell).getDoubleValue();
        if (predCell.isMissing()) {
            throw new IllegalArgumentException("Missing value in prediction column in row: " + row.getKey());
        }
        double pred = ((DoubleValue) predCell).getDoubleValue();
        final double residual = ref - pred;
        meanObserved.increment(ref);
        absError.increment(Math.abs(residual));
        squaredError.increment(residual * residual);
        signedDiff.increment(pred - ref);
        ssRes.increment(residual);
    }
    // Second pass: the total sum of squares needs the mean of the observed values,
    // which is only known once the first pass has finished.
    final double observedMean = meanObserved.getResult();
    for (DataRow row : inData[0]) {
        DataCell refCell = row.getCell(referenceIdx);
        if (refCell.isMissing()) {
            continue;
        }
        double ref = ((DoubleValue) refCell).getDoubleValue();
        ssTot.increment(ref - observedMean);
    }
    container.addRowToTable(new DefaultRow("R^2", m_rSquare = 1 - ssRes.getResult() / ssTot.getResult()));
    container.addRowToTable(new DefaultRow("mean absolute error", m_meanAbsError = absError.getResult()));
    container.addRowToTable(new DefaultRow("mean squared error", m_meanSquaredError = squaredError.getResult()));
    container.addRowToTable(new DefaultRow("root mean squared deviation", m_rmsd = Math.sqrt(squaredError.getResult())));
    container.addRowToTable(new DefaultRow("mean signed difference", m_meanSignedDifference = signedDiff.getResult()));
    container.close();
    if (skippedRowCount > 0) {
        setWarningMessage("Skipped " + skippedRowCount + " rows, because the reference column contained missing values there.");
    }
    pushFlowVars(false);
    return new BufferedDataTable[] { container.getTable() };
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) DataTableSpec(org.knime.core.data.DataTableSpec) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) SumOfSquares(org.apache.commons.math.stat.descriptive.summary.SumOfSquares) DoubleValue(org.knime.core.data.DoubleValue) BufferedDataTable(org.knime.core.node.BufferedDataTable) DataCell(org.knime.core.data.DataCell) DefaultRow(org.knime.core.data.def.DefaultRow) DataRow(org.knime.core.data.DataRow)

Example 5 with Mean

use of org.apache.commons.math.stat.descriptive.moment.Mean in project drill by axbaretto.

the class TestOrderedPartitionExchange method twoBitTwoExchangeRun.

/**
 * Starts two drillbits and runs the physical plan in {@code /sender/ordered_exchange.json}.
 * For every returned batch the test checks that {@code col1} is non-decreasing and that,
 * for equal {@code col1} values, {@code col2} is non-increasing. It also counts the rows
 * seen per run of equal {@code partition} values within a batch, prints the mean and
 * standard deviation of those counts, and asserts the total row count.
 *
 * <p>The check that the standard deviation of the partition sizes is below one tenth of
 * their mean is currently disabled (see the commented-out assertion at the end).
 *
 * @throws Exception if the drillbits, the client or the query fail
 */
@Test
public void twoBitTwoExchangeRun() throws Exception {
    RemoteServiceSet serviceSet = RemoteServiceSet.getLocalServiceSet();
    try (Drillbit bit1 = new Drillbit(CONFIG, serviceSet);
        Drillbit bit2 = new Drillbit(CONFIG, serviceSet);
        DrillClient client = new DrillClient(CONFIG, serviceSet.getCoordinator())) {
        bit1.run();
        bit2.run();
        client.connect();
        List<QueryDataBatch> results = client.runQuery(org.apache.drill.exec.proto.UserBitShared.QueryType.PHYSICAL, Files.toString(DrillFileUtils.getResourceAsFile("/sender/ordered_exchange.json"), Charsets.UTF_8));
        int count = 0;
        List<Integer> partitionRecordCounts = Lists.newArrayList();
        for (QueryDataBatch b : results) {
            if (b.getData() != null) {
                int rows = b.getHeader().getRowCount();
                count += rows;
                DrillConfig config = DrillConfig.create();
                RecordBatchLoader loader = new RecordBatchLoader(new BootStrapContext(config, SystemOptionManager.createDefaultOptionDefinitions(), ClassPathScanner.fromPrescan(config)).getAllocator());
                loader.load(b.getHeader().getDef(), b.getData());
                BigIntVector vv1 = (BigIntVector) loader.getValueAccessorById(BigIntVector.class, loader.getValueVectorId(new SchemaPath("col1", ExpressionPosition.UNKNOWN)).getFieldIds()).getValueVector();
                Float8Vector vv2 = (Float8Vector) loader.getValueAccessorById(Float8Vector.class, loader.getValueVectorId(new SchemaPath("col2", ExpressionPosition.UNKNOWN)).getFieldIds()).getValueVector();
                IntVector pVector = (IntVector) loader.getValueAccessorById(IntVector.class, loader.getValueVectorId(new SchemaPath("partition", ExpressionPosition.UNKNOWN)).getFieldIds()).getValueVector();
                long previous1 = Long.MIN_VALUE;
                // col2 is checked for non-increasing order, so "no previous value" must be
                // the largest possible double. Double.MIN_VALUE is the smallest POSITIVE
                // double and would make the check fail spuriously.
                double previous2 = Double.POSITIVE_INFINITY;
                int partPrevious = -1;
                long current1 = Long.MIN_VALUE;
                double current2 = Double.POSITIVE_INFINITY;
                int partCurrent = -1;
                int partitionRecordCount = 0;
                for (int i = 0; i < rows; i++) {
                    previous1 = current1;
                    previous2 = current2;
                    partPrevious = partCurrent;
                    current1 = vv1.getAccessor().get(i);
                    current2 = vv2.getAccessor().get(i);
                    partCurrent = pVector.getAccessor().get(i);
                    Assert.assertTrue(current1 >= previous1);
                    if (current1 == previous1) {
                        Assert.assertTrue(current2 <= previous2);
                    }
                    if (partCurrent == partPrevious || partPrevious == -1) {
                        partitionRecordCount++;
                    } else {
                        partitionRecordCounts.add(partitionRecordCount);
                        // the current row is the first row of the new partition
                        partitionRecordCount = 1;
                    }
                }
                partitionRecordCounts.add(partitionRecordCount);
                loader.clear();
            }
            b.release();
        }
        double[] values = new double[partitionRecordCounts.size()];
        int i = 0;
        for (Integer rc : partitionRecordCounts) {
            values[i++] = rc.doubleValue();
        }
        StandardDeviation stdDev = new StandardDeviation();
        Mean mean = new Mean();
        double std = stdDev.evaluate(values);
        double m = mean.evaluate(values);
        System.out.println("mean: " + m + " std dev: " + std);
        // Assert.assertTrue(std < 0.1 * m);
        assertEquals(31000, count);
    }
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) BigIntVector(org.apache.drill.exec.vector.BigIntVector) IntVector(org.apache.drill.exec.vector.IntVector) RecordBatchLoader(org.apache.drill.exec.record.RecordBatchLoader) Float8Vector(org.apache.drill.exec.vector.Float8Vector) BigIntVector(org.apache.drill.exec.vector.BigIntVector) QueryDataBatch(org.apache.drill.exec.rpc.user.QueryDataBatch) DrillConfig(org.apache.drill.common.config.DrillConfig) Drillbit(org.apache.drill.exec.server.Drillbit) SchemaPath(org.apache.drill.common.expression.SchemaPath) RemoteServiceSet(org.apache.drill.exec.server.RemoteServiceSet) BootStrapContext(org.apache.drill.exec.server.BootStrapContext) StandardDeviation(org.apache.commons.math.stat.descriptive.moment.StandardDeviation) DrillClient(org.apache.drill.exec.client.DrillClient) OperatorTest(org.apache.drill.categories.OperatorTest) Test(org.junit.Test)

Aggregations

Mean (org.apache.commons.math.stat.descriptive.moment.Mean): 9; StandardDeviation (org.apache.commons.math.stat.descriptive.moment.StandardDeviation): 3; Variance (org.apache.commons.math.stat.descriptive.moment.Variance): 3; Test (org.junit.Test): 3; DataCell (org.knime.core.data.DataCell): 3; DataRow (org.knime.core.data.DataRow): 3; OperatorTest (org.apache.drill.categories.OperatorTest): 2; DrillConfig (org.apache.drill.common.config.DrillConfig): 2; SchemaPath (org.apache.drill.common.expression.SchemaPath): 2; DrillClient (org.apache.drill.exec.client.DrillClient): 2; RecordBatchLoader (org.apache.drill.exec.record.RecordBatchLoader): 2; QueryDataBatch (org.apache.drill.exec.rpc.user.QueryDataBatch): 2; BootStrapContext (org.apache.drill.exec.server.BootStrapContext): 2; Drillbit (org.apache.drill.exec.server.Drillbit): 2; RemoteServiceSet (org.apache.drill.exec.server.RemoteServiceSet): 2; BigIntVector (org.apache.drill.exec.vector.BigIntVector): 2; Float8Vector (org.apache.drill.exec.vector.Float8Vector): 2; IntVector (org.apache.drill.exec.vector.IntVector): 2; FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow): 2; DoubleCell (org.knime.core.data.def.DoubleCell): 2