Use of com.ibm.cohort.cql.spark.aggregation.ContextDefinition in project quality-measure-and-cohort-service by Alvearie.
From class SparkCqlEvaluator, method createDataTypeAliases:
private Map<String, String> createDataTypeAliases(List<ContextDefinition> filteredContexts, CqlToElmTranslator translator) {
    Set<String> dataTypes = filteredContexts.stream()
            .map(ContextDefinition::getPrimaryDataType)
            .collect(Collectors.toSet());
    Collection<ModelInfo> modelInfos = translator.getRegisteredModelInfos().values();

    Map<String, String> dataTypeAliases = new HashMap<>();
    for (ModelInfo modelInfo : modelInfos) {
        modelInfo.getTypeInfo().stream()
                .filter(ClassInfo.class::isInstance)
                .map(ClassInfo.class::cast)
                .filter(classInfo -> dataTypes.contains(classInfo.getName()))
                .forEach(info -> {
                    String dataType = info.getName();

                    QName baseType = ModelUtils.getBaseTypeName(modelInfo, info);
                    if (baseType != null) {
                        // for inheritance types
                        dataTypeAliases.put(dataType, baseType.getLocalPart());
                    }

                    Collection<String> choiceTypes = ModelUtils.getChoiceTypeNames(info);
                    // for choice types
                    choiceTypes.forEach(choiceType -> dataTypeAliases.put(dataType, choiceType));
                });
    }

    return dataTypeAliases;
}
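For illustration, a minimal sketch of the alias map this method produces, assuming a hypothetical model in which Patient extends DomainResource and Observation belongs to a choice type named ChoiceValue (these type names are assumptions, not taken from the project's model info files):

import java.util.HashMap;
import java.util.Map;

// Illustration only: the shape of the alias map produced by createDataTypeAliases
// for a hypothetical model where "Patient" extends "DomainResource" and
// "Observation" participates in a choice type named "ChoiceValue".
public class DataTypeAliasExample {
    public static void main(String[] args) {
        Map<String, String> dataTypeAliases = new HashMap<>();
        dataTypeAliases.put("Patient", "DomainResource");  // inheritance alias
        dataTypeAliases.put("Observation", "ChoiceValue"); // choice-type alias
        // The map is keyed by the primary data type, so a type with both a base
        // type and a choice-type membership retains only the last value written.
        dataTypeAliases.forEach((type, alias) -> System.out.println(type + " -> " + alias));
    }
}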
Use of com.ibm.cohort.cql.spark.aggregation.ContextDefinition in project quality-measure-and-cohort-service by Alvearie.
From class SparkSchemaCreatorTest, method makeContextDefinition:
private ContextDefinition makeContextDefinition(String name, String primaryDataType, String primaryKeyColumn) {
    ContextDefinition contextDefinition = new ContextDefinition();
    contextDefinition.setName(name);
    contextDefinition.setPrimaryDataType(primaryDataType);
    contextDefinition.setPrimaryKeyColumn(primaryKeyColumn);
    return contextDefinition;
}
Use of com.ibm.cohort.cql.spark.aggregation.ContextDefinition in project quality-measure-and-cohort-service by Alvearie.
From class SparkCqlEvaluator, method run:
public void run(PrintStream out) throws Exception {
    EvaluationSummary evaluationSummary = new EvaluationSummary();
    long startTimeMillis = System.currentTimeMillis();
    evaluationSummary.setStartTimeMillis(startTimeMillis);
    evaluationSummary.setJobStatus(JobStatus.FAIL);

    SparkSession.Builder sparkBuilder = SparkSession.builder();
    try (SparkSession spark = sparkBuilder.getOrCreate()) {
        final LongAccumulator contextAccum = spark.sparkContext().longAccumulator("Context");
        final CollectionAccumulator<EvaluationError> errorAccumulator = spark.sparkContext().collectionAccumulator("EvaluationErrors");
        try {
            spark.sparkContext().setLocalProperty("mdc." + CORRELATION_ID, MDC.get(CORRELATION_ID));
            evaluationSummary.setCorrelationId(MDC.get(CORRELATION_ID));

            boolean useJava8API = Boolean.valueOf(spark.conf().get("spark.sql.datetime.java8API.enabled"));
            this.typeConverter = new SparkTypeConverter(useJava8API);
            this.hadoopConfiguration = new SerializableConfiguration(spark.sparkContext().hadoopConfiguration());
            evaluationSummary.setApplicationId(spark.sparkContext().applicationId());

            CqlToElmTranslator cqlTranslator = getCqlTranslator();
            SparkOutputColumnEncoder columnEncoder = getSparkOutputColumnEncoder();

            ContextDefinitions contexts = readContextDefinitions(args.contextDefinitionPath);
            List<ContextDefinition> filteredContexts = contexts.getContextDefinitions();
            if (args.aggregationContexts != null && !args.aggregationContexts.isEmpty()) {
                filteredContexts = filteredContexts.stream()
                        .filter(def -> args.aggregationContexts.contains(def.getName()))
                        .collect(Collectors.toList());
            }
            if (filteredContexts.isEmpty()) {
                throw new IllegalArgumentException("At least one context definition is required (after filtering if enabled).");
            }

            Map<String, StructType> resultSchemas = calculateSparkSchema(
                    filteredContexts.stream().map(ContextDefinition::getName).collect(Collectors.toList()),
                    contexts, columnEncoder, cqlTranslator);

            ZonedDateTime batchRunTime = ZonedDateTime.now();

            final LongAccumulator perContextAccum = spark.sparkContext().longAccumulator("PerContext");
            CustomMetricSparkPlugin.contextAccumGauge.setAccumulator(contextAccum);
            CustomMetricSparkPlugin.perContextAccumGauge.setAccumulator(perContextAccum);
            CustomMetricSparkPlugin.totalContextsToProcessCounter.inc(filteredContexts.size());
            CustomMetricSparkPlugin.currentlyEvaluatingContextGauge.setValue(0);

            ColumnRuleCreator columnRuleCreator = new ColumnRuleCreator(
                    getFilteredJobSpecificationWithIds().getEvaluations(),
                    getCqlTranslator(),
                    createLibraryProvider());

            Map<String, String> dataTypeAliases = createDataTypeAliases(filteredContexts, cqlTranslator);

            for (ContextDefinition context : filteredContexts) {
                final String contextName = context.getName();

                ContextRetriever contextRetriever = new ContextRetriever(
                        args.inputPaths,
                        new DefaultDatasetRetriever(spark, args.inputFormat),
                        args.disableColumnFiltering ? null : columnRuleCreator.getDataRequirementsForContext(context));

                StructType resultsSchema = resultSchemas.get(contextName);

                if (resultsSchema == null || resultsSchema.fields().length == 0) {
                    LOG.warn("Context " + contextName + " has no defines configured. Skipping.");
                } else {
                    LOG.info("Evaluating context " + contextName);
                    long contextStartMillis = System.currentTimeMillis();

                    final String outputPath = MapUtils.getRequiredKey(args.outputPaths, context.getName(), "outputPath");

                    JavaPairRDD<Object, List<Row>> rowsByContextId = contextRetriever.retrieveContext(context);

                    CustomMetricSparkPlugin.currentlyEvaluatingContextGauge.setValue(
                            CustomMetricSparkPlugin.currentlyEvaluatingContextGauge.getValue() + 1);
                    JavaPairRDD<Object, Row> resultsByContext = rowsByContextId.flatMapToPair(
                            x -> evaluate(contextName, resultsSchema, x, dataTypeAliases, perContextAccum, errorAccumulator, batchRunTime));

                    writeResults(spark, resultsSchema, resultsByContext, outputPath);
                    long contextEndMillis = System.currentTimeMillis();
                    LOG.info(String.format("Wrote results for context %s to %s", contextName, outputPath));

                    evaluationSummary.addContextCount(contextName, perContextAccum.value());
                    evaluationSummary.addContextRuntime(contextName, contextEndMillis - contextStartMillis);

                    contextAccum.add(1);
                    perContextAccum.reset();
                }
            }
            CustomMetricSparkPlugin.currentlyEvaluatingContextGauge.setValue(0);

            try {
                boolean metricsEnabled = Boolean.valueOf(spark.conf().get("spark.ui.prometheus.enabled"));
                if (metricsEnabled) {
                    LOG.info("Prometheus metrics enabled, sleeping for 7 seconds to finish gathering metrics");
                    // Sleep for more than 5 seconds because Prometheus only polls every
                    // 5 seconds. If Spark finishes and goes away immediately after completing,
                    // Prometheus will never be able to poll the final set of metrics for the
                    // spark-submit. The default Prometheus config map was changed from a
                    // 2 minute scrape interval to 5 seconds for Spark pods.
                    Thread.sleep(7000);
                } else {
                    LOG.info("Prometheus metrics not enabled");
                }
            } catch (NoSuchElementException e) {
                LOG.info("spark.ui.prometheus.enabled is not set");
            }

            evaluationSummary.setJobStatus(JobStatus.SUCCESS);
        } catch (Exception e) {
            // If we experience an error that would make the program halt, capture the error
            // and report it in the batch summary file.
            ByteArrayOutputStream errorDetailStream = new ByteArrayOutputStream();
            try (PrintStream printStream = new PrintStream(errorDetailStream)) {
                printStream.write(e.getMessage().getBytes());
                printStream.write('\n');
                if (e.getCause() != null) {
                    printStream.write(e.getCause().getMessage().getBytes());
                    printStream.write('\n');
                }
                e.printStackTrace(printStream);
                evaluationSummary.setErrorList(Collections.singletonList(
                        new EvaluationError(null, null, null, errorDetailStream.toString())));
            }
            throw e;
        } finally {
            long endTimeMillis = System.currentTimeMillis();
            evaluationSummary.setEndTimeMillis(endTimeMillis);
            evaluationSummary.setRuntimeMillis(endTimeMillis - startTimeMillis);

            if (args.metadataOutputPath != null) {
                if (evaluationSummary.getErrorList() == null) {
                    evaluationSummary.setErrorList(errorAccumulator.value());
                }
                if (CollectionUtils.isNotEmpty(evaluationSummary.getErrorList())) {
                    evaluationSummary.setJobStatus(JobStatus.FAIL);
                }
                evaluationSummary.setTotalContexts(contextAccum.value());

                OutputMetadataWriter writer = getOutputMetadataWriter();
                writer.writeMetadata(evaluationSummary);
            }
        }
    }
}
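The catch block above concatenates the exception message, cause message, and stack trace into a single string for the batch summary file. A standalone sketch of that pattern, simplified to use println rather than raw byte writes (illustration only, not project code):

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;

// Sketch of the error-capture pattern: message, cause message, and stack trace
// are written to an in-memory stream and then turned into a single string.
public class ErrorDetailExample {
    public static void main(String[] args) {
        Exception e = new RuntimeException("outer failure", new IllegalStateException("root cause"));
        ByteArrayOutputStream errorDetailStream = new ByteArrayOutputStream();
        try (PrintStream printStream = new PrintStream(errorDetailStream)) {
            printStream.println(e.getMessage());
            if (e.getCause() != null) {
                printStream.println(e.getCause().getMessage());
            }
            e.printStackTrace(printStream);
        }
        System.out.println(errorDetailStream.toString());
    }
}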
Use of com.ibm.cohort.cql.spark.aggregation.ContextDefinition in project quality-measure-and-cohort-service by Alvearie.
From class SparkSchemaCreator, method getDataTypeForContextKey:
private Tuple2<String, DataType> getDataTypeForContextKey(String contextName, Set<Tuple2<String, String>> usingInfos) {
    ContextDefinition contextDefinition = contextDefinitions.getContextDefinitionByName(contextName);
    String primaryDataType = contextDefinition.getPrimaryDataType();
    String primaryKeyColumn = contextDefinition.getPrimaryKeyColumn();

    DataType keyType = null;

    ModelManager modelManager = translator.newModelManager();

    // Try to find the key column's type information from a single model info.
    for (Tuple2<String, String> usingInfo : usingInfos) {
        VersionedIdentifier modelInfoIdentifier = new VersionedIdentifier()
                .withId(usingInfo._1())
                .withVersion(usingInfo._2());
        ModelInfo modelInfo = modelManager.getModelInfoLoader().getModelInfo(modelInfoIdentifier);

        // Look for a ClassInfo element matching primaryDataType for the context
        List<ClassInfo> classInfos = getClassInfos(primaryDataType, modelInfo);
        if (!classInfos.isEmpty()) {
            if (classInfos.size() == 1) {
                ClassInfo classInfo = classInfos.get(0);

                List<ClassInfoElement> elements = classInfo.getElement().stream()
                        .filter(x -> x.getName().equals(primaryKeyColumn))
                        .collect(Collectors.toList());

                // check base type
                String baseType = classInfo.getBaseType();
                if (classInfo.getBaseType() != null) {
                    List<ClassInfo> baseClassInfos = getClassInfos(baseType, modelInfo);
                    baseClassInfos.stream()
                            .map(ClassInfo::getElement)
                            .flatMap(List::stream)
                            .filter(element -> element.getName().equals(primaryKeyColumn))
                            .forEach(elements::add);
                }

                // check choice types
                Collection<String> choiceTypes = ModelUtils.getChoiceTypeNames(classInfo);
                choiceTypes.stream()
                        .map(type -> getClassInfos(type, modelInfo))
                        .flatMap(List::stream)
                        .map(ClassInfo::getElement)
                        .flatMap(List::stream)
                        .filter(element -> element.getName().equals(primaryKeyColumn))
                        .findFirst()
                        .ifPresent(elements::add);

                // A future ModelInfo file may contain the information
                if (elements.isEmpty()) {
                    continue;
                } else if (elements.size() == 1) {
                    String elementType = elements.get(0).getElementType();

                    // store it
                    if (keyType == null) {
                        keyType = getSparkTypeForSystemValue(elementType);
                    } else {
                        throw new IllegalArgumentException("Multiple definitions found for " + primaryDataType + "." + primaryKeyColumn
                                + " in the provided ModelInfo files. Cannot infer key type for context: " + contextName);
                    }
                } else if (elements.size() > 1) {
                    throw new IllegalArgumentException("ModelInfo " + modelInfoIdentifier + " contains multiple element definitions for "
                            + primaryKeyColumn + " for type " + primaryDataType);
                }
            } else {
                throw new IllegalArgumentException("ModelInfo " + modelInfoIdentifier + " contains multiple definitions for type " + primaryDataType);
            }
        }
    }

    if (keyType == null) {
        throw new IllegalArgumentException("Could not locate type information for " + primaryDataType + "." + primaryKeyColumn
                + " in the provided ModelInfo files. Cannot infer key type for context: " + contextName);
    }

    return new Tuple2<>(contextDefinition.getPrimaryKeyColumn(), keyType);
}
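The returned (column name, DataType) pair describes the context's key column. A minimal sketch of how such a pair could be turned into the key field of a Spark schema, assuming a hypothetical "patient_id" key of integer type (not the project's actual schema-building code):

import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

// Illustration only: build a one-column schema from a (key column, DataType) pair.
public class ContextKeySchemaExample {
    public static void main(String[] args) {
        Tuple2<String, DataType> key = new Tuple2<>("patient_id", DataTypes.IntegerType);
        StructType schema = new StructType(new StructField[] {
                new StructField(key._1(), key._2(), false, Metadata.empty())
        });
        System.out.println(schema.treeString());
    }
}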