Use of com.sri.ai.praise.evaluate.solver.SolverEvaluator in project aic-praise by aic-sri-international.
Class Evaluation, method instantiateSolvers.
/**
 * Creates one {@link SolverEvaluator} per supplied configuration, in the same order.
 * <p>
 * For each configuration the implementation class named by
 * {@code getImplementationClassName()} is loaded, instantiated through its no-argument
 * constructor, and handed the configuration. Note that the working directory is written
 * into each passed-in configuration object before it is given to the solver.
 *
 * @param solverConfigurations the configurations naming the solver implementations to create.
 * @param workingDirectory     the directory set on every configuration before it is applied.
 * @return the instantiated and configured solvers, one per configuration.
 * @throws IllegalArgumentException if an implementation class cannot be found or does not
 *                                  implement {@link SolverEvaluator}.
 * @throws IllegalStateException    if an implementation cannot be instantiated or configured.
 */
private List<SolverEvaluator> instantiateSolvers(List<SolverEvaluatorConfiguration> solverConfigurations, File workingDirectory) {
    List<SolverEvaluator> result = new ArrayList<>(solverConfigurations.size());
    for (SolverEvaluatorConfiguration configuration : solverConfigurations) {
        String implementationClassName = configuration.getImplementationClassName();
        Class<? extends SolverEvaluator> clazz;
        try {
            // asSubclass verifies the type up front (no unchecked cast), so a class that is not a
            // SolverEvaluator is rejected before any of its constructors get a chance to run.
            clazz = Class.forName(implementationClassName).asSubclass(SolverEvaluator.class);
        } catch (ClassNotFoundException cnfe) {
            throw new IllegalArgumentException("Unable to find " + SolverEvaluator.class.getName() + " implementation class: " + implementationClassName, cnfe);
        } catch (ClassCastException cce) {
            throw new IllegalArgumentException("Class " + implementationClassName + " does not implement " + SolverEvaluator.class.getName(), cce);
        }
        try {
            // Class.newInstance() is deprecated: it rethrows checked constructor exceptions
            // undeclared. Constructor.newInstance() wraps them in InvocationTargetException.
            SolverEvaluator solver = clazz.getDeclaredConstructor().newInstance();
            configuration.setWorkingDirectory(workingDirectory);
            solver.setConfiguration(configuration);
            result.add(solver);
        } catch (Throwable t) {
            // Throwable is caught deliberately so that linkage/initialisation errors raised while
            // constructing the implementation are also reported with the offending class name.
            throw new IllegalStateException("Unable to instantiate instance of " + clazz.getName(), t);
        }
    }
    return result;
}
Use of com.sri.ai.praise.evaluate.solver.SolverEvaluator in project aic-praise by aic-sri-international.
Class Evaluation, method evaluate.
/**
 * Evaluates every configured solver on every default query of every model page in the given
 * container, reporting progress and CSV result lines through the listener.
 * <p>
 * The method instantiates the solvers, performs a burn-in call per solver on the first default
 * query of the first model page (skipped when there is none), emits a CSV header line, and then
 * emits one CSV line per model/query combination containing the results of all solvers.
 * Text fields of the CSV lines are quoted when they contain a comma, a double quote or a line
 * break, so that such values cannot shift the columns.
 * <p>
 * Exceptions raised during the evaluation are not propagated; they are handed to
 * {@code evaluationListener.notificationException}. The elapsed-time notification is sent in
 * either case.
 *
 * @param configuration             the evaluation settings; its working directory, inference type
 *                                  and number of runs to average over are read here, and it is
 *                                  passed on to each solver call.
 * @param modelsToEvaluateContainer the container whose pages (models) and their default queries
 *                                  are evaluated.
 * @param solverConfigurations      the configurations of the solvers to instantiate and evaluate.
 * @param evaluationListener        receives progress notifications, CSV output lines and any
 *                                  exception raised during the evaluation.
 */
