Skip to content

Commit 5d5e895

Browse files
authored
feat(experimental): Add read resumption strategy (#1599)
* feat(experimental): Add read resumption strategy * add unit tests * minor fixes * resolving comments
1 parent 5fb85ea commit 5d5e895

File tree

3 files changed

+281
-1
lines changed

3 files changed

+281
-1
lines changed

google/cloud/storage/_experimental/asyncio/retry/base_strategy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ def generate_requests(self, state: Any) -> Iterable[Any]:
3131
pass
3232

3333
@abc.abstractmethod
34-
def update_state_from_response(self, state: Any) -> None:
34+
def update_state_from_response(self, response: Any, state: Any) -> None:
3535
"""Updates the state based on a successful server response.
3636
3737
This method is called for every message received from the server. It is
3838
responsible for processing the response and updating the shared state
3939
object.
4040
41+
:type response: Any
42+
:param response: The response message received from the server.
43+
4144
:type state: Any
4245
:param state: The shared state object for the operation, which will be
4346
mutated by this method.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Any, List, IO
2+
3+
from google.cloud import _storage_v2 as storage_v2
4+
from google.cloud.storage.exceptions import DataCorruption
5+
from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
6+
_BaseResumptionStrategy,
7+
)
8+
9+
class _DownloadState:
10+
"""A helper class to track the state of a single range download."""
11+
def __init__(self, initial_offset: int, initial_length: int, user_buffer: IO[bytes]):
12+
self.initial_offset = initial_offset
13+
self.initial_length = initial_length
14+
self.user_buffer = user_buffer
15+
self.bytes_written = 0
16+
self.next_expected_offset = initial_offset
17+
self.is_complete = False
18+
19+
20+
class _ReadResumptionStrategy(_BaseResumptionStrategy):
    """The concrete resumption strategy for bidi reads.

    The shared ``state`` handled by this strategy is a dictionary mapping a
    read_id to the ``_DownloadState`` object tracking that range.
    """

    def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]:
        """Generates new ReadRange requests for all incomplete downloads.

        :type state: dict
        :param state: A dictionary mapping a read_id to its corresponding
                      _DownloadState object.

        :rtype: List[storage_v2.ReadRange]
        :returns: One request per download not yet marked complete, starting
                  right after the bytes already written for that download.
        """
        pending_requests = []
        for read_id, read_state in state.items():
            if read_state.is_complete:
                continue

            new_offset = read_state.initial_offset + read_state.bytes_written

            # An initial_length of 0 is handled as an open-ended range (the
            # final byte-count check is skipped for it), so a resumed request
            # must keep read_length at 0 instead of going negative.
            # NOTE(review): a bounded range whose bytes have all arrived but
            # whose range_end has not is still re-requested with
            # read_length 0 -- confirm how the service interprets that.
            if read_state.initial_length == 0:
                new_length = 0
            else:
                new_length = read_state.initial_length - read_state.bytes_written

            pending_requests.append(
                storage_v2.ReadRange(
                    read_offset=new_offset,
                    read_length=new_length,
                    read_id=read_id,
                )
            )
        return pending_requests

    def update_state_from_response(
        self, response: storage_v2.BidiReadObjectResponse, state: dict
    ) -> None:
        """Processes a server response, performs integrity checks, and updates state.

        :type response: storage_v2.BidiReadObjectResponse
        :param response: The response whose ``object_data_ranges`` are
                         written to the matching user buffers.

        :type state: dict
        :param state: A dictionary mapping a read_id to its corresponding
                      _DownloadState object; entries are mutated in place.

        :raises DataCorruption: If a chunk does not start at the expected
            offset, or if a range with a non-zero initial length ends with a
            different number of bytes written.
        :raises KeyError: If the response references a read_id that is not
            present in ``state``.
        """
        for object_data_range in response.object_data_ranges:
            read_id = object_data_range.read_range.read_id
            read_state = state[read_id]

            # Offset Verification
            chunk_offset = object_data_range.read_range.read_offset
            if chunk_offset != read_state.next_expected_offset:
                raise DataCorruption(response, f"Offset mismatch for read_id {read_id}")

            data = object_data_range.checksummed_data.content
            chunk_size = len(data)

            # Write first, then advance the counters: if the write raises,
            # the state must not claim bytes that never reached the buffer,
            # otherwise a later resumption would skip them.
            read_state.user_buffer.write(data)
            read_state.bytes_written += chunk_size
            read_state.next_expected_offset += chunk_size

            # Final Byte Count Verification
            if object_data_range.range_end:
                read_state.is_complete = True
                if (
                    read_state.initial_length != 0
                    and read_state.bytes_written != read_state.initial_length
                ):
                    raise DataCorruption(
                        response, f"Byte count mismatch for read_id {read_id}"
                    )

    async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
        """Handles BidiReadObjectRedirectError for reads.

        Currently a placeholder that does nothing.

        :type error: Exception
        :param error: The error that interrupted the operation.

        :type state: Any
        :param state: The shared state object for the operation.
        """
        # This would parse the gRPC error details, extract the routing_token,
        # and store it on the shared state object.
        pass
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import io
16+
import unittest
17+
import pytest
18+
from google.cloud.storage.exceptions import DataCorruption
19+
20+
from google.cloud import _storage_v2 as storage_v2
21+
from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import (
22+
_DownloadState,
23+
_ReadResumptionStrategy,
24+
)
25+
26+
_READ_ID = 1
27+
28+
29+
class TestDownloadState(unittest.TestCase):
    def test_initialization(self):
        """A fresh _DownloadState mirrors its arguments and shows no progress."""
        offset, length = 10, 100
        sink = io.BytesIO()

        download = _DownloadState(offset, length, sink)

        # Constructor arguments are stored unchanged.
        self.assertEqual(download.initial_offset, offset)
        self.assertEqual(download.initial_length, length)
        self.assertEqual(download.user_buffer, sink)

        # Progress fields start from their "nothing received" values.
        self.assertEqual(download.bytes_written, 0)
        self.assertEqual(download.next_expected_offset, offset)
        self.assertFalse(download.is_complete)
43+
44+
45+
class TestReadResumptionStrategy(unittest.TestCase):
    @staticmethod
    def _single_range_response(content, **range_kwargs):
        """Builds a response with one data range for ``_READ_ID`` at offset 0.

        Extra keyword arguments (e.g. ``range_end=True``) are forwarded to
        ``ObjectRangeData``.
        """
        return storage_v2.BidiReadObjectResponse(
            object_data_ranges=[
                storage_v2.types.ObjectRangeData(
                    read_range=storage_v2.ReadRange(
                        read_id=_READ_ID, read_offset=0, read_length=len(content)
                    ),
                    checksummed_data=storage_v2.ChecksummedData(content=content),
                    **range_kwargs,
                )
            ]
        )

    def test_generate_requests_single_incomplete(self):
        """Test generating a request for a single incomplete download."""
        download = _DownloadState(0, 100, io.BytesIO())
        download.bytes_written = 20

        requests = _ReadResumptionStrategy().generate_requests({_READ_ID: download})

        self.assertEqual(len(requests), 1)
        (request,) = requests
        self.assertEqual(request.read_offset, 20)
        self.assertEqual(request.read_length, 80)
        self.assertEqual(request.read_id, _READ_ID)

    def test_generate_requests_multiple_incomplete(self):
        """Test generating requests for multiple incomplete downloads."""
        second_read_id = 2
        partially_done = _DownloadState(0, 100, io.BytesIO())
        partially_done.bytes_written = 50
        untouched = _DownloadState(200, 100, io.BytesIO())
        state = {_READ_ID: partially_done, second_read_id: untouched}

        requests = _ReadResumptionStrategy().generate_requests(state)

        self.assertEqual(len(requests), 2)
        first = next(req for req in requests if req.read_id == _READ_ID)
        second = next(req for req in requests if req.read_id == second_read_id)

        self.assertEqual(first.read_offset, 50)
        self.assertEqual(first.read_length, 50)
        self.assertEqual(second.read_offset, 200)
        self.assertEqual(second.read_length, 100)

    def test_generate_requests_with_complete(self):
        """Test that no request is generated for a completed download."""
        finished = _DownloadState(0, 100, io.BytesIO())
        finished.is_complete = True

        requests = _ReadResumptionStrategy().generate_requests({_READ_ID: finished})

        self.assertEqual(len(requests), 0)

    def test_generate_requests_empty_state(self):
        """Test generating requests with an empty state."""
        requests = _ReadResumptionStrategy().generate_requests({})
        self.assertEqual(len(requests), 0)

    def test_update_state_processes_single_chunk_successfully(self):
        """Test updating state from a successful response."""
        sink = io.BytesIO()
        download = _DownloadState(0, 100, sink)
        payload = b"test_data"

        _ReadResumptionStrategy().update_state_from_response(
            self._single_range_response(payload), {_READ_ID: download}
        )

        self.assertEqual(download.bytes_written, len(payload))
        self.assertEqual(download.next_expected_offset, len(payload))
        self.assertFalse(download.is_complete)
        self.assertEqual(sink.getvalue(), payload)

    def test_update_state_from_response_offset_mismatch(self):
        """Test that an offset mismatch raises DataCorruption."""
        download = _DownloadState(0, 100, io.BytesIO())
        download.next_expected_offset = 10
        state = {_READ_ID: download}
        strategy = _ReadResumptionStrategy()
        response = self._single_range_response(b"data")

        with pytest.raises(DataCorruption) as exc_info:
            strategy.update_state_from_response(response, state)
        assert "Offset mismatch" in str(exc_info.value)

    def test_update_state_from_response_final_byte_count_mismatch(self):
        """Test that a final byte count mismatch raises DataCorruption."""
        download = _DownloadState(0, 100, io.BytesIO())
        state = {_READ_ID: download}
        strategy = _ReadResumptionStrategy()
        response = self._single_range_response(b"data", range_end=True)

        with pytest.raises(DataCorruption) as exc_info:
            strategy.update_state_from_response(response, state)
        assert "Byte count mismatch" in str(exc_info.value)

    def test_update_state_from_response_completes_download(self):
        """Test that the download is marked complete on range_end."""
        sink = io.BytesIO()
        payload = b"test_data"
        download = _DownloadState(0, len(payload), sink)

        _ReadResumptionStrategy().update_state_from_response(
            self._single_range_response(payload, range_end=True),
            {_READ_ID: download},
        )

        self.assertTrue(download.is_complete)
        self.assertEqual(download.bytes_written, len(payload))
        self.assertEqual(sink.getvalue(), payload)

    def test_update_state_from_response_completes_download_zero_length(self):
        """Test completion for a download with initial_length of 0."""
        sink = io.BytesIO()
        payload = b"test_data"
        download = _DownloadState(0, 0, sink)

        _ReadResumptionStrategy().update_state_from_response(
            self._single_range_response(payload, range_end=True),
            {_READ_ID: download},
        )

        self.assertTrue(download.is_complete)
        self.assertEqual(download.bytes_written, len(payload))

0 commit comments

Comments
 (0)