Firstly, use PyPy.
    $ python2 p.py
    Total: 29.8 seconds
    $ pypy p.py
    Total: 5.1 seconds
PyPy actually does inlining, so the timeit overhead makes up a big part of that. Removing it gives
    $ pypy p.py
    Total: 4.1 seconds
That's a factor-of-7 improvement already.
PyPy 3 is often faster, so try that too. To do so you should

- change the prints to functions, using from __future__ import print_function to keep Python 2 compatibility, and
- change debugger to use def debugger(times, n_runs): and call it with debugger(*main(...)).
It averages a tiny bit faster:
    $ pypy3 p.py
    Total: 4.1 seconds
INDEX_TO_VALUE should be a tuple, not a dictionary:
INDEX_TO_VALUE = ((0, 0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 1), (0, 4, 1), (0, 5, 1), (0, 6, 2), (0, 7, 2), (0, 8, 2), ...)
The comment above it is bad
# index:(row,column,group,value)
The spacing is a bit cramped, but primarily I'm concerned that there are the wrong number of elements! Remove the last one.
Further, a comprehension is much cleaner:
INDEX_TO_VALUE = tuple((row, col, row//3 + col//3*3) for row in range(9) for col in range(9))
although you might want to make a function to help:
    def group_at(row, col):
        return row // 3 + col // 3 * 3

    INDEX_TO_VALUE = tuple((row, col, group_at(row, col)) for row in range(9) for col in range(9))
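As a quick sanity check on the comprehension (a sketch; the assertions are my own addition, not part of the reviewed code), the table should have 81 entries and every group number should occur exactly nine times:

```python
from collections import Counter


def group_at(row, col):
    return row // 3 + col // 3 * 3


INDEX_TO_VALUE = tuple(
    (row, col, group_at(row, col))
    for row in range(9)
    for col in range(9)
)

# One entry per cell of the 9x9 grid.
assert len(INDEX_TO_VALUE) == 81

# Flat index 9*row + col maps back to (row, col, group).
assert INDEX_TO_VALUE[9 * 4 + 7] == (4, 7, 7)

# Each of the nine groups covers exactly nine cells.
group_sizes = Counter(group for _, _, group in INDEX_TO_VALUE)
assert group_sizes == dict.fromkeys(range(9), 9)
```

Note that this formula numbers the groups column-first, so (0, 3) lands in group 3 rather than group 1 as in the literal table. Which label a box gets doesn't matter, as long as each box gets a distinct one.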
Your timeit function should really avoid the globals by using lexical scoping. This strikes me as a good compromise:

    debugged = set()

    def timeit(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            t_start = time.clock()
            ret = func(*args, **kwargs)
            t_end = time.clock() - t_start

            wrapper.calls += 1
            wrapper.time += t_end

            return ret

        wrapper.calls = 0
        wrapper.time = 0

        debugged.add(wrapper)
        return wrapper
You still have global data but you localize it to the function. To me, this seems a sensible trade. Note that debugged and the original DEBUGGER_TIMER were not constants so should not be in ALL_CAPS.
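To show how the per-function counters might be read back, here is a small sketch. It assumes time.perf_counter in place of time.clock (time.clock was removed in Python 3.8, and perf_counter does not exist on Python 2), and square is a made-up function used only for illustration:

```python
import functools
import time

debugged = set()


def timeit(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Assumption: perf_counter stands in for the time.clock used above.
        start = time.perf_counter()
        ret = func(*args, **kwargs)
        wrapper.time += time.perf_counter() - start
        wrapper.calls += 1
        return ret

    wrapper.calls = 0
    wrapper.time = 0.0
    debugged.add(wrapper)
    return wrapper


@timeit
def square(x):  # hypothetical function, only here to be timed
    return x * x


for i in range(1000):
    square(i)

# Report the slowest functions first, using the data stored on each wrapper.
for func in sorted(debugged, key=lambda f: f.time, reverse=True):
    print("{}() called {} times, {:.6f}s total".format(
        func.__name__, func.calls, func.time))
```

Because the counters live on the wrapper itself, the report needs nothing but the debugged set.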
Another thing that irks me is your use of strings to represent the board. Tuples of integers make more sense, and 0 can replace your ".". Strings might be negligibly faster, but it seems like a poor readability trade to me.
You don't use graph as far as I can tell. Just remove it.
You do:
    if vertex not in visited:
        visited.add(vertex)
        ...
        stack.extend(graph[vertex] - visited)
The - visited is redundant with the if vertex not in visited check. Remove it. Further, you don't need any of the checks anyway; the cost of doing them exceeds the benefit, at least after moving to lists of integers. You might need to change solutions to a set.
Your
list(results.add("%s%s%s" % (puzzle[:index], move, puzzle[index + 1:])) for move in moves[1])
is a bit crazy; just do
    for move in moves[1]:
        results.add("%s%s%s" % (puzzle[:index], move, puzzle[index + 1:]))
Cell.board is a global; this is a really bad idea as it wrecks the ability to use this extensibly (say, have two boards). Speed-wise it's very problematic because you rebuild the whole grid each time; it would be better to mutate as you go and undo when backtracking. However, it is possible to improve. I tried something like
    @timeit
    def get_missing_values(row, column, group, board):
        missing = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        for c_row, c_column, c_group, c_value in board:
            if c_row == row or c_column == column or c_group == group:
                if c_value in missing:
                    missing.remove(c_value)
        return missing

    @timeit
    def create_cells(puzzle):
        taken = []
        untaken = []

        for index, value in enumerate(puzzle):
            row, column, group = INDEX_TO_VALUE[index]
            if value:
                taken.append((row, column, group, value))
            else:
                untaken.append((row, column, group))

        return taken, untaken

    @timeit
    def get_next_moves(puzzle):
        taken, untaken = create_cells(puzzle)
        other_option = None
        len_option = 9

        for row, column, group in untaken:
            missing_values = get_missing_values(row, column, group, taken)
            if len(missing_values) == 1:
                return 9*row+column, missing_values
            elif len(missing_values) < len_option:
                len_option = len(missing_values)
                other_option = 9*row+column, missing_values
        return other_option
This gets rid of the class for convenience (tuples are easier here). get_next_moves now generates the board and passes that to get_missing_values. This improves times:
    $ python3 p.py
    HARD - 3 RUNS
    Total: 19.8 seconds
    $ python2 p.py
    HARD - 3 RUNS
    Total: 12.7 seconds
    $ pypy3 p.py
    HARD - 3 RUNS
    Total: 2.3 seconds
    $ pypy p.py
    HARD - 3 RUNS
    Total: 2.3 seconds

But get_missing_values, as expected, still takes the majority of the time. Using a set instead of a list speeds up CPython but slows down PyPy.
This suggests effort should go into a more efficient representation for it. Here's one idea:
    @timeit
    def create_cells(puzzle):
        rows = [set() for _ in range(9)]
        columns = [set() for _ in range(9)]
        groups = [set() for _ in range(9)]
        untaken = []

        for index, value in enumerate(puzzle):
            row, column, group = INDEX_TO_VALUE[index]
            if value:
                rows[row].add(value)
                columns[column].add(value)
                groups[group].add(value)
            else:
                untaken.append((row, column, group))

        return rows, columns, groups, untaken

    @timeit
    def get_next_moves(puzzle):
        rows, columns, groups, untaken = create_cells(puzzle)
        other_option = None
        len_option = 9

        for row, column, group in untaken:
            missing_values = {1, 2, 3, 4, 5, 6, 7, 8, 9} - rows[row] - columns[column] - groups[group]
            if len(missing_values) == 1:
                return 9*row+column, missing_values
            elif len(missing_values) < len_option:
                len_option = len(missing_values)
                other_option = 9*row+column, missing_values
        return other_option
This gives, after disabling timeit,
    $ python3 p.py
    Total: 8.1 seconds
    $ python2 p.py
    Total: 6.7 seconds
    $ pypy3 p.py
    Total: 2.0 seconds
    $ pypy p.py
    Total: 2.0 seconds
which is a massive boost for CPython.
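To make the set arithmetic concrete, here is a tiny worked example for a single empty cell. The three "taken" sets are made-up values for illustration, not taken from a real run:

```python
# Hypothetical values already placed in this cell's row, column and group.
row_taken = {5, 3, 7}
column_taken = {6, 8, 4, 7}
group_taken = {5, 3, 6, 9, 8}

# Whatever is left after removing all three sets is still a candidate.
missing_values = {1, 2, 3, 4, 5, 6, 7, 8, 9} - row_taken - column_taken - group_taken

print(sorted(missing_values))  # [1, 2]
```

Each cell costs three set differences, instead of a scan over every filled cell on the board.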
It so happens that get_next_moves is still taking the most time, although create_cells is catching up. One idea is to change the sets to bitmasks:
    @timeit
    def create_cells(puzzle):
        rows = [0] * 9
        columns = [0] * 9
        groups = [0] * 9
        untaken = []

        for index, value in enumerate(puzzle):
            row, column, group = INDEX_TO_VALUE[index]
            if value:
                rows[row] |= 1<<(value-1)
                columns[column] |= 1<<(value-1)
                groups[group] |= 1<<(value-1)
            else:
                untaken.append((row, column, group))

        return rows, columns, groups, untaken

    decode_bits = [tuple(i+1 for i in range(9) if 1<<i & bits) for bits in range(512)]

    @timeit
    def get_next_moves(puzzle):
        rows, columns, groups, untaken = create_cells(puzzle)
        other_option = None
        len_option = 9

        for row, column, group in untaken:
            missing_values = decode_bits[0b111111111 & ~rows[row] & ~columns[column] & ~groups[group]]
            if len(missing_values) == 1:
                return 9*row+column, missing_values
            elif len(missing_values) < len_option:
                len_option = len(missing_values)
                other_option = 9*row+column, missing_values
        return other_option
This gives a very good boost in speed to PyPy and improves CPython noticeably:
    $ python3 p.py
    HARD - 3 RUNS
    Total: 8.0 seconds
    Mean: 2.66720 seconds
    Max: 2.67445 seconds
    Min: 2.66322 seconds

    create_cells()
        Called: 47680 times per run (143040 total)
        Running for 6.101s (in 3 runs) / 2.03374s per run
    get_next_moves()
        Called: 47680 times per run (143040 total)
        Running for 1.189s (in 3 runs) / 0.39625s per run
    possible_moves()
        Called: 47680 times per run (143040 total)
        Running for 0.443s (in 3 runs) / 0.14770s per run
    depth_first_search()
        Called: 1 times per run (3 total)
        Running for 0.268s (in 3 runs) / 0.08946s per run
    win()
        Called: 1 times per run (3 total)
        Running for 0.000s (in 3 runs) / 0.00004s per run
    main()
        Called: 1 times per run (1 total)
        Running for 0.000s (in 3 runs) / 0.00003s per run
and for PyPy:
    $ pypy3 p.py
    HARD - 4 RUNS
    Total: 1.0 seconds
    Mean: 0.26078 seconds
    Max: 0.35315 seconds
    Min: 0.21801 seconds

    possible_moves()
        Called: 47680 times per run (190720 total)
        Running for 0.519s (in 4 runs) / 0.12972s per run
    create_cells()
        Called: 47680 times per run (190720 total)
        Running for 0.339s (in 4 runs) / 0.08473s per run
    depth_first_search()
        Called: 1 times per run (4 total)
        Running for 0.094s (in 4 runs) / 0.02351s per run
    get_next_moves()
        Called: 47680 times per run (190720 total)
        Running for 0.091s (in 4 runs) / 0.02272s per run
    win()
        Called: 1 times per run (4 total)
        Running for 0.000s (in 4 runs) / 0.00009s per run
    main()
        Called: 1 times per run (1 total)
        Running for 0.000s (in 4 runs) / 0.00005s per run
So the next thing to tackle is create_cells or possible_moves, depending on which interpreter you care about most. Going with create_cells, I currently have:
    @timeit
    def create_cells(puzzle):
        rows = [0] * 9
        columns = [0] * 9
        groups = [0] * 9
        untaken = []

        for index, value in enumerate(puzzle):
            row, column, group = INDEX_TO_VALUE[index]
            if value:
                rows[row] |= 1<<(value-1)
                columns[column] |= 1<<(value-1)
                groups[group] |= 1<<(value-1)
            else:
                untaken.append((row, column, group))

        return rows, columns, groups, untaken

CPython doesn't deduplicate the repeated 1<<(value-1) computation like PyPy does, so one should deduplicate it manually. We should also move to using zip instead of enumerate:

    @timeit
    def create_cells(puzzle):
        rows = [0] * 9
        columns = [0] * 9
        groups = [0] * 9
        untaken = []

        for position, value in zip(INDEX_TO_VALUE, puzzle):
            if value:
                row, column, group = position
                bit = 1<<(value-1)
                rows[row] |= bit
                columns[column] |= bit
                groups[group] |= bit
            else:
                untaken.append(position)

        return rows, columns, groups, untaken

CPython is noticeably faster:

    $ python3 p.py
    Total: 5.9 seconds
    $ python2 p.py
    Total: 4.1 seconds
CPython is now as fast as PyPy was at the start! I didn't get times for PyPy.
It's even a bit faster if we carry the change through; make all operations act on shifted bits and unshift them at the end; eg.
    bit_counts = tuple(bin(i).count("1") for i in range(512))

    @timeit
    def get_next_moves(puzzle):
        rows, columns, groups, untaken = create_cells(puzzle)
        other_option = None
        len_option = 9

        for row, column, group in untaken:
            missing_values = 0b111111111 & ~rows[row] & ~columns[column] & ~groups[group]
            set_bits = bit_counts[missing_values]
            if set_bits == 1:
                return 9*row+column, missing_values
            elif set_bits < len_option:
                len_option = set_bits
                other_option = 9*row+column, missing_values
        return other_option
and similar for the rest of the code.
    $ pypy3 p.py
    Min: 0.11042 seconds
    $ pypy p.py
    Min: 0.13626 seconds
    $ python3 p.py
    Min: 1.70156 seconds
    $ python2 p.py
    Min: 1.24454 seconds
I'm using minimum times now because PyPy gets below the minimum runtime overall. We can see that both interpreters are much faster than previously.
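For readers less used to bitmasks, here is a small sketch of the representation: value v is stored as bit v-1, a set of values is the OR of its bits, and the candidates for a cell are the bits of 0b111111111 not taken by its row, column or group. The encode and decode helpers and the "taken" values are my own illustration, not names from the reviewed code:

```python
ALL = 0b111111111  # bits 0..8 stand for the values 1..9


def encode(value):
    """Map a cell value 1..9 to its single bit; 0 (empty) stays 0."""
    return value and 1 << (value - 1)


def decode(mask):
    """List the values 1..9 whose bits are set in mask."""
    return [i + 1 for i in range(9) if mask >> i & 1]


# Precomputed popcount for every possible 9-bit mask.
BIT_COUNTS = tuple(bin(i).count("1") for i in range(512))

# Hypothetical values already placed in one cell's row, column and group.
row_taken = encode(5) | encode(3) | encode(7)
column_taken = encode(6) | encode(8) | encode(4) | encode(7)
group_taken = encode(5) | encode(3) | encode(6) | encode(9) | encode(8)

missing = ALL & ~row_taken & ~column_taken & ~group_taken

print(bin(missing))         # 0b11
print(decode(missing))      # [1, 2]
print(BIT_COUNTS[missing])  # 2
```

Masking with ALL matters because ~ on a Python int yields a negative number; the mask brings the result back into the 0..511 range that the lookup table covers.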
Personally, I would write possible_moves as a generator:
    @timeit
    def possible_moves(puzzle):
        index_moves = get_next_moves(puzzle)
        if not index_moves:
            return

        index, moves = index_moves
        for bit in bits:
            if moves & bit:
                yield puzzle[:index] + (bit,) + puzzle[index + 1:]
Instead of your timing code, I would have used cProfile. It's built-in and far simpler to use. For CPython I would also sometimes use line_profiler, which gives line-by-line timings. In other words, it's the best thing evah. I would use the time utility to get a time of the whole code when fine-grained times aren't needed.
These get rid of a nontrivial portion of the code.
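As a rough sketch of what using cProfile from inside a script can look like (busy_work is a hypothetical stand-in for the solver; running python -m cProfile -s cumulative p.py from the shell gives a similar report without touching the code):

```python
import cProfile
import io
import pstats


def busy_work(n):  # hypothetical stand-in for the solver
    return sum(i * i for i in range(n))


profiler = cProfile.Profile()
profiler.enable()
busy_work(100000)
profiler.disable()

# Collect the report in a string so it can be printed or logged.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(5)  # five most expensive entries
print(stream.getvalue())
```

The report lists call counts and total and cumulative time per function, which covers what the hand-rolled timeit decorator was collecting.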
I would be very careful to stick to PEP8 spacing and add line breaks where they help. Your code is too dense.
I would extract the grid printing code into another function.
main shouldn't really be returning things; stick everything under if __name__ == '__main__' into main.
depth_first_search will only ever return exactly one solution, so there's no need to return a set. Further,

- Try returning early,
- Raise an exception if not win, as you would have entered an invalid state.
- Don't use the top-level Exception type; use more precise variants.

    def depth_first_search(start):
        stack = [start]
        solution = None

        while stack:
            vertex = stack.pop()

            if 0 not in vertex:
                assert win(vertex)
                if solution and vertex != solution:
                    raise ValueError("More than one solution")
                solution = vertex
            else:
                stack.extend(possible_moves(vertex))

        if solution is None:
            raise ValueError("No solution found")

        return solution
win doesn't really check if you've won; rename it to, say, validate.
INDEX_TO_VALUE should really be renamed by this point. I would go with POSITIONS.
validate can just be:
    def validate(puzzle):
        return Counter(puzzle) == dict.fromkeys(BITS, 9)
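A short sketch of what this check does and does not catch; the two grids below are made up for illustration. Since it only counts occurrences, a grid with the right counts passes even when it breaks the column rule:

```python
from collections import Counter

BITS = tuple(1 << i for i in range(9))


def validate(puzzle):
    return Counter(puzzle) == dict.fromkeys(BITS, 9)


# Every row is 1..9 in order: the counts are right, but each column
# repeats a single value, so this is not a legal grid.
counts_only = tuple(BITS) * 9
print(validate(counts_only))  # True

# Overwrite the first cell: one value now appears 8 times, another 10.
broken = (BITS[1],) + counts_only[1:]
print(validate(broken))  # False
```

So it works as a cheap sanity check on the search output rather than a full proof that the grid is legal.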
In my opinion, depth_first_search should yield its solutions, including duplicates, and the caller should be responsible for checking that the right number of solutions is present and for removing duplicates.
There should be a dictionary mapping puzzle names to puzzles, which is used to look puzzles up. I would also try to format the grid more explicitly.
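One way the calling side could handle that, as a sketch: fake_search and unique_solution are hypothetical names, with fake_search standing in for a generator-style depth_first_search.

```python
def fake_search(found):
    """Hypothetical stand-in for depth_first_search: yields the given grids."""
    for grid in found:
        yield grid


def unique_solution(candidates):
    """Collapse duplicates and insist on exactly one distinct solution."""
    distinct = set(candidates)
    if len(distinct) != 1:
        raise ValueError("expected exactly 1 solution, found %d" % len(distinct))
    [solution] = distinct
    return solution


# The same grid yielded twice still counts as a single solution.
print(unique_solution(fake_search([(1, 2), (1, 2)])))  # (1, 2)
```

A bare [result] = set(...) unpacking does the same job in one line, at the cost of a less descriptive error message when the count is wrong.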
I would go with:
    _ = 0
    puzzles = {
        'easy': [
            5,3,_, _,7,_, _,_,_,
            6,_,_, 1,9,5, _,_,_,
            _,9,8, _,_,_, _,6,_,

            8,_,_, _,6,_, _,_,3,
            4,_,_, 8,_,3, _,_,1,
            7,_,_, _,2,_, _,_,6,

            _,6,_, _,_,_, 2,8,_,
            _,_,_, 4,1,9, _,_,5,
            _,_,_, _,8,_, _,7,9,
        ],

        'hard': [
            8,_,_, _,_,_, _,_,_,
            _,_,3, 6,_,_, _,_,_,
            _,7,_, _,9,_, 2,_,_,

            _,5,_, _,_,7, _,_,_,
            _,_,_, _,4,5, 7,_,_,
            _,_,_, 1,_,_, _,3,_,

            _,_,1, _,_,_, _,6,8,
            _,_,8, 5,_,_, _,1,_,
            _,9,_, _,_,_, 4,_,_,
        ]
    }
Here's the full code:
    # encoding: utf8

    """
    A puzzle-solving masterpiece!

    Solves Sudoku.
    """

    from __future__ import print_function

    import functools
    import time

    from collections import Counter


    def group_at(row, col):
        return row // 3 + col // 3 * 3

    # The row, column and group number for each item in the grid
    POSITIONS = tuple((row, col, group_at(row, col)) for row in range(9) for col in range(9))

    # The number of bits for each value 0 <= i < 512
    BIT_COUNTS = tuple(bin(i).count("1") for i in range(512))

    # For looping
    BITS = tuple(1<<i for i in range(9))

    # Inverse of above for printing
    DECODE_BIT = {1<<i: i+1 for i in range(9)}
    DECODE_BIT[0] = 0


    def find_taken(puzzle):
        """
        Return three lists of what numbers are taken in each row, column and group and
        one list of which positions (row, column, group) are untaken.
        """
        rows = [0] * 9
        columns = [0] * 9
        groups = [0] * 9
        untaken = []

        for position, bit in zip(POSITIONS, puzzle):
            if bit:
                row, column, group = position
                rows[row] |= bit
                columns[column] |= bit
                groups[group] |= bit
            else:
                untaken.append(position)

        return rows, columns, groups, untaken


    def get_next_moves(puzzle):
        """
        Return the (index, missing_values) pair with the fewest possible moves.
        index is an integer 0 <= index < 81 and missing_values is a bitset
        of length 9.

        Returns None if there are no candidate moves.
        """
        rows, columns, groups, untaken = find_taken(puzzle)
        other_option = None
        len_option = 9

        for row, column, group in untaken:
            missing_values = 0b111111111 & ~rows[row] & ~columns[column] & ~groups[group]
            set_bits = BIT_COUNTS[missing_values]

            if set_bits == 1:
                return 9*row+column, missing_values
            elif set_bits < len_option:
                len_option = set_bits
                other_option = 9*row+column, missing_values

        return other_option


    def possible_moves(puzzle, index, moves):
        """
        Yield the states of the grid after taking the given moves at index on puzzle.

        index is an integer 0 <= index < 81 and moves is a bitset of length 9.
        """
        for bit in BITS:
            if moves & bit:
                yield puzzle[:index] + (bit,) + puzzle[index + 1:]


    def validate(puzzle):
        """
        Validate that the puzzle contains 9 of each number and is length 81.

        This does not fully validate that the solution is valid.
        """
        return Counter(puzzle) == dict.fromkeys(BITS, 9)


    def depth_first_search(puzzle):
        """
        Do a depth-first search of the solution space for the input Sudoku puzzle.
        Yields solutions. May yield duplicates.
        """
        stack = [puzzle]

        while stack:
            vertex = stack.pop()

            if 0 not in vertex:
                assert validate(vertex)
                yield vertex
            else:
                stack.extend(possible_moves(vertex, *get_next_moves(vertex)))


    def print_grid(puzzle):
        """
        Print a pretty representation of the input Sudoku puzzle.
        """
        for i, bit in enumerate(puzzle, 1):
            value = DECODE_BIT[bit] or "·"

            if i % 9 == 0:
                print(value)
            else:
                print(value, end="")

            if i % 9 and not i % 3:
                print(" ", end="")

            if i == 27 or i == 54:
                print()


    def main(puzzle_name):
        _ = 0
        puzzles = {
            'easy': [
                5,3,_, _,7,_, _,_,_,
                6,_,_, 1,9,5, _,_,_,
                _,9,8, _,_,_, _,6,_,

                8,_,_, _,6,_, _,_,3,
                4,_,_, 8,_,3, _,_,1,
                7,_,_, _,2,_, _,_,6,

                _,6,_, _,_,_, 2,8,_,
                _,_,_, 4,1,9, _,_,5,
                _,_,_, _,8,_, _,7,9,
            ],

            'hard': [
                8,_,_, _,_,_, _,_,_,
                _,_,3, 6,_,_, _,_,_,
                _,7,_, _,9,_, 2,_,_,

                _,5,_, _,_,7, _,_,_,
                _,_,_, _,4,5, 7,_,_,
                _,_,_, 1,_,_, _,3,_,

                _,_,1, _,_,_, _,6,8,
                _,_,8, 5,_,_, _,1,_,
                _,9,_, _,_,_, 4,_,_,
            ]
        }

        grid = tuple(i and 1<<(i-1) for i in puzzles[puzzle_name])

        print("Puzzle:")
        print_grid(grid)

        [result] = set(depth_first_search(grid))

        print()
        print("Result:")
        print_grid(result)


    if __name__ == '__main__':
        main('hard')