Use of com.sri.ai.praise.model.common.io.ModelPage in project aic-praise by aic-sri-international.
Class Evaluation, method evaluate.
/**
 * Evaluates every default query of every model page in {@code modelsToEvaluateContainer} with each configured solver.
 * <p>
 * A single burn-in query is first run on every solver so that OS caching etc. evens out timings across runs. Then a
 * CSV header line is sent to {@code evaluationListener.csvResultOutput}, followed by one CSV line per
 * (model page, query) pair. Progress messages go to {@code evaluationListener.notification}. Any exception is
 * reported through {@code evaluationListener.notificationException} rather than propagated. The total elapsed time
 * is always reported at the end.
 * <p>
 * CSV fields that contain a comma, double quote or line break are quoted per RFC 4180, so they no longer spill
 * across columns. All other fields are written unchanged.
 *
 * @param configuration             supplies the working directory, inference type and number of runs to average over
 * @param modelsToEvaluateContainer the model pages (and their default queries) to evaluate
 * @param solverConfigurations      configurations of the solvers to instantiate and compare
 * @param evaluationListener        receives progress notifications, CSV output lines and exceptions
 */
public void evaluate(Evaluation.Configuration configuration, PagedModelContainer modelsToEvaluateContainer, List<SolverEvaluatorConfiguration> solverConfigurations, Evaluation.Listener evaluationListener) {
    // Note, varying domain sizes etc... is achieved by creating variants of a base model in the provided paged model container
    long evaluationStart = System.currentTimeMillis();
    try {
        List<SolverEvaluator> solvers = instantiateSolvers(solverConfigurations, configuration.getWorkingDirectory());

        // Do an initial burn in to ensure any OS caching etc... occurs so as to even out times across runs.
        // Use the first page that actually has a query. Blindly indexing page 0 / query 0 made an empty container,
        // or a first page without queries, abort the whole evaluation with an IndexOutOfBoundsException.
        ModelPage burnInModel = null;
        String burnInQuery = null;
        for (ModelPage page : modelsToEvaluateContainer.getPages()) {
            List<String> pageQueries = page.getDefaultQueriesToRun();
            if (!pageQueries.isEmpty()) {
                burnInModel = page;
                burnInQuery = pageQueries.get(0);
                break;
            }
        }
        if (burnInModel == null) {
            evaluationListener.notification("No queries found in '" + modelsToEvaluateContainer.getName() + "', skipping solver burn in");
        }
        else {
            evaluationListener.notification("Starting solver burn in based on '" + modelsToEvaluateContainer.getName() + " - " + burnInModel.getName() + " : " + burnInQuery + "'");
            for (SolverEvaluator solver : solvers) {
                SolverCallResult solverResult = callSolver(configuration, solver, burnInModel, burnInQuery);
                evaluationListener.notification("Burn in for " + solver.getName() + " complete. Average inference time = " + toDurationString(solverResult.averageInferenceTimeInMilliseconds));
            }
        }

        // Output the report header line (solver names are dynamic, so they are escaped too)
        StringJoiner csvLine = new StringJoiner(",");
        csvLine.add("Problem");
        csvLine.add("Inference Type");
        csvLine.add("Domain Size(s)");
        csvLine.add("# runs values averaged over");
        for (SolverEvaluator solver : solvers) {
            csvLine.add("Solver");
            csvLine.add(escapeCsvField("Result for " + solver.getName()));
            csvLine.add(escapeCsvField("Inference ms. for " + solver.getName()));
            csvLine.add("HH:MM:SS.");
            csvLine.add(escapeCsvField("Translation ms. for " + solver.getName()));
            csvLine.add("HH:MM:SS.");
        }
        evaluationListener.notification("Starting to generate Evaluation Report");
        evaluationListener.csvResultOutput(csvLine.toString());

        // Now evaluate each of the model-query-solver combinations, one CSV line per model-query pair.
        for (ModelPage model : modelsToEvaluateContainer.getPages()) {
            String domainSizes = getDomainSizes(model.getModel());
            for (String query : model.getDefaultQueriesToRun()) {
                csvLine = new StringJoiner(",");
                String problemName = modelsToEvaluateContainer.getName() + " - " + model.getName() + " : " + query;
                evaluationListener.notification("Starting to evaluate " + problemName);
                csvLine.add(escapeCsvField(problemName));
                csvLine.add(escapeCsvField(configuration.type.name()));
                csvLine.add(escapeCsvField(domainSizes));
                csvLine.add("" + configuration.getNumberRunsToAverageOver());
                for (SolverEvaluator solver : solvers) {
                    SolverCallResult solverResult = callSolver(configuration, solver, model, query);
                    csvLine.add(escapeCsvField(solver.getName()));
                    csvLine.add(solverResult.failed ? "FAILED" : escapeCsvField("" + solverResult.answer));
                    csvLine.add("" + solverResult.averageInferenceTimeInMilliseconds);
                    csvLine.add(escapeCsvField(toDurationString(solverResult.averageInferenceTimeInMilliseconds)));
                    csvLine.add("" + solverResult.averagelTranslationTimeInMilliseconds);
                    csvLine.add(escapeCsvField(toDurationString(solverResult.averagelTranslationTimeInMilliseconds)));
                    evaluationListener.notification("Solver " + solver.getName() + " took an average inference time of " + toDurationString(solverResult.averageInferenceTimeInMilliseconds) + " to solve " + problemName);
                }
                evaluationListener.csvResultOutput(csvLine.toString());
            }
        }
    }
    catch (Exception ex) {
        evaluationListener.notificationException(ex);
    }
    long evaluationEnd = System.currentTimeMillis();
    evaluationListener.notification("Evaluation took " + toDurationString(evaluationEnd - evaluationStart) + " to run to completion.");
}

