Search in sources:

Example 6 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project gatk-protected by broadinstitute.

In class CoverageModelParameters, method write.

/**
 * Writes the model to disk.
 *
 * <p>The target list, the target mean log bias and the target unexplained variance are always written. When bias
 * covariates are enabled, the mean bias covariates and the squared Euclidean norm of each bias covariate column
 * (the diagonal of W^T W, where W is the mean bias covariates matrix) are written as well. When ARD is also
 * enabled, the ARD coefficients of the bias covariates are written too.</p>
 *
 * @param model the model to write; must be non-null
 * @param outputPath model output path; must be non-null
 */
public static void write(@Nonnull final CoverageModelParameters model, @Nonnull final String outputPath) {
    /* validate both arguments before touching the file system */
    Utils.nonNull(outputPath, "The output path string must be non-null");
    Utils.nonNull(model, "The model must be non-null");
    /* create output directory if it doesn't exist */
    createOutputPath(outputPath);
    /* write targets list */
    final File targetListFile = new File(outputPath, CoverageModelGlobalConstants.TARGET_LIST_OUTPUT_FILE);
    TargetWriter.writeTargetsToFile(targetListFile, model.getTargetList());
    final List<String> targetNames = model.getTargetList().stream().map(Target::getName).collect(Collectors.toList());
    /* write target mean bias to file */
    final File targetMeanBiasFile = new File(outputPath, CoverageModelGlobalConstants.TARGET_MEAN_LOG_BIAS_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getTargetMeanLogBias().transpose(), targetMeanBiasFile, MEAN_LOG_BIAS_MATRIX_NAME, targetNames, null);
    /* write target unexplained variance to file */
    final File targetUnexplainedVarianceFile = new File(outputPath, CoverageModelGlobalConstants.TARGET_UNEXPLAINED_VARIANCE_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getTargetUnexplainedVariance().transpose(), targetUnexplainedVarianceFile, TARGET_UNEXPLAINED_VARIANCE_MATRIX_NAME, targetNames, null);
    if (model.isBiasCovariatesEnabled()) {
        final int numLatents = model.getNumLatents();
        final INDArray meanBiasCovariates = model.getMeanBiasCovariates();
        /* column names shared by every bias-covariate output below */
        final List<String> meanBiasCovariatesNames = IntStream.range(0, numLatents)
                .mapToObj(li -> String.format(BIAS_COVARIATE_COLUMN_NAME_FORMAT, li))
                .collect(Collectors.toList());
        /* write mean bias covariates to file */
        final File meanBiasCovariatesFile = new File(outputPath, CoverageModelGlobalConstants.MEAN_BIAS_COVARIATES_OUTPUT_FILE);
        Nd4jIOUtils.writeNDArrayMatrixToTextFile(meanBiasCovariates, meanBiasCovariatesFile, MEAN_BIAS_COVARIATES_MATRIX_NAME, targetNames, meanBiasCovariatesNames);
        /* write the squared norm of each mean bias covariate column (the diagonal of W^T W) to file */
        final INDArray gramMatrix = meanBiasCovariates.transpose().mmul(meanBiasCovariates);
        final double[] biasCovariatesNorm2 = IntStream.range(0, numLatents).mapToDouble(li -> gramMatrix.getDouble(li, li)).toArray();
        final File biasCovariatesNorm2File = new File(outputPath, CoverageModelGlobalConstants.MEAN_BIAS_COVARIATES_NORM2_OUTPUT_FILE);
        Nd4jIOUtils.writeNDArrayMatrixToTextFile(Nd4j.create(biasCovariatesNorm2, new int[] { 1, numLatents }), biasCovariatesNorm2File, MEAN_BIAS_COVARIATES_NORM_2_MATRIX_NAME, null, meanBiasCovariatesNames);
        /* if ARD is enabled, also write the ARD coefficients of the bias covariates */
        if (model.isARDEnabled()) {
            final File biasCovariatesARDCoefficientsFile = new File(outputPath, CoverageModelGlobalConstants.BIAS_COVARIATES_ARD_COEFFICIENTS_OUTPUT_FILE);
            Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getBiasCovariateARDCoefficients(), biasCovariatesARDCoefficientsFile, BIAS_COVARIATES_ARD_COEFFICIENTS_MATRIX_NAME, null, meanBiasCovariatesNames);
        }
    }
}
Also used: IntStream (java.util.stream.IntStream), java.util (java.util), NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex), Nd4jIOUtils (org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils), Nd4j (org.nd4j.linalg.factory.Nd4j), Collectors (java.util.stream.Collectors), ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair), ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils), Sets (com.google.cloud.dataflow.sdk.repackaged.com.google.common.collect.Sets), Logger (org.apache.logging.log4j.Logger), ReadCountCollection (org.broadinstitute.hellbender.tools.exome.ReadCountCollection), Pair (org.apache.commons.lang3.tuple.Pair), UserException (org.broadinstitute.hellbender.exceptions.UserException), java.io (java.io), RandomGenerator (org.apache.commons.math3.random.RandomGenerator), RandomGeneratorFactory (org.apache.commons.math3.random.RandomGeneratorFactory), Target (org.broadinstitute.hellbender.tools.exome.Target), TargetTableReader (org.broadinstitute.hellbender.tools.exome.TargetTableReader), INDArray (org.nd4j.linalg.api.ndarray.INDArray), TargetWriter (org.broadinstitute.hellbender.tools.exome.TargetWriter), Utils (org.broadinstitute.hellbender.utils.Utils), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)

