Usage of com.amazon.randomcutforest.state.ExecutionContext in the random-cut-forest-by-aws project (by AWS):
the convert method of class V1JsonToV3StateConverter.
/**
 * Merges a collection of RCF-1.0 models that share the same model parameters
 * into a single new model state with a fixed number of trees. The requested
 * number of trees has to be less than or equal to the total number of trees in
 * the input models. Trees are taken in list order: first all trees of the first
 * forest, then those of the second, and so on, until numberOfTrees trees have
 * been collected. Only the sampler of each selected tree is converted; tree
 * structure is not saved.
 *
 * <p>All scalar configuration (dimensions, time decay, sample size,
 * center-of-mass and sequence-index flags, output-after, total updates) and the
 * execution context are copied from the first forest. They can be adjusted
 * subsequently through setters.
 *
 * @param serializedForests a non-null, non-empty list of forests that together
 *                          have at least numberOfTrees trees; every forest must
 *                          have the same number of dimensions as the first one
 * @param numberOfTrees     the number of trees in the merged model; must be
 *                          positive
 * @param precision         the (non-null) precision of the new forest
 * @return a merged RCF state containing the first numberOfTrees trees
 */
public RandomCutForestState convert(List<V1SerializedRandomCutForest> serializedForests, int numberOfTrees,
        Precision precision) {
    checkArgument(serializedForests != null && !serializedForests.isEmpty(),
            "incorrect usage of convert: serializedForests must be a non-empty list");
    checkArgument(numberOfTrees > 0, "incorrect parameter: numberOfTrees must be positive, got " + numberOfTrees);
    // checked up front so we fail before building any state, rather than at precision.name() below
    checkArgument(precision != null, "incorrect parameter: precision must not be null");

    V1SerializedRandomCutForest first = serializedForests.get(0);
    int dimensions = first.getDimensions();

    // Accumulate in a long so that an extreme number of input trees cannot overflow the total.
    long totalTrees = 0;
    for (V1SerializedRandomCutForest forest : serializedForests) {
        // The point store below is sized from the first forest's dimensions, so every
        // forest must agree on it (the documented "same model parameters" precondition).
        checkArgument(forest.getDimensions() == dimensions, "incorrect parameters: all forests must have "
                + dimensions + " dimensions, found " + forest.getDimensions());
        totalTrees += forest.getNumberOfTrees();
    }
    checkArgument(totalTrees >= numberOfTrees, "incorrect parameters: requested " + numberOfTrees
            + " trees but the forests contain only " + totalTrees);

    RandomCutForestState state = new RandomCutForestState();
    state.setNumberOfTrees(numberOfTrees);

    // Model configuration is taken from the first forest only.
    state.setDimensions(dimensions);
    state.setTimeDecay(first.getLambda());
    state.setSampleSize(first.getSampleSize());
    state.setCenterOfMassEnabled(first.isCenterOfMassEnabled());
    state.setOutputAfter(first.getOutputAfter());
    state.setStoreSequenceIndexesEnabled(first.isStoreSequenceIndexesEnabled());
    state.setTotalUpdates(first.getExecutor().getExecutor().getTotalUpdates());

    // Fixed settings for the converted state; these are not read from the V1 models.
    state.setShingleSize(1);
    state.setCompact(true);
    state.setInternalShinglingEnabled(false);
    state.setBoundingBoxCacheFraction(1.0);
    state.setSaveSamplerStateEnabled(true);
    state.setSaveTreeStateEnabled(false);
    state.setSaveCoordinatorStateEnabled(true);
    state.setPrecision(precision.name());
    state.setCompressed(false);
    state.setPartialTreeState(false);

    ExecutionContext executionContext = new ExecutionContext();
    executionContext.setParallelExecutionEnabled(first.isParallelExecutionEnabled());
    executionContext.setThreadPoolSize(first.getThreadPoolSize());
    state.setExecutionContext(executionContext);

    // NOTE(review): the capacity appears to be one slot per sampled point across all trees,
    // plus one; confirm against SamplerConverter.
    SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(),
            state.getNumberOfTrees() * state.getSampleSize() + 1, precision, numberOfTrees);

    // Walk the trees of all forests in list order and convert the samplers of the first numberOfTrees.
    serializedForests.stream()
            .flatMap(forest -> Arrays.stream(forest.getExecutor().getExecutor().getTreeUpdaters()))
            .limit(numberOfTrees)
            .map(V1SerializedRandomCutForest.TreeUpdater::getSampler)
            .forEach(samplerConverter::addSampler);

    state.setPointStoreState(samplerConverter.getPointStoreState(precision));
    state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
    return state;
}
Aggregations