
Example 1 with InstanceDetails

Use of beagle.InstanceDetails in project beast2 by CompEvol.

The class BeagleTreeLikelihood, method initialize, which creates the BEAGLE instance and reports the selected resource through InstanceDetails:

private boolean initialize() {
    m_nNodeCount = treeInput.get().getNodeCount();
    m_bUseAmbiguities = m_useAmbiguities.get();
    m_bUseTipLikelihoods = m_useTipLikelihoods.get();
    if (!(siteModelInput.get() instanceof SiteModel.Base)) {
        throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
    }
    m_siteModel = (SiteModel.Base) siteModelInput.get();
    m_siteModel.setDataType(dataInput.get().getDataType());
    substitutionModel = m_siteModel.substModelInput.get();
    branchRateModel = branchRateModelInput.get();
    if (branchRateModel == null) {
        branchRateModel = new StrictClockModel();
    }
    m_branchLengths = new double[m_nNodeCount];
    storedBranchLengths = new double[m_nNodeCount];
    m_nStateCount = dataInput.get().getMaxStateCount();
    patternCount = dataInput.get().getPatternCount();
    // System.err.println("Attempt to load BEAGLE TreeLikelihood");
    // this.branchSubstitutionModel.getEigenCount();
    eigenCount = 1;
    double[] categoryRates = m_siteModel.getCategoryRates(null);
    // check for invariant rates category
    if (m_siteModel.hasPropInvariantCategory) {
        for (int i = 0; i < categoryRates.length; i++) {
            if (categoryRates[i] == 0) {
                proportionInvariant = m_siteModel.getRateForCategory(i, null);
                int stateCount = dataInput.get().getMaxStateCount();
                int patterns = dataInput.get().getPatternCount();
                calcConstantPatternIndices(patterns, stateCount);
                invariantCategory = i;
                double[] tmp = new double[categoryRates.length - 1];
                for (int k = 0; k < invariantCategory; k++) {
                    tmp[k] = categoryRates[k];
                }
                for (int k = invariantCategory + 1; k < categoryRates.length; k++) {
                    tmp[k - 1] = categoryRates[k];
                }
                categoryRates = tmp;
                break;
            }
        }
        if (constantPattern != null && constantPattern.size() > dataInput.get().getPatternCount()) {
            // if there are many more constant patterns than site patterns (each pattern can
            // have several constant patterns, one for each state), the constant-sites
            // optimisation is less efficient than calculating them through a separate
            // TreeLikelihood, so switch it off
            Log.debug("switch off constant sites optimisation: calculating through separate TreeLikelihood category (as in the olden days)");
            invariantCategory = -1;
            proportionInvariant = 0;
            constantPattern = null;
            categoryRates = m_siteModel.getCategoryRates(null);
        }
    }
    this.categoryCount = m_siteModel.getCategoryCount() - (invariantCategory >= 0 ? 1 : 0);
    tipCount = treeInput.get().getLeafNodeCount();
    internalNodeCount = m_nNodeCount - tipCount;
    int compactPartialsCount = tipCount;
    if (m_bUseAmbiguities) {
        // if we are using ambiguities then we don't use tip partials
        compactPartialsCount = 0;
    }
    // one partials buffer for each tip and two for each internal node (for store restore)
    partialBufferHelper = new BufferIndexHelper(m_nNodeCount, tipCount);
    // two eigen buffers for each decomposition for store and restore.
    eigenBufferHelper = new BufferIndexHelper(eigenCount, 0);
    // two matrices for each node less the root
    matrixBufferHelper = new BufferIndexHelper(m_nNodeCount, 0);
    // one scaling buffer for each internal node plus an extra for the accumulation, then doubled for store/restore
    scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0);
    // Attempt to get the resource order from the System Property
    if (resourceOrder == null) {
        resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
    }
    if (preferredOrder == null) {
        preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
    }
    if (requiredOrder == null) {
        requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
    }
    if (scalingOrder == null) {
        scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY);
    }
    // first set the rescaling scheme to use from the parser
    rescalingScheme = DEFAULT_RESCALING_SCHEME;
    int[] resourceList = null;
    long preferenceFlags = 0;
    long requirementFlags = 0;
    if (scalingOrder.size() > 0) {
        this.rescalingScheme = PartialsRescalingScheme.parseFromString(scalingOrder.get(instanceCount % scalingOrder.size()));
    }
    if (resourceOrder.size() > 0) {
        // added the zero on the end so that a CPU is selected if requested resource fails
        resourceList = new int[] { resourceOrder.get(instanceCount % resourceOrder.size()), 0 };
        if (resourceList[0] > 0) {
            // Add preference weight against CPU
            preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask();
        }
    }
    if (preferredOrder.size() > 0) {
        preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
    }
    if (requiredOrder.size() > 0) {
        requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
    }
    if (scaling.get().equals(Scaling.always)) {
        this.rescalingScheme = PartialsRescalingScheme.ALWAYS;
    }
    if (scaling.get().equals(Scaling.none)) {
        this.rescalingScheme = PartialsRescalingScheme.NONE;
    }
    // Define default behaviour here
    if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) {
        // if GPU: the default was dynamic scaling in BEAST, now NONE
        if (resourceList != null && resourceList[0] > 1) {
            // this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
            this.rescalingScheme = PartialsRescalingScheme.NONE;
        } else {
            // if CPU: just run as fast as possible
            // this.rescalingScheme = PartialsRescalingScheme.NONE;
            // Dynamic should run as fast as none until first underflow
            this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        }
    }
    if (this.rescalingScheme == PartialsRescalingScheme.AUTO) {
        preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask();
        useAutoScaling = true;
    } else {
        // preferenceFlags |= BeagleFlag.SCALING_MANUAL.getMask();
    }
    String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY);
    if (r != null) {
        rescalingFrequency = Integer.parseInt(r);
        if (rescalingFrequency < 1) {
            rescalingFrequency = RESCALE_FREQUENCY;
        }
    }
    if (preferenceFlags == 0 && resourceList == null) {
        // else determine dataset characteristics
        // TODO determine good cut-off
        if (m_nStateCount == 4 && patternCount < 10000) {
            preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask();
        }
    }
    if (substitutionModel.canReturnComplexDiagonalization()) {
        requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
    }
    instanceCount++;
    try {
        beagle = BeagleFactory.loadBeagleInstance(
                tipCount, partialBufferHelper.getBufferCount(), compactPartialsCount,
                m_nStateCount, patternCount,
                eigenBufferHelper.getBufferCount(), // eigenBufferCount
                matrixBufferHelper.getBufferCount(), categoryCount,
                scaleBufferHelper.getBufferCount(), // always allocate; they may become necessary
                resourceList, preferenceFlags, requirementFlags);
    } catch (Exception e) {
        // creating the native instance failed; fall back to the Java implementation
        beagle = null;
    }
    if (beagle == null) {
        return false;
    }
    InstanceDetails instanceDetails = beagle.getDetails();
    ResourceDetails resourceDetails = null;
    if (instanceDetails != null) {
        resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
        if (resourceDetails != null) {
            StringBuilder sb = new StringBuilder("  Using BEAGLE version: " + BeagleInfo.getVersion() + " resource ");
            sb.append(resourceDetails.getNumber()).append(": ");
            sb.append(resourceDetails.getName()).append("\n");
            if (resourceDetails.getDescription() != null) {
                String[] description = resourceDetails.getDescription().split("\\|");
                for (String desc : description) {
                    if (desc.trim().length() > 0) {
                        sb.append("    ").append(desc.trim()).append("\n");
                    }
                }
            }
            sb.append("    with instance flags: ").append(instanceDetails.toString());
            Log.info.println(sb.toString());
        } else {
            Log.warning.println("  Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
            beagle = null;
            return false;
        }
    } else {
        Log.warning.println("  No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
        beagle = null;
        return false;
    }
    Log.warning.println("  " + (m_bUseAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
    Log.warning.println("  " + (m_bUseTipLikelihoods ? "Using" : "Ignoring") + " character uncertainty in tree likelihood.");
    Log.warning.println("  With " + patternCount + " unique site patterns.");
    Node[] nodes = treeInput.get().getNodesAsArray();
    for (int i = 0; i < tipCount; i++) {
        int taxon = dataInput.get().getTaxonIndex(nodes[i].getID());
        if (m_bUseAmbiguities || m_bUseTipLikelihoods) {
            setPartials(beagle, i, taxon);
        } else {
            setStates(beagle, i, taxon);
        }
    }
    if (dataInput.get().isAscertained) {
        ascertainedSitePatterns = true;
    }
    double[] patternWeights = new double[patternCount];
    for (int i = 0; i < patternCount; i++) {
        patternWeights[i] = dataInput.get().getPatternWeight(i);
    }
    beagle.setPatternWeights(patternWeights);
    if (this.rescalingScheme == PartialsRescalingScheme.AUTO && resourceDetails != null && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) {
        // If auto scaling in BEAGLE is not supported then do it here
        this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        Log.warning.println("  Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText());
    } else {
        Log.warning.println("  Using rescaling scheme : " + this.rescalingScheme.getText());
    }
    if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        // If false, BEAST does not rescale until first under-/over-flow.
        everUnderflowed = false;
    }
    updateSubstitutionModel = true;
    updateSiteModel = true;
    // some subst models (e.g. WAG) never become dirty, so set up subst models right now
    setUpSubstModel();
    // set up sitemodel
    beagle.setCategoryRates(categoryRates);
    currentCategoryRates = categoryRates;
    currentFreqs = new double[m_nStateCount];
    currentCategoryWeights = new double[categoryRates.length];
    return true;
}
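
The InstanceDetails handling above reduces to a small reporting pattern: ask the freshly created instance for its InstanceDetails, map the resource number back to a ResourceDetails record via BeagleFactory.getResourceDetails(), log the resource name, description, and instance flags, and return false whenever that information is missing so the caller can fall back to the pure-Java likelihood. The sketch below isolates that pattern; the class and method names (BeagleInstanceReport, describeInstance) and the use of System.out in place of BEAST's Log are illustrative assumptions, while the beagle API calls are the same ones used in the example.

import beagle.Beagle;
import beagle.BeagleFactory;
import beagle.BeagleInfo;
import beagle.InstanceDetails;
import beagle.ResourceDetails;

public class BeagleInstanceReport {

    /**
     * Summarises which BEAGLE resource backs the given instance.
     * Returns false when no usable resource information is available,
     * mirroring the "return false, fall back to Java" branches above.
     */
    static boolean describeInstance(Beagle beagle) {
        if (beagle == null) {
            return false;
        }
        InstanceDetails instanceDetails = beagle.getDetails();
        if (instanceDetails == null) {
            System.out.println("No external BEAGLE resources available");
            return false;
        }
        ResourceDetails resourceDetails =
                BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
        if (resourceDetails == null) {
            System.out.println("Error retrieving BEAGLE resource for instance: " + instanceDetails);
            return false;
        }
        StringBuilder sb = new StringBuilder("Using BEAGLE version: " + BeagleInfo.getVersion()
                + " resource " + resourceDetails.getNumber() + ": " + resourceDetails.getName() + "\n");
        if (resourceDetails.getDescription() != null) {
            // the description is a '|'-separated list of capability strings
            for (String desc : resourceDetails.getDescription().split("\\|")) {
                if (!desc.trim().isEmpty()) {
                    sb.append("    ").append(desc.trim()).append("\n");
                }
            }
        }
        sb.append("    with instance flags: ").append(instanceDetails);
        System.out.println(sb);
        return true;
    }
}

In the likelihood code this check is what decides whether the class keeps the native BEAGLE instance or returns false so that BEAST falls back to its Java TreeLikelihood implementation.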
Also used: InstanceDetails (beagle.InstanceDetails), CalculationNode (beast.core.CalculationNode), Node (beast.evolution.tree.Node), SiteModel (beast.evolution.sitemodel.SiteModel), ResourceDetails (beagle.ResourceDetails), StrictClockModel (beast.evolution.branchratemodel.StrictClockModel)

Aggregations

InstanceDetails (beagle.InstanceDetails): 1
ResourceDetails (beagle.ResourceDetails): 1
CalculationNode (beast.core.CalculationNode): 1
StrictClockModel (beast.evolution.branchratemodel.StrictClockModel): 1
SiteModel (beast.evolution.sitemodel.SiteModel): 1
Node (beast.evolution.tree.Node): 1