此生此世做过的第一恶心的树型DP


题意:

很简单:就是要求一棵树上相距距离不超过k的点有几对。多组数据,n<=10000

分析:

看到不能$O(n^2)$赶紧想有没有带$log$的算法。没想到。于是百度。网上说是树的点分治,大概看懂了。

算法思想:

对于一棵树,其中任何一条路径都有:要么经过根,要么不经过根。于是我们把这棵树的根节点揪出来,寻找所有经过根的合法路径,然后把这个根删掉,得到很多子树,对子树再进行上述计算。由于新的子树上的路径绝对不经过旧的大树的根,所以路径不会有重复。又由于很显然地每一条路径都只能属于一棵子树也必然属于一棵树,所以这种方法是正确的。经过合理的选择根节点,我们只计算了$logn$层,所以这种方法的复杂度为$O(logn)*P(n)$,$P(n)$的复杂度取决于你寻找路径的算法的好坏。

下面我们讨论如何寻找经过一个树的根的路径总数。
首先,我们可以获得这棵树中所有节点到根的距离。那么所有和小于k的点对有可能构成合法路径的两端。为什么只是有可能呢?因为两个点有可能出现在这棵树的同一个子树上,他们构成的路径不经过根。
设$A$为满足$dis[x]+dis[y]\leq k$的$(x,y)$的数量,$B$为满足$dis[x]+dis[y]\leq k$且$x,y$所在子树相同的$(x,y)$数量,那么这棵树中经过根节点的路径条数就是$A-B$。
我们可以$O(n)$求出$dis[i]$,$O(nlogn)$将$dis[]$排序,再$O(n)$利用单调性找出每一个x所对应的$dis$最大的$y$(当$dis[x]$增加时,$dis[y]$不会增加,呈单调递减),也就是$A$。对于这棵树的每一棵子树,我们又可以$O(n)$求出所有$B$。然后$A-B$就是答案。这样的理想时间复杂度为$O(nlognlogn)$

还需要注意几点:

1.单纯的将子树分治是不可行的,因为出题人会`专门把树扯成一条链让你$O(n^2logn)$地吃屎。于是我们需要专门用一个DP找出树的重心不断进行分治,而复杂度为$O(nlognlogn)$。
2. 这一道题很无耻的卡memset(),删掉memset你就奇迹般地从TLE变成了547ms。此题还会与分段式桶排发生反应,我的比sort在1e7下快10倍的桶排竟然会TLE,而sort奇迹般AC。考试时如果遇到这种让我做一下午一晚上的题,我一定会说:打暴力。如果非要给这个暴力加上一个期限,我希望是+1s。(+1s就可以用memset)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

#define MX 10010

using namespace std;

int fst[MX],nxt[MX*2],v[MX*2],w[MX*2],lnum;
int vis[MX];
int n,k;
int mx[MX],sum[MX];

inline void addeg(int nu,int nv,int nw)
{
    lnum++;
    nxt[lnum]=fst[nu];
    fst[nu]=lnum;
    v[lnum]=nv;
    w[lnum]=nw;
}

inline void init()
{
    lnum=-1;
    memset(fst,-1,sizeof(fst));
}

inline void input()
{
    int a,b,c;
    for(int i=1;i<n;i++)
    {
        scanf("%d%d%d",&a,&b,&c);
        addeg(a,b,c);
        addeg(b,a,c);
    }
}

int root,sz;
void _getroot(int x,int fa)
{
    sum[x]=mx[x]=0;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        _getroot(v[i],x);
        sum[x]+=sum[v[i]]+1;
        mx[x]=max(mx[x],sum[v[i]]+1);
    }
    mx[x]=max(mx[x],sz-sum[x]-1);
    if(mx[x]<mx[root])root=x;
}
inline void getroot(int x,int fa)
{
    root=0;
    mx[root]=99999999;
    _getroot(x,fa);
}

int q[MX],dp[MX];
int dis[MX];
inline void getdep(int pred,int x,int fa)
{
    int h=1,t=1,now;
    memset(dp,0xff,sizeof(dp));
    dis[0]=0;
    dp[x]=pred;
    dis[++dis[0]]=pred;
    q[h]=x;
    while(h>=t)
    {
        now=q[t++];
        for(int i=fst[now];i!=-1;i=nxt[i])
        {
            if(v[i]==fa||vis[v[i]]||dp[v[i]]!=-1)continue;
            dp[v[i]]=dp[now]+w[i];
            dis[++dis[0]]=dp[v[i]];
            q[++h]=v[i];
        }
    }
}

int tdis[MX];
int sch(int x,int fa)
{
    int a=0,b=0;
    vis[x]=1;
    tdis[0]=0;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        getdep(w[i],v[i],x);
        sort(dis+1,dis+dis[0]+1);
        for(int j=1;j<=dis[0];j++)tdis[++tdis[0]]=dis[j];
        for(int j=1,c=dis[0];j<=dis[0];j++)
        {
            while(dis[c]+dis[j]>k&&c>=1)c--;
            if(c<=j)break;
            b+=c-j;
        }
    }
    sort(tdis+1,tdis+tdis[0]+1);
    for(int j=tdis[0];j>=1;j--)if(tdis[j]<=k){a+=j;break;}
    for(int i=1,j=tdis[0];i<=tdis[0];i++)
    {
        while(tdis[j]+tdis[i]>k&&j>=1)j--;
        if(j<=i)break;
        a+=j-i;
    }
    a-=b;
    for(int i=fst[x];i!=-1;i=nxt[i])
    {
        if(v[i]==fa||vis[v[i]])continue;
        sz=sum[v[i]]+1;
        getroot(v[i],x);
        a+=sch(root,x);
    }
    return a;
}

int main()
{
    while(~scanf("%d%d",&n,&k))
    {
        memset(vis,0,sizeof(vis));
        if(n==0&&k==0)break;
        init();
        input();
        sz=n;
        getroot(1,0);
        printf("%d\n",sch(root,0));
    }
    return 0;
}

分类: 文章

发表评论

电子邮件地址不会被公开。 必填项已用*标注

你是机器人吗? =。= *