Example 56 with StreamObserver

use of org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver in project beam by apache.

the class ArtifactStagingService method reverseArtifactRetrievalService.

@Override
public StreamObserver<ArtifactApi.ArtifactResponseWrapper> reverseArtifactRetrievalService(StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {

        /**
         * The maximum number of parallel threads to use to stage.
         */
        public static final int THREAD_POOL_SIZE = 10;

        /**
         * The maximum number of bytes to buffer across all writes before throttling.
         */
        // 100 MB
        public static final int MAX_PENDING_BYTES = 100 << 20;

        IdGenerator idGenerator = IdGenerators.incrementingLongs();

        String stagingToken;

        Map<String, List<RunnerApi.ArtifactInformation>> toResolve;

        Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;

        ExecutorService stagingExecutor;

        OverflowingSemaphore totalPendingBytes;

        State state = State.START;

        Queue<String> pendingResolves;

        String currentEnvironment;

        Queue<RunnerApi.ArtifactInformation> pendingGets;

        BlockingQueue<ByteString> currentOutput;

        @Override
        @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = "fallthrough intended")
        public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper responseWrapper) {
            switch(state) {
                case START:
                    stagingToken = responseWrapper.getStagingToken();
                    LOG.info("Staging artifacts for {}.", stagingToken);
                    toResolve = toStage.get(stagingToken);
                    if (toResolve == null) {
                        responseObserver.onError(new StatusException(Status.INVALID_ARGUMENT.withDescription("Unknown staging token " + stagingToken)));
                        return;
                    }
                    stagedFutures = new ConcurrentHashMap<>();
                    pendingResolves = new ArrayDeque<>();
                    pendingResolves.addAll(toResolve.keySet());
                    stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
                    totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
                    resolveNextEnvironment(responseObserver);
                    break;
                case RESOLVE:
                    {
                        currentEnvironment = pendingResolves.remove();
                        stagedFutures.put(currentEnvironment, new ArrayList<>());
                        pendingGets = new ArrayDeque<>();
                        for (RunnerApi.ArtifactInformation artifact : responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
                            Optional<RunnerApi.ArtifactInformation> fetched = getLocal();
                            if (fetched.isPresent()) {
                                stagedFutures.get(currentEnvironment).add(CompletableFuture.completedFuture(fetched.get()));
                            } else {
                                pendingGets.add(artifact);
                                responseObserver.onNext(ArtifactApi.ArtifactRequestWrapper.newBuilder().setGetArtifact(ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact)).build());
                            }
                        }
                        LOG.info("Getting {} artifacts for {}.{}.", pendingGets.size(), stagingToken, pendingResolves.peek());
                        if (pendingGets.isEmpty()) {
                            resolveNextEnvironment(responseObserver);
                        } else {
                            state = State.GET;
                        }
                        break;
                    }
                case GET:
                    RunnerApi.ArtifactInformation currentArtifact = pendingGets.remove();
                    String name = createFilename(currentEnvironment, currentArtifact);
                    try {
                        LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
                        currentOutput = new ArrayBlockingQueue<ByteString>(100);
                        stagedFutures.get(currentEnvironment).add(stagingExecutor.submit(new StoreArtifact(stagingToken, name, currentArtifact, currentOutput, totalPendingBytes)));
                    } catch (Exception exn) {
                        LOG.error("Error submitting.", exn);
                        responseObserver.onError(exn);
                    }
                    state = State.GETCHUNK;
                case GETCHUNK:
                    try {
                        ByteString chunk = responseWrapper.getGetArtifactResponse().getData();
                        if (chunk.size() > 0) {
                            // Make sure we don't accidentally send the EOF value.
                            totalPendingBytes.aquire(chunk.size());
                            currentOutput.put(chunk);
                        }
                        if (responseWrapper.getIsLast()) {
                            // The EOF value.
                            currentOutput.put(ByteString.EMPTY);
                            if (pendingGets.isEmpty()) {
                                resolveNextEnvironment(responseObserver);
                            } else {
                                state = State.GET;
                                LOG.debug("Waiting for {}", pendingGets.peek());
                            }
                        }
                    } catch (Exception exn) {
                        LOG.error("Error submitting.", exn);
                        onError(exn);
                    }
                    break;
                default:
                    responseObserver.onError(new StatusException(Status.INVALID_ARGUMENT.withDescription("Illegal state " + state)));
            }
        }

        private void resolveNextEnvironment(StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
            if (pendingResolves.isEmpty()) {
                finishStaging(responseObserver);
            } else {
                state = State.RESOLVE;
                LOG.info("Resolving artifacts for {}.{}.", stagingToken, pendingResolves.peek());
                responseObserver.onNext(ArtifactApi.ArtifactRequestWrapper.newBuilder().setResolveArtifact(ArtifactApi.ResolveArtifactsRequest.newBuilder().addAllArtifacts(toResolve.get(pendingResolves.peek()))).build());
            }
        }

        private void finishStaging(StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
            LOG.debug("Finishing staging for {}.", stagingToken);
            Map<String, List<RunnerApi.ArtifactInformation>> staged = new HashMap<>();
            try {
                for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> entry : stagedFutures.entrySet()) {
                    List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
                    for (Future<RunnerApi.ArtifactInformation> future : entry.getValue()) {
                        envStaged.add(future.get());
                    }
                    staged.put(entry.getKey(), envStaged);
                }
                ArtifactStagingService.this.staged.put(stagingToken, staged);
                stagingExecutor.shutdown();
                state = State.DONE;
                LOG.info("Artifacts fully staged for {}.", stagingToken);
                responseObserver.onCompleted();
            } catch (Exception exn) {
                LOG.error("Error staging artifacts", exn);
                responseObserver.onError(exn);
                state = State.ERROR;
                return;
            }
        }

        /**
         * Return an alternative artifact if we do not need to get this over the artifact API, or
         * possibly at all.
         */
        private Optional<RunnerApi.ArtifactInformation> getLocal() {
            return Optional.empty();
        }

        /**
         * Attempts to provide a reasonable filename for the artifact. A monotonically increasing
         * id from {@code idGenerator} is prepended to provide uniqueness.
         *
         * @param environment the environment id
         * @param artifact the artifact itself
         */
        private String createFilename(String environment, RunnerApi.ArtifactInformation artifact) {
            String path;
            try {
                if (artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN)) {
                    path = RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload()).getStagedName();
                } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
                    path = RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
                } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
                    path = RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
                } else {
                    path = "artifact";
                }
            } catch (InvalidProtocolBufferException exn) {
                throw new RuntimeException(exn);
            }
            // Limit to the last contiguous alpha-numeric sequence. In particular, this will exclude
            // all path separators.
            List<String> components = Splitter.onPattern("[^A-Za-z-_.]").splitToList(path);
            String base = components.get(components.size() - 1);
            return clip(String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 25), base), 100);
        }

        private String clip(String s, int maxLength) {
            return s.length() < maxLength ? s : s.substring(0, maxLength);
        }

        @Override
        public void onError(Throwable throwable) {
            stagingExecutor.shutdownNow();
            LOG.error("Error staging artifacts", throwable);
            state = State.ERROR;
        }

        @Override
        public void onCompleted() {
            Preconditions.checkArgument(state == State.DONE);
        }
    };
}
Also used : ArtifactApi(org.apache.beam.model.jobmanagement.v1.ArtifactApi) HashMap(java.util.HashMap) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) ArrayList(java.util.ArrayList) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) StatusException(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.StatusException) List(java.util.List) ImmutableList(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList) BlockingQueue(java.util.concurrent.BlockingQueue) ArrayBlockingQueue(java.util.concurrent.ArrayBlockingQueue) Queue(java.util.Queue) StreamObserver(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver) Optional(java.util.Optional) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) IdGenerator(org.apache.beam.sdk.fn.IdGenerator) ArrayDeque(java.util.ArrayDeque) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) ExecutorService(java.util.concurrent.ExecutorService) CompletableFuture(java.util.concurrent.CompletableFuture) Future(java.util.concurrent.Future) ConcurrentMap(java.util.concurrent.ConcurrentMap) Map(java.util.Map)
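
For orientation, the sketch below shows what the SDK-side peer of this bidirectional stream could look like: it announces a staging token, answers each ResolveArtifacts request by echoing the artifacts as their own replacements, and streams artifact bytes back in chunks. This is a hypothetical illustration, not Beam code; the wiring to the generated async stub, the readArtifactBytes helper, and the 1 MB chunk size are all assumptions.

import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;

abstract class SdkArtifactResponder implements StreamObserver<ArtifactApi.ArtifactRequestWrapper> {

    // 1 MB per GetArtifactResponse chunk; an assumed value, not taken from Beam.
    private static final int CHUNK_SIZE = 1 << 20;

    // Observer for messages flowing back to the staging service; in real wiring this would be the
    // StreamObserver returned by the generated async stub when this responder is registered.
    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> outbound;

    SdkArtifactResponder(StreamObserver<ArtifactApi.ArtifactResponseWrapper> outbound) {
        this.outbound = outbound;
    }

    // How artifact bytes are obtained is left to the caller; this helper is an assumption.
    protected abstract byte[] readArtifactBytes(RunnerApi.ArtifactInformation artifact);

    /** The first message carries the staging token, moving the service out of State.START. */
    void start(String stagingToken) {
        outbound.onNext(
            ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
    }

    @Override
    public void onNext(ArtifactApi.ArtifactRequestWrapper request) {
        if (request.hasResolveArtifact()) {
            // Trivial resolution: report each requested artifact as its own replacement.
            outbound.onNext(
                ArtifactApi.ArtifactResponseWrapper.newBuilder()
                    .setResolveArtifactResponse(
                        ArtifactApi.ResolveArtifactsResponse.newBuilder()
                            .addAllReplacements(request.getResolveArtifact().getArtifactsList()))
                    .build());
        } else if (request.hasGetArtifact()) {
            // Stream the bytes in chunks; the final chunk (possibly empty) is flagged with isLast.
            byte[] bytes = readArtifactBytes(request.getGetArtifact().getArtifact());
            int offset = 0;
            do {
                int end = Math.min(bytes.length, offset + CHUNK_SIZE);
                outbound.onNext(
                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
                        .setGetArtifactResponse(
                            ArtifactApi.GetArtifactResponse.newBuilder()
                                .setData(ByteString.copyFrom(bytes, offset, end - offset)))
                        .setIsLast(end == bytes.length)
                        .build());
                offset = end;
            } while (offset < bytes.length);
        }
    }

    @Override
    public void onError(Throwable t) {
        // Nothing to clean up in this sketch.
    }

    @Override
    public void onCompleted() {
        outbound.onCompleted();
    }
}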

Example 57 with StreamObserver

use of org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver in project beam by apache.

the class ArtifactRetrievalService method getArtifact.

@Override
public void getArtifact(ArtifactApi.GetArtifactRequest request, StreamObserver<ArtifactApi.GetArtifactResponse> responseObserver) {
    try {
        InputStream inputStream = getArtifact(request.getArtifact());
        byte[] buffer = new byte[bufferSize];
        int bytesRead;
        while ((bytesRead = inputStream.read(buffer)) > 0) {
            responseObserver.onNext(ArtifactApi.GetArtifactResponse.newBuilder().setData(ByteString.copyFrom(buffer, 0, bytesRead)).build());
        }
        responseObserver.onCompleted();
    } catch (IOException exn) {
        exn.printStackTrace();
        responseObserver.onError(exn);
    } catch (UnsupportedOperationException exn) {
        responseObserver.onError(new StatusException(Status.INVALID_ARGUMENT.withDescription(exn.getMessage())));
    }
}
Also used : StatusException(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.StatusException) InputStream(java.io.InputStream) IOException(java.io.IOException)
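
As a consumer-side counterpart (a hedged sketch, not taken from the Beam sources): for a server-streaming RPC the generated blocking stub yields the GetArtifactResponse chunks as an Iterator, which can simply be drained to a file. The stub class name ArtifactRetrievalServiceGrpc and its package are assumed here.

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Iterator;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel;

class ArtifactDownloader {

    // Downloads one artifact by draining the streamed GetArtifactResponse chunks into a file.
    static void download(ManagedChannel channel, RunnerApi.ArtifactInformation artifact, String destination)
            throws IOException {
        ArtifactApi.GetArtifactRequest request =
            ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact).build();
        Iterator<ArtifactApi.GetArtifactResponse> chunks =
            ArtifactRetrievalServiceGrpc.newBlockingStub(channel).getArtifact(request);
        try (OutputStream out = new FileOutputStream(destination)) {
            while (chunks.hasNext()) {
                // Each response carries one buffer-sized slice of the artifact.
                chunks.next().getData().writeTo(out);
            }
        }
    }
}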

Example 58 with StreamObserver

use of org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver in project beam by apache.

the class BeamFnDataGrpcServiceTest method testMessageReceivedBySingleClientWhenThereAreMultipleClients.

@Test
public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throws Exception {
    BlockingQueue<Elements> clientInboundElements = new LinkedBlockingQueue<>();
    ExecutorService executorService = Executors.newCachedThreadPool();
    CountDownLatch waitForInboundElements = new CountDownLatch(1);
    int numberOfClients = 3;
    for (int client = 0; client < numberOfClients; ++client) {
        executorService.submit(() -> {
            ManagedChannel channel = ManagedChannelFactory.createDefault().withInterceptors(Arrays.asList(AddHarnessIdInterceptor.create(WORKER_ID))).forDescriptor(service.getApiServiceDescriptor());
            StreamObserver<BeamFnApi.Elements> outboundObserver = BeamFnDataGrpc.newStub(channel).data(TestStreams.withOnNext(clientInboundElements::add).build());
            waitForInboundElements.await();
            outboundObserver.onCompleted();
            return null;
        });
    }
    for (int i = 0; i < 3; ++i) {
        CloseableFnDataReceiver<WindowedValue<String>> consumer = service.getDataService(WORKER_ID).send(LogicalEndpoint.data(Integer.toString(i), TRANSFORM_ID), CODER);
        consumer.accept(valueInGlobalWindow("A" + i));
        consumer.accept(valueInGlobalWindow("B" + i));
        consumer.accept(valueInGlobalWindow("C" + i));
        consumer.close();
    }
    // Specifically copy the elements to a new list so we perform blocking calls on the queue
    // to ensure the elements arrive.
    List<Elements> copy = new ArrayList<>();
    for (int i = 0; i < numberOfClients; ++i) {
        copy.add(clientInboundElements.take());
    }
    assertThat(copy, containsInAnyOrder(elementsWithData("0"), elementsWithData("1"), elementsWithData("2")));
    waitForInboundElements.countDown();
}
Also used : ArrayList(java.util.ArrayList) Elements(org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) CountDownLatch(java.util.concurrent.CountDownLatch) LogicalEndpoint(org.apache.beam.sdk.fn.data.LogicalEndpoint) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ExecutorService(java.util.concurrent.ExecutorService) ManagedChannel(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel) Test(org.junit.Test)
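
TestStreams.withOnNext(clientInboundElements::add).build() supplies the client's inbound observer in the test above; the sketch below is a rough hand-rolled equivalent (illustrative only, not the Beam test utility itself), included to make explicit what gets plugged into BeamFnDataGrpc.newStub(channel).data(...).

import java.util.function.Consumer;
import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;

// Forwards every inbound message to a consumer (e.g. a queue's add method) and ignores
// stream termination, which is all the test above needs from its inbound observer.
class CollectingObserver<T> implements StreamObserver<T> {

    private final Consumer<T> onNext;

    CollectingObserver(Consumer<T> onNext) {
        this.onNext = onNext;
    }

    @Override
    public void onNext(T value) {
        onNext.accept(value);
    }

    @Override
    public void onError(Throwable t) {
        // The test tolerates errors silently; production code would surface them.
    }

    @Override
    public void onCompleted() {}
}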

Example 59 with StreamObserver

use of org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver in project beam by apache.

the class BeamFnDataGrpcServiceTest method testMultipleClientsSendMessagesAreDirectedToProperConsumers.

@Test
public void testMultipleClientsSendMessagesAreDirectedToProperConsumers() throws Exception {
    LinkedBlockingQueue<BeamFnApi.Elements> clientInboundElements = new LinkedBlockingQueue<>();
    ExecutorService executorService = Executors.newCachedThreadPool();
    CountDownLatch waitForInboundElements = new CountDownLatch(1);
    for (int i = 0; i < 3; ++i) {
        String instructionId = Integer.toString(i);
        executorService.submit(() -> {
            ManagedChannel channel = ManagedChannelFactory.createDefault().withInterceptors(Arrays.asList(AddHarnessIdInterceptor.create(WORKER_ID))).forDescriptor(service.getApiServiceDescriptor());
            StreamObserver<BeamFnApi.Elements> outboundObserver = BeamFnDataGrpc.newStub(channel).data(TestStreams.withOnNext(clientInboundElements::add).build());
            outboundObserver.onNext(elementsWithData(instructionId));
            waitForInboundElements.await();
            outboundObserver.onCompleted();
            return null;
        });
    }
    List<Collection<WindowedValue<String>>> serverInboundValues = new ArrayList<>();
    Collection<InboundDataClient> inboundDataClients = new ArrayList<>();
    for (int i = 0; i < 3; ++i) {
        BlockingQueue<WindowedValue<String>> serverInboundValue = new LinkedBlockingQueue<>();
        serverInboundValues.add(serverInboundValue);
        inboundDataClients.add(service.getDataService(WORKER_ID).receive(LogicalEndpoint.data(Integer.toString(i), TRANSFORM_ID), CODER, serverInboundValue::add));
    }
    // Waiting for the client provides the necessary synchronization for the elements to arrive.
    for (InboundDataClient inboundDataClient : inboundDataClients) {
        inboundDataClient.awaitCompletion();
    }
    waitForInboundElements.countDown();
    for (int i = 0; i < 3; ++i) {
        assertThat(serverInboundValues.get(i), contains(valueInGlobalWindow("A" + i), valueInGlobalWindow("B" + i), valueInGlobalWindow("C" + i)));
    }
    assertThat(clientInboundElements, empty());
}
Also used : ArrayList(java.util.ArrayList) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) Elements(org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) CountDownLatch(java.util.concurrent.CountDownLatch) LogicalEndpoint(org.apache.beam.sdk.fn.data.LogicalEndpoint) InboundDataClient(org.apache.beam.sdk.fn.data.InboundDataClient) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ExecutorService(java.util.concurrent.ExecutorService) ManagedChannel(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel) Collection(java.util.Collection) Test(org.junit.Test)
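
The elementsWithData helper is defined elsewhere in the test class and not shown here. A hypothetical reconstruction of its shape, assuming the instruction_id/transform_id/data/is_last fields of BeamFnApi.Elements.Data and passing the transform id and payload in explicitly, could look like this:

import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;

class ElementsHelper {

    // Hypothetical stand-in for the test helper: one data block carrying the encoded payload
    // for the given instruction, followed by an is_last marker for the same logical endpoint.
    static BeamFnApi.Elements elementsWithData(String instructionId, String transformId, ByteString payload) {
        return BeamFnApi.Elements.newBuilder()
            .addData(
                BeamFnApi.Elements.Data.newBuilder()
                    .setInstructionId(instructionId)
                    .setTransformId(transformId)
                    .setData(payload))
            .addData(
                BeamFnApi.Elements.Data.newBuilder()
                    .setInstructionId(instructionId)
                    .setTransformId(transformId)
                    .setIsLast(true))
            .build();
    }
}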

Example 60 with StreamObserver

use of org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver in project beam by apache.

the class BeamFnLoggingServiceTest method testMultipleClientsFailingIsHandledGracefullyByServer.

@Test(timeout = 5000)
public void testMultipleClientsFailingIsHandledGracefullyByServer() throws Exception {
    Collection<Callable<Void>> tasks = new ArrayList<>();
    ConcurrentLinkedQueue<BeamFnApi.LogEntry> logs = new ConcurrentLinkedQueue<>();
    try (BeamFnLoggingService service = new BeamFnLoggingService(findOpenPort(), logs::add, ServerStreamObserverFactory.fromOptions(PipelineOptionsFactory.create())::from, GrpcContextHeaderAccessorProvider.getHeaderAccessor())) {
        server = ServerFactory.createDefault().create(Arrays.asList(service), service.getApiServiceDescriptor());
        CountDownLatch waitForTermination = new CountDownLatch(3);
        final BlockingQueue<StreamObserver<List>> outboundObservers = new LinkedBlockingQueue<>();
        for (int i = 1; i <= 3; ++i) {
            int instructionId = i;
            tasks.add(() -> {
                ManagedChannel channel = ManagedChannelFactory.createDefault().withInterceptors(Arrays.asList(AddHarnessIdInterceptor.create(WORKER_ID + instructionId))).forDescriptor(service.getApiServiceDescriptor());
                StreamObserver<BeamFnApi.LogEntry.List> outboundObserver = BeamFnLoggingGrpc.newStub(channel).logging(TestStreams.withOnNext(BeamFnLoggingServiceTest::discardMessage).withOnError(waitForTermination::countDown).build());
                outboundObserver.onNext(createLogsWithIds(instructionId, -instructionId));
                outboundObservers.add(outboundObserver);
                return null;
            });
        }
        ExecutorService executorService = Executors.newCachedThreadPool();
        executorService.invokeAll(tasks);
        for (int i = 1; i <= 3; ++i) {
            outboundObservers.take().onError(new RuntimeException("Client " + i));
        }
        waitForTermination.await();
    }
}
Also used : StreamObserver(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ArrayList(java.util.ArrayList) CountDownLatch(java.util.concurrent.CountDownLatch) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) Callable(java.util.concurrent.Callable) ExecutorService(java.util.concurrent.ExecutorService) ManagedChannel(org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel) List(org.apache.beam.model.fnexecution.v1.BeamFnApi.LogEntry.List) ConcurrentLinkedQueue(java.util.concurrent.ConcurrentLinkedQueue) Test(org.junit.Test)
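
createLogsWithIds is likewise defined elsewhere in the test class. A hypothetical stand-in, assuming the instruction_id field on BeamFnApi.LogEntry and the log_entries field on BeamFnApi.LogEntry.List, could be:

import org.apache.beam.model.fnexecution.v1.BeamFnApi;

class LogHelper {

    // Hypothetical stand-in for the test helper: one LogEntry per id, tagged with that id so
    // the logging service can attribute each entry to a client/instruction.
    static BeamFnApi.LogEntry.List createLogsWithIds(int... ids) {
        BeamFnApi.LogEntry.List.Builder list = BeamFnApi.LogEntry.List.newBuilder();
        for (int id : ids) {
            list.addLogEntries(BeamFnApi.LogEntry.newBuilder().setInstructionId(Integer.toString(id)));
        }
        return list.build();
    }
}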

Aggregations

StreamObserver (io.grpc.stub.StreamObserver): 133
Test (org.junit.Test): 95
CountDownLatch (java.util.concurrent.CountDownLatch): 50
ArrayList (java.util.ArrayList): 47
AtomicReference (java.util.concurrent.atomic.AtomicReference): 38
StreamObserver (org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver): 27
StatusRuntimeException (io.grpc.StatusRuntimeException): 26
Status (io.grpc.Status): 20
List (java.util.List): 18
BeamFnApi (org.apache.beam.model.fnexecution.v1.BeamFnApi): 18
ManagedChannel (org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel): 18
CompletableFuture (java.util.concurrent.CompletableFuture): 17
ExecutorService (java.util.concurrent.ExecutorService): 16
SegmentId (io.pravega.controller.stream.api.grpc.v1.Controller.SegmentId): 14
ServerRequest (io.pravega.controller.stream.api.grpc.v1.Controller.ServerRequest): 14
VisibleForTesting (com.google.common.annotations.VisibleForTesting): 12
Strings (com.google.common.base.Strings): 12
Throwables (com.google.common.base.Throwables): 12
ImmutableMap (com.google.common.collect.ImmutableMap): 12
AuthHandler (io.pravega.auth.AuthHandler): 12