动态区间第k大的一种$O(nlog^2n)$的树套树解法

题意:

给定一个序列,有“将一个数修改为另一个”的操作,和询问“[l,r]区间内的第k小值是几”的询问。要求在1s内对一个长$10^4$的序列完成$10^4$组修改和询问。

思路:

1.
我最先想到的树套树的思路是:线段树套Splay。

线段树的每一个区间[l,r]定义为原序列的[l,r]中的数字组成的排序Splay。这样,对于每个修改操作,我们对$logn$棵包含了这个位置的Splay删除一个数,添加一个数即可。对于每个询问操作[l,r],我们需要二分答案x,再在logn棵恰好组成[l,r]区间的Splay中查询x的排名。

由于线段树有logn层,每层的所有Splay合并后恰为一个完整的原序列,所以空间复杂度为O(nlogn)的。

经过上面的分析可知:修改复杂度为$O(nlog^2n)$,查询为$O(nlog^3n)$。实践证明可以通过所有测试点。洛谷最慢测试点340ms.
2.
第二种思路,即本文终点介绍的思路,两种操作都是$O(nlog^2n)$的,只是还需要对所有出现过的数字离散化。

刚才,我们的线段树对应的是区间,Splay对应的是值。如果互换一下呢?

现在线段树是建立在离散化后的实数域上的。线段树区间[l,r]定义为原序列中值属于[l,r]的所有值的下标的排序Splay。即:将值属于[l,r]的所有位置提取为一个新的序列,保存在这个Splay里。

对于修改操作,我们依旧修改$logn$棵Splay。如果我们将a[i]修改为b,则将所有的线段树节点$[l,r](l\leq a[i]\leq r)$中的Splay中删除i。同时向所有的线段树节点$[l,r](l\leq b\leq r)$中的Splay添加i。故一次修改操作的复杂度为$O(log^2n)$。

对于查询操作,注意到这棵线段树是支持前缀和的。即实数区间[l,r]内的数x$(a\leq x \leq b)$的个数等于[1,r]内的个数减[1,l-1]内的个数(这里的x就是原序列的下标)。所以我们只需要在树上二分即可。如果现在我们确定了k小值一定在实数区间[a,b]内,那么如果[a,(a+b)/2]内的Splay中满足上述条件的节点小于k,则k小值一定在[(a+b)/2+1,b]内,反之同理。这样的查询操作只需要对logn个线段树节点进行查询,故一次询问的时间复杂度为$O(log^n)$。

实践证明可以通过所有测试点,洛谷最慢测试点92ms.

细节问题

这种方法虽然省去了一个log,但是其代码量却比之前多了一个log。

以下是一些需要注意的地方:

  • 离散化不但要离散原序列的值,也要离散修改出的值。
  • 在最初插入节点时最好使用类似线段树建树一样的归并方式,这样可以降低常数。实践证明不这样洛谷最慢点为296ms.
  • 最好不要用这种方法因为我打了200行。
/*
A data structure used to maintain interval kth number
With splays in a segment tree
*/
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>

#define MX 61005
#define mid ((l+r)>>1)
#define ls (a<<1)
#define rs (a<<1|1)

using namespace std;

typedef struct splnode
{
    int x,f,siz,s[2];
}node;
typedef struct tqeury
{
    int l,r,x,t;
}query;
query qur[MX];
node t[MX*18];
int seq[MX],real[MX];
vector<int>ord[MX];
map<int,int>mp;
map<int,int>::iterator itr;
int tnum;
int n,m;
int bar[MX],bnum;

