leetcode 1289

问题

Given a square grid of integers arr, a falling path with non-zero shifts is a choice of exactly one element from each row of arr, such that no two elements chosen in adjacent rows are in the same column.

Return the minimum sum of a falling path with non-zero shifts.

Example 1:

1
2
3
4
5
6
7
8
Input: arr = [[1,2,3],[4,5,6],[7,8,9]]
Output: 13
Explanation:
The possible falling paths are:
[1,5,9], [1,5,7], [1,6,7], [1,6,8],
[2,4,8], [2,4,9], [2,6,7], [2,6,8],
[3,4,8], [3,4,9], [3,5,7], [3,5,9]
The falling path with the smallest sum is [1,5,7], so the answer is 13.

Constraints:

  • 1 <= arr.length == arr[i].length <= 200
  • -99 <= arr[i][j] <= 99

分析

思路1: $dp[i][j]$ 表示从第 $0$ 行到第 $i$ 行,最后停止在位置 $j$ 上的最短路径。每次找到最小的 $dp[i-1][k] \; (k \neq j)$ 然后加上 $arr[i][j]$ 即可。时间复杂度 $O(n^3)$

思路2: 虽短时间复杂度。针对上一行寻找最小值的情况,我们可以记录上一行的最小值 $fm$ ,上一行的次小值 $sm$ ,以及上一行最小值的下标 $k$ ,对于 $dp[i][j] $ ,$j \neq k$ 时,使用 $fm + arr[i][j]$ ,当 $j == k$ 时,使用 $sm + arr[i][j]$ 。处理完一行的同时更新 $fm, sm, k$。时间复杂度为 $O(n^2)$。

代码1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// O(n^3)
class Solution {
public:
int minFallingPathSum(vector<vector<int>>& arr) {
int m = arr.size();
int n = arr[0].size();
for (int i = 1; i < m; ++i) {
for (int j = 0; j < n; ++j) {
int t = INT_MAX;
for (int k = 0; k < n; ++k) {
if (k == j) continue;
t = min(t, arr[i-1][k]);
}
arr[i][j] += t;
}
}
return *min_element(arr[n-1].begin(), arr[n-1].end());
}
};

代码2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// O(n^2)
class Solution {
public:
int minFallingPathSum(vector<vector<int>>& arr) {
int m = arr.size();
int n = arr[0].size();
int fm = 0, sm = 0, k = -1;
for (int i = 0; i < m; ++i) {
int fm2 = INT_MAX, sm2 = INT_MAX, k2 = 0;
for (int j = 0; j < n; ++j) {
arr[i][j] += (j!=k? fm: sm);
if (arr[i][j] < fm2) {
sm2 = fm2;
fm2 = arr[i][j];
k2 = j;
} else if (arr[i][j] < sm2) {
sm2 = arr[i][j];
}
}
fm = fm2;
sm = sm2;
k = k2;
}
return *min_element(arr[n-1].begin(), arr[n-1].end());
}
};