Leetcode Link: 37. 解数独 - 力扣(LeetCode)

题目

编写一个程序,通过填充空格来解决数独问题。

数独的解法需 遵循如下规则

数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。(请参考示例图) 数独部分空格内已填入了数字,空白格用 '.' 表示。

解法一

思路:回溯的棋盘问题,在多个维度上需要考虑剪枝、去重的问题

2022.5.14 一个小时,两次提交直接 AC!无任何参考!

在考虑这个问题的时候,很显然每个递归函数需要输入的是”即将填写数字的位置”,但是我觉得如果给行列的参数的话每次还要判断,感觉代码会很丑,而且怕漏条件,所以我只给一个 pos 参数,通过这个参数计算出当前位置所处的行、列,所属的 patch

从之前的回溯中,需要有一个 for 来遍历一个东西,本题中遍历什么东西呢? 显然,本题需要遍历所有的数字,然后判断放在这个位置是不是合规。

有一个难点是怎么保证输入的棋盘是不变的,即在回溯的过程中,如果回溯到一个题目给定数字的位置,需要保证不会把这个位置删除或者改变。通过到达这个节点时判断该节点是不是有值,是可行的。

然而,在做的时候想想,在每个节点,都需要从 1-9 遍历,并遍历此行、此列、此 patch 是不是重复。感觉时间复杂度巨大!

怎么减少这种遍历的时间呢?

我的方法是:构建每行、每列、每个 patch 的”可用数字”,然后就不用遍历 1-9 了,只需要在每个节点遍历其中一个”可用数字”,看看当前拿出来的值是不是在其他两个”可用数字”中,就可以快速判断当前值是不是合规

其实除了”可用数字”,也可以构建”已用数字”,但是这样还需要在决定每个位置的数字的时候遍历 1-9,然后看看在不在”已用数字”中,这样就多了几个判断,感觉还是有点麻烦。

满足题目要求 到达叶子节点就可以,但是要及时返回,不要考虑其他的子树

题解

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        def backtarcking(pos):
            nonlocal board, row_contents, col_contents, patch_contents
            if pos == 81:
                return True
            # 通过给定pos计算所在行、列、patch
            # 这里最好稍微琢磨琢磨
            i = pos % 9   # 放置行数
            j = pos // 9  # 放置列数
            px = i // 3 # 放置位置所属的patch x
            py = j // 3 # 放置位置所属的patch y
            # 题目中给出直接跳过就可以
            if board[i][j] != '.':
                if backtarcking(pos+1):
                    return True
            else:
                for val in patch_contents[px][py]:
	                # val 在3个地方都是可用得时候,就合规
                    if val in row_contents[i] and val in col_contents[j]:
                        board[i][j] = val
                        # 从"可用数字"中移除
                        patch_contents[px][py].remove(val)
                        row_contents[i].remove(val)
                        col_contents[j].remove(val)
                        if backtarcking(pos+1):
                            return True
                        # 及时修正回来
                        board[i][j] = '.'
                        patch_contents[px][py].add(val)
                        row_contents[i].add(val)
                        col_contents[j].add(val)
 
        ## 先验知识,创建3个"可用数字"
        # 使用set()保存行、列、patch内容供查阅
        tmp = [str(x) for x in range(1,10)]
        tmp = set(tmp) # tmp = {'1', '2', ...,'9'}
        # 保存可以使用的数字,而不是已经使用的数字
        # 后者每次回溯需要从1-9检查一次后才能知道哪个没用过,时间复杂度暴增
        row_contents = [tmp.copy() for _ in range(9)] # [tmp.copy()]*9 和 [tmp]*9 都共享内存
        col_contents = [tmp.copy() for _ in range(9)]
        patch_contents = [[tmp.copy() for _ in range(3)] for _ in range(3)] # 是一个 3*3 的二维list
        for i in range(9):
            for j in range(9):
                px = i // 3   # 获取patch位置
                py = j // 3
                if board[i][j] != '.':
	                # 移除掉已经用了的,就是还没用的
                    row_contents[i].remove(board[i][j])
                    patch_contents[px][py].remove(board[i][j])
                if board[j][i] != '.':
                    col_contents[i].remove(board[j][i])
        backtarcking(0)

复习时新写的,更好理解,使用了set()函数来过滤,大体的思路是一样的。

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        def bkt(i, j):
            nonlocal row_dict, col_dict, pat_dict, board
            if i == len(board):
                return True
            if board[i][j] == '.':
                pat_idx = self.get_pat_idx(i, j)
                common = row_dict[i] & col_dict[j] & pat_dict[pat_idx]
                if len(common)==0: 
                    return False
                for val in common:
                    board[i][j] = str(val)
                    row_dict[i].remove(val)
                    col_dict[j].remove(val)
                    pat_dict[pat_idx].remove(val)
                    if j+1 < len(board[0]):
                        if bkt(i, j+1): return True
                    else:
                        if bkt(i+1, 0): return True
                    board[i][j] = '.'
                    row_dict[i].add(val)
                    col_dict[j].add(val)
                    pat_dict[pat_idx].add(val)
            else:
                if j+1 < len(board[0]):
                    if bkt(i, j+1): return True
                else:
                    if bkt(i+1, 0): return True
            return False
        # 制作每个col,row,patch可以用的数字
        backup = set([i for i in range(1,10)])
        row_dict = {}
        col_dict = {}
        pat_dict = {}
        for i in range(len(board)):
            row_dict[i] = backup.copy()
            for j in range(len(board[0])):
                if j not in col_dict.keys(): col_dict[j] = backup.copy()
                pat_idx = self.get_pat_idx(i, j)
                if pat_idx not in pat_dict.keys(): pat_dict[pat_idx] = backup.copy()
                if board[i][j] != '.':
                    row_dict[i].remove(int(board[i][j]))
                    col_dict[j].remove(int(board[i][j]))
                    pat_dict[pat_idx].remove(int(board[i][j]))
 
        bkt(0,0)
 
        
    def get_pat_idx(self, i, j):
        return i // 3 * 3 + j // 3 + 1
 

启发和联系

  1. 在创建 x_contents 变量的时候,我们使用的是 x_contents=[tmp.copy() for _ in range(9)] 要十分注意,[tmp]*3[tmp.copy]*3 都是共享内存的,一个变,全都变,开始就吃亏在这
  2. 这里从 pos 到所在行、列、patch 的计算,要细心,当时这里也卡了一小会

    九宫格中本题各索引的合理范围

    • pos: 不溢出的范围是[0, 81)
    • i, 列 j :[0, 9)
    • patch的行 px 和列 py : [0,2)