Search in sources:

Example 6 with RestorableConfig

Use of io.prestosql.spi.snapshot.RestorableConfig in project hetu-core by openLooKeng.

Method mergeSorted in class WorkProcessorUtils.

/**
 * Merges the results of several work processors into a single processor, using a priority queue
 * ordered by {@code comparator}. The merged output is fully sorted only if each input processor
 * already yields its results in {@code comparator} order (presumably a caller contract; it is not
 * checked here).
 *
 * <p>The returned processor supports snapshots. {@code capture} records each input processor's
 * state, and records the queued and current processors as indexes into {@code processorList}.
 * {@code restore} rebuilds the queue and the current processor from those indexes, including the
 * finished state where there is no current processor.
 *
 * @param processorList the input processors; must be non-empty and contain no null entries
 * @param comparator ordering applied to the produced elements; must not be null
 * @return a processor yielding the merged results
 */
static <T> WorkProcessor<T> mergeSorted(List<WorkProcessor<T>> processorList, Comparator<T> comparator) {
    requireNonNull(comparator, "comparator is null");
    checkArgument(processorList.size() > 0, "There must be at least one base processor");
    // Holds the pending head element of every input processor that has produced a result and is not finished
    PriorityQueue<ElementAndProcessor<T>> queue = new PriorityQueue<>(2, comparing(ElementAndProcessor::getElement, comparator));
    return create(new WorkProcessor.Process<T>() {
        // The queue is not serialized directly; capture() records it as processor indexes instead.
        @RestorableConfig(stateClassName = "MergeSortedState", uncapturedFields = { "val$queue" })
        private final RestorableConfig restorableConfig = null;

        // Index of the next input processor that has not been advanced for the first time yet
        int nextProcessor;

        // Processor that must be advanced before the next emission; null once merging has finished
        WorkProcessor<T> processor = requireNonNull(processorList.get(nextProcessor++));

        @Override
        public ProcessState<T> process() {
            while (true) {
                if (processor.process()) {
                    if (!processor.isFinished()) {
                        queue.add(new ElementAndProcessor<>(processor.getResult(), processor));
                    }
                } else if (processor.isBlocked()) {
                    return ProcessState.blocked(processor.getBlockedFuture());
                } else {
                    return ProcessState.yield();
                }
                // Advance every input processor once before emitting anything, so the queue sees each head
                if (nextProcessor < processorList.size()) {
                    processor = requireNonNull(processorList.get(nextProcessor++));
                    continue;
                }
                if (queue.isEmpty()) {
                    processor = null;
                    return ProcessState.finished();
                }
                // Emit the smallest head; its processor is advanced at the start of the next call
                ElementAndProcessor<T> elementAndProcessor = queue.poll();
                processor = elementAndProcessor.getProcessor();
                return ProcessState.ofResult(elementAndProcessor.getElement());
            }
        }

        @Override
        public Object capture(BlockEncodingSerdeProvider serdeProvider) {
            MergeSortedState myState = new MergeSortedState();
            myState.processorList = new Object[processorList.size()];
            for (int i = 0; i < processorList.size(); i++) {
                myState.processorList[i] = processorList.get(i).capture(serdeProvider);
            }
            myState.nextProcessor = nextProcessor;
            // Record which processors are in the queue; restore() re-reads their head elements via getResult()
            myState.queueProcessorIndex = new ArrayList<>();
            for (ElementAndProcessor<T> queued : queue) {
                myState.queueProcessorIndex.add(processorList.indexOf(queued.getProcessor()));
            }
            // Becomes -1 when processor is null (merging finished); restore() maps a negative index back to null
            myState.processor = processorList.indexOf(processor);
            return myState;
        }

        @Override
        public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) {
            MergeSortedState myState = (MergeSortedState) state;
            checkArgument(myState.processorList.length == processorList.size());
            for (int i = 0; i < myState.processorList.length; i++) {
                processorList.get(i).restore(myState.processorList[i], serdeProvider);
            }
            nextProcessor = myState.nextProcessor;
            queue.clear();
            for (Integer queueProcessorIndex : myState.queueProcessorIndex) {
                checkArgument(queueProcessorIndex >= 0 && queueProcessorIndex < processorList.size(), "Processor index exceeded processor list.");
                WorkProcessor<T> queuedProcessor = processorList.get(queueProcessorIndex);
                queue.add(new ElementAndProcessor<>(queuedProcessor.getResult(), queuedProcessor));
            }
            // A negative index means the snapshot was taken after merging finished (processor was null)
            int processorIndex = myState.processor;
            this.processor = processorIndex < 0 ? null : processorList.get(processorIndex);
        }

        @Override
        public Object captureResult(T result, BlockEncodingSerdeProvider serdeProvider) {
            // Encode the result as the index of the input processor whose current result is this exact object
            for (int i = 0; i < processorList.size(); i++) {
                if (((ProcessWorkProcessor) processorList.get(i)).state.getType() == ProcessState.Type.RESULT && processorList.get(i).getResult() == result) {
                    return i;
                }
            }
            throw new IllegalArgumentException("Unable to capture result.");
        }

        @Override
        public T restoreResult(Object resultState, BlockEncodingSerdeProvider serdeProvider) {
            // resultState is the processor index produced by captureResult()
            checkArgument(((int) resultState) < processorList.size());
            ProcessWorkProcessor<T> targetProcessor = (ProcessWorkProcessor) processorList.get((int) resultState);
            checkArgument(targetProcessor.state != null && targetProcessor.state.getType() == ProcessState.Type.RESULT);
            return targetProcessor.getResult();
        }
    });
}
Also used: ArrayList(java.util.ArrayList) BlockEncodingSerdeProvider(io.prestosql.spi.snapshot.BlockEncodingSerdeProvider) PriorityQueue(java.util.PriorityQueue) ProcessState(io.prestosql.operator.WorkProcessor.ProcessState) RestorableConfig(io.prestosql.spi.snapshot.RestorableConfig)

