BZOJ 2752 [HAOI2012]高速公路(road):線段樹【維護區間內子串和】

Leohh發表於2018-03-12

題目連結:http://www.lydsy.com/JudgeOnline/problem.php?id=2752

題意:

  有一個初始全為0的,長度為n的序列a。

  有兩種操作:

    (1)C l r v: 將[l,r)內的數全部加v。

    (2)Q l r: 在[l,r)內隨機選兩個數x,y(x < y),問你∑(a[x to y])的期望,用最簡分數形式輸出。

 

題解:

  首先,題中要求的期望 = 區間內所有子串之和 / 區間內子串個數。

  如果一個區間的長度為len,顯然區間內的子串個數為len*(len+1)/2。

  所以題目就變成了怎樣維護區間內所有子串之和。

 

  dat表示某個區間的子串和。

  假設有兩個相鄰區間l,r,合併起來的區間叫x。

  那麼dat[x] = dat[x] + dat[y] + 跨兩個區間的子串和

 

  所以接下來考慮如何求跨區間的子串和。

  sum表示某個區間的所有元素之和。

  ln表示區間l的長度,rn表示區間r的長度。

  ls表示某個區間的所有所有字首之和,rs表示某個區間的所有字尾之和。

  則跨區間的子串之和 = rs[l]*rn + ls[r]*ln

  即dat[x] = dat[x] + dat[y] + rs[l]*rn + ls[r]*ln

 

  ls,rs和sum的合併就很好求了:

  ls[x] = ls[l] + rn*sum[l] + ls[r]

  rs[x] = rs[r] + ln*sum[r] + rs[l]

  sum[x] = sum[l] + sum[r]

 

  這樣線段樹的pushup函式就寫完了。

  然後考慮如何pushdown傳標記。

 

  tag表示某個區間被同時加了多少。

  現在只考慮當前節點x的某一個兒子y,兒子y的區間長度為len。

 

  首先考慮tag[x]對dat[y]的貢獻。

  貢獻 = 列舉子串的長度 * 這種長度的子串個數 * tag[x]

  即:dat[y] += ∑ i*(len-i+1)*tag[x],其中i∈[1,len]。

  化簡得:dat[y] += ( len*(len+1)/2*(len+1) + ∑(i^2) ) * tag[x]

  對於其中的∑(i^2),事先O(n)預處理出來一個平方字首和陣列sqr即可。

 

  然後易得tag[x]對ls,rs,sum的貢獻:

  ls[y] += len*(len+1)/2*tag[x]

  rs[y] += len*(len+1)/2*tag[x]

  sum[y] += len*tag[x]

 

  這樣pushdown也就寫好了。

  然後大力線段樹即可QAQ……

 

AC Code:

  1 #include <iostream>
  2 #include <stdio.h>
  3 #include <string.h>
  4 #include <algorithm>
  5 #define MAX_N 100005
  6 #define MAX_T 400005
  7 #define int ll
  8 
  9 using namespace std;
 10 
 11 typedef long long ll;
 12 
 13 struct Node
 14 {
 15     int dt,ls,rs,s,ln;
 16     Node(int _dt,int _ls,int _rs,int _s,int _ln)
 17     {
 18         dt=_dt; ls=_ls; rs=_rs; s=_s; ln=_ln;
 19     }
 20     Node(){}
 21     friend Node mix(const Node &a,const Node &b)
 22     {
 23         int _dt=a.dt+b.dt+a.rs*b.ln+b.ls*a.ln;
 24         int _ls=a.ls+b.ln*a.s+b.ls;
 25         int _rs=b.rs+a.ln*b.s+a.rs;
 26         int _s=a.s+b.s;
 27         int _ln=a.ln+b.ln;
 28         return Node(_dt,_ls,_rs,_s,_ln);
 29     }
 30 };
 31 
 32 int n,m;
 33 int ls[MAX_T];
 34 int rs[MAX_T];
 35 int dat[MAX_T];
 36 int sum[MAX_T];
 37 int tag[MAX_T];
 38 int sqr[MAX_N];
 39 
 40 void cal_sqr()
 41 {
 42     for(int i=1;i<=n;i++) sqr[i]=sqr[i-1]+i*i;
 43 }
 44 
 45 void push_up(int x,int len)
 46 {
 47     int l=x*2+1,r=x*2+2;
 48     Node L(dat[l],ls[l],rs[l],sum[l],len-(len>>1));
 49     Node R(dat[r],ls[r],rs[r],sum[r],(len>>1));
 50     Node tmp=mix(L,R);
 51     dat[x]=tmp.dt;
 52     ls[x]=tmp.ls;
 53     rs[x]=tmp.rs;
 54     sum[x]=tmp.s;
 55 }
 56 
 57 void push_down(int x,int len)
 58 {
 59     if(tag[x])
 60     {
 61         int l=x*2+1,r=x*2+2;
 62         int ln=(len-(len>>1)),rn=(len>>1);
 63         dat[l]+=(ln*(ln+1)/2*(ln+1)-sqr[ln])*tag[x];
 64         dat[r]+=(rn*(rn+1)/2*(rn+1)-sqr[rn])*tag[x];
 65         ls[l]+=ln*(ln+1)/2*tag[x];
 66         ls[r]+=rn*(rn+1)/2*tag[x];
 67         rs[l]+=ln*(ln+1)/2*tag[x];
 68         rs[r]+=rn*(rn+1)/2*tag[x];
 69         sum[l]+=ln*tag[x];
 70         sum[r]+=rn*tag[x];
 71         tag[l]+=tag[x];
 72         tag[r]+=tag[x];
 73         tag[x]=0;
 74     }
 75 }
 76 
 77 void update(int a,int b,int k,int l,int r,int x)
 78 {
 79     if(a<=l && r<=b)
 80     {
 81         int len=r-l+1;
 82         tag[k]+=x;
 83         sum[k]+=len*x;
 84         ls[k]+=len*(len+1)/2*x;
 85         rs[k]+=len*(len+1)/2*x;
 86         dat[k]+=(len*(len+1)/2*(len+1)-sqr[len])*x;
 87         return;
 88     }
 89     if(r<a || b<l) return;
 90     push_down(k,r-l+1);
 91     int mid=(l+r)>>1;
 92     update(a,b,k*2+1,l,mid,x);
 93     update(a,b,k*2+2,mid+1,r,x);
 94     push_up(k,r-l+1);
 95 }
 96 
 97 Node query(int a,int b,int k,int l,int r)
 98 {
 99     if(a<=l && r<=b) return Node(dat[k],ls[k],rs[k],sum[k],r-l+1);
100     if(r<a || b<l) return Node(0,0,0,0,0);
101     push_down(k,r-l+1);
102     int mid=(l+r)>>1;
103     Node v1=query(a,b,k*2+1,l,mid);
104     Node v2=query(a,b,k*2+2,mid+1,r);
105     return mix(v1,v2);
106 }
107 
108 signed main()
109 {
110     scanf("%lld%lld",&n,&m);
111     n--;
112     cal_sqr();
113     char opt[16];
114     int l,r,v;
115     while(m--)
116     {
117         scanf("%s%lld%lld",opt,&l,&r);
118         if(opt[0]=='C')
119         {
120             scanf("%lld",&v);
121             update(l,r-1,0,1,n,v);
122         }
123         else
124         {
125             int dt=query(l,r-1,0,1,n).dt;
126             int len=r-l;
127             int tot=len*(len+1)/2;
128             int g=__gcd(dt,tot);
129             printf("%lld/%lld\n",dt/g,tot/g);
130         }
131     }
132 }

 

相關文章