@@ -230,27 +230,29 @@ class MatrixDiagonalIndexIterator:
230230 Custom iterator class to return successive diagonal indices of a matrix
231231 '''
232232
233- def __init__ (self , m , n , k_start = 0 ):
233+ def __init__ (self , m , n , k_start = 0 , bandwidth = None ):
234234 '''
235- __init__(self, m, n, k_start=0):
235+ __init__(self, m, n, k_start=0, bandwidth=None ):
236236
237237 Arguments:
238- m (int) : number of rows in matrix
239- n (int) : number of columns in matrix
240- k_start (int) : (k_start, k_start) index to begin from
238+ m (int) : number of rows in matrix
239+ n (int) : number of columns in matrix
240+ k_start (int) : (k_start, k_start) index to begin from
241+ bandwidth (int) : bandwidth to constrain indices within
241242 '''
242- self .m = m
243- self .n = n
244- self .k = k_start
245- self .k_max = self .m + self .n - k_start
243+ self .m = m
244+ self .n = n
245+ self .k = k_start
246+ self .k_max = self .m + self .n - k_start - 1
247+ self .bandwidth = bandwidth
246248
247249 def __iter__ (self ):
248250 return self
249251
250252 def __next__ (self ):
251253 if hasattr (self , 'i' ) and hasattr (self , 'j' ):
252254
253- if self .k == self .k_max - 1 :
255+ if self .k == self .k_max :
254256 raise StopIteration
255257
256258 elif self .k < self .m and self .k < self .n :
@@ -278,4 +280,13 @@ def __next__(self):
278280 self .j = [self .k ]
279281 self .k += 1
280282
281- return self .i .copy (), self .j .copy ()
283+ if bandwidth :
284+ i_scb , j_scb = sakoe_chiba_band (self .i .copy (), self .j .copy (), self .m , self .n , bandwidth )
285+ return i_scb , j_scb
286+ else :
287+ return self .i .copy (), self .j .copy ()
288+
289+ def sakoe_chiba_band (i_list , j_list , m , n , bandwidth = 1 ):
290+ i_scb , j_scb = zip (* [(i , j ) for i ,j in zip (i_list , j_list )
291+ if abs (2 * (i * (n - 1 ) - j * (m - 1 ))) < max (m , n )* (bandwidth + 1 )])
292+ return list (i_scb ), list (j_scb )
0 commit comments