Example 7 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project knime-core by knime.

In class Learner, method perform.

/**
 * Fits the logistic regression model on the given table.
 *
 * <p>Each iteration runs {@code irlsRls} followed by {@code likelihood} on the KNIME thread pool so the work can be
 * interrupted. If the resulting log-likelihood is lower than in the previous iteration (or is infinite/NaN after a
 * step halving), the coefficient step is halved until it no longer decreases or the coefficients stop changing.
 * The loop ends after {@code m_maxIter} iterations, or once every coefficient differs from its previous value by at
 * most {@code m_eps} times that previous value's magnitude.</p>
 *
 * @param data The data table.
 * @param exec The execution context used for reporting progress.
 * @return An object which holds the results.
 * @throws CanceledExecutionException when method is cancelled, including a cancellation raised inside the worker
 * @throws InvalidSettingsException When settings are inconsistent with the data
 */
public LogisticRegressionContent perform(final BufferedDataTable data, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    exec.checkCanceled();
    int iter = 0;
    boolean converged = false;
    final RegressionTrainingData trainingData = new RegressionTrainingData(data, m_outSpec, m_specialColumns, true, m_targetReferenceCategory, m_sortTargetCategories, m_sortFactorsCategories);
    final int targetIndex = data.getDataTableSpec().findColumnIndex(m_outSpec.getTargetCols().get(0).getName());
    final int tcC = trainingData.getDomainValues().get(targetIndex).size();
    final int rC = trainingData.getRegressorCount();
    // all (tcC - 1) * (rC + 1) coefficients are stored in a single ROW, so convergence checks must walk columns
    final RealMatrix beta = new Array2DRowRealMatrix(1, (tcC - 1) * (rC + 1));
    double loglike = 0.0;
    double loglikeOld = 0.0;
    exec.setMessage("Iterative optimization. Processing iteration 1.");
    // main loop
    while (iter < m_maxIter && !converged) {
        final RealMatrix betaOld = beta.copy();
        loglikeOld = loglike;
        // Do heavy work in a separate thread which allows to interrupt it
        // note the queue may block if no more threads are available (e.g. thread count = 1)
        // as soon as we stall in 'get' this thread reduces the number of running thread
        final Future<Double> future = ThreadPool.currentPool().enqueue(new Callable<Double>() {

            @Override
            public Double call() throws Exception {
                final ExecutionMonitor progMon = exec.createSubProgress(1.0 / m_maxIter);
                irlsRls(trainingData, beta, rC, tcC, progMon);
                progMon.setProgress(1.0);
                return likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            }
        });
        try {
            loglike = future.get();
        } catch (InterruptedException e) {
            future.cancel(true);
            // restore the interrupt status, which was cleared when InterruptedException was thrown
            Thread.currentThread().interrupt();
            exec.checkCanceled();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            final Throwable cause = e.getCause();
            // a cancellation detected inside the worker must reach the caller as a cancellation
            if (cause instanceof CanceledExecutionException) {
                throw (CanceledExecutionException) cause;
            }
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            throw new RuntimeException(cause);
        }
        if (Double.isInfinite(loglike) || Double.isNaN(loglike)) {
            throw new RuntimeException(FAILING_MSG);
        }
        exec.checkCanceled();
        // test for decreasing likelihood: halve the step towards betaOld until it no longer decreases
        while ((Double.isInfinite(loglike) || Double.isNaN(loglike) || loglike < loglikeOld) && iter > 0) {
            if (betaConverged(beta, betaOld)) {
                break;
            }
            // half the step size of beta
            beta.setSubMatrix((beta.add(betaOld)).scalarMultiply(0.5).getData(), 0, 0);
            exec.checkCanceled();
            loglike = likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            exec.checkCanceled();
        }
        // test for convergence over ALL coefficients
        converged = betaConverged(beta, betaOld);
        iter++;
        LOGGER.debug("#Iterations: " + iter);
        LOGGER.debug("Log Likelihood: " + loglike);
        LOGGER.debug("beta: " + formatBetaForLog(beta));
        exec.checkCanceled();
        exec.setMessage("Iterative optimization. #Iterations: " + iter + " | Log-likelihood: " + DoubleFormat.formatDouble(loglike) + ". Processing iteration " + (iter + 1) + ".");
    }
    // The covariance matrix is the negated inverse of A.
    // NOTE(review): A is not declared in this method; presumably a field filled by irlsRls - confirm.
    final RealMatrix covMat = new QRDecomposition(A).getSolver().getInverse().scalarMultiply(-1);
    final List<String> factorList = new ArrayList<>();
    final List<String> covariateList = new ArrayList<>();
    final Map<String, List<DataCell>> factorDomainValues = new HashMap<>();
    for (int i : trainingData.getActiveCols()) {
        final DataColumnSpec columnSpec = data.getDataTableSpec().getColumnSpec(i);
        if (trainingData.getIsNominal().get(i)) {
            final String factor = columnSpec.getName();
            factorList.add(factor);
            final List<DataCell> values = trainingData.getDomainValues().get(i);
            factorDomainValues.put(factor, values);
        } else {
            if (columnSpec.getType().isCompatible(BitVectorValue.class) || columnSpec.getType().isCompatible(ByteVectorValue.class)) {
                // vector columns contribute one covariate per vector position
                final int length = trainingData.getVectorLengths().getOrDefault(i, 0).intValue();
                for (int j = 0; j < length; ++j) {
                    covariateList.add(columnSpec.getName() + "[" + j + "]");
                }
            } else {
                covariateList.add(columnSpec.getName());
            }
        }
    }
    final Map<? extends Integer, Integer> vectorIndexLengths = trainingData.getVectorLengths();
    final Map<String, Integer> vectorLengths = new LinkedHashMap<>();
    for (DataColumnSpec spec : m_specialColumns) {
        final int colIndex = data.getSpec().findColumnIndex(spec.getName());
        if (colIndex >= 0) {
            vectorLengths.put(spec.getName(), vectorIndexLengths.get(colIndex));
        }
    }
    // create content
    return new LogisticRegressionContent(m_outSpec, factorList, covariateList, vectorLengths, m_targetReferenceCategory, m_sortTargetCategories, m_sortFactorsCategories, beta, loglike, covMat, iter);
}