/**
 * Formats {@code value} as a single CSV field per RFC 4180.
 * <p>
 * A value containing a comma, double quote, carriage return or line feed is wrapped in double quotes, with each
 * embedded double quote doubled. Any other value is returned unchanged. A {@code null} value yields {@code "null"},
 * which is what {@link StringJoiner#add} would have appended for it.
 */
private static String escapeCsvField(String value) {
    String text = String.valueOf(value);
    boolean needsQuoting = text.indexOf(',') >= 0 || text.indexOf('"') >= 0 || text.indexOf('\n') >= 0 || text.indexOf('\r') >= 0;
    return needsQuoting ? "\"" + text.replace("\"", "\"\"") + "\"" : text;
}
Use of com.sri.ai.praise.model.common.io.ModelPage in project aic-praise by aic-sri-international.
Class PRAiSE, method run.
/**
 * Parses {@code args} and queries every resulting model page.
 * <p>
 * For each page, prints the page name and model text to {@code solverArgs.out}, followed by each query's
 * (simplified) result and computation time. Any errors reported for a query are also printed, including stack traces
 * of attached throwables. When {@code theorySupplier} is non-null, the theory it supplies is set as the query
 * runner's option theory. Any exception is reported on standard error instead of being propagated.
 *
 * @param args           the command-line arguments
 * @param theorySupplier supplier of the theory to use, or {@code null} to keep the runner's default
 */
