Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 92 additions & 19 deletions workers/proxy_worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@
# Library worker import reloaded in init and reload request
_library_worker = None

# Thread-local invocation ID registry for efficient lookup
_thread_invocation_registry: typing.Dict[int, str] = {}
_registry_lock = threading.Lock()

# Global current invocation tracker (as a fallback)
_current_invocation_id: Optional[str] = None
_current_invocation_lock = threading.Lock()


class ContextEnabledTask(asyncio.Task):
AZURE_INVOCATION_ID = '__azure_function_invocation_id__'
Expand All @@ -61,16 +69,63 @@ def set_azure_invocation_id(self, invocation_id: str) -> None:
_invocation_id_local = threading.local()


def set_thread_invocation_id(thread_id: int, invocation_id: str) -> None:
"""Set the invocation ID for a specific thread"""
with _registry_lock:
_thread_invocation_registry[thread_id] = invocation_id


def clear_thread_invocation_id(thread_id: int) -> None:
"""Clear the invocation ID for a specific thread"""
with _registry_lock:
_thread_invocation_registry.pop(thread_id, None)


def get_thread_invocation_id(thread_id: int) -> Optional[str]:
"""Get the invocation ID for a specific thread"""
with _registry_lock:
return _thread_invocation_registry.get(thread_id)


def set_current_invocation_id(invocation_id: str) -> None:
"""Set the global current invocation ID"""
global _current_invocation_id
with _current_invocation_lock:
_current_invocation_id = invocation_id


def get_global_current_invocation_id() -> Optional[str]:
"""Get the global current invocation ID"""
with _current_invocation_lock:
return _current_invocation_id


def get_current_invocation_id() -> Optional[Any]:
loop = asyncio._get_running_loop()
if loop is not None:
current_task = asyncio.current_task(loop)
if current_task is not None:
task_invocation_id = getattr(current_task,
ContextEnabledTask.AZURE_INVOCATION_ID,
None)
if task_invocation_id is not None:
return task_invocation_id
# Check global current invocation first (most up-to-date)
global_invocation_id = get_global_current_invocation_id()
if global_invocation_id is not None:
return global_invocation_id

# Check asyncio task context
try:
loop = asyncio._get_running_loop()
if loop is not None:
current_task = asyncio.current_task(loop)
if current_task is not None:
task_invocation_id = getattr(current_task,
ContextEnabledTask.AZURE_INVOCATION_ID,
None)
if task_invocation_id is not None:
return task_invocation_id
except RuntimeError:
# No event loop running
pass

# Check the thread-local invocation ID registry
current_thread_id = threading.get_ident()
thread_invocation_id = get_thread_invocation_id(current_thread_id)
if thread_invocation_id is not None:
return thread_invocation_id

return getattr(_invocation_id_local, 'invocation_id', None)

Expand Down Expand Up @@ -516,14 +571,32 @@ async def _handle__invocation_request(self, request):
'invocation_id: %s, worker_id: %s',
self.request_id, function_id, invocation_id, self.worker_id)

invocation_request = WorkerRequest(name="FunctionInvocationRequest",
request=request,
properties={
"threadpool": self._sync_call_tp})
invocation_response = await (
_library_worker.invocation_request( # type: ignore[union-attr]
invocation_request))
# Set the global current invocation ID first (for all threads to access)
set_current_invocation_id(invocation_id)

return protos.StreamingMessage(
request_id=self.request_id,
invocation_response=invocation_response)
# Set the current `invocation_id` to the current task so
# that our logging handler can find it.
current_task = asyncio.current_task()
if current_task is not None and isinstance(current_task, ContextEnabledTask):
current_task.set_azure_invocation_id(invocation_id)

# Register the invocation ID for the current thread
current_thread_id = threading.get_ident()
set_thread_invocation_id(current_thread_id, invocation_id)

try:
invocation_request = WorkerRequest(name="FunctionInvocationRequest",
request=request,
properties={
"threadpool": self._sync_call_tp})
invocation_response = await (
_library_worker.invocation_request( # type: ignore[union-attr]
invocation_request))

return protos.StreamingMessage(
request_id=self.request_id,
invocation_response=invocation_response)
except Exception:
# Clear thread registry on exception to prevent stale IDs
clear_thread_invocation_id(current_thread_id)
raise
Loading
Loading