首页 技术分享 正文
  • 本文约727字,阅读需4分钟
  • 1436
  • 0

熟练剖分(tree) 树形DP

熟练剖分(tree) 树形DP

题目描述

题目传送门

分析

我们设\(f[i][j]\)为以\(i\)为根节点的子树中最坏时间复杂度小于等于\(j\)的概率

\(g[i][j]\)为当前扫到的以\(i\)为父亲节点的所有儿子最坏时间复杂度小于等于\(j\)的概率之和

因为每遍历到一个新的节点,原来的\(g\)数组中的值就要全部更新,因此我们压掉第一维

下面我们考虑转移

对于当前枚举到的某一个节点,我们用三重循环分别扫一边

第一重循环代表当前哪一个节点充当重儿子,第二重循环枚举所有儿子,第三充循环枚举最坏时间复杂度\(k\)

如果第二重循环中枚举的儿子恰好是重儿子的话,那么父亲节点的最坏时间复杂度为\(k\)的情况可以由两种情况转移过来

第一种情况就是重儿子的时间复杂度恰好为\(k\)的概率乘上其它儿子时间复杂度小于等于\(k\)的概率

第二种情况就是其它儿子的时间复杂度恰好为\(k\)的概率乘上重儿子的时间复杂度小于等于\(k\)的概率

不要忘了减去重复的情况

如果第二重循环中枚举的儿子不是重儿子的话,那么父亲节点的最坏时间复杂度为\(k\)的情况可以由两种情况转移过来

第一种情况就是重儿子的时间复杂度恰好为\(k-1\)的概率乘上其它儿子时间复杂度小于等于\(k\)的概率

第二种情况就是其它儿子的时间复杂度恰好为\(k\)的概率乘上重儿子的时间复杂度小于等于\(k-1\)的概率

也不要忘了减去重复的情况

代码

#include<cstdio>
#include<cstring>
#include<vector>
inline int read(){
    int x=0,fh=1;
    char ch=getchar();
    while(ch<'0' || ch>'9'){
        if(ch=='-') fh=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*fh;
}
const int maxn=3e3+5;
const int mod=1e9+7;
int fa[maxn],head[maxn],tot=1,n,rt;
struct asd{
    int to,next;
}b[maxn<<1];
void ad(int aa,int bb){
    b[tot].to=bb;
    b[tot].next=head[aa];
    head[aa]=tot++;
}
int ksm(int ds,int zs){
    int ans=1;
    while(zs){
        if(zs&1) ans=1LL*ans*ds%mod;
        ds=1LL*ds*ds%mod;
        zs>>=1;
    }
    return ans;
}
int son[maxn],siz[maxn];
long long f[maxn][maxn],g[maxn],h[maxn];
void dfs(int now){
    siz[now]=1;
    for(int i=head[now];i!=-1;i=b[i].next){
        int u=b[i].to;
        if(u==fa[now]) continue;
        dfs(u);
        siz[now]+=siz[u];
    }
    int p=ksm(son[now],mod-2);
    for(int i=head[now];i!=-1;i=b[i].next){
        if(b[i].to==fa[now]) continue;
        for(int j=0;j<=n;j++) g[j]=1;
        //初始化g数组
        int zez=b[i].to;
        //枚举重儿子
        for(int j=head[now];j!=-1;j=b[j].next){
            if(b[j].to==fa[now]) continue;
            int qez=b[j].to;
            //枚举其它儿子
            for(int k=0;k<=siz[qez]+1;k++){
                //枚举最大时间复杂度
                long long qt=g[k],xz=f[qez][k];
                if(k) qt-=g[k-1],xz-=f[qez][k-1];
                if(zez==qez){
                    h[k]=(qt*f[qez][k]%mod+xz*g[k]%mod-xz*qt%mod+mod)%mod;
                } else {
                    xz=f[qez][k-1];
                    if(k>1) xz-=f[qez][k-2];
                    h[k]=(qt*f[qez][k-1]%mod+xz*g[k]%mod-xz*qt%mod+mod)%mod;
                }
            }
            g[0]=h[0],h[0]=0;
            for(int k=1;k<=siz[qez]+1;k++){
                g[k]=(g[k-1]+h[k])%mod;
                h[k]=0;
            }
            //h数组临时存储状态
        }
        for(int j=siz[now];j>=1;j--){
            g[j]=(g[j]-g[j-1]+mod)%mod;
            //将前缀和数组还原成正常数组
        }
        for(int j=0;j<=siz[now];j++){
            f[now][j]=(f[now][j]+g[j]*p%mod)%mod;
        }
    }
    if(son[now]==0) f[now][0]=1;
    for(int i=1;i<=siz[now]+1;i++){
        f[now][i]=(f[now][i]+f[now][i-1])%mod;
    }
}
int main(){
    memset(head,-1,sizeof(head));
    n=read();
    int aa;
    for(int i=1;i<=n;i++){
        son[i]=read();
        for(int j=1;j<=son[i];j++){
            aa=read(),fa[aa]=i;
            ad(i,aa),ad(aa,i);
        }
    }
    rt=1;
    while(fa[rt]) rt=fa[rt];
    dfs(rt);
    long long ans=0;
    for(int i=1;i<=n;i++){
        ans=(ans+i*(f[rt][i]-f[rt][i-1]+mod)%mod)%mod;
    }
    printf("%lld\n",ans);
    return 0;
}


    评论
    博主关闭了当前页面的评论
    友情链接