数位dp入门详解
1. 介绍
数位 d p dp dp一般出现在来求一个范围 [ a , b ] [a, b] [a,b]内满足条件的数有多少。数位 d p dp dp的解决比较公式化,考虑每一位对最终答案的影响。
2. 案例
Luogu P2602: 求给定范围 [ a , b ] [a,b] [a,b]各个数位 k k k出现了多少次。
考虑
n
n
n的
10
10
10进制表示
n
=
∑
i
=
0
m
a
i
1
0
i
n=\sum_{i=0}^{m}a_i 10^{i}\\
n=i=0∑mai10i
我们令
n
n
n的最高位数的数字为
a
m
a_m
am,最高幂次为
m
m
m
m
=
⌊
lg
n
⌋
a
m
=
⌊
n
/
(
1
0
m
)
⌋
m = \lfloor \lg n \rfloor \\ a_m =\lfloor n /( 10^{m})\rfloor
m=⌊lgn⌋am=⌊n/(10m)⌋
一个数
n
n
n中
k
k
k出现的次数
d
i
g
i
t
C
n
t
(
n
,
k
)
=
∑
i
=
0
m
ϕ
(
a
i
)
ϕ
(
x
)
=
{
1
,
a
i
=
k
0
,
a
i
≠
k
digitCnt(n, k) =\sum_{i=0}^m \phi(a_i)\\ \phi(x) = \begin{cases} 1, \quad a_i=k\\ 0, \quad a_i \ne k \end{cases} \\
digitCnt(n,k)=i=0∑mϕ(ai)ϕ(x)={1,ai=k0,ai=k
最终我们需要求
a
n
s
(
a
,
b
)
=
∑
x
=
a
b
d
i
g
i
t
C
n
t
(
x
)
\\ ans(a,b) = \sum_{x=a}^{b} digitCnt(x)
ans(a,b)=x=a∑bdigitCnt(x)
如果我们直接进行遍历,时间复杂度为
O
(
n
log
n
)
O(n \log n)
O(nlogn),对于稍稍大一点的情况是处理不了的,这时候我们的数位dp出场了。
我们定义
d
p
[
n
]
[
k
]
dp[n][k]
dp[n][k]为
n
n
n以内的
k
k
k位出现的次数,那么我们可以把
n
n
n以内的数分为两个部分
d
p
[
n
]
=
a
n
s
(
1
,
a
m
1
0
m
−
1
)
+
a
n
s
(
a
m
1
0
m
,
n
)
dp[n] =ans(1,a_{m}10^{m}-1) +\\ans(a_{m}10^{m}, n)
dp[n]=ans(1,am10m−1)+ans(am10m,n)
考虑最高位对最终增加的数为
M
s
b
C
n
t
(
n
,
k
)
=
{
1
0
m
,
a
m
>
k
0
,
k
=
0
∨
a
m
<
k
n
−
a
m
1
0
m
+
1
,
a
m
=
k
MsbCnt(n,k) = \begin{cases} 10^{m} \quad , a_{m} > k \\ 0\quad, k=0 \vee a_{m} <k\\ n - a_{m}10^{m}+1 \quad,a_{m} =k \end{cases}
MsbCnt(n,k)=⎩
⎨
⎧10m,am>k0,k=0∨am<kn−am10m+1,am=k
我们先忽略前导
0
0
0, 除去最高位的贡献后可以得到
d
p
[
n
]
=
M
s
b
C
n
t
(
n
)
+
a
m
T
[
1
0
m
−
1
]
+
T
[
n
−
a
m
1
0
m
]
dp[n]=MsbCnt(n)+a_m T[10^{m} -1]\\+T[n-a_m10^{m}]
dp[n]=MsbCnt(n)+amT[10m−1]+T[n−am10m]
这里的
T
[
n
]
T[n]
T[n]定义为
T
[
n
]
=
d
i
g
i
t
C
n
t
(
S
{
n
}
)
S
{
n
}
:
=
{
s
0
s
1
⋯
s
m
−
1
,
0
≤
s
i
≤
9
,
S
≤
n
−
a
m
1
0
m
}
T[n]= digitCnt(S\{n\})\\ \quad S\{n\}:= \{s_0s_1\cdots s_{m-1} \quad ,0 \le s_i \le 9\\ ,S \le n- a_m 10^{m}\}
T[n]=digitCnt(S{n})S{n}:={s0s1⋯sm−1,0≤si≤9,S≤n−am10m}
通俗一点就是比如
n
=
123
n=123
n=123
那么 S { 123 } = { 000 , 001 , ⋯ , 123 } S\{123\} = \{000,001,\cdots,123\} S{123}={000,001,⋯,123}。
也就是把 n n n之前的数列出,并加上前导 0 0 0使之与 n n n的位数对齐。
前导 0 0 0的添加并不会影响非 0 0 0的其他数位置的计数,因此可以得到
d p [ n ] [ k ] = T [ n ] [ k ] , k ≠ 0 dp[n][k] =T[n][k], k\ne0 dp[n][k]=T[n][k],k=0
我们考虑
T
[
1
0
d
−
1
]
[
k
]
T[10^{d} -1][k]
T[10d−1][k]的值,
S
{
1
0
d
−
1
}
S\{10^{d}-1\}
S{10d−1}集合中总共有
1
0
d
10^{d}
10d个数,每个数有
d
d
d位,而每个数字在排列中等可能出现,因此
T
[
1
0
d
−
1
]
[
1
]
=
⋯
=
T
[
1
0
d
−
1
]
[
9
]
=
d
×
1
0
d
−
1
T[10^{d}-1][1]=\cdots=T[10^{d}-1][9] \\=d \times10^{d-1}
T[10d−1][1]=⋯=T[10d−1][9]=d×10d−1
我们再考虑减去前导
0
0
0, 容易得到
1
0
d
−
1
10^{d}-1
10d−1内的数在
S
{
1
0
d
−
1
}
S\{10^{d}-1\}
S{10d−1}集合中的前导
0
0
0的个数为
F r o n t Z e r o ( S { 1 0 d − 1 } ) = ∑ i = 1 d − 1 1 0 i FrontZero(S\{10^{d}-1\}) =\sum_{i=1}^{d-1}10^{i} FrontZero(S{10d−1})=i=1∑d−110i
因此
T
[
1
0
d
−
1
]
[
0
]
=
∑
i
=
1
d
−
1
(
d
−
1
)
1
0
i
+
d
T[10^d-1][0] =\sum_{i=1}^{d-1} (d-1)10^{i} +d
T[10d−1][0]=i=1∑d−1(d−1)10i+d
对于一个数
n
n
n而言, 在它之前有
k
≠
0
k \ne 0
k=0的个数为
d
p
[
n
]
[
k
]
=
M
o
s
t
C
n
t
(
n
,
k
)
+
m
a
m
1
0
m
−
1
+
d
p
[
n
−
a
m
1
0
m
]
[
k
]
dp[n][k] = MostCnt(n,k)+ma_{m}10^{m-1}+\\ dp[n-a_m10^{m}][k ]
dp[n][k]=MostCnt(n,k)+mam10m−1+dp[n−am10m][k]
如果
k
=
0
k =0
k=0, 还需要减去前导
0
0
0
d
p
[
n
]
[
0
]
=
T
[
n
]
[
0
]
−
∑
i
=
1
m
1
0
m
dp[n][0] =T[n][0]-\sum_{i=1}^{m}10^{m}
dp[n][0]=T[n][0]−i=1∑m10m
- 代码一
#include <iostream>
#include <vector>
#include <functional>
#include <unordered_set>
constexpr static int BASE = 10;
constexpr static int MAX_POW = 12;
unsigned long long typeVal;
using int_type = decltype(typeVal);
int MaxPowNotGreater(int_type BASE, int_type v) {
int ans = 1;
auto tb = BASE;
for ( ;tb <= v; tb *= BASE, ans++) {
}
return ans - 1;
}
int_type getDigitCntUntil(int_type val, int_type k) {
int_type v = val;
int digitCnt = MaxPowNotGreater(BASE, v);
int_type mod = val % BASE;
int_type ans = ((mod >= k))? 1 : 0;
v /= BASE;
int_type cpow = BASE;
for (int d = 1; d < digitCnt + 1; d++, v/= BASE, cpow *= BASE) {
int_type m = v % BASE;
if ( m > k)
ans += cpow;
else if ( m == k)
ans += mod + 1;
else {
}
ans += m * (cpow / BASE) * d;
mod += cpow * m;
}
if ( k == 0) {
cpow /= BASE;
while (cpow >= 1) {
ans -= cpow;
cpow /= BASE;
}
}
return ans;
}
int main()
{
int_type a = 1;
int_type b = 99;
std::cin >> a >> b;
std::vector<int_type> cal(BASE, 0);
for (int i = 0;i < 10;i++) {
cal[i] = getDigitCntUntil( b,i) - getDigitCntUntil(a - 1, i);
}
for (auto num:cal) {
std::cout << num << " ";
}
std::cout << std::endl;
return 0;
}
我们可以将整次幂的数位置个数存起来
T
C
n
t
[
d
]
[
k
]
=
T
[
1
0
d
−
1
]
[
k
]
TCnt[d][k] = T[10^{d}-1][k]
TCnt[d][k]=T[10d−1][k]
容易得到得到递推关系式
T
C
n
t
[
d
]
[
k
]
=
{
1
0
d
−
1
+
T
C
n
t
[
d
−
1
]
[
k
]
,
k
≠
0
9
T
C
n
t
[
d
−
1
]
[
1
]
+
T
C
n
t
[
d
−
1
]
[
0
]
,
k
=
0
TCnt[d][k] = \\ \begin{cases} 10^{d-1} +TCnt[d-1][k],\quad k \ne 0 \\ 9TCnt[d-1][1] + TCnt[d-1][0],\quad k =0 \end{cases}
TCnt[d][k]={10d−1+TCnt[d−1][k],k=09TCnt[d−1][1]+TCnt[d−1][0],k=0
- 代码2
#include <iostream>
#include <vector>
#include <functional>
#include <unordered_set>
constexpr static int BASE = 10;
constexpr static int MAX_POW = 12;
unsigned long long typeVal;
using int_type = decltype(typeVal);
auto fpow = [](int_type base, int_type cnt) {
int_type ans = 1;
while (cnt) {
if (cnt & 1) ans *= base;
base *= base;
cnt = cnt >> 1;
}
return ans;
};
auto LogFloor = [](int_type base, int_type v) {
int_type m = base;
int_type kdigits = 1;
while ( m <= v) {
m *= base;
kdigits++;
}
kdigits--;
return kdigits;
};
template<typename T>
void TEST_EQ( T a, T b) {
bool ret = (a == b);
if (!ret) {
std::cout << a << " NOT EQUAL " << b << '\n';
}
else {
std::cout << a << " EQUAL " << b << '\n';
}
}
void testEqual(int_type a, int_type b, const std::vector<int_type>& cal) {
std::vector<int> tmpCnt(BASE, 0);
for (int_type i = a; i <= b; i++) {
auto ti = i;
while (ti) {
tmpCnt[ti % BASE]++;
ti /= BASE;
}
}
bool ok = true;
for (int i = 0; i < BASE; i++) {
if (cal[i] != tmpCnt[i]) {
std::cout << i << "failed: Real=" << tmpCnt[i] << " ;Cal= " << cal[i] <<'\n';
ok = false;
}
}
if (ok) {
std::cout << "ok fine result!!!\n";
}
}
std::vector<std::vector<int_type>> FCnt( MAX_POW + 1, std::vector<int_type>(BASE, 0));
int_type getDigitCntUntil(int_type val, int_type k) {
int_type v = val;
int digitCnt = LogFloor(BASE, v);
int_type mod = val % BASE;
int_type ans = ((mod >= k) && k)? 1 : 0;
v /= BASE;
for (int d = 1; d < digitCnt + 1;d++, v /= BASE) {
int_type lsb = v % BASE ;
if (d != digitCnt) {
ans += lsb * fpow(BASE, d - 1) * d;
if ( lsb > k)
ans += fpow(BASE, d);
else if (lsb == k)
ans += mod + 1;
}
else {
ans += FCnt[digitCnt][k];
ans += (lsb - 1) * fpow(BASE, d - 1) * d;
if (0 != k) {
if ( lsb > k)
ans += fpow(BASE, d);
else if (k == lsb)
ans += mod + 1;
else
ans += 0;
}
}
mod += lsb * fpow(BASE, d);
}
return ans;
}
int main()
{
for (int i = 0;i < BASE;i++) {
FCnt[1][i] = 1;
}
for (int i = 2;i <= MAX_POW;i++) {
for (int d = 0; d < BASE;d++) {
if (d == 0)
FCnt[i][d] = 9 * FCnt[i - 1][1] + FCnt[i - 1][0];
else {
FCnt[i][d] = fpow(10, i - 1) + 10 * FCnt[i - 1][d];
}
}
}
int_type a = 1;
int_type b = 99;
std::cin >> a >> b;
std::vector<int_type> cal(BASE, 0);
for (int i = 0;i < 10;i++) {
cal[i] = getDigitCntUntil( b,i) - getDigitCntUntil(a - 1, i);
}
for (auto num:cal) {
std::cout << num << " ";
}
std::cout << std::endl;
return 0;
}