1+ import threading
12from functools import partial
23
3- from daphne .testing import DaphneProcess
44from django .contrib .staticfiles .handlers import ASGIStaticFilesHandler
5- from django .core .exceptions import ImproperlyConfigured
65from django .db import connections
76from django .db .backends .base .creation import TEST_DATABASE_PREFIX
87from django .test .testcases import TransactionTestCase
98from django .test .utils import modify_settings
9+ from django .utils .functional import classproperty
1010
1111from channels .routing import get_default_application
1212
@@ -28,65 +28,172 @@ def set_database_connection():
2828 settings .DATABASES ["default" ]["NAME" ] = test_db_name
2929
3030
31+ class ChannelsLiveServerThread (threading .Thread ):
32+ """Thread for running a live ASGI server while the tests are running."""
33+
34+ def __init__ (self , host , get_application , connections_override = None , port = 0 , setup = None ):
35+ self .host = host
36+ self .port = port
37+ self .get_application = get_application
38+ self .connections_override = connections_override
39+ self .setup = setup
40+ self .is_ready = threading .Event ()
41+ self .error = None
42+ super ().__init__ ()
43+
44+ def run (self ):
45+ """
46+ Set up the live server and databases, and then loop over handling
47+ ASGI requests.
48+ """
49+ if self .connections_override :
50+ # Override this thread's database connections with the ones
51+ # provided by the main thread.
52+ for alias , conn in self .connections_override .items ():
53+ connections [alias ] = conn
54+
55+ try :
56+ # Reinstall the reactor for this thread (same as DaphneProcess)
57+ from daphne .testing import _reinstall_reactor
58+ _reinstall_reactor ()
59+
60+ from twisted .internet import reactor
61+ from daphne .endpoints import build_endpoint_description_strings
62+ from daphne .server import Server
63+
64+ # Get the application
65+ application = self .get_application ()
66+
67+ # Create the server
68+ endpoints = build_endpoint_description_strings (
69+ host = self .host , port = self .port
70+ )
71+ self .server = Server (
72+ application = application ,
73+ endpoints = endpoints ,
74+ signal_handlers = False ,
75+ ready_callable = self ._set_ready ,
76+ verbosity = 0 ,
77+ )
78+
79+ # Run setup if provided
80+ if self .setup is not None :
81+ self .setup ()
82+
83+ # Start the server
84+ self .server .run ()
85+ except Exception as e :
86+ self .error = e
87+ self .is_ready .set ()
88+ finally :
89+ connections .close_all ()
90+
91+ def _set_ready (self ):
92+ """Called by Daphne when the server is ready."""
93+ if self .server .listening_addresses :
94+ self .port = self .server .listening_addresses [0 ][1 ]
95+ self .is_ready .set ()
96+
97+ def terminate (self ):
98+ if hasattr (self , "server" ):
99+ # Stop the ASGI server
100+ from twisted .internet import reactor
101+
102+ if reactor .running :
103+ reactor .callFromThread (reactor .stop )
104+ self .join (timeout = 5 )
105+
106+
31107class ChannelsLiveServerTestCase (TransactionTestCase ):
32108 """
33- Does basically the same as TransactionTestCase but also launches a
34- live Daphne server in a separate process, so
35- that the tests may use another test framework, such as Selenium,
36- instead of the built-in dummy client.
109+ Do basically the same as TransactionTestCase but also launch a live ASGI
110+ server in a separate thread so that the tests may use another testing
111+ framework, such as Selenium for example, instead of the built-in dummy
112+ client.
113+ It inherits from TransactionTestCase instead of TestCase because the
114+ threads don't share the same transactions (unless if using in-memory
115+ sqlite) and each thread needs to commit all their transactions so that the
116+ other thread can see the changes.
37117 """
38118
39119 host = "localhost"
40- ProtocolServerProcess = DaphneProcess
41- static_wrapper = ASGIStaticFilesHandler
120+ port = 0
121+ server_thread_class = ChannelsLiveServerThread
122+ static_handler = ASGIStaticFilesHandler
42123 serve_static = True
43124
44- @property
45- def live_server_url (self ):
46- return "http://%s:%s" % (self .host , self . _port )
125+ @classproperty
126+ def live_server_url (cls ):
127+ return "http://%s:%s" % (cls .host , cls . server_thread . port )
47128
48- @property
49- def live_server_ws_url (self ):
50- return "ws://%s:%s" % (self .host , self ._port )
129+ @classproperty
130+ def live_server_ws_url (cls ):
131+ return "ws://%s:%s" % (cls .host , cls .server_thread .port )
132+
133+ @classproperty
134+ def allowed_host (cls ):
135+ return cls .host
51136
52137 @classmethod
53- def setUpClass (cls ):
54- for connection in connections .all ():
55- if cls ._is_in_memory_db (connection ):
56- raise ImproperlyConfigured (
57- "ChannelLiveServerTestCase can not be used with in memory databases"
58- )
138+ def _make_connections_override (cls ):
139+ connections_override = {}
140+ for conn in connections .all ():
141+ # If using in-memory sqlite databases, pass the connections to
142+ # the server thread.
143+ if conn .vendor == "sqlite" and conn .is_in_memory_db ():
144+ connections_override [conn .alias ] = conn
145+ return connections_override
59146
147+ @classmethod
148+ def setUpClass (cls ):
60149 super ().setUpClass ()
61-
62- cls ._live_server_modified_settings = modify_settings (
63- ALLOWED_HOSTS = {"append" : cls .host }
150+ cls .enterClassContext (
151+ modify_settings (ALLOWED_HOSTS = {"append" : cls .allowed_host })
64152 )
65- cls ._live_server_modified_settings .enable ()
153+ cls ._start_server_thread ()
154+
155+ @classmethod
156+ def _start_server_thread (cls ):
157+ connections_override = cls ._make_connections_override ()
158+ for conn in connections_override .values ():
159+ # Explicitly enable thread-shareability for this connection.
160+ conn .inc_thread_sharing ()
161+
162+ cls .server_thread = cls ._create_server_thread (connections_override )
163+ cls .server_thread .daemon = True
164+ cls .server_thread .start ()
165+ cls .addClassCleanup (cls ._terminate_thread )
166+
167+ # Wait for the live server to be ready
168+ cls .server_thread .is_ready .wait ()
169+ if cls .server_thread .error :
170+ raise cls .server_thread .error
66171
172+ @classmethod
173+ def _create_server_thread (cls , connections_override ):
67174 get_application = partial (
68175 make_application ,
69- static_wrapper = cls .static_wrapper if cls .serve_static else None ,
176+ static_wrapper = cls .static_handler if cls .serve_static else None ,
70177 )
71- cls . _server_process = cls .ProtocolServerProcess (
178+ return cls .server_thread_class (
72179 cls .host ,
73180 get_application ,
181+ connections_override = connections_override ,
182+ port = cls .port ,
74183 setup = set_database_connection ,
75184 )
76- cls . _server_process . start ()
77- while True :
78- if not cls . _server_process . ready . wait ( timeout = 1 ):
79- if cls . _server_process . is_alive ():
80- continue
81- raise RuntimeError ( "Server stopped" ) from None
82- break
83- cls . _port = cls . _server_process . port . value
185+
186+ @ classmethod
187+ def _terminate_thread ( cls ):
188+ # Terminate the live server's thread.
189+ cls . server_thread . terminate ()
190+ # Restore shared connections' non-shareability.
191+ for conn in cls . server_thread . connections_override . values ():
192+ conn . dec_thread_sharing ()
84193
85194 @classmethod
86195 def tearDownClass (cls ):
87- cls ._server_process .terminate ()
88- cls ._server_process .join ()
89- cls ._live_server_modified_settings .disable ()
196+ # The cleanup is now handled by addClassCleanup in _start_server_thread
90197 super ().tearDownClass ()
91198
92199 @classmethod
0 commit comments