public static void run(String[] args, Supplier<Theory> theorySupplier) {
    try (SGSolverArgs solverArgs = getArgs(args)) {
        for (ModelPage modelPage : getHOGModelsToQuery(solverArgs)) {
            solverArgs.out.print("MODEL NAME = ");
            solverArgs.out.println(modelPage.getName());
            solverArgs.out.println("MODEL = ");
            solverArgs.out.println(modelPage.getModel());

            HOGMQueryRunner runner = new HOGMQueryRunner(modelPage.getModel(), modelPage.getDefaultQueriesToRun());
            if (theorySupplier != null) {
                runner.setOptionTheory(theorySupplier.get());
            }

            List<HOGMQueryResult> queryResults = runner.query();
            for (HOGMQueryResult queryResult : queryResults) {
                solverArgs.out.print("QUERY = ");
                solverArgs.out.println(queryResult.getQueryString());
                solverArgs.out.print(RESULT_PREFIX);
                solverArgs.out.println(runner.simplifyAnswer(queryResult.getResult(), queryResult.getQueryExpression()));
                solverArgs.out.print("TOOK = ");
                solverArgs.out.println(duration(queryResult.getMillisecondsToCompute()) + "\n");
                if (!queryResult.isErrors()) {
                    continue;
                }
                queryResult.getErrors().forEach(error -> {
                    solverArgs.out.println("ERROR =" + error.getErrorMessage());
                    if (error.getThrowable() != null) {
                        solverArgs.out.println("THROWABLE =");
                        error.getThrowable().printStackTrace(solverArgs.out);
                    }
                });
            }
        }
    }
    catch (Exception ex) {
        System.err.println("Error calling SGSolver");
        ex.printStackTrace();
    }
}
Use of com.sri.ai.praise.model.common.io.ModelPage in project aic-praise by aic-sri-international.
Class PRAiSE, method guessLanguageModel.
/**
 * Guesses the modeling language of the given input files.
 * <p>
 * Resolution order:
 * <ol>
 * <li>the first {@link ModelLanguage} (in declaration order) whose default file extension ends the name of any input file;</li>
 * <li>otherwise, the language of the first model page of the first container file that can be read and is non-empty;</li>
 * <li>otherwise {@link ModelLanguage#HOGMv1}.</li>
 * </ol>
 * Extension matching is case-insensitive and independent of the platform's default locale.
 *
 * @param inputFiles the input files to inspect
 * @return the guessed language, never {@code null}
 */
private static ModelLanguage guessLanguageModel(List<File> inputFiles) {
    ModelLanguage result = Arrays.stream(ModelLanguage.values())
            .filter(ml -> inputFiles.stream().anyMatch(inputFile -> endsWithIgnoreCase(inputFile.getName(), ml.getDefaultFileExtension())))
            .findFirst()
            .orElse(null);
    if (result == null) {
        // Check if the input is a container file and if so get the language from it.
        // An unreadable or empty container maps to null. Those are filtered out so a later container can still
        // decide; previously Stream.findFirst threw a NullPointerException on the null element.
        result = inputFiles.stream()
                .filter(inputFile -> endsWithIgnoreCase(inputFile.getName(), PagedModelContainer.DEFAULT_CONTAINER_FILE_EXTENSION))
                .map(containerInputFile -> {
                    ModelLanguage containedLanguage = null;
                    try {
                        List<ModelPage> models = PagedModelContainer.getModelPagesFromURI(containerInputFile.toURI());
                        if (!models.isEmpty()) {
                            containedLanguage = models.get(0).getLanguage();
                        }
                    }
                    catch (IOException ioe) {
                        // Best effort: report and fall through to the next container / the default language
                        System.err.println(ioe.getMessage());
                        ioe.printStackTrace();
                    }
                    return containedLanguage;
                })
                .filter(containedLanguage -> containedLanguage != null)
                .findFirst()
                .orElse(null);
    }
    if (result == null) {
        // For simplicity, defaults to HOGMv1 if nothing specified
        result = ModelLanguage.HOGMv1;
    }
    return result;
}

/**
 * Returns whether {@code text} ends with {@code suffix}, ignoring case.
 * <p>
 * Uses {@link String#regionMatches(boolean, int, String, int, int)}, which compares character by character and does
 * not depend on the default locale (unlike {@code toLowerCase()}).
 */
