更快的NTT

题意

给定两个序列$A$和$B$,$Alice$从$A$序列中随机抽取一个数字$x$,$Bob$从$B$序列中随机抽取一个数字$y$,求$(x+y)^t$的期望在模$998244353$意义下的值。

前置理论

期望的线性性

设$a,b$是两个不相关的随机变量,$E(x)$为$x$的期望取值,则:
$$
E(a\cdot b)=E(a)\cdot E(b)
$$
证明:

如果$a,b$的取值是离散的(比如题目中的$a,b$,只能从给定的有限个数中选择),那么可以设$P_a(x)$为$a$取到$x$的概率

则我们要求的$E(a\cdot b)=\sum_{\forall x,y}{P_a(x)P_b(y)}$

又因为$E(a)\cdot E(b)=\sum_{\forall x} P_a(x)\sum_{\forall y} P_b(y)$

将下面的式子的$P_a(x)$乘进后面的$\sum$即可发现两个式子是相等的。

但是,要注意,如果$a,b$是相关的,那么上述性质就不再适用了。

整数$k$次幂和

对于一个数列,如果我们要求:对于所有的$t\in [0,k]$,$A_i^t$的和,怎么做呢?

方法1:

对于这个问题,我们有一个$O(nlog^2n)$的做法。

写出这个和关于$t$的生成函数(函数中$x^c$系数是$c$次幂和)
$$
1+a_1^1x+a_1^2x^2+a_1^3x^3+\cdots+\\
1+a_2^1x+a_2^2x^2+a_2^3x^3+\cdots +\\
\cdots \\
$$
根据等比数列求和公式(在生成函数中的推广),我们可以知道,上面的式子实际上是:
$$
\frac{1}{1-a_1x}+\frac{1}{1-a_2x}+\cdots+\frac{1}{1-a_nx}
$$
如果我们暴力从左往右合并这个式子,时间复杂度是:$O(n^2)$的,因为每进行一次通分,分子最高项次数必然增加。

如果逐层合并这个式子,时间复杂度可以变成$O(nlog^2n)$。逐层的意思,即每一层合并所有相邻的分式,且不重复合并。这个过程类似于第二种方法中的”分治FFT“,所以复杂度证明将会与方法二相同,这里不再赘述。

方法2:

一位大神坐在书桌前,对着眼前的《高等数学》略有所思,他观察$ln(1+x)$的泰勒展开后,捻灭了手中快要燃尽的雪茄,又若有所悟地点点头,合上了桌上的书。

他发现了什么?
$$
ln(1+x)=x-\frac{x^2}{2}+\frac{x^3}{3}-\frac{x^4}{4}+\cdots \\
ln(1+ax)=ax-\frac{a^2x^2}{2}+\frac{a^3x^3}{3}-\frac{a^4x^4}{4}+\cdots
$$
也就是说,$ln(1+ax)$的泰勒展开里隐藏了$a$的$k$次幂和。

那么我们怎么利用这个求多个$a$的$k$次幂和呢?
$$
\sum_{i=1}^nln(1+a_ix)=ln(\Pi_{i=1}^{n}(1+a_ix))
$$

假设我们已经搞出了
$$
F=ln(\Pi_{i=1}^{n}(1+a_ix))
$$
由于取对数的函数是个多项式,所以我们需要用点巧办法:
$$
ln(x)’=\frac{1}{x} \\
ln(x)=\int \frac{1}{x}
$$
由于多项式求逆是很方便的,而多项式积分更方便(具体实现请参考其他资料),因此多项式求对数也很方便。

现在的主要问题就是:如何求$\Pi_{i=1}^{n}(1+a_ix)$

如果我们把所有要求积的括号分成两部分分别求积,然后合并,那么复杂度将是:
$$
T(n)=T(\frac{n}{2})+O(nlogn)
$$
故$T(n)=O(nlog^2n)$

题目分析

首先,由于我们要求$(x+y)^t$的期望,根据第一个前置知识,我们就可以先将这个式子展开
$$
E(x+y)^t=\sum_{i=0}^tC_t^iE(x^i)E(y^{t-i})
$$
所以,我有种预感,就是为了求对于所有$t$ 的答案,我们只要分别求出对于所有的$t$,$x^t$的期望和$y^t$的期望,就可以用卷积的方式快速求出$(x+y)^t$的期望。

实际上,的确是这样的:
$$
\begin{aligned}
F(t)& =\sum_{i=0}^tC_t^iA(i)B(t-i) \\
F(t)& =\sum_{i=0}^t\frac{t!}{i!(t-i)!}A(i)B(t-i) \\
\frac{1}{t!}F(t)& =\sum_{i=0}^t\frac{1}{i!}A(i)\cdot\frac{1}{t-i}B(t-i)
\end{aligned}
$$
这是一个卷积的形式,被卷积的两函数分别是$\frac{1}{i!}A(i)$和$\frac{1}{t-i}B(t-i)$

