Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions azure/worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class FunctionInfo(typing.NamedTuple):
outputs: typing.Set[str]
requires_context: bool
is_async: bool
has_return: bool


class DispatcherMeta(type):
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(self, loop, host, port, worker_id, request_id,
self._grpc_thread = threading.Thread(
name='grpc-thread', target=self.__poll_grpc)

self._logger = logging.getLogger('python-azure-worker')

@classmethod
async def connect(cls, host, port, worker_id, request_id,
connect_timeout):
Expand Down Expand Up @@ -184,14 +187,25 @@ def _register_function(self, function_id: str, func: callable,
bindings = {}
return_binding = None
for name, desc in metadata.bindings.items():
if desc.direction == protos.BindingInfo.inout:
raise TypeError(
f'cannot load the {func_name} function: '
f'"inout" bindings are not supported')

if name == '$return':
# TODO:
# * add proper gRPC->Python type reflection;
# * convert the type from function.json to a Python type;
# * enforce return type of a function call in Python;
# * use the return type information to marshal the result into
# a correct gRPC type.
return_binding = desc # NoQA

if desc.direction != protos.BindingInfo.out:
raise TypeError(
f'cannot load the {func_name} function: '
f'"$return" binding must have direction set to "out"')

return_binding = desc
else:
bindings[name] = desc

Expand Down Expand Up @@ -251,14 +265,19 @@ def _register_function(self, function_id: str, func: callable,
directory=metadata.directory,
outputs=frozenset(outputs),
requires_context=requires_context,
is_async=inspect.iscoroutinefunction(func))
is_async=inspect.iscoroutinefunction(func),
has_return=return_binding is not None)

async def _dispatch_grpc_request(self, request):
content_type = request.WhichOneof('content')
request_handler = getattr(self, f'_handle__{content_type}', None)
if request_handler is None:
raise RuntimeError(
# Don't crash on unknown messages. Some of them can be ignored;
# and if something goes really wrong the host can always just
# kill the worker's process.
self._logger.error(
f'unknown StreamingMessage content type {content_type}')
return

resp = await request_handler(request)
self._grpc_resp_queue.put_nowait(resp)
Expand Down Expand Up @@ -349,11 +368,15 @@ async def _handle__invocation_request(self, req):
name=name,
data=rpc_val))

return_value = None
if fi.has_return:
return_value = rpc_types.to_outgoing_proto(call_result)

return protos.StreamingMessage(
request_id=self.request_id,
invocation_response=protos.InvocationResponse(
invocation_id=invocation_id,
return_value=rpc_types.to_outgoing_proto(call_result),
return_value=return_value,
result=protos.StatusResult(
status=protos.StatusResult.Success),
output_data=output_data))
Expand Down
1 change: 1 addition & 0 deletions azure/worker/protos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FunctionLoadResponse,
InvocationRequest,
InvocationResponse,
WorkerHeartbeat,
BindingInfo,
StatusResult,
RpcException,
Expand Down
17 changes: 17 additions & 0 deletions azure/worker/tests/broken_functions/inout_param/function.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"scriptFile": "main.py",
"disabled": false,
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "inout",
"name": "abc"
}
]
}
2 changes: 2 additions & 0 deletions azure/worker/tests/broken_functions/inout_param/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def main(req, abc):
return 'trust me, it is OK!'
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"scriptFile": "main.py",
"disabled": false,
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req"
},
{
"type": "http",
"direction": "in",
"name": "$return"
}
]
}
2 changes: 2 additions & 0 deletions azure/worker/tests/broken_functions/return_param_in/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def main(req):
return 'trust me, it is OK!'
12 changes: 12 additions & 0 deletions azure/worker/tests/functions/no_return/function.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"scriptFile": "main.py",
"disabled": false,
"bindings": [
{
"authLevel": "anonymous",
"type": "httpTrigger",
"direction": "in",
"name": "req"
}
]
}
8 changes: 8 additions & 0 deletions azure/worker/tests/functions/no_return/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import logging


logger = logging.getLogger('test')


def main(req):
logger.error('hi')
30 changes: 30 additions & 0 deletions azure/worker/tests/test_broken_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,33 @@ async def test_load_broken__syntax_error(self):
protos.StatusResult.Failure)

self.assertIn('SyntaxError', r.response.result.exception.message)

async def test_load_broken__inout_param(self):
async with testutils.start_mockhost(
script_root='broken_functions') as host:

func_id, r = await host.load_function('inout_param')

self.assertEqual(r.response.function_id, func_id)
self.assertEqual(r.response.result.status,
protos.StatusResult.Failure)

self.assertRegex(
r.response.result.exception.message,
r'.*cannot load the inout_param function'
r'.*"inout" bindings.*')

async def test_load_broken__return_param_in(self):
async with testutils.start_mockhost(
script_root='broken_functions') as host:

func_id, r = await host.load_function('return_param_in')

self.assertEqual(r.response.function_id, func_id)
self.assertEqual(r.response.result.status,
protos.StatusResult.Failure)

self.assertRegex(
r.response.result.exception.message,
r'.*cannot load the return_param_in function'
r'.*"\$return" .* set to "out"')
4 changes: 4 additions & 0 deletions azure/worker/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def test_return_str(self):
self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, 'Hello World!')

def test_no_return(self):
r = self.webhost.request('GET', 'no_return')
self.assertEqual(r.status_code, 204)

def test_async_return_str(self):
r = self.webhost.request('GET', 'async_return_str')
self.assertEqual(r.status_code, 200)
Expand Down
22 changes: 22 additions & 0 deletions azure/worker/tests/test_mockhost.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,25 @@ async def test_call_function_out_int_param(self):
name='foo',
data=protos.TypedData(int=42))
])

async def test_handles_unsupported_messages_gracefully(self):
async with testutils.start_mockhost() as host:
# Intentionally send a message to worker that isn't
# going to be ever supported by it. The idea is that
# workers should survive such messages and continue
# their operation. If anything, the host can always
# terminate the worker.
await host.send(
protos.StreamingMessage(
worker_heartbeat=protos.WorkerHeartbeat()))

_, r = await host.load_function('return_out')
self.assertEqual(r.response.result.status,
protos.StatusResult.Success)

for log in r.logs:
if 'unknown StreamingMessage' in log.message:
break
else:
raise AssertionError('the worker did not log about an '
'"unknown StreamingMessage"')
12 changes: 9 additions & 3 deletions azure/worker/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def EventStream(self, client_response_iterator, context):

yield message

if wait_for is None:
continue

response = None
logs = []

Expand Down Expand Up @@ -169,7 +172,7 @@ async def load_function(self, name):
type=b['type'],
direction=direction)

r = await self.send(
r = await self.communicate(
protos.StreamingMessage(
function_load_request=protos.FunctionLoadRequest(
function_id=func.id,
Expand All @@ -191,7 +194,7 @@ async def invoke_function(
func = self._available_functions[name]
invocation_id = self.make_id()

r = await self.send(
r = await self.communicate(
protos.StreamingMessage(
invocation_request=protos.InvocationRequest(
invocation_id=invocation_id,
Expand All @@ -201,7 +204,10 @@ async def invoke_function(

return invocation_id, r

async def send(self, message, *, wait_for):
async def send(self, message):
self._in_queue.put_nowait((message, None))

async def communicate(self, message, *, wait_for):
self._in_queue.put_nowait((message, wait_for))
return await self._out_aqueue.get()

Expand Down