Here's my attempt for a grayscale image using only NumPy, based on the solution from here. With some minor tweaks (adding channels), it might suit your needs:
import numpy as np


def _axis_tiling(length: int, tile: int) -> tuple[int, tuple[int, int]]:
    """Return the tile count along one axis and the (before, after) padding."""
    count, remainder = divmod(length, tile)
    if remainder == 0:
        return count, (0, 0)
    missing = tile - remainder
    # The odd extra pixel goes in front (top/left): ceil before, floor after.
    # Plain integer arithmetic avoids the overflow a fixed-width int8 would
    # hit as soon as the padding exceeds 127 pixels.
    return count + 1, (missing - missing // 2, missing // 2)


def to_tiles(arr: np.ndarray, tilesize: tuple[int, int]) -> np.ndarray:
    """Separate a grayscale image into tiles of ``tilesize``.

    If the image size is not a multiple of the tile size, the image is
    zero-padded (split as evenly as possible between both sides of the axis,
    with the extra pixel on the top/left) before it is cut.

    Args:
        arr: 2-D array holding the grayscale image, shape ``(height, width)``.
        tilesize: ``(tile_height, tile_width)`` of every tile.

    Returns:
        Array of shape ``(n_tiles, tile_height, tile_width)`` with the tiles
        in row-major order (left to right, then top to bottom).

    Raises:
        AssertionError: if ``arr`` is not 2-D or a tile dimension is larger
            than the corresponding image dimension.
    """
    assert arr.ndim == 2, "array must be 2d (grayscale) image"
    a_h, a_w = arr.shape
    h, w = tilesize
    # A tile exactly as large as the image is valid: it yields a single tile.
    assert a_h >= h, "tile height is larger than arr height"
    assert a_w >= w, "tile width is larger than arr width"

    row, row_pad = _axis_tiling(a_h, h)
    col, col_pad = _axis_tiling(a_w, w)

    padded = np.pad(arr, [row_pad, col_pad], "constant", constant_values=0)

    # (row*h, col*w) -> (row, h, col, w) -> (row, col, h, w) -> (row*col, h, w)
    tiles = padded.reshape(row, h, col, w).swapaxes(1, 2)
    return tiles.reshape(-1, h, w)
Here's an example of the result. The image is from the FUNSD dataset.
tiles = [im[x:x+M,y:y+N] for x in range(0,im.shape[0],M) for y in range(0,im.shape[1],N)] - see my answer below.