typedef struct trenode
{
    int root,l,r;
    inline int pos(int a){return t[t[a].f].s[1]==a;}
    inline void upd(int a){t[a].siz=t[t[a].s[0]].siz+t[t[a].s[1]].siz+1;}
    inline void rot(int a)
    {
        int f=t[a].f,g=t[f].f,p=pos(a),q=pos(f);
        t[f].s[p]=t[a].s[!p],t[a].s[!p]=f,t[f].f=a;
        if(t[f].s[p])t[t[f].s[p]].f=f;
        if(t[a].f=g)t[g].s[q]=a;
        upd(f),upd(a);
    }
    inline void spl(int tar,int a)
    {
        while(t[a].f!=tar)
            if(t[t[a].f].f==tar)rot(a);
            else if(pos(a)==pos(t[a].f))rot(t[a].f),rot(a);
            else rot(a),rot(a);
        if(!tar)root=a;
    }
    int merg(int f,int l,int r)
    {
        if(l>r)return 0;
        int a=++tnum;
        t[a]=(node){bar[mid],f,1,0,0};
        t[a].s[0]=merg(a,l,mid-1);
        t[a].s[1]=merg(a,mid+1,r);
        upd(a);
        return a;
    }
    void insrt(int &a,int f,int x)
    {
        if(!a)t[a=++tnum]=(node){x,f,1,0,0},spl(0,tnum);
        else if(x<t[a].x)insrt(t[a].s[0],a,x);
        else insrt(t[a].s[1],a,x);
    }
    int findn(int a,int x)
    {
        if(!a)return 0;
        else if(t[a].x==x)return a;
        else if(x<t[a].x)return findn(t[a].s[0],x);
        else return findn(t[a].s[1],x);
    }
    void del(int x)
    {
        int a=findn(root,x);
        spl(0,a);
        int la=t[a].s[0],ra=t[a].s[1];
        while(t[la].s[1])la=t[la].s[1];
        spl(a,la);
        t[la].s[1]=ra,t[ra].f=la,t[root=la].f=0;
        spl(0,ra);
    }
    int rank(int a,int x)
    {
        if(!a)return 0;
        else if(x>=t[a].x)return rank(t[a].s[1],x)+t[t[a].s[0]].siz+1;
        else return rank(t[a].s[0],x);
    }
}segt;
segt tre[MX*4];

void build(int a,int l,int r)
{
    tre[a].l=l,tre[a].r=r;
    if(l<r)build(ls,l,mid),build(rs,mid+1,r);
    bar[1]=-MX;
    bnum=1;
    for(int p=l;p<=r;p++)
        for(int i=0;i<ord[p].size();i++)
            bar[++bnum]=ord[p][i];
    bar[++bnum]=MX;
    sort(bar+1,bar+bnum+1);
    tre[a].root=tre[a].merg(0,1,bnum);
}

void del(int a,int p,int x)
{
    int l=tre[a].l,r=tre[a].r;
    tre[a].del(x);
    if(l==r)return;
    else if(p<=mid)del(ls,p,x);
    else del(rs,p,x);
}

void ins(int a,int p,int x)
{
    int l=tre[a].l,r=tre[a].r;
    tre[a].insrt(tre[a].root,0,x);
    if(l==r)return;
    else if(p<=mid)ins(ls,p,x);
    else ins(rs,p,x);
}

int kth(int a,int ql,int qr,int k)
{
    int dlt=tre[ls].rank(tre[ls].root,qr)-tre[ls].rank(tre[ls].root,ql);
    if(tre[a].l==tre[a].r)return tre[a].r;
    else if(dlt<k)return kth(rs,ql,qr,k-dlt);
    else return kth(ls,ql,qr,k);
}

void lsh()
{
    int x;
    for(x=1;x<=n;x++)mp[seq[x]]=1;
    for(x=1;x<=m;x++)if(qur[x].t==0)mp[qur[x].x]=1;
    for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)itr->second=x;
    for(x=1,itr=mp.begin();itr!=mp.end();itr++,x++)real[itr->second]=itr->first;
    for(x=1;x<=n;x++)seq[x]=mp[seq[x]];
    for(x=1;x<=m;x++)if(qur[x].t==0)qur[x].x=mp[qur[x].x];
}

void inpt()
{
    char str[10];
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",&seq[i]);
    for(int i=1;i<=m;i++)
    {
        scanf("%s",str);
        if(str[0]=='C')qur[i].t=0,scanf("%d%d",&qur[i].l,&qur[i].x);
        else qur[i].t=1,scanf("%d%d%d",&qur[i].l,&qur[i].r,&qur[i].x);
    }
    lsh();
    for(int i=1;i<=n;i++)ord[seq[i]].push_back(i);
    n=mp.size();
    build(1,1,n);
}

void work()
{
    for(int i=1;i<=m;i++)
    {
        if(qur[i].t==0)
        {
            del(1,seq[qur[i].l],qur[i].l);
            ins(1,qur[i].x,qur[i].l);
            seq[qur[i].l]=qur[i].x;
        }
        else printf("%d\n",real[kth(1,qur[i].l-1,qur[i].r,qur[i].x)]);
    }
}

void init()
{
    tnum=0;
    mp.clear();
    for(int i=1;i<=n;i++)ord[i].clear();
}

int main()
{
    int T;
    scanf("%d",&T);
    for(int i=1;i<=T;i++)
    {
        init();
        inpt();
        work();
    }
    return 0;
}

分享至ヾ(≧∇≦*)ゝ:
分类: 所有

1 条评论

konnyakuxzy · 2018年1月9日 8:35 下午

我去这代码确实挺长的QvQ
您码力太强了Orz
不过确实奇怪网上居然没有这种权值线段树の题解

发表评论

电子邮件地址不会被公开。 必填项已用*标注

你是机器人吗? =。= *