Codeforces Round #189 Div1C Kalila and Dimna in the Logging Industry
問題
木を全部切りたいです。
n(≦10^5)本の木があって、それぞれ高さがa_iメートルです。
チェーンソーがあってそれで切り倒すんですが、1メートル切ると使えなくなって新しいのを買わなきゃいけないというポンコツです。
新しいチェーンソーの値段は、b[(今まで高さ0になるまで切った最大の木の番号)]です。
全部の木を完全に切り倒すための最小コストを求めなさい。
制約:
1本目の木の高さは必ず1メートルで、b[n]は必ず0、
a_iは単調増加、b_iは単調減少です。
つまり、最初は必ず1本目の木を切ることになります。
たとえば、
6
1 2 3 10 20 30
6 5 4 3 2 0
(n=6、2行目がa_i、3行目がb_i)
の場合、
最初に1番目の木を切り(ここまでコスト0)、
次に3番目の木を切り(ここまでコスト6*3)、
次に6番目の木を切り(ここまでコスト18+4*30)、
残りの木全部をコスト0で切り、138というのが最小コストです。
解法
最後の木を倒したらそのあとはコスト0なので、そこまでのコストがいくら掛かるのかを求めたいです。
dp[i]をi番目の木を倒すまでの最小コストとすると、
dp[i]=min(dp[j]+a[i]*b[j]) (j<i)
となります(a[i]が単調増加なので、i番目より後の木を倒してから戻ってくる、という倒し方が最適になることはないです)。
ですがこれを愚直に実装するとO(n^2)でTLEしてしまいます。
そこでこのdp表更新をもう少し効率良く行う必要があります。
dp[j]+a[i]*b[j] という式は、よく見ると
y = b[j]*x + dp[j]
という一次関数にx=a[i]を代入したものだと考えられます。
つまり、ある数a[i]に対して、i-1本の一次関数の直線の中から一番下の点で交わる直線を判定する問題であることがわかります。
ところでbは単調減少かつaは単調増加なので、i番目のときに最小値としてj番目が選ばれた時、
i+1番目以降にj未満の直線が選ばれることはありません。
よって、どの直線が一番値を最小化するのかを覚えながら、今の直線より次の直線のほうがコストが低くなるのであればそちらを選ぶ、という形を取ることでO(n)の計算が可能になりそうです。
ところがひとつ罠があります。こんな例です。
a[5]のときはコスト最小の直線はL2でした。そこで新たな直線L5を追加して次のステップに進みます。
a[6]の最小値を求めるとき、まずL2との交差点とL3との交差点を求めます。
L3のほうが小さいのでL3が現状最小となります。
同様にL3とL4で比較して、L3のほうが小さい値で交差するので、L3が最小という結論になります。
しかし、図を見ると最小を実現するのは明らかにL5です。
というわけで、「こんな直線が追加されちゃうと、もうどこのxを取ってきても最小値を取らなくなってしまう」ような直線を適宜排除していく必要があります。
既存の線との交差点の位置関係からこれを割り出すことが、初等幾何的な計算でできます。
これをスタックっぽい構造で実現したのが以下のプログラムです。
(といっても、ほとんど他の方のソースの写経ですが・・・)
ちなみにa,bの各要素が最大10^9なので、64ビット整数でもこの判定はオーバーフローするらしく、doubleへのキャストを使っています。
なお今回はa,bの単調性などの過程があったためO(n)解法が実現できていますが、
一般化してn本の線からある値xにおける最小を実現する直線を判定するためにはO(nlogn)時間がかかります。
この一般化された問題はconvex hull trickと呼ばれてるらしいです。
本番中に解いた方々はこういうことちゃんと知ってるのかーと絶望的な気分になりました。
プログラム
#include<iostream> #include <fstream> #include <stdio.h> #include <stdlib.h> #define _USE_MATH_DEFINES #include <math.h> #include<string> #include<vector> #include<cmath> #include<stack> #include<queue> #include<sstream> #include<algorithm> #include<map> #include<complex> using namespace std; #define li long long int #define rep(i,to) for(li i=0;i<((li)(to));i++) #define repp(i,start,to) for(li i=(li)(start);i<((li)(to));i++) #define pb push_back #define sz(v) ((li)(v).size()) #define bgn(v) ((v).begin()) #define eend(v) ((v).end()) #define allof(v) (v).begin(),(v).end() #define dodp(v,n) memset(v,(li)n,sizeof(v)) #define bit(n) (1ll<<(li)(n)) #define mp(a,b) make_pair(a,b) #define rin rep(i,n) #define EPS 1e-10 #define ETOL 1e-8 #define MOD 1000000007 #define _T_ <<"\t"<< #define p2(a,b) cout<<a_T_b<<endl #define p3(a,b,c) cout<<a_T_b_T_c<<endl li a[100001], b[100001], dp[100001]; li aa[100001], bb[100001]; li lines=0; li best=0; bool owata(){ double x12=(double)(bb[lines-2]-bb[lines-3])/(double)(aa[lines-3]-aa[lines-2]); double x23=(double)(bb[lines-2]-bb[lines-1])/(double)(aa[lines-1]-aa[lines-2]); return x12>=x23; } void addLine(li an, li bn){ aa[lines]=an; bb[lines]=bn; lines++; while(lines>=3 && owata()){ aa[lines-2]=aa[lines-1]; bb[lines-2]=bb[lines-1]; lines--; } if(best>=lines)best=lines-1; } li calc(li x){ if(best>=lines)best=lines-1; while(best<lines-1 && aa[best]*x+bb[best] >= aa[best+1]*x+bb[best+1]){ best++; } return aa[best]*x+bb[best]; } int main(){ li n; cin>>n; vector<li> cand; rin{cin>>a[i];} rin{cin>>b[i];} rin{dp[i]=1e+18;} dp[0]=0; addLine(b[0], 0); repp(i,1,n){ dp[i]=calc(a[i]); addLine(b[i],dp[i]); } cout<<dp[n-1]<<endl; return 0; }