点分治

点分治一般用来处理树上路径信息问题


分治

不断解决子问题,常见的套路是 $n$ 个问题,分成 $2$ 个 $\frac{n}{2}$ 的问题…..,不断缩小问题的规模,进而求解。
一般用到调和级数 ${\rm nlogn}$。

点分治

洛谷笔记

解决步骤

  • 寻找树的重心:降低复杂度。
  • 求解经过当前子树的重心的满足要求路径数量。
  • 不断递归求解子问题。

时间复杂度

$\rm O(n{log^2n})$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn = 50010;

vector<pair<int,int> >g[maxn];
int n, k, u, v, c;
int dis[maxn], sz[maxn], rt, maxsz[maxn];
int ans, cnt;
int minid, minx, sum;
bool vis[maxn];

void init()
{
ans=0;
memset(vis, false, sizeof(vis));
}

void dfs1(int u, int father)
{
sz[u]=1;
maxsz[u]=0;
for(int i=0;i<int(g[u].size());i++)
{
int v=g[u][i].first;
if(v==father || vis[v]) continue;
dfs1(v, u);
sz[u]+=sz[v];
maxsz[u]=max(maxsz[u], sz[v]);
}
}

void dfs2(int u, int father)
{
int temp=max(maxsz[u], sum-maxsz[u]);
if(temp < minx)
{
minx=temp;
minid=u;
}
for(int i=0;i<int(g[u].size());i++)
{
int v=g[u][i].first;
if(v==father || vis[v]) continue;
dfs2(v, u);
}
}

void getdis(int u, int father, int d)
{
dis[++cnt]=d;
for(int i=0;i<int(g[u].size());i++)
{
int v=g[u][i].first, w=g[u][i].second;
if(v==father || vis[v]) continue;
getdis(v, u, d+w);
}

}

int getroot(int u)
{
dfs1(u, 0);
sum=sz[u];
minx=1e9;
minid=-1;
dfs2(u, 0);
return minid;
}

int look(int l,int x)//урвС╠ъ╫Г
{
int ans=0,r=cnt;
while(l<=r)
{
int mid(l+r>>1);
if(dis[mid]<x) l=mid+1;
else ans=mid,r=mid-1;
}
return ans;
}
int look2(int l,int x)//урср╠ъ╫Г
{
int ans=0,r=cnt;
while(l<=r)
{
int mid(l+r>>1);
if(dis[mid]<=x) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}

int calc(int u, int val)
{
cnt=1;
int res=0;
getdis(u, 0, val);
sort(dis+1, dis+cnt+1);
int l=1, r=cnt;

for(int i=1;i<=cnt;i++)
{
//if(k-dis[i]<dis[i] || k-dis[i]>dis[cnt]) break;
int p1=lower_bound(dis+i+1, dis+cnt+1, k-dis[i])-dis;
int p2=upper_bound(dis+i+1, dis+cnt+1, k-dis[i])-dis;
//if(dis[p1] != k-dis[i]) continue;
p2-=1;
if(p2>=p1) res+=p2-p1+1;

}

// while(l<cnt && dis[l]+dis[cnt]<k) ++l;
// while(l<cnt && k-dis[l]>=dis[l])
// {
// int D1(look(l+1,k-dis[l])),D2(look2(l+1,k-dis[l]));
// if(D2>=D1) res+=D2-D1+1;
// ++l;
// }
return res;
}

void solve(int u)
{
int root=getroot(u);
int num=calc(root, 0);
ans += num;
vis[root]=1;
for(int i=0;i<int(g[root].size());i++)
{
int v=g[root][i].first, w=g[root][i].second;
if(vis[v]) continue;
int num=calc(v, w);
ans-=num;
solve(v);
}
}

int main()
{
while(scanf("%d%d", &n, &k)!=EOF &&n &&k)
{
init();
for(int i=1;i<=n-1;i++)
{
scanf("%d%d", &u, &v);
g[u].push_back(make_pair(v, 1));
g[v].push_back(make_pair(u, 1));
}
solve(1);
printf("%d\n", ans);
for(int i=1;i<=n;i++)
g[i].clear();
}
}