local.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import asyncio
  2. import sys
  3. import threading
  4. import time
  5. class Local:
  6. """
  7. A drop-in replacement for threading.locals that also works with asyncio
  8. Tasks (via the current_task asyncio method), and passes locals through
  9. sync_to_async and async_to_sync.
  10. Specifically:
  11. - Locals work per-coroutine on any thread not spawned using asgiref
  12. - Locals work per-thread on any thread not spawned using asgiref
  13. - Locals are shared with the parent coroutine when using sync_to_async
  14. - Locals are shared with the parent thread when using async_to_sync
  15. (and if that thread was launched using sync_to_async, with its parent
  16. coroutine as well, with this working for indefinite levels of nesting)
  17. Set thread_critical to True to not allow locals to pass from an async Task
  18. to a thread it spawns. This is needed for code that truly needs
  19. thread-safety, as opposed to things used for helpful context (e.g. sqlite
  20. does not like being called from a different thread to the one it is from).
  21. Thread-critical code will still be differentiated per-Task within a thread
  22. as it is expected it does not like concurrent access.
  23. This doesn't use contextvars as it needs to support 3.6. Once it can support
  24. 3.7 only, we can then reimplement the storage more nicely.
  25. """
  26. CLEANUP_INTERVAL = 60 # seconds
  27. def __init__(self, thread_critical=False):
  28. self._storage = {}
  29. self._last_cleanup = time.time()
  30. self._clean_lock = threading.Lock()
  31. self._thread_critical = thread_critical
  32. def _get_context_id(self):
  33. """
  34. Get the ID we should use for looking up variables
  35. """
  36. # Prevent a circular reference
  37. from .sync import AsyncToSync, SyncToAsync
  38. # First, pull the current task if we can
  39. context_id = SyncToAsync.get_current_task()
  40. # OK, let's try for a thread ID
  41. if context_id is None:
  42. context_id = threading.current_thread()
  43. # If we're thread-critical, we stop here, as we can't share contexts.
  44. if self._thread_critical:
  45. return context_id
  46. # Now, take those and see if we can resolve them through the launch maps
  47. for i in range(sys.getrecursionlimit()):
  48. try:
  49. if isinstance(context_id, threading.Thread):
  50. # Threads have a source task in SyncToAsync
  51. context_id = SyncToAsync.launch_map[context_id]
  52. else:
  53. # Tasks have a source thread in AsyncToSync
  54. context_id = AsyncToSync.launch_map[context_id]
  55. except KeyError:
  56. break
  57. else:
  58. # Catch infinite loops (they happen if you are screwing around
  59. # with AsyncToSync implementations)
  60. raise RuntimeError("Infinite launch_map loops")
  61. return context_id
  62. def _cleanup(self):
  63. """
  64. Cleans up any references to dead threads or tasks
  65. """
  66. for key in list(self._storage.keys()):
  67. if isinstance(key, threading.Thread):
  68. if not key.is_alive():
  69. del self._storage[key]
  70. elif isinstance(key, asyncio.Task):
  71. if key.done():
  72. del self._storage[key]
  73. self._last_cleanup = time.time()
  74. def _maybe_cleanup(self):
  75. """
  76. Cleans up if enough time has passed
  77. """
  78. if time.time() - self._last_cleanup > self.CLEANUP_INTERVAL:
  79. with self._clean_lock:
  80. self._cleanup()
  81. def __getattr__(self, key):
  82. context_id = self._get_context_id()
  83. if key in self._storage.get(context_id, {}):
  84. return self._storage[context_id][key]
  85. else:
  86. raise AttributeError("%r object has no attribute %r" % (self, key))
  87. def __setattr__(self, key, value):
  88. if key in ("_storage", "_last_cleanup", "_clean_lock", "_thread_critical"):
  89. return super().__setattr__(key, value)
  90. self._maybe_cleanup()
  91. self._storage.setdefault(self._get_context_id(), {})[key] = value
  92. def __delattr__(self, key):
  93. context_id = self._get_context_id()
  94. if key in self._storage.get(context_id, {}):
  95. del self._storage[context_id][key]
  96. else:
  97. raise AttributeError("%r object has no attribute %r" % (self, key))