Usage of beagle.InstanceDetails in project beast2 by CompEvol.
Class BeagleTreeLikelihood, method initialize.
/**
 * Attempts to create and configure a BEAGLE instance for this tree likelihood.
 *
 * <p>Reads the tree, alignment, site-model and branch-rate-model inputs and sizes the buffer
 * helpers. It then chooses the resource list, preference/requirement flags and partials
 * rescaling scheme from system properties, the {@code scaling} input and built-in defaults.
 * Finally it loads the BEAGLE instance and uploads tip data, pattern weights, the
 * substitution model and the category rates.
 *
 * <p>If the site model has a proportion-invariant category, the zero-rate category is removed
 * from the rates handed to BEAGLE and constant-pattern indices are computed instead. This is
 * skipped when there are more constant patterns than alignment patterns.
 *
 * @return {@code true} if a BEAGLE instance was created and configured; {@code false} if no
 *     instance could be loaded or its resource could not be identified, in which case
 *     {@code beagle} is left {@code null}
 * @throws IllegalArgumentException if the site model input is not a {@code SiteModel.Base}
 */
private boolean initialize() {
    m_nNodeCount = treeInput.get().getNodeCount();
    m_bUseAmbiguities = m_useAmbiguities.get();
    m_bUseTipLikelihoods = m_useTipLikelihoods.get();
    if (!(siteModelInput.get() instanceof SiteModel.Base)) {
        throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
    }
    m_siteModel = (SiteModel.Base) siteModelInput.get();
    m_siteModel.setDataType(dataInput.get().getDataType());
    substitutionModel = m_siteModel.substModelInput.get();
    branchRateModel = branchRateModelInput.get();
    if (branchRateModel == null) {
        branchRateModel = new StrictClockModel();
    }
    m_branchLengths = new double[m_nNodeCount];
    storedBranchLengths = new double[m_nNodeCount];

    m_nStateCount = dataInput.get().getMaxStateCount();
    patternCount = dataInput.get().getPatternCount();

    // A single eigen decomposition is used.
    eigenCount = 1;

    double[] categoryRates = m_siteModel.getCategoryRates(null);
    // Look for an invariant (zero-rate) category.
    if (m_siteModel.hasPropInvariantCategory) {
        for (int i = 0; i < categoryRates.length; i++) {
            if (categoryRates[i] == 0) {
                // NOTE(review): this reads the rate of the category just found to be zero,
                // which presumably yields 0 rather than the invariant proportion; confirm
                // proportionInvariant is refreshed elsewhere before it is used.
                proportionInvariant = m_siteModel.getRateForCategory(i, null);
                calcConstantPatternIndices(patternCount, m_nStateCount);
                invariantCategory = i;
                // Drop the zero-rate category from the rates handed to BEAGLE.
                double[] tmp = new double[categoryRates.length - 1];
                System.arraycopy(categoryRates, 0, tmp, 0, i);
                System.arraycopy(categoryRates, i + 1, tmp, i, categoryRates.length - i - 1);
                categoryRates = tmp;
                break;
            }
        }
        if (constantPattern != null && constantPattern.size() > patternCount) {
            // Each pattern can contribute one constant pattern per state. When constant
            // patterns outnumber patterns, computing constant sites through a regular
            // likelihood category is cheaper than the separate optimisation.
            Log.debug("switch off constant sites optimisiation: calculating through separate TreeLikelihood category (as in the olden days)");
            invariantCategory = -1;
            proportionInvariant = 0;
            constantPattern = null;
            categoryRates = m_siteModel.getCategoryRates(null);
        }
    }
    this.categoryCount = m_siteModel.getCategoryCount() - (invariantCategory >= 0 ? 1 : 0);
    tipCount = treeInput.get().getLeafNodeCount();
    internalNodeCount = m_nNodeCount - tipCount;

    // With ambiguities, tip data is uploaded as partials (setPartials below), so no compact
    // tip-state buffers are allocated.
    int compactPartialsCount = m_bUseAmbiguities ? 0 : tipCount;

    // one partials buffer for each tip and two for each internal node (for store restore)
    partialBufferHelper = new BufferIndexHelper(m_nNodeCount, tipCount);
    // two eigen buffers for each decomposition for store and restore.
    eigenBufferHelper = new BufferIndexHelper(eigenCount, 0);
    // two matrices for each node less the root
    matrixBufferHelper = new BufferIndexHelper(m_nNodeCount, 0);
    // one scaling buffer for each internal node plus an extra for the accumulation, then doubled for store/restore
    scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0);

    // Resource, flag and scaling orders are read from system properties if not already set.
    if (resourceOrder == null) {
        resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
    }
    if (preferredOrder == null) {
        preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
    }
    if (requiredOrder == null) {
        requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
    }
    if (scalingOrder == null) {
        scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY);
    }

    rescalingScheme = DEFAULT_RESCALING_SCHEME;
    int[] resourceList = null;
    long preferenceFlags = 0;
    long requirementFlags = 0;

    // Each instance picks entry (instanceCount % size) of each order list.
    if (scalingOrder.size() > 0) {
        this.rescalingScheme = PartialsRescalingScheme.parseFromString(scalingOrder.get(instanceCount % scalingOrder.size()));
    }
    if (resourceOrder.size() > 0) {
        // Trailing zero so that a CPU is selected if the requested resource fails.
        resourceList = new int[] { resourceOrder.get(instanceCount % resourceOrder.size()), 0 };
        if (resourceList[0] > 0) {
            // Add preference weight against CPU.
            preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask();
        }
    }
    if (preferredOrder.size() > 0) {
        // NOTE(review): this replaces (rather than ORs into) the GPU preference set above;
        // confirm that explicit preferred flags are meant to override it.
        preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
    }
    if (requiredOrder.size() > 0) {
        requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
    }

    // The scaling input overrides any scheme chosen so far.
    if (scaling.get().equals(Scaling.always)) {
        this.rescalingScheme = PartialsRescalingScheme.ALWAYS;
    }
    if (scaling.get().equals(Scaling.none)) {
        this.rescalingScheme = PartialsRescalingScheme.NONE;
    }

    // Resolve DEFAULT into a concrete scheme.
    if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) {
        // NOTE(review): resource 1 is treated like the CPU here, while the GPU preference
        // above uses resourceList[0] > 0; confirm which threshold is intended.
        if (resourceList != null && resourceList[0] > 1) {
            // GPU: no rescaling (this used to be DYNAMIC).
            this.rescalingScheme = PartialsRescalingScheme.NONE;
        } else {
            // CPU: DYNAMIC should run as fast as NONE until the first underflow.
            this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        }
    }
    if (this.rescalingScheme == PartialsRescalingScheme.AUTO) {
        preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask();
        useAutoScaling = true;
    }

    String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY);
    if (r != null) {
        try {
            rescalingFrequency = Integer.parseInt(r.trim());
        } catch (NumberFormatException e) {
            // A malformed property must not abort initialisation; use the default instead.
            Log.warning.println("  Ignoring invalid value '" + r + "' for " + RESCALE_FREQUENCY_PROPERTY + "; using default rescaling frequency.");
            rescalingFrequency = RESCALE_FREQUENCY;
        }
        if (rescalingFrequency < 1) {
            rescalingFrequency = RESCALE_FREQUENCY;
        }
    }

    if (preferenceFlags == 0 && resourceList == null) {
        // No explicit preference: decide from dataset characteristics.
        // TODO determine good cut-off
        if (m_nStateCount == 4 && patternCount < 10000) {
            preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask();
        }
    }
    if (substitutionModel.canReturnComplexDiagonalization()) {
        requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
    }

    instanceCount++;

    try {
        beagle = BeagleFactory.loadBeagleInstance(
                tipCount,
                partialBufferHelper.getBufferCount(),
                compactPartialsCount,
                m_nStateCount,
                patternCount,
                eigenBufferHelper.getBufferCount(),
                matrixBufferHelper.getBufferCount(),
                categoryCount,
                // Always allocate scale buffers; they may become necessary.
                scaleBufferHelper.getBufferCount(),
                resourceList,
                preferenceFlags,
                requirementFlags);
    } catch (Exception e) {
        // Deliberate best-effort: record why loading failed, then return false below.
        Log.debug("Failed to load BEAGLE instance: " + e);
        beagle = null;
    }
    if (beagle == null) {
        return false;
    }

    InstanceDetails instanceDetails = beagle.getDetails();
    if (instanceDetails == null) {
        Log.warning.println(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
        beagle = null;
        return false;
    }
    ResourceDetails resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
    if (resourceDetails == null) {
        Log.warning.println(" Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
        beagle = null;
        return false;
    }

    StringBuilder sb = new StringBuilder(" Using BEAGLE version: " + BeagleInfo.getVersion() + " resource ");
    sb.append(resourceDetails.getNumber()).append(": ");
    sb.append(resourceDetails.getName()).append("\n");
    if (resourceDetails.getDescription() != null) {
        for (String desc : resourceDetails.getDescription().split("\\|")) {
            String trimmed = desc.trim();
            if (!trimmed.isEmpty()) {
                sb.append(" ").append(trimmed).append("\n");
            }
        }
    }
    sb.append(" with instance flags: ").append(instanceDetails.toString());
    Log.info.println(sb.toString());

    Log.warning.println(" " + (m_bUseAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
    Log.warning.println(" " + (m_bUseTipLikelihoods ? "Using" : "Ignoring") + " character uncertainty in tree likelihood.");
    Log.warning.println(" With " + patternCount + " unique site patterns.");

    // Upload tip data: partials when ambiguities or tip likelihoods are used, states otherwise.
    Node[] nodes = treeInput.get().getNodesAsArray();
    for (int i = 0; i < tipCount; i++) {
        int taxon = dataInput.get().getTaxonIndex(nodes[i].getID());
        if (m_bUseAmbiguities || m_bUseTipLikelihoods) {
            setPartials(beagle, i, taxon);
        } else {
            setStates(beagle, i, taxon);
        }
    }

    if (dataInput.get().isAscertained) {
        ascertainedSitePatterns = true;
    }

    double[] patternWeights = new double[patternCount];
    for (int i = 0; i < patternCount; i++) {
        patternWeights[i] = dataInput.get().getPatternWeight(i);
    }
    beagle.setPatternWeights(patternWeights);

    // resourceDetails is non-null here (the null case returned above).
    if (this.rescalingScheme == PartialsRescalingScheme.AUTO && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) {
        // If auto scaling in BEAGLE is not supported then do it here.
        // NOTE(review): useAutoScaling was set to true above and is not reset on this
        // fallback; confirm how it is consumed elsewhere.
        this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        Log.warning.println(" Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText());
    } else {
        Log.warning.println(" Using rescaling scheme : " + this.rescalingScheme.getText());
    }
    if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        // If false, BEAST does not rescale until first under-/over-flow.
        everUnderflowed = false;
    }

    updateSubstitutionModel = true;
    updateSiteModel = true;
    // Some substitution models (e.g. WAG) never become dirty, so set them up right now.
    setUpSubstModel();
    // Set up the site model.
    beagle.setCategoryRates(categoryRates);
    currentCategoryRates = categoryRates;
    currentFreqs = new double[m_nStateCount];
    currentCategoryWeights = new double[categoryRates.length];
    return true;
}
Aggregations