private static boolean endsWithIgnoreCase(String text, String suffix) {
    int start = text.length() - suffix.length();
    return start >= 0 && text.regionMatches(true, start, suffix, 0, suffix.length());
}
Use of com.sri.ai.praise.model.common.io.ModelPage in project aic-praise by aic-sri-international.
Class PRAiSE, method getHOGModelsToQuery.
//
// PRIVATE
//
/**
 * Collects the model pages to query from the input files in {@code solverArgs}.
 * <p>
 * Each container file contributes its own pages. When global queries were given, each such page is re-created with
 * the global queries followed by the page's own default queries. All non-container files are concatenated, in input
 * order and separated by a newline, into one extra page. That page uses {@code solverArgs.inputLanguage} and its
 * queries are the global queries.
 *
 * @param solverArgs the parsed command-line arguments
 * @return the container pages in input order, followed by the concatenated page if there were non-container files
 * @throws IOException if a container or model file cannot be read
 */
private static List<ModelPage> getHOGModelsToQuery(SGSolverArgs solverArgs) throws IOException {
    List<ModelPage> result = new ArrayList<>();
    // First handle container files and track non-container files
    List<File> nonContainerFiles = new ArrayList<>();
    for (File inputFile : solverArgs.inputFiles) {
        if (isContainerFile(inputFile)) {
            List<ModelPage> containerPages = PagedModelContainer.getModelPagesFromURI(inputFile.toURI());
            if (solverArgs.globalQueries.isEmpty()) {
                // Take the models as is
                result.addAll(containerPages);
            }
            else {
                // Prefix each page's own queries with the global queries
                for (ModelPage containerModelPage : containerPages) {
                    List<String> combinedQueries = new ArrayList<>(solverArgs.globalQueries);
                    combinedQueries.addAll(containerModelPage.getDefaultQueriesToRun());
                    result.add(new ModelPage(containerModelPage.getLanguage(), containerModelPage.getName(), containerModelPage.getModel(), combinedQueries));
                }
            }
        }
        else {
            nonContainerFiles.add(inputFile);
        }
    }
    // Concatenate all non-container files to construct a single model
    if (!nonContainerFiles.isEmpty()) {
        List<String> fileContents = new ArrayList<>();
        for (File file : nonContainerFiles) {
            // Read in a plain loop (not a lambda) so an IOException propagates as declared instead of being wrapped
            fileContents.add(String.join("\n", Files.readAllLines(file.toPath())));
        }
        String textModel = String.join("\n", fileContents);
        // Copy the global queries so the page does not alias the argument list
        result.add(new ModelPage(solverArgs.inputLanguage, "Model from concatenation of non-container input files", textModel, new ArrayList<>(solverArgs.globalQueries)));
    }
    return result;
}

/**
 * Returns whether {@code file} has the paged model container extension, ignoring case.
 * <p>
 * The check is case-insensitive so that a file recognized as a container when guessing the input language is also
 * treated as one here. It uses {@link String#regionMatches(boolean, int, String, int, int)}, which does not depend on
 * the default locale.
 */
private static boolean isContainerFile(File file) {
    String name = file.getName();
    String extension = PagedModelContainer.DEFAULT_CONTAINER_FILE_EXTENSION;
    int start = name.length() - extension.length();
    return start >= 0 && name.regionMatches(true, start, extension, 0, extension.length());
}
Use of com.sri.ai.praise.model.common.io.ModelPage in project aic-praise by aic-sri-international.
Class AbstractPerspective, method newModel.
/**
 * Starts a new model built from {@code examples}, then resets {@code modelFile} to {@code null}.
 * <p>
 * The model has one page editor supplier per example page, keyed by the page's position (starting at 0). Each
 * supplier is initialized with that page's model text and default queries.
 *
 * @param examples the example pages to seed the new model with
 */
@Override
public void newModel(ExamplePages examples) {
    newModel(() -> {
        Map<Integer, Supplier<ModelPageEditor>> editorSuppliersByIndex = new HashMap<>();
        int pageIndex = 0;
        for (ModelPage examplePage : examples.getPages()) {
            editorSuppliersByIndex.put(pageIndex, new ModelPageEditorSupplier(examplePage.getModel(), examplePage.getDefaultQueriesToRun()));
            pageIndex++;
        }
        return FXCollections.observableMap(editorSuppliersByIndex);
    });
    modelFile.set(null);
}
Aggregations