public void evaluate(Evaluation.Configuration configuration, PagedModelContainer modelsToEvaluateContainer, List<SolverEvaluatorConfiguration> solverConfigurations, Evaluation.Listener evaluationListener) {
    // Note, varying domain sizes etc... is achieved by creating variants of a base model in the provided paged model container
    // nanoTime is monotonic, so the reported duration is unaffected by wall-clock adjustments.
    long evaluationStartNanos = System.nanoTime();
    try {
        List<SolverEvaluator> solvers = instantiateSolvers(solverConfigurations, configuration.getWorkingDirectory());
        List<ModelPage> pages = modelsToEvaluateContainer.getPages();

        // Do an initial burn in to ensure any OS caching etc... occurs so as to even out times across runs.
        // Guarded so that an empty container, or a first page without queries, does not abort the
        // whole evaluation with an IndexOutOfBoundsException.
        List<String> burnInQueries = pages.isEmpty() ? null : pages.get(0).getDefaultQueriesToRun();
        if (burnInQueries == null || burnInQueries.isEmpty()) {
            evaluationListener.notification("Skipping solver burn in: no model query available to run it on");
        } else {
            ModelPage burnInModel = pages.get(0);
            String burnInQuery = burnInQueries.get(0);
            evaluationListener.notification("Starting solver burn in based on '" + modelsToEvaluateContainer.getName() + " - " + burnInModel.getName() + " : " + burnInQuery + "'");
            for (SolverEvaluator solver : solvers) {
                SolverCallResult solverResult = callSolver(configuration, solver, burnInModel, burnInQuery);
                evaluationListener.notification("Burn in for " + solver.getName() + " complete. Average inference time = " + toDurationString(solverResult.averageInferenceTimeInMilliseconds));
            }
        }

        // Output the report header line: four leading columns, then six columns per solver.
        StringJoiner csvLine = new StringJoiner(",");
        csvLine.add("Problem");
        csvLine.add("Inference Type");
        csvLine.add("Domain Size(s)");
        csvLine.add("# runs values averaged over");
        for (SolverEvaluator solver : solvers) {
            csvLine.add("Solver");
            csvLine.add(escapeCsvField("Result for " + solver.getName()));
            csvLine.add(escapeCsvField("Inference ms. for " + solver.getName()));
            csvLine.add("HH:MM:SS.");
            csvLine.add(escapeCsvField("Translation ms. for " + solver.getName()));
            csvLine.add("HH:MM:SS.");
        }
        evaluationListener.notification("Starting to generate Evaluation Report");
        evaluationListener.csvResultOutput(csvLine.toString());

        // Now evaluate each of the model-query-solver combinations.
        for (ModelPage model : pages) {
            String domainSizes = getDomainSizes(model.getModel());
            for (String query : model.getDefaultQueriesToRun()) {
                csvLine = new StringJoiner(",");
                String problemName = modelsToEvaluateContainer.getName() + " - " + model.getName() + " : " + query;
                evaluationListener.notification("Starting to evaluate " + problemName);
                csvLine.add(escapeCsvField(problemName));
                csvLine.add(configuration.type.name());
                csvLine.add(escapeCsvField(domainSizes));
                csvLine.add("" + configuration.getNumberRunsToAverageOver());
                for (SolverEvaluator solver : solvers) {
                    SolverCallResult solverResult = callSolver(configuration, solver, model, query);
                    csvLine.add(escapeCsvField(solver.getName()));
                    csvLine.add(solverResult.failed ? "FAILED" : escapeCsvField("" + solverResult.answer));
                    csvLine.add("" + solverResult.averageInferenceTimeInMilliseconds);
                    csvLine.add(toDurationString(solverResult.averageInferenceTimeInMilliseconds));
                    csvLine.add("" + solverResult.averagelTranslationTimeInMilliseconds);
                    csvLine.add(toDurationString(solverResult.averagelTranslationTimeInMilliseconds));
                    evaluationListener.notification("Solver " + solver.getName() + " took an average inference time of " + toDurationString(solverResult.averageInferenceTimeInMilliseconds) + " to solve " + problemName);
                }
                evaluationListener.csvResultOutput(csvLine.toString());
            }
        }
    } catch (Exception ex) {
        evaluationListener.notificationException(ex);
    }
    // Sent whether or not the evaluation above was cut short by an exception.
    long elapsedMillis = (System.nanoTime() - evaluationStartNanos) / 1_000_000L;
    evaluationListener.notification("Evaluation took " + toDurationString(elapsedMillis) + " to run to completion.");
}

/**
 * Renders a value as a single CSV field: values containing a comma, a double quote or a line
 * break are wrapped in double quotes with embedded quotes doubled; other values are returned
 * unchanged. A {@code null} value is rendered as the text {@code null}.
 */
private static String escapeCsvField(String field) {
    String text = String.valueOf(field);
    boolean needsQuoting = text.indexOf(',') >= 0
            || text.indexOf('"') >= 0
            || text.indexOf('\n') >= 0
            || text.indexOf('\r') >= 0;
    if (!needsQuoting) {
        return text;
    }
    return "\"" + text.replace("\"", "\"\"") + "\"";
}
Aggregations