所以下面的任务就是求$A(i),B(i)$,也就是$a,b$的$i$次幂期望。这个期望就是$i$次幂和除以项数,用第二个前置知识解决即可,总复杂度$O(nlog^2n)$。

更快的NTT

考虑到一般的NTT都略慢于FFT,优化NTT的常数以加速计算势在必行。

本题中由于需要进行分治FFT,NTT执行次数将非常多,所以优化的主要方法有两个:

  • 1.将一切可以预处理的东西预处理
  • 2.将一切可以省略的取模省略

那么原来的NTT是长这样的:

void dft(int f)
{
    for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
    for(int w=1;w<len;w<<=1)
    {
        ll w1=ksm(G,(MOD-1)/(w<<1));
        if(f==-1)w1=inv(w1);
        for(int j=0;j<len;j+=(w<<1))
        {
            ll wn=1;
            for(int i=j;i<j+w;i++)
            {
                ll a=x[i],b=x[i+w];
                x[i]=(a+wn*b)%MOD;
                x[i+w]=(a-wn*b%MOD+MOD)%MOD;
                wn=wn*w1%MOD;
            }
        }
    }
    if(f==-1)
    {
        ll mul=ksm(len,MOD-2);
        for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
    }
}

经过我比较彻底的常数优化后,成了这样:

void dft(int f)
{
    for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
    for(int w=1,b=1;w<len;w<<=1,b++)
    {
        ll *wx=(f==1?wi[b]:wii[b]);     //wi和wii为原根的幂及幂的逆
        for(int j=0;j<len;j+=(w<<1))
        {
            for(int i=j;i<j+w;i++)
            {
                ll a=x[i],b=wx[i-j]*x[i+w];
                x[i]=(a+b)%MOD;
                x[i+w]=(a-b+MOD*MOD)%MOD;
            }
        }
    }
    if(f==-1)
    {
        ll mul=ksm(len,MOD-2);
        for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
    }
}

在洛谷”【模板】多项式乘法”中,效率提升20%

在洛谷”玩游戏”中,效率提升48%

代码

下面的代码是优化过的方法一的代码。

不优化时无法通过本题。

#pragma GCC optimize("Ofast")
#pragma GCC target("sse3","sse2","sse")
#pragma GCC target("avx","sse4","sse4.1","sse4.2","ssse3")
#pragma GCC target("f16c")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"
#pragma GCC diagnostic error "-std=c++14"

#include <cstdlib>
#include <iostream>
#include <cstring>
#include <cstdio>
#define MX (524288+10)
#define MOD 998244353LL
#define G 3

using namespace std;

typedef long long ll;

int len,bit,rev[MX];
int n1,n2,m,A[MX],B[MX];
ll fac[MX];
ll EA[MX],EB[MX];
ll gp[MX],gpi[MX],wi[22][MX],wii[22][MX];

ll read()
{
    ll x=0;char ch=getchar();
    while(!isdigit(ch))ch=getchar();
    while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
    return x;
}

void qm(ll& x){x%=MOD;if(x<0)x+=MOD;}

