笔记:蓝桥杯python搜索(3-2)——DFS剪支和记忆化搜索
目录
一、DFS剪支
二、例题
P2942 数字王国之军训军队
P3075 特殊的多边形
三、记忆化搜索
四、例题
例题 P3820 混境之地
P216 地宫取宝
一、DFS剪支
- 在搜索过程中,如果需要完全遍历所有情况可能需要很多时间
- 在搜索到某种状态时,根据当前状态判断出后续无解,则该状态无需继续深入搜索
- 例如:给定N个正整数,求出有多少个子集之和小于等于K。在搜索过程中当前选择的数字和已经超过K则不需要继续搜索。
- 可行性剪枝:当前状态和题意不符,并且往后的所有情况和题意都不符,那么就可以进行剪枝。
- 最优性剪枝:在搜索过程中,当前状态已经不如已经找到的最优解,也可以剪枝,不需要继续搜索。
二、例题
P2942 数字王国之军训军队
数字王国开学了,它们也和我们人类一样有开学前的军训,现在一共有 n 名学生,每个学生有自己的一个名字 ai(数字王国里的名字就是一个正整数,注意学生们可能出现重名的情况),此时叛逆教官来看了之后感觉十分别扭,决定将学生重新分队。
排队规则为:将学生分成若干队,每队里面至少一个学生,且每队里面学生的名字不能出现倍数关系(注意名字相同也算是倍数关系)。
现在请你帮忙算算最少可以分成几队?
例:有 4 名学生 (2,3,4,4),最少可以分成 (2,3)、(4)、(4) 共 3 队。
输入格式
第一行包含一个正整数 n,表示学生数量。
第二行包含 n 个由空格隔开的整数,第 i 个整数表示第 i 个学生的名字 ai。
输出格式
输出共 1 行,包含一个整数,表示最少可以分成几队。
DFS 搜索,枚举每个学生分到每个组内
可行性剪枝:要满足题目条件
最优性剪枝:判断当前状态是否比 ans 更劣
def check(x,group):
for y in group:
if x%y ==0 or y%x==0:
return False
return True
def dfs(depth):
if depth==n:
global answer
answer=min(answer,len(Groups))
return
for every_group in Groups:
if check(a[depth],every_group):
every_group.append(a[depth])
dfs(depth+1)
every_group.pop()
Groups.append([a[depth]])
dfs(depth+1)
Groups.pop()
n=int(input())
a=list(map(int,input().split()))
Groups=[]
answer=n
dfs(0)
print(answer)
# 判断x能否加入group组
def check(x, group):
# 要保证不能存在倍数关系
for y in group:
if x % y == 0 or y % x == 0:
return False
return True
# depth表示当前为第depth个学生
def dfs(depth):
global ans
# 最优性剪枝:当前已经比ans大,说明该策略不可行
if len(Groups) > ans:
return
if depth == n:
ans = min(len(Groups), ans)
return
for each_group in Groups:
# 枚举第depth个学生能否加入当前组each_group
# 剪枝:必须满足题意
if check(a[depth], each_group):
each_group.append(a[depth])
dfs(depth + 1)
each_group.pop()
# 单独作为一组,将 a[depth] 作为一个新的分组添加,即 Groups.append([a[depth]])
Groups.append([a[depth]])
dfs(depth + 1)
Groups.pop()
n = int(input())
a = list(map(int, input().split()))
# ans表示最少能分多少队
ans = n
# Groups表示分组情况
Groups = []
dfs(0)
print(ans)
P3075 特殊的多边形
假设一个 n 边形 n 条边为 a1,a2,a3,⋯,an,定义该 n 边形的值 v=a1×a2×a3×⋯×an。
定义两个 n 边形不同是指至少有一条边的长度在一个 n 边形中有使用而另一个 n 边形没有用到,如 n 边形 (3,4,5,6)和 (3,5,4,6) 是两个相同的 n 边形,(3,4,5,6)和 (4,5,6,7) 是两个不相同的 n 边形。
现在有 t 和 n,表示 t 个询问并且询问的是 n 边形,每个询问给定一个区间 [l,r],问有多少个 n 边形(要求该 n 边形自己的 n 条边的长度互不相同)的值在该区间范围内。
输入格式
第一行包含两个正整数 t、n,表示有 t 个询问,询问的是 n 边形。
接下来 t 行,每行有两个空格隔开的正整数 l、r,表示询问区间 [l,r]。
输出格式
输出共 t 行,第 i行对应第 i 个查询的 n 边形个数。
-
先考虑简单版:乘积为 v 有多少种 n 边形
-
DFS 处理出所有乘积对应的所有可能
-
维护一个递增的边长序列(唯一性)
-
枚举第 i 边的长度,最小最大范围(剪枝)
-
最终 check 是否满足 N 边形:
-
最小的 N - 1 条边之和大于第 N 边
-
-
-
预处理+前缀和 O(1)查询答案
import os
import sys
# 请在此输入您的代码
def dfs(depth, last_val, tot, mul):
"""
:param depth: 第depth条边长
:param last_val: 上一边长长度
:param tot: 累计和
:param mul: 累计乘积
"""
if depth == n:
# 前n-1条边之和大于第n条边
if tot - path[-1] > path[-1]:
ans[mul] += 1
return
for i in range(last_val + 1, 100001):
# 最优性剪枝, 后续还有n-depth个数字, 每个数字都要>=i
# 累计乘积要不超过100000: mul * (i ** (n - depth))
if mul * (i ** (n - depth)) > 100000:
break
path.append(i)
dfs(depth + 1, i, tot + i, mul * i)
path.pop()
t, n = map(int, input().split())
ans = [0] * 100001
path = []
dfs(0, 0, 0, 1)
for i in range(1, 100001):
ans[i] += ans[i - 1]
for _ in range(t):
l, r = map(int, input().split())
print(ans[r] - ans[l - 1])
三、记忆化搜索
- 记忆化:通过记录已经遍历过的状态的信息,从而避免对同一状态重复遍历的搜索实现方式。
- 记忆化=dfs+额外字典
- 如果先前已经搜索过:直接查字典,返回字典中结果
- 如果先前没有搜索过:继续搜索,最终将该状态结果记录到字典中
- 斐波那契数列:设F[0] = 1, F[1] = 1, F[n] = F[n - 1] + F[n - 2],求F[n],结果对1e9 + 7取模。 0 <= n <= 10000
- 样例输入: 5000
- 样例输出: 976496506
def f(n):
if n==0 or n==1:
return 1
return (f(n-1)+f(n-2))%(1e9+7)
n=int(input())
print(f(n))
# 直接递归存在大量重复计算
计算 F(5)
时,要先算 F(4)
和 F(3)
;算 F(4)
又需算 F(3)
和 F(2)
,这里 F(3)
就被重复计算了。随着 n
增大,像 F(2)
、F(3)
这类中间结果会被反复多次计算,导致时间复杂度呈指数级增长,效率极低。
- 每次搜索时将当前状态答案记录到字典中
- 后续搜索直接返回结果
import sys
sys.setrecursionlimit(100000)
# 记忆化1
dic={0:1,1:1}
def f(n):
if n in dic.keys():
return dic[n]
dic[n]=(f(n-1)+f(n-2))%1000000007
return dic[n]
n=int(input())
print(f(n))
from functools import lru_cache #记忆化搜索2 @lru_cache(maxsize=None)
from functools import lru_cache
# 记忆化2
@lru_cache(maxsize=None)
def f(n):
if n==0 or n==1:
return 1
return f(n-1)+f(n-2)
n=int(input())
print(f(n))
四、例题
例题 P3820 混境之地
小蓝有一天误入了一个混境之地。
好消息是:他误打误撞拿到了一张地图,并从中获取到以下信息:
混境之地是一个 n⋅m 大小的矩阵,其中第 i 行第 j 列的的点 hij 表示第 i 行第 j 列的高度。
他现在所在位置的坐标为 (A,B) ,而这个混境之地出口的坐标为 (C,D) ,当站在出口时即表示可以逃离混境之地。
小蓝有一个喷气背包,使用时,可以原地升高 k 个单位高度。
坏消息是:
由于小蓝的体力透支,所以只可以往低于当前高度的方向走。
喷漆背包燃料不足,只可以最后使用一次。
小蓝可以往上下左右四个方向行走,不消耗能量。
小蓝想知道他能否逃离这个混境之地,如果可以逃离这里,输入
Yes
,反之输出No
。
输入格式
第 1 行输入三个正整数 n,m 和 k , n,m 表示混境之地的大小, k 表示使用一次喷气背包可以升高的高度。
第 2 行输入四个正整数 A,B,C,D ,表示小蓝当前所在位置的坐标,以及混境之地出口的坐标。
第 3 行至第 n+2 行,每行 m 个整数,表示混境之地不同位置的高度。
输出格式
输出数据共一行一个字符串:
若小蓝可以逃离混境之地,则输出
Yes
。若小蓝无法逃离混境之地,则输出
No
。
走到(x,y),z表示是否使用喷气背包
当x,y,z固定时,具有唯一解,因此可以使用记忆化搜索
时间复杂度<x,y,z>三元组数量:1000*1000*2
from functools import lru_cache
#记忆化搜索2
@lru_cache(maxsize=None)
def dfs(x, y, z):
# print(x, y, z)
# 当前处于(x,y), z表示是否使用喷气背包
# 如果能逃离, 返回True, 否则返回False
if x == C - 1 and y == D - 1:
return True
#没走到终点,四个方向判断
for i in range(4):
xx, yy = x + dir[i][0], y + dir[i][1]
#不能越界
if xx < 0 or xx >= n or yy < 0 or yy >= m:
continue
#新坐标比旧坐标低,走到新坐标(xx,yy)
if Map[xx][yy] < Map[x][y]:
if dfs(xx, yy, z):
return True
#在(x,y)处使用喷气背包
elif Map[xx][yy] < Map[x][y] + k and z == False:
if dfs(xx, yy, True):
return True
return False
# 四个方向
dir = [(1, 0), (0, 1), (-1, 0), (0, -1)]
n, m, k = map(int, input().split())
A, B, C, D = map(int, input().split())
Map = []
for i in range(n):
Map.append(list(map(int, input().split())))
if dfs(A - 1, B - 1, False):
print("Yes")
else:
print("No")
P216 地宫取宝
X 国王有一个地宫宝库。是 n×m 个格子的矩阵。每个格子放一件宝贝。每个宝贝贴着价值标签。
地宫的入口在左上角,出口在右下角。
小明被带到地宫的入口,国王要求他只能向右或向下行走。
走过某个格子时,如果那个格子中的宝贝价值比小明手中任意宝贝价值都大,小明就可以拿起它(当然,也可以不拿)。
当小明走到出口时,如果他手中的宝贝恰好是 k 件,则这些宝贝就可以送给小明。
请你帮小明算一算,在给定的局面下,他有多少种不同的行动方案能获得这 k 件宝贝。
输入描述
输入一行 3 个整数,用空格分开:n,m,k (1≤n,m≤50,1≤k≤12)。
接下来有 n 行数据,每行有m 个整数 Ci (0≤Ci≤12) 代表这个格子上的宝物的价值。
输出描述
要求输出一个整数,表示正好取 k 个宝贝的行动方案数。该数字可能很大,输出它对 10^9+7 取模的结果。
从(x,y)出发,先前已经有宝物z件,已有的最大宝物价值为w的方案数记为dfs(x,y,z,w)
只要确定四元组,就确定当前的方案数
答案=dfs(1,1,0,-1)
最终点=dfs(n,m,z,?) or dfs(n,m,z - 1,?)
(x,y)处可选,可不选,然后可以往右走或者往下走
from functools import lru_cache
@lru_cache(maxsize=None)
def dfs(x, y, z, w):
#从(x,y)出发,先前已有z件宝物,已有宝物最大价值为w
#走到终点
if x == n and y == m:
#当前不需要选择
if z == k:
return 1
#当前需要选择
if z == k - 1:
if w < a[x][y]:
return 1
#其他情况,均为0
return 0
ans = 0
#遍历两个方向
for delta_x, delta_y in [(1, 0), (0, 1)]:
xx, yy = x + delta_x, y + delta_y
if xx <= n and yy <= m:
# 当前不选择
ans += dfs(xx, yy, z, w)
# 当前选择
if w < a[x][y]:
ans += dfs(xx, yy, z + 1, a[x][y])
return ans % 1000000007
n, m, k = map(int, input().split())
a = [[0]*(m+1)]
for i in range(n):
a.append([0] + list(map(int, input().split())))
print(dfs(1, 1, 0, -1))