Use of com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState in project ml-commons by opensearch-project.
Class FixedInTimeRandomCutForest, method predict:
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    // A trained model is required; FIT RCF cannot predict without one.
    if (model == null) {
        throw new IllegalArgumentException("No model found for FIT RCF prediction.");
    }
    // Deserialize the stored forest state, then rebuild a live ThresholdedRandomCutForest from it.
    ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState) ModelSerDeSer.deserialize(model.getContent());
    ThresholdedRandomCutForest forest = trcfMapper.toModel(state);
    // Run the input rows through the restored forest and wrap the results in a prediction output.
    List<Map<String, Object>> predictResult = process(dataFrame, forest);
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build();
}
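The predict path turns stored bytes back into a state object and only then into something that can score data. The self-contained sketch below shows that guard, deserialize, score sequence without the RCF library. Everything in it is hypothetical: `PredictSketch`, the `DetectorState` record (a stand-in for ThresholdedRandomCutForestState), and the `serialize`/`deserialize` helpers, which assume plain Java object serialization (the real ModelSerDeSer may use a different format). The null check is applied to the raw content bytes rather than to a Model object, and the "score" is a toy z-score, not an RCF anomaly score. Records require Java 16 or later.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PredictSketch {
    // Hypothetical stand-in for ThresholdedRandomCutForestState: a serializable snapshot
    // holding running statistics (count, mean, sum of squared deviations).
    public record DetectorState(long count, double mean, double m2) implements Serializable {}

    // Hypothetical stand-in for ModelSerDeSer.serialize, assuming plain Java serialization.
    public static byte[] serialize(Serializable state) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(state);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Hypothetical stand-in for ModelSerDeSer.deserialize.
    public static Object deserialize(byte[] content) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(content))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    // Mirrors predict(): reject a missing model, restore its state, then run every row through it.
    public static List<Map<String, Object>> predict(List<Double> rows, byte[] content) {
        if (content == null) {
            throw new IllegalArgumentException("No model found for FIT RCF prediction.");
        }
        DetectorState state = (DetectorState) deserialize(content);
        // Sample standard deviation derived from the stored statistics.
        double std = state.count() > 1 ? Math.sqrt(state.m2() / (state.count() - 1)) : 0.0;
        List<Map<String, Object>> results = new ArrayList<>();
        for (double x : rows) {
            Map<String, Object> row = new HashMap<>();
            // Toy score: distance from the stored mean in standard deviations (not an RCF score).
            row.put("score", std > 0 ? Math.abs(x - state.mean()) / std : 0.0);
            results.add(row);
        }
        return results;
    }

    public static void main(String[] args) {
        byte[] content = serialize(new DetectorState(4, 2.5, 5.0));
        System.out.println(predict(List.of(2.5, 10.0), content));
    }
}
```

A value equal to the stored mean scores 0, while one far from it scores several standard deviations; the shape of the result (a list of per-row maps) is chosen here to resemble the `List<Map<String, Object>>` that `process` returns above.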
Class FixedInTimeRandomCutForest, method train:
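@Override
public Model train(DataFrame dataFrame) {
    // Create a forest for this input and feed it the training data (the result of process is not used here).
    ThresholdedRandomCutForest forest = createThresholdedRandomCutForest(dataFrame);
    process(dataFrame, forest);
    // Package the trained forest as a FIT_RCF model, version 1.
    Model model = new Model();
    model.setName(FunctionName.FIT_RCF.name());
    model.setVersion(1);
    // Capture the forest's state and store its serialized form as the model content.
    ThresholdedRandomCutForestState state = trcfMapper.toState(forest);
    model.setContent(ModelSerDeSer.serialize(state));
    return model;
}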
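Training is the mirror image of prediction: build a stateful detector, stream the data through it, then snapshot its state into bytes. The self-contained sketch below follows that train-then-snapshot pattern with a toy detector based on Welford's running mean and variance, which stands in for the forest and is not a random cut forest. `TrainSketch`, `ToyDetector`, `DetectorState`, `StoredModel` (a stand-in for ml-commons' Model) and `serialize` (assuming plain Java serialization in place of ModelSerDeSer) are all hypothetical names introduced for illustration. Records require Java 16 or later.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

public class TrainSketch {
    // Hypothetical snapshot of learned state (stand-in for ThresholdedRandomCutForestState).
    public record DetectorState(long count, double mean, double m2) implements Serializable {}

    // Hypothetical stand-in for ml-commons' Model: a name, a version, and opaque content bytes.
    public record StoredModel(String name, int version, byte[] content) {}

    // Toy streaming detector using Welford's running mean/variance; NOT a random cut forest.
    public static final class ToyDetector {
        private long count;
        private double mean;
        private double m2;

        public void process(double x) {
            count++;
            double delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean); // uses the updated mean, per Welford's method
        }

        public DetectorState toState() {
            return new DetectorState(count, mean, m2);
        }
    }

    // Hypothetical stand-in for ModelSerDeSer.serialize, assuming plain Java serialization.
    public static byte[] serialize(Serializable state) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(state);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Mirrors train(): create a detector, feed it every row, then snapshot and serialize its state.
    public static StoredModel train(double[] rows) {
        ToyDetector detector = new ToyDetector();
        for (double row : rows) {
            detector.process(row);
        }
        return new StoredModel("FIT_RCF", 1, serialize(detector.toState()));
    }

    public static void main(String[] args) {
        StoredModel model = train(new double[] {1.0, 2.0, 3.0, 4.0});
        System.out.println(model.name() + " v" + model.version() + ": " + model.content().length + " bytes");
    }
}
```

Keeping the live detector and its serializable snapshot as separate types loosely parallels the `toState`/`toModel` mapper split in the original: the object that does the work never has to be serializable itself, only its state does.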