void init(int l)
{
    len=1,bit=0,rev[0]=0;
    while(len<l)len<<=1,bit++;
    for(int i=1;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}

ll ksm(ll x,ll t)
{
    ll ans=1;
    while(t)
    {
        if(t&1)ans=ans*x%MOD;
        x=x*x%MOD;
        t>>=1;
    }
    return ans;
}

ll inv(ll x){return ksm(x,MOD-2);}

struct poly
{
    ll *x;
    void resize(){delete x;x=new ll[len];memset(x,0,len*sizeof(ll));}
    void reset(){memset(x,0,len*sizeof(ll));}
    void input(int s){for(int i=0;i<s;i++)scanf("%lld",&x[i]),qm(x[i]);}
    void output(int s){for(int i=0;i<s;i++)printf("%lld ",x[i]);putchar('\n');}
    void dft(int f)
    {
        for(int i=0;i<len;i++)if(i<rev[i])swap(x[i],x[rev[i]]);
        for(int w=1,b=1;w<len;w<<=1,b++)
        {
            ll *wx=(f==1?wi[b]:wii[b]);
            for(int j=0;j<len;j+=(w<<1))
            {
                for(int i=j;i<j+w;i++)
                {
                    ll a=x[i],b=wx[i-j]*x[i+w];
                    x[i]=(a+b)%MOD;
                    x[i+w]=(a-b+MOD*MOD)%MOD;
                }
            }
        }
        if(f==-1)
        {
            ll mul=ksm(len,MOD-2);
            for(int i=0;i<len;i++)x[i]=x[i]*mul%MOD;
        }
    }
};

poly up[MX],dn[MX];
poly dl,ul,dr,ur,dv,f1,f2;
poly tmp;

void get_inv(poly &a,poly &b,int nx)
{
    if(nx==1)b.x[0]=inv(a.x[0]);
    else
    {
        get_inv(a,b,nx>>1);
        for(int i=0;i<nx;i++)tmp.x[i]=a.x[i],tmp.x[i+nx]=0;
        init(nx<<1);
        tmp.dft(1);
        b.dft(1);
        for(int i=0;i<len;i++)
            tmp.x[i]=(b.x[i]*2-tmp.x[i]*b.x[i]%MOD*b.x[i]%MOD+MOD)%MOD;
        tmp.dft(-1);
        for(int i=0;i<nx;i++)b.x[i]=tmp.x[i],b.x[i+nx]=0;
    }
}

void work(int *val,ll *tar,int sz,int ori)
{
    init(2);
    for(int i=0;i<sz;i++)
    {
        up[i].resize();
        dn[i].resize();
        if(i<ori)up[i].x[0]=1;
        dn[i].x[0]=1;
        dn[i].x[1]=-val[i],qm(dn[i].x[1]);
    }
    for(int i=1;i<sz;i<<=1)
    {
        init(min(i<<2,sz<<1));
        dl.resize();
        dr.resize();
        ul.resize();
        ur.resize();
        for(int j=0;j+i<sz;j+=(i<<1))
        {
            dl.reset();
            dr.reset();
            ul.reset();
            ur.reset();
            for(int k=0;k<(len>>1);k++)dl.x[k]=dn[j].x[k];
            for(int k=0;k<(len>>1);k++)dr.x[k]=dn[j+i].x[k];
            for(int k=0;k<(len>>1);k++)ul.x[k]=up[j].x[k];
            for(int k=0;k<(len>>1);k++)ur.x[k]=up[j+i].x[k];
            dn[j].resize();
            up[j].resize();
            dl.dft(1);
            dr.dft(1);
            ul.dft(1);
            ur.dft(1);
            for(int k=0;k<len;k++)dn[j].x[k]=dl.x[k]*dr.x[k]%MOD;
            for(int k=0;k<len;k++)up[j].x[k]=(ul.x[k]*dr.x[k]+dl.x[k]*ur.x[k])%MOD;
            dn[j].dft(-1);
            up[j].dft(-1);
        }
    }

    dv.resize();
    tmp.resize();
    get_inv(dn[0],dv,len>>1);

    dv.dft(1);
    up[0].dft(1);
    for(int i=0;i<len;i++)tmp.x[i]=dv.x[i]*up[0].x[i]%MOD;
    tmp.dft(-1);

    for(int i=0;i<sz;i++)tar[i]=tmp.x[i];
}

void calc_ans()
{
    init(max(max(n1,m),n2));
    int slen=len;
    work(A,EA,slen,n1);
    work(B,EB,slen,n2);
    init(slen*2);
    f1.resize();
    f2.resize();
    for(int i=0;i<len>>1;i++)f1.x[i]=EA[i]*inv(fac[i])%MOD*inv(n1)%MOD,f2.x[i]=EB[i]*inv(fac[i])%MOD*inv(n2)%MOD;
    f1.dft(1);
    f2.dft(1);
    for(int i=0;i<len;i++)f1.x[i]=f1.x[i]*f2.x[i]%MOD;
    f1.dft(-1);
    for(int i=1;i<m;i++)printf("%lld\n",f1.x[i]*fac[i]%MOD);putchar('\n');
}

void prework()
{
    fac[0]=1;
    for(int i=1;i<MX;i++)fac[i]=fac[i-1]*(ll)i%MOD;
    for(int i=0;i<MX;i++)gp[i]=ksm(G,(MOD-1)>>i),gpi[i]=inv(gp[i]);
    for(int w=1,i=1;w<MX;i++,w<<=1)
    {
        wi[i][0]=1;
        wii[i][0]=1;
        for(int j=1;j<w;j++)
        {
            wi[i][j]=wi[i][j-1]*gp[i]%MOD;
            wii[i][j]=wii[i][j-1]*gpi[i]%MOD;
        }
    }
}

int main()
{
    prework();
    scanf("%d%d",&n1,&n2);
    for(int i=0;i<n1;i++)A[i]=read();
    for(int i=0;i<n2;i++)B[i]=read();
    scanf("%d",&m);m++;
    calc_ans();
    return 0;
}
分类: 所有

2 条评论

XZYQvQ · 2018年7月8日 10:31 下午

Do you like van♂游戏?

    boshi · 2018年7月9日 1:11 下午

    I like VAN 游 Shi*

发表评论

电子邮件地址不会被公开。 必填项已用*标注

你是机器人吗? =。= *