Usage of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class FixedInTimeRandomCutForest, method process.
/**
 * Runs anomaly detection over every row of the data frame.
 * <p>
 * For each row, the non-time columns are collected into a feature vector and the
 * configured {@code timeField} column (if any) supplies the timestamp, read either
 * as a LONG epoch value or parsed from a STRING via {@code simpleDateFormat}.
 * Each point is fed to the thresholded RCF model and the resulting score and
 * anomaly grade are collected per row.
 *
 * @param dataFrame input rows; assumed already sorted ascending by time field
 * @param forest    the thresholded random cut forest model to feed points into
 * @return one result map per row containing the timestamp, "score" and "anomaly_grade"
 * @throws MLValidationException if a STRING timestamp cannot be parsed or the
 *                               time column has an unsupported type
 */
private List<Map<String, Object>> process(DataFrame dataFrame, ThresholdedRandomCutForest forest) {
    ColumnMeta[] columnMetas = dataFrame.columnMetas();
    List<Map<String, Object>> predictResult = new ArrayList<>();
    List<Double> features = new ArrayList<>();
    for (int rowNum = 0; rowNum < dataFrame.size(); rowNum++) {
        Row row = dataFrame.getRow(rowNum);
        long timestamp = -1;
        features.clear();
        for (int col = 0; col < columnMetas.length; col++) {
            ColumnMeta columnMeta = columnMetas[col];
            ColumnValue value = row.getValue(col);
            // TODO: sort dataframe by time field with asc order. Currently consider the date already sorted by time.
            boolean isTimeColumn = timeField != null && timeField.equals(columnMeta.getName());
            if (!isTimeColumn) {
                features.add(value.doubleValue());
                continue;
            }
            ColumnType columnType = columnMeta.getColumnType();
            if (columnType == ColumnType.LONG) {
                timestamp = value.longValue();
            } else if (columnType == ColumnType.STRING) {
                try {
                    timestamp = simpleDateFormat.parse(value.stringValue()).getTime();
                } catch (ParseException e) {
                    log.error("Failed to parse timestamp " + value.stringValue(), e);
                    throw new MLValidationException("Failed to parse timestamp " + value.stringValue());
                }
            } else {
                throw new MLValidationException("Wrong data type of time field. Should use LONG or STRING, but got " + columnType);
            }
        }
        double[] point = new double[features.size()];
        for (int i = 0; i < point.length; i++) {
            point[i] = features.get(i);
        }
        AnomalyDescriptor descriptor = forest.process(point, timestamp);
        Map<String, Object> result = new HashMap<>();
        result.put(timeField, timestamp);
        result.put("score", descriptor.getRCFScore());
        result.put("anomaly_grade", descriptor.getAnomalyGrade());
        predictResult.add(result);
    }
    return predictResult;
}
Usage of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class MLEngineTest, method trainKMeansModel.
/**
 * Trains a KMeans model for testing: 2 centroids, 10 iterations, Euclidean
 * distance, over a synthetic 100-row training data frame.
 *
 * @return the trained model produced by {@link MLEngine#train}
 */
private Model trainKMeansModel() {
    DataFrame trainDataFrame = constructKMeansDataFrame(100);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
    KMeansParams parameters = KMeansParams
        .builder()
        .centroids(2)
        .iterations(10)
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
    return MLEngine.train(mlInput);
}
Usage of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class AnomalyDetectionLibSVMTest, method constructDataFrame.
/**
 * Converts a Tribuo {@link Dataset} of events into an ml-commons {@link DataFrame}.
 * <p>
 * The column schema (one DOUBLE column per feature) is derived lazily from the
 * first example's feature names; every example is then appended as a row via
 * {@code addRow}. The original implementation duplicated the feature-value
 * extraction loop in both branches of the first-example check; the extraction
 * is now done once per example.
 *
 * @param data     the Tribuo dataset to convert
 * @param training whether rows are added as training rows (forwarded to addRow)
 * @param labels   label sink populated by addRow
 * @return the populated data frame, or {@code null} if the dataset is empty
 */
