
Example 1 with ExecutionContext

Use of com.amazon.randomcutforest.state.ExecutionContext in the project random-cut-forest-by-aws by aws.

From the class V1JsonToV3StateConverter, method convert.

/**
 * Merges a collection of RCF-1.0 models that share the same model parameters
 * and fixes the number of trees in the new model, which must be less than or
 * equal to the total number of trees in the old models. The conversion uses
 * the execution context of the first forest; it can be adjusted afterwards
 * through setters.
 *
 * @param serializedForests a non-empty list of forests that together contain
 *                          at least numberOfTrees trees
 * @param numberOfTrees     the number of trees in the new forest
 * @param precision         the precision of the new forest
 * @return a merged RCF state containing the first numberOfTrees trees
 */
public RandomCutForestState convert(List<V1SerializedRandomCutForest> serializedForests, int numberOfTrees, Precision precision) {
    // Require at least one forest, a positive tree count, and enough trees in total.
    checkArgument(serializedForests.size() > 0, "incorrect usage of convert");
    checkArgument(numberOfTrees > 0, "incorrect parameter");
    int sum = 0;
    for (int i = 0; i < serializedForests.size(); i++) {
        sum += serializedForests.get(i).getNumberOfTrees();
    }
    checkArgument(sum >= numberOfTrees, "incorrect parameters");
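    // Copy the model parameters from the first forest; the remaining flags are fixed
    // for converted models. Tree state is not saved, only sampler and point-store state.
    RandomCutForestState state = new RandomCutForestState();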
    state.setNumberOfTrees(numberOfTrees);
    state.setDimensions(serializedForests.get(0).getDimensions());
    state.setTimeDecay(serializedForests.get(0).getLambda());
    state.setSampleSize(serializedForests.get(0).getSampleSize());
    state.setShingleSize(1);
    state.setCenterOfMassEnabled(serializedForests.get(0).isCenterOfMassEnabled());
    state.setOutputAfter(serializedForests.get(0).getOutputAfter());
    state.setStoreSequenceIndexesEnabled(serializedForests.get(0).isStoreSequenceIndexesEnabled());
    state.setTotalUpdates(serializedForests.get(0).getExecutor().getExecutor().getTotalUpdates());
    state.setCompact(true);
    state.setInternalShinglingEnabled(false);
    state.setBoundingBoxCacheFraction(1.0);
    state.setSaveSamplerStateEnabled(true);
    state.setSaveTreeStateEnabled(false);
    state.setSaveCoordinatorStateEnabled(true);
    state.setPrecision(precision.name());
    state.setCompressed(false);
    state.setPartialTreeState(false);
    // The execution context is taken from the first forest and can be changed later via setters.
    ExecutionContext executionContext = new ExecutionContext();
    executionContext.setParallelExecutionEnabled(serializedForests.get(0).isParallelExecutionEnabled());
    executionContext.setThreadPoolSize(serializedForests.get(0).getThreadPoolSize());
    state.setExecutionContext(executionContext);
    // Convert the samplers of the first numberOfTrees trees, taken in forest order,
    // into a single point store and compact sampler states.
    SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(),
            state.getNumberOfTrees() * state.getSampleSize() + 1, precision, numberOfTrees);
    serializedForests.stream()
            .flatMap(f -> Arrays.stream(f.getExecutor().getExecutor().getTreeUpdaters()))
            .limit(numberOfTrees)
            .map(V1SerializedRandomCutForest.TreeUpdater::getSampler)
            .forEach(samplerConverter::addSampler);
    state.setPointStoreState(samplerConverter.getPointStoreState(precision));
    state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
    return state;
}
Also used: Arrays (java.util.Arrays), RandomCutForestState (com.amazon.randomcutforest.state.RandomCutForestState), ExecutionContext (com.amazon.randomcutforest.state.ExecutionContext), CompactSamplerState (com.amazon.randomcutforest.state.sampler.CompactSamplerState), URL (java.net.URL), Precision (com.amazon.randomcutforest.config.Precision), CommonUtils.checkArgument (com.amazon.randomcutforest.CommonUtils.checkArgument), ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper), IOException (java.io.IOException), Reader (java.io.Reader), PointStoreState (com.amazon.randomcutforest.state.store.PointStoreState), RandomCutTree (com.amazon.randomcutforest.tree.RandomCutTree), ArrayList (java.util.ArrayList), PointStore (com.amazon.randomcutforest.store.PointStore), ITree (com.amazon.randomcutforest.tree.ITree), List (java.util.List), Optional (java.util.Optional), PointStoreMapper (com.amazon.randomcutforest.state.store.PointStoreMapper), IPointStore (com.amazon.randomcutforest.store.IPointStore), Collections (java.util.Collections)
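To make the tree-selection step of convert concrete, here is a minimal, self-contained Java sketch (Java 16+ for the record). The `Forest` record and the `selectTrees` method are hypothetical stand-ins, not part of random-cut-forest-by-aws: each forest is reduced to a list of tree labels, and plain `IllegalArgumentException`s stand in for `CommonUtils.checkArgument`. It mirrors the input checks and the `flatMap(...).limit(numberOfTrees)` chain: trees are taken from the first forest, then the next, until `numberOfTrees` have been collected, and any remaining trees in later forests are dropped.

```java
import java.util.List;
import java.util.stream.Collectors;

public class MergeSketch {

    // Hypothetical stand-in for a serialized RCF-1.0 forest: only tree labels are kept.
    record Forest(List<String> trees) {
        int numberOfTrees() {
            return trees.size();
        }
    }

    // Returns the first numberOfTrees trees in forest order, like flatMap(...).limit(...) in convert.
    static List<String> selectTrees(List<Forest> forests, int numberOfTrees) {
        if (forests.isEmpty()) {
            throw new IllegalArgumentException("at least one forest is required");
        }
        if (numberOfTrees <= 0) {
            throw new IllegalArgumentException("numberOfTrees must be positive");
        }
        // The forests together must hold at least numberOfTrees trees.
        int total = forests.stream().mapToInt(Forest::numberOfTrees).sum();
        if (total < numberOfTrees) {
            throw new IllegalArgumentException("forests contain only " + total + " trees");
        }
        return forests.stream()
                .flatMap(f -> f.trees().stream()) // concatenate trees, first forest first
                .limit(numberOfTrees)             // keep only the first numberOfTrees of them
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Forest> forests = List.of(
                new Forest(List.of("a0", "a1", "a2")),
                new Forest(List.of("b0", "b1")));
        System.out.println(selectTrees(forests, 4)); // prints [a0, a1, a2, b0]
    }
}
```

With forests holding three and two trees, asking for four keeps all three trees of the first forest and only the first tree of the second; asking for six is rejected because the forests contain only five trees.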
