123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import asyncio
- import sys
- import threading
- import time
class Local:
    """
    A drop-in replacement for threading.local that also works with asyncio
    Tasks (via the current_task asyncio method), and passes locals through
    sync_to_async and async_to_sync.

    Specifically:
     - Locals work per-coroutine on any thread not spawned using asgiref
     - Locals work per-thread on any thread not spawned using asgiref
     - Locals are shared with the parent coroutine when using sync_to_async
     - Locals are shared with the parent thread when using async_to_sync
       (and if that thread was launched using sync_to_async, with its parent
       coroutine as well, with this working for indefinite levels of nesting)

    Set thread_critical to True to not allow locals to pass from an async Task
    to a thread it spawns. This is needed for code that truly needs
    thread-safety, as opposed to things used for helpful context (e.g. sqlite
    does not like being called from a different thread to the one it is from).
    Thread-critical code will still be differentiated per-Task within a thread
    as it is expected it does not like concurrent access.

    Values are kept in a dict keyed by the resolved context object (a thread
    or a task). Entries whose key is a dead thread or a finished asyncio.Task
    are purged, at most once every CLEANUP_INTERVAL seconds, as a side effect
    of attribute assignment.

    This doesn't use contextvars as it needs to support 3.6. Once it can support
    3.7 only, we can then reimplement the storage more nicely.
    """

    CLEANUP_INTERVAL = 60  # seconds

    # Bookkeeping attributes stored in the instance __dict__ rather than in
    # the per-context storage.
    _INTERNAL_ATTRS = frozenset(
        ("_storage", "_last_cleanup", "_clean_lock", "_thread_critical")
    )

    def __init__(self, thread_critical=False):
        # Maps context id -> {attribute name: value}.
        self._storage = {}
        # Monotonic clock, so a wall-clock jump backwards cannot postpone
        # cleanup (and leak entries for dead contexts) indefinitely.
        self._last_cleanup = time.monotonic()
        # Serialises cleanup passes between threads.
        self._clean_lock = threading.Lock()
        self._thread_critical = thread_critical

    def _get_context_id(self):
        """
        Get the ID we should use for looking up variables.

        Starts from the current task (or, when there is none, the current
        thread). Unless thread-critical, it then follows the SyncToAsync /
        AsyncToSync launch maps back to the outermost context that spawned it.

        Raises RuntimeError if resolution does not terminate within
        sys.getrecursionlimit() steps (i.e. the launch maps loop).
        """
        # Imported here to prevent a circular import with .sync
        from .sync import AsyncToSync, SyncToAsync

        # First, pull the current task if we can
        context_id = SyncToAsync.get_current_task()
        # OK, let's try for a thread ID
        if context_id is None:
            context_id = threading.current_thread()
        # If we're thread-critical, we stop here, as we can't share contexts.
        if self._thread_critical:
            return context_id
        # Now, take those and see if we can resolve them through the launch maps
        for _ in range(sys.getrecursionlimit()):
            try:
                if isinstance(context_id, threading.Thread):
                    # Threads have a source task in SyncToAsync
                    context_id = SyncToAsync.launch_map[context_id]
                else:
                    # Tasks have a source thread in AsyncToSync
                    context_id = AsyncToSync.launch_map[context_id]
            except KeyError:
                break
        else:
            # Catch infinite loops (they happen if you are screwing around
            # with AsyncToSync implementations)
            raise RuntimeError("Infinite launch_map loops")
        return context_id

    def _cleanup(self):
        """
        Cleans up any references to dead threads or finished tasks.

        _maybe_cleanup calls this while holding self._clean_lock.
        """
        # Snapshot the keys: other threads may add entries while we iterate.
        for key in list(self._storage):
            if isinstance(key, threading.Thread):
                stale = not key.is_alive()
            elif isinstance(key, asyncio.Task):
                stale = key.done()
            else:
                stale = False
            if stale:
                self._storage.pop(key, None)
        self._last_cleanup = time.monotonic()

    def _maybe_cleanup(self):
        """
        Cleans up if more than CLEANUP_INTERVAL seconds have passed since the
        last cleanup.
        """
        if time.monotonic() - self._last_cleanup > self.CLEANUP_INTERVAL:
            with self._clean_lock:
                # Re-check under the lock: a thread that was waiting for it
                # would otherwise repeat the pass another thread just did.
                if time.monotonic() - self._last_cleanup > self.CLEANUP_INTERVAL:
                    self._cleanup()

    def __getattr__(self, key):
        # Only reached when normal lookup fails. For internal attributes that
        # means __init__ never ran (e.g. copy/pickle reconstruction via
        # __new__). Falling through would call _get_context_id, which reads
        # the same missing attributes and recurses without bound.
        if key in self._INTERNAL_ATTRS:
            raise AttributeError("%r object has no attribute %r" % (self, key))
        context_id = self._get_context_id()
        try:
            return self._storage[context_id][key]
        except KeyError:
            raise AttributeError(
                "%r object has no attribute %r" % (self, key)
            ) from None

    def __setattr__(self, key, value):
        if key in self._INTERNAL_ATTRS:
            return super().__setattr__(key, value)
        self._maybe_cleanup()
        self._storage.setdefault(self._get_context_id(), {})[key] = value

    def __delattr__(self, key):
        context_id = self._get_context_id()
        try:
            del self._storage[context_id][key]
        except KeyError:
            raise AttributeError(
                "%r object has no attribute %r" % (self, key)
            ) from None
|