4646import io .grpc .InternalServerInterceptors ;
4747import io .grpc .Metadata ;
4848import io .grpc .ServerCall ;
49+ import io .grpc .ServerCallExecutorSupplier ;
4950import io .grpc .ServerCallHandler ;
5051import io .grpc .ServerInterceptor ;
5152import io .grpc .ServerMethodDefinition ;
5253import io .grpc .ServerServiceDefinition ;
5354import io .grpc .ServerTransportFilter ;
5455import io .grpc .Status ;
56+ import io .grpc .StatusException ;
5557import io .perfmark .Link ;
5658import io .perfmark .PerfMark ;
5759import io .perfmark .Tag ;
@@ -125,6 +127,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
125127 private final InternalChannelz channelz ;
126128 private final CallTracer serverCallTracer ;
127129 private final Deadline .Ticker ticker ;
130+ private final ServerCallExecutorSupplier executorSupplier ;
128131
129132 /**
130133 * Construct a server.
@@ -159,6 +162,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
159162 this .serverCallTracer = builder .callTracerFactory .create ();
160163 this .ticker = checkNotNull (builder .ticker , "ticker" );
161164 channelz .addServer (this );
165+ this .executorSupplier = builder .executorSupplier ;
162166 }
163167
164168 /**
@@ -469,11 +473,11 @@ private void streamCreatedInternal(
469473 final Executor wrappedExecutor ;
470474 // This is a performance optimization that avoids the synchronization and queuing overhead
471475 // that comes with SerializingExecutor.
472- if (executor == directExecutor ()) {
476+ if (executorSupplier != null || executor != directExecutor ()) {
477+ wrappedExecutor = new SerializingExecutor (executor );
478+ } else {
473479 wrappedExecutor = new SerializeReentrantCallsDirectExecutor ();
474480 stream .optimizeForDirectExecutor ();
475- } else {
476- wrappedExecutor = new SerializingExecutor (executor );
477481 }
478482
479483 if (headers .containsKey (MESSAGE_ENCODING_KEY )) {
@@ -499,52 +503,124 @@ private void streamCreatedInternal(
499503
500504 final JumpToApplicationThreadServerStreamListener jumpListener
501505 = new JumpToApplicationThreadServerStreamListener (
502- wrappedExecutor , executor , stream , context , tag );
506+ wrappedExecutor , executor , stream , context , tag );
503507 stream .setListener (jumpListener );
504- // Run in wrappedExecutor so jumpListener.setListener() is called before any callbacks
505- // are delivered, including any errors. Callbacks can still be triggered, but they will be
506- // queued.
507-
508- final class StreamCreated extends ContextRunnable {
509- StreamCreated () {
508+ final SettableFuture <ServerCallParameters <?,?>> future = SettableFuture .create ();
509+ // Run in serializing executor so jumpListener.setListener() is called before any callbacks
510+ // are delivered, including any errors. MethodLookup() and HandleServerCall() are proactively
511+ // queued before any callbacks are queued at serializing executor.
512+ // MethodLookup() runs on the default executor.
513+ // When executorSupplier is enabled, MethodLookup() may set/change the executor in the
514+ // SerializingExecutor before it finishes running.
515+ // Then HandleServerCall() and callbacks would switch to the executorSupplier executor.
516+ // Otherwise, they all run on the default executor.
517+
518+ final class MethodLookup extends ContextRunnable {
519+ MethodLookup () {
510520 super (context );
511521 }
512522
513523 @ Override
514524 public void runInContext () {
515- PerfMark .startTask ("ServerTransportListener$StreamCreated .startCall" , tag );
525+ PerfMark .startTask ("ServerTransportListener$MethodLookup .startCall" , tag );
516526 PerfMark .linkIn (link );
517527 try {
518528 runInternal ();
519529 } finally {
520- PerfMark .stopTask ("ServerTransportListener$StreamCreated .startCall" , tag );
530+ PerfMark .stopTask ("ServerTransportListener$MethodLookup .startCall" , tag );
521531 }
522532 }
523533
524534 private void runInternal () {
525- ServerStreamListener listener = NOOP_LISTENER ;
535+ ServerMethodDefinition <?, ?> wrapMethod ;
536+ ServerCallParameters <?, ?> callParams ;
526537 try {
527538 ServerMethodDefinition <?, ?> method = registry .lookupMethod (methodName );
528539 if (method == null ) {
529540 method = fallbackRegistry .lookupMethod (methodName , stream .getAuthority ());
530541 }
531542 if (method == null ) {
532543 Status status = Status .UNIMPLEMENTED .withDescription (
533- "Method not found: " + methodName );
544+ "Method not found: " + methodName );
534545 // TODO(zhangkun83): this error may be recorded by the tracer, and if it's kept in
535546 // memory as a map whose key is the method name, this would allow a misbehaving
536547 // client to blow up the server in-memory stats storage by sending large number of
537548 // distinct unimplemented method
538549 // names. (https://github.com/grpc/grpc-java/issues/2285)
539550 stream .close (status , new Metadata ());
540551 context .cancel (null );
552+ future .cancel (false );
541553 return ;
542554 }
543- listener = startCall (stream , methodName , method , headers , context , statsTraceCtx , tag );
555+ wrapMethod = wrapMethod (stream , method , statsTraceCtx );
556+ callParams = maySwitchExecutor (wrapMethod , stream , headers , context , tag );
557+ future .set (callParams );
544558 } catch (Throwable t ) {
545559 stream .close (Status .fromThrowable (t ), new Metadata ());
546560 context .cancel (null );
561+ future .cancel (false );
547562 throw t ;
563+ }
564+ }
565+
566+ private <ReqT , RespT > ServerCallParameters <ReqT , RespT > maySwitchExecutor (
567+ final ServerMethodDefinition <ReqT , RespT > methodDef ,
568+ final ServerStream stream ,
569+ final Metadata headers ,
570+ final Context .CancellableContext context ,
571+ final Tag tag ) {
572+ final ServerCallImpl <ReqT , RespT > call = new ServerCallImpl <>(
573+ stream ,
574+ methodDef .getMethodDescriptor (),
575+ headers ,
576+ context ,
577+ decompressorRegistry ,
578+ compressorRegistry ,
579+ serverCallTracer ,
580+ tag );
581+ if (executorSupplier != null ) {
582+ Executor switchingExecutor = executorSupplier .getExecutor (call , headers );
583+ if (switchingExecutor != null ) {
584+ ((SerializingExecutor )wrappedExecutor ).setExecutor (switchingExecutor );
585+ }
586+ }
587+ return new ServerCallParameters <>(call , methodDef .getServerCallHandler ());
588+ }
589+ }
590+
591+ final class HandleServerCall extends ContextRunnable {
592+ HandleServerCall () {
593+ super (context );
594+ }
595+
596+ @ Override
597+ public void runInContext () {
598+ PerfMark .startTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
599+ PerfMark .linkIn (link );
600+ try {
601+ runInternal ();
602+ } finally {
603+ PerfMark .stopTask ("ServerTransportListener$HandleServerCall.startCall" , tag );
604+ }
605+ }
606+
607+ private void runInternal () {
608+ ServerStreamListener listener = NOOP_LISTENER ;
609+ ServerCallParameters <?,?> callParameters ;
610+ try {
611+ if (future .isCancelled ()) {
612+ return ;
613+ }
614+ if (!future .isDone () || (callParameters = future .get ()) == null ) {
615+ Status status = Status .INTERNAL .withDescription (
616+ "Unexpected failure retrieving server call parameters." );
617+ throw new StatusException (status );
618+ }
619+ listener = startWrappedCall (methodName , callParameters , headers );
620+ } catch (Throwable ex ) {
621+ stream .close (Status .fromThrowable (ex ), new Metadata ());
622+ context .cancel (null );
623+ throw new IllegalStateException (ex );
548624 } finally {
549625 jumpListener .setListener (listener );
550626 }
@@ -568,7 +644,8 @@ public void cancelled(Context context) {
568644 }
569645 }
570646
571- wrappedExecutor .execute (new StreamCreated ());
647+ wrappedExecutor .execute (new MethodLookup ());
648+ wrappedExecutor .execute (new HandleServerCall ());
572649 }
573650
574651 private Context .CancellableContext createContext (
@@ -593,9 +670,8 @@ private Context.CancellableContext createContext(
593670 }
594671
595672 /** Never returns {@code null}. */
596- private <ReqT , RespT > ServerStreamListener startCall (ServerStream stream , String fullMethodName ,
597- ServerMethodDefinition <ReqT , RespT > methodDef , Metadata headers ,
598- Context .CancellableContext context , StatsTraceContext statsTraceCtx , Tag tag ) {
673+ private <ReqT , RespT > ServerMethodDefinition <?,?> wrapMethod (ServerStream stream ,
674+ ServerMethodDefinition <ReqT , RespT > methodDef , StatsTraceContext statsTraceCtx ) {
599675 // TODO(ejona86): should we update fullMethodName to have the canonical path of the method?
600676 statsTraceCtx .serverCallStarted (
601677 new ServerCallInfoImpl <>(
@@ -609,34 +685,31 @@ private <ReqT, RespT> ServerStreamListener startCall(ServerStream stream, String
609685 ServerMethodDefinition <ReqT , RespT > interceptedDef = methodDef .withServerCallHandler (handler );
610686 ServerMethodDefinition <?, ?> wMethodDef = binlog == null
611687 ? interceptedDef : binlog .wrapMethodDefinition (interceptedDef );
612- return startWrappedCall (fullMethodName , wMethodDef , stream , headers , context , tag );
688+ return wMethodDef ;
689+ }
690+
691+ private final class ServerCallParameters <ReqT , RespT > {
692+ ServerCallImpl <ReqT , RespT > call ;
693+ ServerCallHandler <ReqT , RespT > callHandler ;
694+
695+ public ServerCallParameters (ServerCallImpl <ReqT , RespT > call ,
696+ ServerCallHandler <ReqT , RespT > callHandler ) {
697+ this .call = call ;
698+ this .callHandler = callHandler ;
699+ }
613700 }
614701
615702 private <WReqT , WRespT > ServerStreamListener startWrappedCall (
616703 String fullMethodName ,
617- ServerMethodDefinition <WReqT , WRespT > methodDef ,
618- ServerStream stream ,
619- Metadata headers ,
620- Context .CancellableContext context ,
621- Tag tag ) {
622-
623- ServerCallImpl <WReqT , WRespT > call = new ServerCallImpl <>(
624- stream ,
625- methodDef .getMethodDescriptor (),
626- headers ,
627- context ,
628- decompressorRegistry ,
629- compressorRegistry ,
630- serverCallTracer ,
631- tag );
632-
633- ServerCall .Listener <WReqT > listener =
634- methodDef .getServerCallHandler ().startCall (call , headers );
635- if (listener == null ) {
704+ ServerCallParameters <WReqT , WRespT > params ,
705+ Metadata headers ) {
706+ ServerCall .Listener <WReqT > callListener =
707+ params .callHandler .startCall (params .call , headers );
708+ if (callListener == null ) {
636709 throw new NullPointerException (
637- "startCall() returned a null listener for method " + fullMethodName );
710+ "startCall() returned a null listener for method " + fullMethodName );
638711 }
639- return call .newServerStreamListener (listener );
712+ return params . call .newServerStreamListener (callListener );
640713 }
641714 }
642715
0 commit comments