use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.
the class V1JsonToV3StateConverter method convert.
/**
 * Merges a collection of RCF-1.0 models that share the same model parameters
 * and fixes the number of trees in the new model (which has to be less than or
 * equal to the sum of the trees in the old models). The conversion uses the
 * execution context of the first forest; it can be adjusted subsequently via
 * setters.
 *
 * @param serializedForests a non-empty list of forests (together having at
 *                          least numberOfTrees trees)
 * @param numberOfTrees     the number of trees in the new forest
 * @param precision         the precision of the new forest
 * @return a merged RCF state containing the first numberOfTrees trees
 */
public RandomCutForestState convert(List<V1SerializedRandomCutForest> serializedForests, int numberOfTrees,
        Precision precision) {
    checkArgument(serializedForests.size() > 0, "incorrect usage of convert");
    checkArgument(numberOfTrees > 0, "incorrect parameter");
    int sum = 0;
    for (int i = 0; i < serializedForests.size(); i++) {
        sum += serializedForests.get(i).getNumberOfTrees();
    }
    checkArgument(sum >= numberOfTrees, "incorrect parameters");
    RandomCutForestState state = new RandomCutForestState();
    state.setNumberOfTrees(numberOfTrees);
    state.setDimensions(serializedForests.get(0).getDimensions());
    state.setTimeDecay(serializedForests.get(0).getLambda());
    state.setSampleSize(serializedForests.get(0).getSampleSize());
    state.setShingleSize(1);
    state.setCenterOfMassEnabled(serializedForests.get(0).isCenterOfMassEnabled());
    state.setOutputAfter(serializedForests.get(0).getOutputAfter());
    state.setStoreSequenceIndexesEnabled(serializedForests.get(0).isStoreSequenceIndexesEnabled());
    state.setTotalUpdates(serializedForests.get(0).getExecutor().getExecutor().getTotalUpdates());
    state.setCompact(true);
    state.setInternalShinglingEnabled(false);
    state.setBoundingBoxCacheFraction(1.0);
    state.setSaveSamplerStateEnabled(true);
    state.setSaveTreeStateEnabled(false);
    state.setSaveCoordinatorStateEnabled(true);
    state.setPrecision(precision.name());
    state.setCompressed(false);
    state.setPartialTreeState(false);
    ExecutionContext executionContext = new ExecutionContext();
    executionContext.setParallelExecutionEnabled(serializedForests.get(0).isParallelExecutionEnabled());
    executionContext.setThreadPoolSize(serializedForests.get(0).getThreadPoolSize());
    state.setExecutionContext(executionContext);
    SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(),
            state.getNumberOfTrees() * state.getSampleSize() + 1, precision, numberOfTrees);
    serializedForests.stream()
            .flatMap(f -> Arrays.stream(f.getExecutor().getExecutor().getTreeUpdaters()))
            .limit(numberOfTrees)
            .map(V1SerializedRandomCutForest.TreeUpdater::getSampler)
            .forEach(samplerConverter::addSampler);
    state.setPointStoreState(samplerConverter.getPointStoreState(precision));
    state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
    return state;
}
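For orientation, here is a minimal calling sketch. It assumes the converter has a no-argument constructor and uses the List<String> overload of convert exercised by the tests below (which returns an Optional); the wrapper method name and the target tree count of 50 are hypothetical.
// A minimal calling sketch, assuming a no-argument V1JsonToV3StateConverter and
// the List<String> overload of convert used by the tests below (it returns an
// Optional). The method name and the tree count (50) are hypothetical; jsonA and
// jsonB stand for two V1 model documents with identical parameters.
static RandomCutForest mergeV1Models(String jsonA, String jsonB) {
    V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter();
    List<String> v1Models = Arrays.asList(jsonA, jsonB);
    RandomCutForestState merged = converter.convert(v1Models, 50, Precision.FLOAT_32).get();
    return new RandomCutForestMapper().toModel(merged, 0);
}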
use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.
the class V1JsonToV3StateConverterTest method testConvert.
@ParameterizedTest
@MethodSource("args")
public void testConvert(V1JsonResource jsonResource, Precision precision) {
    String resource = jsonResource.getResource();
    try (InputStream is = V1JsonToV3StateConverterTest.class.getResourceAsStream(resource);
            BufferedReader rr = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
        StringBuilder b = new StringBuilder();
        String line;
        while ((line = rr.readLine()) != null) {
            b.append(line);
        }
        String json = b.toString();
        RandomCutForestState state = converter.convert(json, precision);
        assertEquals(jsonResource.getDimensions(), state.getDimensions());
        assertEquals(jsonResource.getNumberOfTrees(), state.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), state.getSampleSize());
        RandomCutForest forest = new RandomCutForestMapper().toModel(state, 0);
        assertEquals(jsonResource.getDimensions(), forest.getDimensions());
        assertEquals(jsonResource.getNumberOfTrees(), forest.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), forest.getSampleSize());
        // perform a simple validation of the deserialized forest by updating and
        // scoring with a few points
        Random random = new Random(0);
        for (int i = 0; i < 100; i++) {
            double[] point = getPoint(jsonResource.getDimensions(), random);
            double score = forest.getAnomalyScore(point);
            assertTrue(score > 0);
            forest.update(point);
        }
        String newString = new ObjectMapper().writeValueAsString(new RandomCutForestMapper().toState(forest));
        System.out.println("Old size " + json.length() + ", new size " + newString.length()
                + ", improvement factor " + json.length() / newString.length());
    } catch (IOException e) {
        fail("Unable to load JSON resource");
    }
}
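The getPoint helper referenced above is defined elsewhere in the test class and is not part of this excerpt; a plausible stand-in, purely for illustration, is a uniform random point of the requested dimension.
// Hypothetical stand-in for the getPoint helper used in the tests: a uniform
// random point in [0, 1)^dimensions. The actual helper in the test class may differ.
private double[] getPoint(int dimensions, Random random) {
    double[] point = new double[dimensions];
    for (int i = 0; i < dimensions; i++) {
        point[i] = random.nextDouble();
    }
    return point;
}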
use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.
the class V1JsonToV3StateConverterTest method testMerge.
@ParameterizedTest
@MethodSource("args")
public void testMerge(V1JsonResource jsonResource, Precision precision) {
    String resource = jsonResource.getResource();
    try (InputStream is = V1JsonToV3StateConverterTest.class.getResourceAsStream(resource);
            BufferedReader rr = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
        StringBuilder b = new StringBuilder();
        String line;
        while ((line = rr.readLine()) != null) {
            b.append(line);
        }
        String json = b.toString();
        int number = new Random().nextInt(10) + 1;
        int testNumberOfTrees = Math.min(100, 1 + new Random().nextInt(number * jsonResource.getNumberOfTrees() - 1));
        ArrayList<String> models = new ArrayList<>();
        for (int i = 0; i < number; i++) {
            models.add(json);
        }
        RandomCutForestState state = converter.convert(models, testNumberOfTrees, precision).get();
        assertEquals(jsonResource.getDimensions(), state.getDimensions());
        assertEquals(testNumberOfTrees, state.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), state.getSampleSize());
        RandomCutForest forest = new RandomCutForestMapper().toModel(state, 0);
        assertEquals(jsonResource.getDimensions(), forest.getDimensions());
        assertEquals(testNumberOfTrees, forest.getNumberOfTrees());
        assertEquals(jsonResource.getSampleSize(), forest.getSampleSize());
        // perform a simple validation of the deserialized forest by updating and
        // scoring with a few points
        Random random = new Random(0);
        for (int i = 0; i < 100; i++) {
            double[] point = getPoint(jsonResource.getDimensions(), random);
            double score = forest.getAnomalyScore(point);
            assertTrue(score > 0);
            forest.update(point);
        }
        int expectedSize = (int) Math.floor(1.0 * testNumberOfTrees * json.length() / (number * jsonResource.getNumberOfTrees()));
        String newString = new ObjectMapper().writeValueAsString(new RandomCutForestMapper().toState(forest));
        System.out.println("Copied " + number + " times, old number of trees " + jsonResource.getNumberOfTrees()
                + ", new trees " + testNumberOfTrees + ", expected new size " + expectedSize + ", new size "
                + newString.length());
    } catch (IOException e) {
        fail("Unable to load JSON resource");
    }
}
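To make the expectedSize estimate concrete with purely hypothetical numbers: if the V1 resource serializes 50 trees in about 1,000,000 characters of JSON, is duplicated number = 3 times, and testNumberOfTrees = 75, then expectedSize = floor(75 * 1,000,000 / (3 * 50)) = 500,000 characters; that is, the merged state is expected to scale with the fraction of trees retained.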
use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.
the class RandomCutForestShingledFunctionalTest method testUpdate.
@Test
public void testUpdate() {
    int dimensions = 10;
    RandomCutForest forest = RandomCutForest.builder().numberOfTrees(100).compact(true).dimensions(dimensions)
            .randomSeed(0).sampleSize(200).precision(Precision.FLOAT_32).build();
    double[][] trainingData = genShingledData(1000, dimensions, 0);
    double[][] testData = genShingledData(100, dimensions, 1);
    for (int i = 0; i < testData.length; i++) {
        RandomCutForestMapper mapper = new RandomCutForestMapper();
        mapper.setSaveExecutorContextEnabled(true);
        mapper.setSaveTreeStateEnabled(true);
        double score = forest.getAnomalyScore(testData[i]);
        forest.update(testData[i]);
        RandomCutForestState forestState = mapper.toState(forest);
        forest = mapper.toModel(forestState);
    }
}
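The genShingledData helper is defined elsewhere in the test class; a plausible sketch, assuming a sliding-window shingle over a noisy sine series, is shown below. The constants and structure are illustrative only.
// Hypothetical sketch of the genShingledData helper used above: build a noisy
// sine series and emit overlapping shingles of length shingleSize. The actual
// helper in the test class may differ.
private static double[][] genShingledData(int size, int shingleSize, long seed) {
    Random random = new Random(seed);
    double[] series = new double[size + shingleSize - 1];
    for (int i = 0; i < series.length; i++) {
        series[i] = 10 * Math.sin(0.02 * i) + random.nextGaussian();
    }
    double[][] shingles = new double[size][shingleSize];
    for (int i = 0; i < size; i++) {
        System.arraycopy(series, i, shingles[i], 0, shingleSize);
    }
    return shingles;
}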
use of com.amazon.randomcutforest.state.RandomCutForestState in project random-cut-forest-by-aws by aws.
the class StateMapperShingledBenchmark method roundTripFromProtostuff.
@Benchmark
@OperationsPerInvocation(NUM_TEST_SAMPLES)
public byte[] roundTripFromProtostuff(BenchmarkState state, Blackhole blackhole) {
    bytes = state.protostuff;
    double[][] testData = state.testData;
    for (int i = 0; i < NUM_TEST_SAMPLES; i++) {
        Schema<RandomCutForestState> schema = RuntimeSchema.getSchema(RandomCutForestState.class);
        RandomCutForestState forestState = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, forestState, schema);
        RandomCutForestMapper mapper = new RandomCutForestMapper();
        mapper.setSaveExecutorContextEnabled(true);
        mapper.setSaveTreeStateEnabled(state.saveTreeState);
        RandomCutForest forest = mapper.toModel(forestState);
        double score = forest.getAnomalyScore(testData[i]);
        blackhole.consume(score);
        forest.update(testData[i]);
        forestState = mapper.toState(forest);
        LinkedBuffer buffer = LinkedBuffer.allocate(512);
        try {
            bytes = ProtostuffIOUtil.toByteArray(forestState, schema, buffer);
        } finally {
            buffer.clear();
        }
    }
    return bytes;
}
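The BenchmarkState fields used above (protostuff, testData, saveTreeState) come from a JMH @State class that is not part of this excerpt. A minimal sketch of producing the initial Protostuff payload, using only the mapper and Protostuff calls already visible in the benchmark body, might look like the following; the method name is hypothetical.
// Hypothetical helper producing the benchmark's initial Protostuff payload:
// serialize a trained forest's state with the same RuntimeSchema/ProtostuffIOUtil
// calls used in the benchmark loop above.
static byte[] toProtostuff(RandomCutForest trainedForest) {
    RandomCutForestMapper mapper = new RandomCutForestMapper();
    mapper.setSaveExecutorContextEnabled(true);
    Schema<RandomCutForestState> schema = RuntimeSchema.getSchema(RandomCutForestState.class);
    LinkedBuffer buffer = LinkedBuffer.allocate(512);
    try {
        return ProtostuffIOUtil.toByteArray(mapper.toState(trainedForest), schema, buffer);
    } finally {
        buffer.clear();
    }
}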