记忆化搜索
视频讲解
🎥 视频讲解
记忆化搜索
记忆化搜索是在DFS的基础上,用数组或哈希表记录已经计算过的状态,使每个状态只计算一次。
例题1
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 1e5 + 10;
int dp[N];
int fib(int n) {
if (n == 0) return 0;
if (n == 1) return 1;
if (dp[n] != -1) return dp[n];
return dp[n] = (fib(n - 1) + fib(n - 2)) % Mod;
}
int main() {
int t;
cin >> t;
memset(dp, -1, sizeof dp);
while (t--) {
int n;
cin >> n;
cout << fib(n) << "\n";
}
return 0;
}
import java.util.Scanner;
import java.util.Arrays;
public class Main {
static final int Mod = 998244353;
static final int N = 100010;
static int[] dp = new int[N];
static int fib(int n) {
if (n == 0) return 0;
if (n == 1) return 1;
if (dp[n] != -1) return dp[n];
dp[n] = (fib(n - 1) + fib(n - 2)) % Mod;
return dp[n];
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
Arrays.fill(dp, -1);
while (t-- > 0) {
int n = sc.nextInt();
System.out.println(fib(n));
}
}
}
import sys
sys.setrecursionlimit(200000)
Mod = 998244353
N = 100010
dp = [-1] * N
def fib(n: int) -> int:
if n == 0:
return 0
if n == 1:
return 1
if dp[n] != -1:
return dp[n]
dp[n] = (fib(n - 1) + fib(n - 2)) % Mod
return dp[n]
t = int(input())
for _ in range(t):
n = int(input())
print(fib(n))
例题2
代码实现
参考实现
#include <bits/stdc++.h>
using namespace std;
const int N = 3010;
const int mod = 1000000007;
int dp[N][N];
int C(int a, int b) {
if (a == b || b == 0) return 1;
if (dp[a][b] != -1) return dp[a][b];
return dp[a][b] = (C(a - 1, b) + C(a - 1, b - 1)) % mod;
}
int main() {
int n, m;
cin >> n >> m;
memset(dp, -1, sizeof(dp));
cout << C(n, m);
return 0;
}
import java.util.Scanner;
import java.util.Arrays;
public class Main {
static final int N = 3010;
static final int mod = 1000000007;
static int[][] dp = new int[N][N];
static int C(int a, int b) {
if (a == b || b == 0) return 1;
if (dp[a][b] != -1) return dp[a][b];
return dp[a][b] = (C(a - 1, b) + C(a - 1, b - 1)) % mod;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
for (int i = 0; i < N; i++) {
Arrays.fill(dp[i], -1);
}
System.out.println(C(n, m));
}
}
import sys
sys.setrecursionlimit(10000)
mod = 1000000007
N = 3010
dp = [[-1] * N for _ in range(N)]
def C(a: int, b: int) -> int:
if a == b or b == 0:
return 1
if dp[a][b] != -1:
return dp[a][b]
dp[a][b] = (C(a - 1, b) + C(a - 1, b - 1)) % mod
return dp[a][b]
n, m = map(int, input().split())
print(C(n, m))
练习
练习1
请将以下的搜索代码改成记忆化搜索并通过上题。
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int N = 1010;
int a[N][N];
int n;
int dfs(int x, int y) {
if (x == n) return a[x][y];
return max(dfs(x + 1, y), dfs(x + 1, y + 1)) + a[x][y];
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= i; j++) {
cin >> a[i][j];
}
}
cout << dfs(1, 1);
return 0;
}
import java.util.Scanner;
public class Main {
static final int N = 1010;
static int[][] a = new int[N][N];
static int n;
static int dfs(int x, int y) {
if (x == n) return a[x][y];
return Math.max(dfs(x + 1, y), dfs(x + 1, y + 1)) + a[x][y];
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= i; j++) {
a[i][j] = sc.nextInt();
}
}
System.out.println(dfs(1, 1));
}
}
import sys
sys.setrecursionlimit(10000)
N = 1010
a = [[0] * N for _ in range(N)]
n = int(input())
for i in range(1, n + 1):
row = list(map(int, input().split()))
for j in range(1, i + 1):
a[i][j] = row[j - 1]
def dfs(x: int, y: int) -> int:
if x == n:
return a[x][y]
return max(dfs(x + 1, y), dfs(x + 1, y + 1)) + a[x][y]
print(dfs(1, 1))
练习2
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int N = 5005;
int n;
int a[N];
int dfs(int u) {
int best = 1;
for (int i = u + 1; i <= n; i++) {
if (a[i] > a[u]) {
best = max(best, dfs(i) + 1);
}
}
return best;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
int ans = 0;
for (int i = 1; i <= n; i++) {
ans = max(ans, dfs(i));
}
cout << ans;
return 0;
}
import java.util.Scanner;
public class Main {
static final int N = 5005;
static int n;
static int[] a = new int[N];
static int dfs(int u) {
int best = 1;
for (int i = u + 1; i <= n; i++) {
if (a[i] > a[u]) {
best = Math.max(best, dfs(i) + 1);
}
}
return best;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
for (int i = 1; i <= n; i++) {
a[i] = sc.nextInt();
}
int ans = 0;
for (int i = 1; i <= n; i++) {
ans = Math.max(ans, dfs(i));
}
System.out.println(ans);
}
}
import sys
sys.setrecursionlimit(1000000)
N = 5005
a = [0] * N
n = int(input())
arr = list(map(int, input().split()))
for i in range(1, n + 1):
a[i] = arr[i - 1]
def dfs(u: int) -> int:
best = 1
for i in range(u + 1, n + 1):
if a[i] > a[u]:
best = max(best, dfs(i) + 1)
return best
ans = 0
for i in range(1, n + 1):
ans = max(ans, dfs(i))
print(ans)
练习3
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int N = 5005;
string s1, s2;
int n, m;
int dfs(int x, int y) {
if (x > n || y > m) {
return 0;
}
if (s1[x] == s2[y]) {
return dfs(x + 1, y + 1) + 1;
}
return max(dfs(x + 1, y), dfs(x, y + 1));
}
int main() {
cin >> s1 >> s2;
n = s1.size();
m = s2.size();
s1 = " " + s1;
s2 = " " + s2;
cout << dfs(1, 1);
return 0;
}
import java.util.Scanner;
public class Main {
static final int N = 5005;
static String s1, s2;
static int n, m;
static int dfs(int x, int y) {
if (x > n || y > m) {
return 0;
}
if (s1.charAt(x) == s2.charAt(y)) {
return dfs(x + 1, y + 1) + 1;
}
return Math.max(dfs(x + 1, y), dfs(x, y + 1));
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
s1 = sc.next();
s2 = sc.next();
n = s1.length();
m = s2.length();
s1 = " " + s1;
s2 = " " + s2;
System.out.println(dfs(1, 1));
}
}
import sys
sys.setrecursionlimit(1000000)
s1 = ""
s2 = ""
n = 0
m = 0
s1 = input().strip()
s2 = input().strip()
n = len(s1)
m = len(s2)
s1 = " " + s1
s2 = " " + s2
def dfs(x: int, y: int) -> int:
if x > n or y > m:
return 0
if s1[x] == s2[y]:
return dfs(x + 1, y + 1) + 1
return max(dfs(x + 1, y), dfs(x, y + 1))
print(dfs(1, 1))