Example 7 with RestorableConfig

Use of io.prestosql.spi.snapshot.RestorableConfig in project hetu-core by openLooKeng.

Method test in class TestSnapshotCompleteness.

/**
 * Verifies that {@code clazz} is set up correctly for snapshotting and returns how many fields it
 * captures. Results are memoized in {@code capturedFieldCount}; any problems found are recorded in
 * {@code failures}. Uncaptured-field names from the config that also appear in the state class are
 * collected into {@code unusedUncapturedFields}.
 *
 * @param clazz the class to inspect
 * @return the captured-field count recorded for {@code clazz}
 * @throws Exception propagated from the reflective helpers
 */
private int test(Class<?> clazz) throws Exception {
    if (capturedFieldCount.containsKey(clazz)) {
        return capturedFieldCount.get(clazz);
    }

    RestorableConfig config = getConfigAnnotation(clazz);
    boolean hasConfig = config != null;
    if (hasConfig && config.unsupported()) {
        // Explicitly unsupported classes are not checked and count as capturing nothing
        capturedFieldCount.put(clazz, 0);
        return 0;
    }

    // Gather every field that needs capturing, then keep only the part after the last '.'
    List<String> qualifiedNames = new ArrayList<>();
    getAllFields(qualifiedNames, clazz);
    List<String> allFields = new ArrayList<>(qualifiedNames.size());
    for (String qualifiedName : qualifiedNames) {
        String[] parts = qualifiedName.split("\\.");
        allFields.add(parts[parts.length - 1]);
    }
    if (new HashSet<>(allFields).size() != allFields.size()) {
        failures.put(clazz, "Set size shrank, there are fields with duplicated name in base class(s).");
    }

    if (!containsCaptureMethod(clazz)) {
        // Without its own capture(), the class is fine only if its sole field is the base-class state
        // and the superclass supplies capture()
        String baseClassStateName = hasConfig ? config.baseClassStateName() : "baseState";
        boolean onlyBaseState = allFields.size() == 1 && allFields.contains(baseClassStateName);
        if (onlyBaseState && containsCaptureMethod(clazz.getSuperclass())) {
            // 1 signals that subclasses must capture baseState
            capturedFieldCount.put(clazz, 1);
            return 1;
        }
        boolean isAbstract = Modifier.isAbstract(clazz.getModifiers());
        if (!(allFields.isEmpty() && isAbstract)) {
            failures.put(clazz, "No capture function");
        }
    }

    // Fall back to "<SimpleName>State" when the config does not name the state class
    String stateClassName = hasConfig ? config.stateClassName() : "";
    if (stateClassName.isEmpty()) {
        stateClassName = clazz.getSimpleName() + "State";
    }
    Class<?> stateClass = findStateClass(clazz, stateClassName);

    if (stateClass == null) {
        if (hasConfig && !config.stateClassName().isEmpty()) {
            failures.put(clazz, "Can't find specified state class " + config.stateClassName());
        }
        int capturedCount = allFields.size();
        capturedFieldCount.put(clazz, capturedCount);
        if (capturedCount > 1) {
            failures.put(clazz, "No State class but has more than 1 field that needs capturing: " + allFields);
        }
        return capturedFieldCount.get(clazz);
    }

    if (!Serializable.class.isAssignableFrom(stateClass)) {
        failures.put(clazz, "State class " + stateClass.getSimpleName() + " is doesn't implement Serializable.");
    }
    if (!Modifier.isStatic(stateClass.getModifiers())) {
        failures.put(clazz, "Inner State class is not static.");
    }

    Set<String> snapshotFieldsName = new HashSet<>();
    for (Field snapshotField : stateClass.getDeclaredFields()) {
        snapshotFieldsName.add(snapshotField.getName());
    }
    capturedFieldCount.put(clazz, snapshotFieldsName.size());

    // A "...val$x" field also counts as covered when the state class has a field named "x"
    Set<String> fieldsNotCovered = new HashSet<>();
    for (String field : allFields) {
        if (snapshotFieldsName.contains(field)) {
            continue;
        }
        int marker = field.lastIndexOf("val$");
        boolean coveredByCapturedVariable = marker >= 0 && snapshotFieldsName.contains(field.substring(marker + 4));
        if (!coveredByCapturedVariable) {
            fieldsNotCovered.add(field);
        }
    }
    if (!fieldsNotCovered.isEmpty()) {
        failures.put(clazz, "Captured and Uncaptured fields doesn't include all fields.\n" + fieldsNotCovered);
    }

    if (hasConfig) {
        Set<String> unusedUncaptured = new HashSet<>();
        for (String uncaptured : config.uncapturedFields()) {
            if (snapshotFieldsName.contains(uncaptured)) {
                unusedUncaptured.add(uncaptured);
            }
        }
        if (!unusedUncaptured.isEmpty()) {
            unusedUncapturedFields.computeIfAbsent(clazz, ignored -> new HashSet<>()).addAll(unusedUncaptured);
        }
    }
    return capturedFieldCount.get(clazz);
}
Also used: Arrays(java.util.Arrays) OptionalDouble(java.util.OptionalDouble) Restorable(io.prestosql.spi.snapshot.Restorable) Test(org.testng.annotations.Test) HashMap(java.util.HashMap) RestorableConfig(io.prestosql.spi.snapshot.RestorableConfig) Multimap(com.google.common.collect.Multimap) OptionalInt(java.util.OptionalInt) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) BigDecimal(java.math.BigDecimal) OptionalLong(java.util.OptionalLong) HashMultimap(com.google.common.collect.HashMultimap) Map(java.util.Map) BigInteger(java.math.BigInteger) ClassPath(com.google.common.reflect.ClassPath) Method(java.lang.reflect.Method) Assert.fail(org.testng.Assert.fail) Set(java.util.Set) Field(java.lang.reflect.Field) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) List(java.util.List) ParameterizedType(java.lang.reflect.ParameterizedType) Type(java.lang.reflect.Type) Modifier(java.lang.reflect.Modifier) Optional(java.util.Optional)

Aggregations

RestorableConfig (io.prestosql.spi.snapshot.RestorableConfig)7 BlockEncodingSerdeProvider (io.prestosql.spi.snapshot.BlockEncodingSerdeProvider)5 ArrayList (java.util.ArrayList)4 TransformationState (io.prestosql.operator.WorkProcessor.TransformationState)3 List (java.util.List)3 HashMultimap (com.google.common.collect.HashMultimap)2 Multimap (com.google.common.collect.Multimap)2 ClassPath (com.google.common.reflect.ClassPath)2 ProcessState (io.prestosql.operator.WorkProcessor.ProcessState)2 Page (io.prestosql.spi.Page)2 Restorable (io.prestosql.spi.snapshot.Restorable)2 Serializable (java.io.Serializable)2 Field (java.lang.reflect.Field)2 Method (java.lang.reflect.Method)2 Modifier (java.lang.reflect.Modifier)2 ParameterizedType (java.lang.reflect.ParameterizedType)2 Type (java.lang.reflect.Type)2 BigDecimal (java.math.BigDecimal)2 BigInteger (java.math.BigInteger)2 Arrays (java.util.Arrays)2