/**
 * Returns whether every entry of {@code beta} differs from the matching entry of {@code betaOld} by at most
 * {@code m_eps} times the magnitude of the old entry. Both matrices are assumed to have the same dimensions.
 */
private boolean betaConverged(final RealMatrix beta, final RealMatrix betaOld) {
    for (int r = 0; r < beta.getRowDimension(); r++) {
        for (int c = 0; c < beta.getColumnDimension(); c++) {
            final double old = betaOld.getEntry(r, c);
            if (Math.abs(beta.getEntry(r, c) - old) > m_eps * Math.abs(old)) {
                return false;
            }
        }
    }
    return true;
}

/** Renders every entry of {@code beta} (row-major) as a comma-separated list for debug logging. */
private static String formatBetaForLog(final RealMatrix beta) {
    final StringBuilder betaBuilder = new StringBuilder();
    boolean first = true;
    for (int r = 0; r < beta.getRowDimension(); r++) {
        for (int c = 0; c < beta.getColumnDimension(); c++) {
            if (!first) {
                betaBuilder.append(", ");
            }
            betaBuilder.append(Double.toString(beta.getEntry(r, c)));
            first = false;
        }
    }
    return betaBuilder.toString();
}
Also used: HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), ArrayList (java.util.ArrayList), ByteVectorValue (org.knime.core.data.vector.bytevector.ByteVectorValue), DataColumnSpec (org.knime.core.data.DataColumnSpec), Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix), RegressionTrainingData (org.knime.base.node.mine.regression.RegressionTrainingData), List (java.util.List), ExecutionMonitor (org.knime.core.node.ExecutionMonitor), CanceledExecutionException (org.knime.core.node.CanceledExecutionException), ExecutionException (java.util.concurrent.ExecutionException), InvalidSettingsException (org.knime.core.node.InvalidSettingsException), QRDecomposition (org.apache.commons.math3.linear.QRDecomposition), RealMatrix (org.apache.commons.math3.linear.RealMatrix), DataCell (org.knime.core.data.DataCell), BitVectorValue (org.knime.core.data.vector.bitvector.BitVectorValue)

