題目連結:http://www.lydsy.com/JudgeOnline/problem.php?id=2752
題意:
有一個初始全為0的,長度為n的序列a。
有兩種操作:
(1)C l r v: 將[l,r)內的數全部加v。
(2)Q l r: 在[l,r)內隨機選兩個數x,y(x < y),問你∑(a[x to y])的期望,用最簡分數形式輸出。
題解:
首先,題中要求的期望 = 區間內所有子串之和 / 區間內子串個數。
如果一個區間的長度為len,顯然區間內的子串個數為len*(len+1)/2。
所以題目就變成了怎樣維護區間內所有子串之和。
dat表示某個區間的子串和。
假設有兩個相鄰區間l,r,合併起來的區間叫x。
那麼dat[x] = dat[x] + dat[y] + 跨兩個區間的子串和
所以接下來考慮如何求跨區間的子串和。
sum表示某個區間的所有元素之和。
ln表示區間l的長度,rn表示區間r的長度。
ls表示某個區間的所有所有字首之和,rs表示某個區間的所有字尾之和。
則跨區間的子串之和 = rs[l]*rn + ls[r]*ln
即dat[x] = dat[x] + dat[y] + rs[l]*rn + ls[r]*ln
ls,rs和sum的合併就很好求了:
ls[x] = ls[l] + rn*sum[l] + ls[r]
rs[x] = rs[r] + ln*sum[r] + rs[l]
sum[x] = sum[l] + sum[r]
這樣線段樹的pushup函式就寫完了。
然後考慮如何pushdown傳標記。
tag表示某個區間被同時加了多少。
現在只考慮當前節點x的某一個兒子y,兒子y的區間長度為len。
首先考慮tag[x]對dat[y]的貢獻。
貢獻 = 列舉子串的長度 * 這種長度的子串個數 * tag[x]
即:dat[y] += ∑ i*(len-i+1)*tag[x],其中i∈[1,len]。
化簡得:dat[y] += ( len*(len+1)/2*(len+1) + ∑(i^2) ) * tag[x]
對於其中的∑(i^2),事先O(n)預處理出來一個平方字首和陣列sqr即可。
然後易得tag[x]對ls,rs,sum的貢獻:
ls[y] += len*(len+1)/2*tag[x]
rs[y] += len*(len+1)/2*tag[x]
sum[y] += len*tag[x]
這樣pushdown也就寫好了。
然後大力線段樹即可QAQ……
AC Code:
1 #include <iostream> 2 #include <stdio.h> 3 #include <string.h> 4 #include <algorithm> 5 #define MAX_N 100005 6 #define MAX_T 400005 7 #define int ll 8 9 using namespace std; 10 11 typedef long long ll; 12 13 struct Node 14 { 15 int dt,ls,rs,s,ln; 16 Node(int _dt,int _ls,int _rs,int _s,int _ln) 17 { 18 dt=_dt; ls=_ls; rs=_rs; s=_s; ln=_ln; 19 } 20 Node(){} 21 friend Node mix(const Node &a,const Node &b) 22 { 23 int _dt=a.dt+b.dt+a.rs*b.ln+b.ls*a.ln; 24 int _ls=a.ls+b.ln*a.s+b.ls; 25 int _rs=b.rs+a.ln*b.s+a.rs; 26 int _s=a.s+b.s; 27 int _ln=a.ln+b.ln; 28 return Node(_dt,_ls,_rs,_s,_ln); 29 } 30 }; 31 32 int n,m; 33 int ls[MAX_T]; 34 int rs[MAX_T]; 35 int dat[MAX_T]; 36 int sum[MAX_T]; 37 int tag[MAX_T]; 38 int sqr[MAX_N]; 39 40 void cal_sqr() 41 { 42 for(int i=1;i<=n;i++) sqr[i]=sqr[i-1]+i*i; 43 } 44 45 void push_up(int x,int len) 46 { 47 int l=x*2+1,r=x*2+2; 48 Node L(dat[l],ls[l],rs[l],sum[l],len-(len>>1)); 49 Node R(dat[r],ls[r],rs[r],sum[r],(len>>1)); 50 Node tmp=mix(L,R); 51 dat[x]=tmp.dt; 52 ls[x]=tmp.ls; 53 rs[x]=tmp.rs; 54 sum[x]=tmp.s; 55 } 56 57 void push_down(int x,int len) 58 { 59 if(tag[x]) 60 { 61 int l=x*2+1,r=x*2+2; 62 int ln=(len-(len>>1)),rn=(len>>1); 63 dat[l]+=(ln*(ln+1)/2*(ln+1)-sqr[ln])*tag[x]; 64 dat[r]+=(rn*(rn+1)/2*(rn+1)-sqr[rn])*tag[x]; 65 ls[l]+=ln*(ln+1)/2*tag[x]; 66 ls[r]+=rn*(rn+1)/2*tag[x]; 67 rs[l]+=ln*(ln+1)/2*tag[x]; 68 rs[r]+=rn*(rn+1)/2*tag[x]; 69 sum[l]+=ln*tag[x]; 70 sum[r]+=rn*tag[x]; 71 tag[l]+=tag[x]; 72 tag[r]+=tag[x]; 73 tag[x]=0; 74 } 75 } 76 77 void update(int a,int b,int k,int l,int r,int x) 78 { 79 if(a<=l && r<=b) 80 { 81 int len=r-l+1; 82 tag[k]+=x; 83 sum[k]+=len*x; 84 ls[k]+=len*(len+1)/2*x; 85 rs[k]+=len*(len+1)/2*x; 86 dat[k]+=(len*(len+1)/2*(len+1)-sqr[len])*x; 87 return; 88 } 89 if(r<a || b<l) return; 90 push_down(k,r-l+1); 91 int mid=(l+r)>>1; 92 update(a,b,k*2+1,l,mid,x); 93 update(a,b,k*2+2,mid+1,r,x); 94 push_up(k,r-l+1); 95 } 96 97 Node query(int a,int b,int k,int l,int r) 98 { 99 if(a<=l && r<=b) return Node(dat[k],ls[k],rs[k],sum[k],r-l+1); 100 if(r<a || b<l) return Node(0,0,0,0,0); 101 push_down(k,r-l+1); 102 int mid=(l+r)>>1; 103 Node v1=query(a,b,k*2+1,l,mid); 104 Node v2=query(a,b,k*2+2,mid+1,r); 105 return mix(v1,v2); 106 } 107 108 signed main() 109 { 110 scanf("%lld%lld",&n,&m); 111 n--; 112 cal_sqr(); 113 char opt[16]; 114 int l,r,v; 115 while(m--) 116 { 117 scanf("%s%lld%lld",opt,&l,&r); 118 if(opt[0]=='C') 119 { 120 scanf("%lld",&v); 121 update(l,r-1,0,1,n,v); 122 } 123 else 124 { 125 int dt=query(l,r-1,0,1,n).dt; 126 int len=r-l; 127 int tot=len*(len+1)/2; 128 int g=__gcd(dt,tot); 129 printf("%lld/%lld\n",dt/g,tot/g); 130 } 131 } 132 }