@@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs):
3737
3838 return Connection (* args , ** kwargs )
3939
40- def _transaction_mock (self ):
40+ def _transaction_mock (self , mock_response = [] ):
4141 from google .rpc .code_pb2 import OK
4242
4343 transaction = mock .Mock (committed = False , rolled_back = False )
44- transaction .batch_update = mock .Mock (return_value = [mock .Mock (code = OK ), []])
44+ transaction .batch_update = mock .Mock (
45+ return_value = [mock .Mock (code = OK ), mock_response ]
46+ )
4547 return transaction
4648
4749 def test_property_connection (self ):
@@ -62,10 +64,12 @@ def test_property_description(self):
6264 self .assertIsInstance (cursor .description [0 ], ColumnInfo )
6365
6466 def test_property_rowcount (self ):
67+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
68+
6569 connection = self ._make_connection (self .INSTANCE , self .DATABASE )
6670 cursor = self ._make_one (connection )
6771
68- assert cursor .rowcount == - 1
72+ self . assertEqual ( cursor .rowcount , _UNSET_COUNT )
6973
7074 def test_callproc (self ):
7175 from google .cloud .spanner_dbapi .exceptions import InterfaceError
@@ -93,25 +97,58 @@ def test_close(self, mock_client):
9397 cursor .execute ("SELECT * FROM database" )
9498
9599 def test_do_execute_update (self ):
96- from google .cloud .spanner_dbapi .checksum import ResultsChecksum
100+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
97101
98102 connection = self ._make_connection (self .INSTANCE , self .DATABASE )
99103 cursor = self ._make_one (connection )
100- cursor ._checksum = ResultsChecksum ()
101104 transaction = mock .MagicMock ()
102105
103106 def run_helper (ret_value ):
104107 transaction .execute_update .return_value = ret_value
105- cursor ._do_execute_update (
108+ res = cursor ._do_execute_update (
106109 transaction = transaction , sql = "SELECT * WHERE true" , params = {},
107110 )
108- return cursor . fetchall ()
111+ return res
109112
110113 expected = "good"
111- self .assertEqual (run_helper (expected ), [expected ])
114+ self .assertEqual (run_helper (expected ), expected )
115+ self .assertEqual (cursor ._row_count , _UNSET_COUNT )
112116
113117 expected = 1234
114- self .assertEqual (run_helper (expected ), [expected ])
118+ self .assertEqual (run_helper (expected ), expected )
119+ self .assertEqual (cursor ._row_count , expected )
120+
121+ def test_do_batch_update (self ):
122+ from google .cloud .spanner_dbapi import connect
123+ from google .cloud .spanner_v1 .param_types import INT64
124+ from google .cloud .spanner_v1 .types .spanner import Session
125+
126+ sql = "DELETE FROM table WHERE col1 = %s"
127+
128+ connection = connect ("test-instance" , "test-database" )
129+
130+ connection .autocommit = True
131+ transaction = self ._transaction_mock (mock_response = [1 , 1 , 1 ])
132+ cursor = connection .cursor ()
133+
134+ with mock .patch (
135+ "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session" ,
136+ return_value = Session (),
137+ ):
138+ with mock .patch (
139+ "google.cloud.spanner_v1.session.Session.transaction" ,
140+ return_value = transaction ,
141+ ):
142+ cursor .executemany (sql , [(1 ,), (2 ,), (3 ,)])
143+
144+ transaction .batch_update .assert_called_once_with (
145+ [
146+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 1 }, {"a0" : INT64 }),
147+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 2 }, {"a0" : INT64 }),
148+ ("DELETE FROM table WHERE col1 = @a0" , {"a0" : 3 }, {"a0" : INT64 }),
149+ ]
150+ )
151+ self .assertEqual (cursor ._row_count , 3 )
115152
116153 def test_execute_programming_error (self ):
117154 from google .cloud .spanner_dbapi .exceptions import ProgrammingError
@@ -704,6 +741,7 @@ def test_setoutputsize(self):
704741
705742 def test_handle_dql (self ):
706743 from google .cloud .spanner_dbapi import utils
744+ from google .cloud .spanner_dbapi .cursor import _UNSET_COUNT
707745
708746 connection = self ._make_connection (self .INSTANCE , mock .MagicMock ())
709747 connection .database .snapshot .return_value .__enter__ .return_value = (
@@ -715,6 +753,7 @@ def test_handle_dql(self):
715753 cursor ._handle_DQL ("sql" , params = None )
716754 self .assertEqual (cursor ._result_set , ["0" ])
717755 self .assertIsInstance (cursor ._itr , utils .PeekIterator )
756+ self .assertEqual (cursor ._row_count , _UNSET_COUNT )
718757
719758 def test_context (self ):
720759 connection = self ._make_connection (self .INSTANCE , self .DATABASE )
0 commit comments