Example 8 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project knime-core by knime.

In class BinaryNominalSplitsPCA, method calculateWeightedCovarianceMatrix.

/**
 * Builds the weighted covariance matrix of the class probability vectors held by {@code attVals}.
 *
 * <p>Each entry contributes {@code w * (p - mean)(p - mean)^T}, where {@code p} is its class probability vector and
 * {@code w} its total weight. The accumulated sum is then divided by {@code totalWeight - 1}.</p>
 *
 * @param attVals the combined attribute values whose class probability vectors are aggregated
 * @param meanClassProbabilityVec the mean class probability vector subtracted from every vector
 * @param totalWeight the total weight used in the {@code totalWeight - 1} normalization
 * @param numTargetVals the dimension of the (square) result matrix
 * @return The weighted covariance matrix of the class probability vectors of the CombinedAttributeValues
 */
static RealMatrix calculateWeightedCovarianceMatrix(final CombinedAttributeValues[] attVals, final RealVector meanClassProbabilityVec, final double totalWeight, final int numTargetVals) {
    RealMatrix scatter = MatrixUtils.createRealMatrix(numTargetVals, numTargetVals);
    for (int i = 0; i < attVals.length; i++) {
        final CombinedAttributeValues current = attVals[i];
        final RealVector deviation = current.m_classProbabilityVector.subtract(meanClassProbabilityVec);
        final RealMatrix weightedOuter = deviation.outerProduct(deviation).scalarMultiply(current.m_totalWeight);
        scatter = scatter.add(weightedOuter);
    }
    final double normalizer = 1 / (totalWeight - 1);
    return scatter.scalarMultiply(normalizer);
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) RealVector(org.apache.commons.math3.linear.RealVector)

