Skip to content

Commit 2c291fb

Browse files
committed
added sakoe-chiba band to MatrixDiagonalIndexIterator
1 parent 7a04c77 commit 2c291fb

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

sdtw.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,27 +230,29 @@ class MatrixDiagonalIndexIterator:
230230
Custom iterator class to return successive diagonal indices of a matrix
231231
'''
232232

233-
def __init__(self, m, n, k_start=0):
233+
def __init__(self, m, n, k_start=0, bandwidth=None):
234234
'''
235-
__init__(self, m, n, k_start=0):
235+
__init__(self, m, n, k_start=0, bandwidth=None):
236236
237237
Arguments:
238-
m (int) : number of rows in matrix
239-
n (int) : number of columns in matrix
240-
k_start (int) : (k_start, k_start) index to begin from
238+
m (int) : number of rows in matrix
239+
n (int) : number of columns in matrix
240+
k_start (int) : (k_start, k_start) index to begin from
241+
bandwidth (int) : bandwidth to constrain indices within
241242
'''
242-
self.m = m
243-
self.n = n
244-
self.k = k_start
245-
self.k_max = self.m + self.n - k_start
243+
self.m = m
244+
self.n = n
245+
self.k = k_start
246+
self.k_max = self.m + self.n - k_start - 1
247+
self.bandwidth = bandwidth
246248

247249
def __iter__(self):
248250
return self
249251

250252
def __next__(self):
251253
if hasattr(self, 'i') and hasattr(self, 'j'):
252254

253-
if self.k == self.k_max-1:
255+
if self.k == self.k_max:
254256
raise StopIteration
255257

256258
elif self.k < self.m and self.k < self.n:
@@ -278,4 +280,13 @@ def __next__(self):
278280
self.j = [self.k]
279281
self.k+=1
280282

281-
return self.i.copy(), self.j.copy()
283+
if bandwidth:
284+
i_scb, j_scb = sakoe_chiba_band(self.i.copy(), self.j.copy(), self.m, self.n, bandwidth)
285+
return i_scb, j_scb
286+
else:
287+
return self.i.copy(), self.j.copy()
288+
289+
def sakoe_chiba_band(i_list, j_list, m, n, bandwidth=1):
290+
i_scb, j_scb = zip(*[(i, j) for i,j in zip(i_list, j_list)
291+
if abs(2*(i*(n-1) - j*(m-1))) < max(m, n)*(bandwidth+1)])
292+
return list(i_scb), list(j_scb)

0 commit comments

Comments
 (0)