@@ -549,6 +549,74 @@ def test_get_table_column_schema(self):
549549 )
550550 self .assertEqual (result , expected )
551551
552+ def test_peek_iterator_aborted (self ):
553+ """
554+ Checking that an Aborted exception is retried in case it happened
555+ while streaming the first element with a PeekIterator.
556+ """
557+ from google .api_core .exceptions import Aborted
558+ from google .cloud .spanner_dbapi .connection import connect
559+
560+ with mock .patch (
561+ "google.cloud.spanner_v1.instance.Instance.exists" , return_value = True ,
562+ ):
563+ with mock .patch (
564+ "google.cloud.spanner_v1.database.Database.exists" , return_value = True ,
565+ ):
566+ connection = connect ("test-instance" , "test-database" )
567+
568+ cursor = connection .cursor ()
569+ with mock .patch (
570+ "google.cloud.spanner_dbapi.utils.PeekIterator.__init__" ,
571+ side_effect = (Aborted ("Aborted" ), None ),
572+ ):
573+ with mock .patch (
574+ "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
575+ ) as retry_mock :
576+ with mock .patch (
577+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
578+ return_value = ((1 , 2 , 3 ), None ),
579+ ):
580+ cursor .execute ("SELECT * FROM table_name" )
581+
582+ retry_mock .assert_called_with ()
583+
584+ def test_peek_iterator_aborted_autocommit (self ):
585+ """
586+ Checking that an Aborted exception is retried in case it happened while
587+ streaming the first element with a PeekIterator in autocommit mode.
588+ """
589+ from google .api_core .exceptions import Aborted
590+ from google .cloud .spanner_dbapi .connection import connect
591+
592+ with mock .patch (
593+ "google.cloud.spanner_v1.instance.Instance.exists" , return_value = True ,
594+ ):
595+ with mock .patch (
596+ "google.cloud.spanner_v1.database.Database.exists" , return_value = True ,
597+ ):
598+ connection = connect ("test-instance" , "test-database" )
599+
600+ connection .autocommit = True
601+ cursor = connection .cursor ()
602+ with mock .patch (
603+ "google.cloud.spanner_dbapi.utils.PeekIterator.__init__" ,
604+ side_effect = (Aborted ("Aborted" ), None ),
605+ ):
606+ with mock .patch (
607+ "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
608+ ) as retry_mock :
609+ with mock .patch (
610+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
611+ return_value = ((1 , 2 , 3 ), None ),
612+ ):
613+ with mock .patch (
614+ "google.cloud.spanner_v1.database.Database.snapshot"
615+ ):
616+ cursor .execute ("SELECT * FROM table_name" )
617+
618+ retry_mock .assert_called_with ()
619+
552620 def test_fetchone_retry_aborted (self ):
553621 """Check that aborted fetch re-executing transaction."""
554622 from google .api_core .exceptions import Aborted
0 commit comments