Example 9 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project knime-core by knime.

In class AbstractSGOptimizer, method optimize.

/**
 * Trains the weight matrix for at most {@code maxEpoch} epochs.
 *
 * <p>Each epoch performs {@code data.getRowCount()} updates, each on a row drawn by {@code data.getRandomRow()}.
 * Training stops early as soon as {@code m_stoppingCriterion} reports convergence after an epoch. If
 * {@code m_calcCovMatrix} is set, a covariance matrix is computed as well. If that computation fails with a
 * {@link SingularMatrixException}, the covariance stays {@code null} and a warning is recorded. Warnings are stored
 * in {@code m_warning}, which is {@code null} when there are none.</p>
 *
 * @param maxEpoch the maximum number of epochs
 * @param data the training data
 * @param progress progress reporter, also used to check for cancellation
 * @return the learned coefficients, the covariance matrix (possibly {@code null}), the epoch counter and the
 *         negated total loss
 * @throws CanceledExecutionException if {@code progress} reports a cancellation
 */
public LogRegLearnerResult optimize(final int maxEpoch, final TrainingData<T> data, final Progress progress) throws CanceledExecutionException {
    final int rowCount = data.getRowCount();
    final int featureCount = data.getFeatureCount();
    final int targetDimension = data.getTargetDimension();
    final U epochUpdater = m_updaterFactory.create();
    final WeightMatrix<T> weights = new SimpleWeightMatrix<>(featureCount, targetDimension, true);
    int epoch = 0;
    while (epoch < maxEpoch) {
        // tell the learning rate strategy that a new epoch begins
        m_lrStrategy.startNewEpoch(epoch);
        progress.setProgress(((double) epoch) / maxEpoch, "Start epoch " + epoch + " of " + maxEpoch);
        for (int step = 0; step < rowCount; step++) {
            progress.checkCanceled();
            final T row = data.getRandomRow();
            prepareIteration(weights, row, epochUpdater, m_regUpdater, step);
            final double[] prediction = weights.predict(row);
            final double[] gradient = m_loss.gradient(row, prediction);
            final double learningRate = m_lrStrategy.getCurrentLearningRate(row, prediction, gradient);
            // the weights are updated in two steps: first by m_regUpdater, then by performUpdate
            m_regUpdater.update(weights, learningRate, step);
            performUpdate(row, epochUpdater, gradient, weights, learningRate, step);
            // renormalize once the scale magnitude leaves [1e-10, 1e10] (an exact zero is left alone)
            final double scale = weights.getScale();
            final double magnitude = Math.abs(scale);
            final boolean scaleOutOfRange = magnitude > 1e10 || (scale != 0 && magnitude < 1e-10);
            if (scaleOutOfRange) {
                normalize(weights, epochUpdater, step);
                weights.normalize();
            }
        }
        postProcessEpoch(weights, epochUpdater, m_regUpdater);
        if (m_stoppingCriterion.checkConvergence(weights)) {
            // leave the counter at the epoch in which convergence was detected
            break;
        }
        epoch++;
    }
    final StringBuilder warnings = new StringBuilder();
    if (epoch >= maxEpoch) {
        warnings.append("The algorithm did not reach convergence after the specified number of epochs. " + "Setting the epoch limit higher might result in a better model.");
    }
    final double lossSum = totalLoss(weights);
    final RealMatrix betaMat = MatrixUtils.createRealMatrix(weights.getWeightVector());
    RealMatrix covMat = null;
    if (m_calcCovMatrix) {
        try {
            covMat = calculateCovariateMatrix(weights);
        } catch (SingularMatrixException e) {
            // covMat keeps its null value; record why it is missing
            if (warnings.length() > 0) {
                warnings.append("\n");
            }
            warnings.append("The covariance matrix could not be calculated because the" + " observed fisher information matrix was singular. Did you properly normalize the numerical features?");
        }
    }
    m_warning = warnings.length() == 0 ? null : warnings.toString();
    // the last argument is the negated total loss, presumably reported as the model's log-likelihood
    return new LogRegLearnerResult(betaMat, covMat, epoch, -lossSum);
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) LogRegLearnerResult(org.knime.base.node.mine.regression.logistic.learner4.LogRegLearnerResult)