private DataFrame constructDataFrame(Dataset<Event> data, boolean training, List<Event.EventType> labels) {
    DataFrame dataFrame = null;
    Iterator<Example<Event>> iterator = data.iterator();
    while (iterator.hasNext()) {
        Example<Event> example = iterator.next();
        // Extract this example's feature values once, regardless of whether the
        // schema has been initialized yet.
        List<ColumnValue> columnValues = new ArrayList<>();
        for (Feature feature : example) {
            columnValues.add(new DoubleValue(feature.getValue()));
        }
        if (dataFrame == null) {
            // First example: derive the column metadata from its feature names.
            List<ColumnMeta> columns = new ArrayList<>();
            for (Feature feature : example) {
                columns.add(new ColumnMeta(feature.getName(), ColumnType.DOUBLE));
            }
            dataFrame = new DefaultDataFrame(columns.toArray(new ColumnMeta[0]));
        }
        addRow(columnValues, training, example, dataFrame, labels);
    }
    return dataFrame;
}
Usage of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class MLInputDatasetHandler, method parseSearchQueryInput.
/**
 * Create DataFrame based on given search query.
 * <p>
 * Executes the dataset's search request against its indices and converts the
 * returned hits' source maps into a {@link DataFrame}. Failures — including an
 * empty result set — are delivered through {@code listener.onFailure}.
 *
 * @param mlInputDataset MLInputDataset; must be of type SEARCH_QUERY
 * @param listener ActionListener notified with the built DataFrame or a failure
 * @throws IllegalArgumentException if the dataset is not SEARCH_QUERY type
 */
public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<DataFrame> listener) {
    if (!mlInputDataset.getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
        throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
    }
    SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInputDataset;
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(inputDataset.getSearchSourceBuilder());
    List<String> indicesList = inputDataset.getIndices();
    searchRequest.indices(indicesList.toArray(new String[0]));
    client.search(searchRequest, ActionListener.wrap(r -> {
        // Treat a null/empty response as a failure rather than an empty frame.
        if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
            listener.onFailure(new IllegalArgumentException("No document found"));
            return;
        }
        SearchHits hits = r.getHits();
        List<Map<String, Object>> input = new ArrayList<>();
        SearchHit[] searchHits = hits.getHits();
        for (SearchHit hit : searchHits) {
            input.add(hit.getSourceAsMap());
        }
        DataFrame dataFrame = DataFrameBuilder.load(input);
        listener.onResponse(dataFrame);
    }, e -> {
        // Pass the throwable as the logger's last argument so the stack trace
        // is recorded (the original concatenated it into the message string).
        log.error("Failed to search", e);
        listener.onFailure(e);
    }));
}
Usage of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class MLInput, method parse.
/**
 * Parses an {@code MLInput} from XContent for the given algorithm name.
 * <p>
 * Recognized fields: ML parameters (delegated to the named-object registry),
 * an array of input index names, an embedded search query, and an inline
 * data frame. Unknown fields are skipped.
 *
 * @param parser        positioned at the START_OBJECT of the input body
 * @param inputAlgoName algorithm name (case-insensitive)
 * @return the parsed MLInput
 * @throws IOException on malformed XContent
 */
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
    String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
    FunctionName algorithm = FunctionName.from(algorithmName);
    MLAlgoParams mlParameters = null;
    SearchSourceBuilder searchSourceBuilder = null;
    List<String> sourceIndices = new ArrayList<>();
    DataFrame dataFrame = null;
    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
        String fieldName = parser.currentName();
        parser.nextToken();
        switch(fieldName) {
            case ML_PARAMETERS_FIELD:
                mlParameters = parser.namedObject(MLAlgoParams.class, algorithmName, null);
                break;
            case INPUT_INDEX_FIELD:
                ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                    sourceIndices.add(parser.text());
                }
                break;
            case INPUT_QUERY_FIELD:
                ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
                searchSourceBuilder = SearchSourceBuilder.fromXContent(parser, false);
                break;
            case INPUT_DATA_FIELD:
                dataFrame = DefaultDataFrame.parse(parser);
                // Fix: the original fell through to default and called
                // parser.skipChildren() after the data frame was already consumed.
                break;
            default:
                parser.skipChildren();
                break;
        }
    }
    return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, null);
}
Aggregations