Example 10 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project knime-core by knime.

In class Learner, method perform.

/**
 * Estimates a polynomial regression model from the given table.
 *
 * <p>Every regressor contributes {@code m_maxExponent} terms. The table is passed to {@code processTable} together
 * with one {@link SummaryStatistics} slot per term and the updating regression, which presumably accumulates both.
 * The regression result then supplies the coefficients, the covariance matrix and the (adjusted) R-squared.</p>
 *
 * @param data The data table.
 * @param exec The execution context used for reporting progress.
 * @return An object which holds the results.
 * @throws CanceledExecutionException When method is cancelled
 * @throws InvalidSettingsException When settings are inconsistent with the data
 */
@Override
public PolyRegContent perform(final BufferedDataTable data, final ExecutionContext exec) throws CanceledExecutionException, InvalidSettingsException {
    exec.checkCanceled();
    final RegressionTrainingData training = new RegressionTrainingData(data, m_outSpec, m_failOnMissing);
    final int termCount = training.getRegressorCount() * m_maxExponent;
    final SummaryStatistics[] termStats = new SummaryStatistics[termCount];
    final UpdatingMultipleLinearRegression regression = initStatistics(termCount, termStats);
    exec.setProgress(0, "Estimating polynomial regression model.");
    processTable(exec, training, termStats, regression);
    final RegressionResults fit = regression.regress();
    final RealMatrix coefficients = MatrixUtils.createRowRealMatrix(fit.getParameterEstimates());
    final List<String> factors = new ArrayList<String>();
    final List<String> covariates = createCovariateListAndFillFactors(data, training, factors);
    // covariance matrix of the estimated parameters
    final RealMatrix covariance = createCovarianceMatrix(fit);
    final int observationCount = (int) termStats[0].getN();
    return new PolyRegContent(m_outSpec, observationCount, factors, covariates, coefficients, m_offsetValue, covariance, fit.getRSquared(), fit.getAdjustedRSquared(), termStats, m_maxExponent);
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) ArrayList(java.util.ArrayList) RegressionTrainingData(org.knime.base.node.mine.regression.RegressionTrainingData) SummaryStatistics(org.apache.commons.math3.stat.descriptive.SummaryStatistics) RegressionResults(org.apache.commons.math3.stat.regression.RegressionResults) UpdatingMultipleLinearRegression(org.apache.commons.math3.stat.regression.UpdatingMultipleLinearRegression)

Aggregations

RealMatrix (org.apache.commons.math3.linear.RealMatrix): 12, ArrayList (java.util.ArrayList): 5, java.util (java.util): 4, Collectors (java.util.stream.Collectors): 4, IntStream (java.util.stream.IntStream): 4, Nonnull (javax.annotation.Nonnull): 4, Nullable (javax.annotation.Nullable): 4, ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair): 4, Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 4, Covariance (org.apache.commons.math3.stat.correlation.Covariance): 4, Logger (org.apache.logging.log4j.Logger): 4, UserException (org.broadinstitute.hellbender.exceptions.UserException): 4, Nd4jIOUtils (org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils): 4, Utils (org.broadinstitute.hellbender.utils.Utils): 4, ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils): 4, RegressionTrainingData (org.knime.base.node.mine.regression.RegressionTrainingData): 4, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 4, Nd4j (org.nd4j.linalg.factory.Nd4j): 4, NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex): 4, TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException): 3