CF1906D題解

OEAiHAN發表於2024-03-26

CF1906D

這裡更容易進入且有翻譯

題意

給定一個 \(n\) 個頂點的凸多邊形,多次進行詢問。每次詢問給出兩個不在多邊形內的點 \(P^1_j (A_j, B_j), P^2_j (C_j, D_j)\),問能否找到一個點 \(P^3\),使線段 \(P^1P^3, P^2P^3\) 不與凸多邊形相交(可以相切),並求最短的 \(|P^1P^3| + |P^2P^3|\)

解析

這題思路比較簡單,但碼量不小。

對於每次詢問,先判斷線段 \(P^1P^2\) 是否與凸多邊形相交,若不相交直接輸出 \(|P^1P^2|\) 即可;否則,可分別二分求 \(P^1, P^2\) 的切線(各兩條,為以 \(P^1, P^2\) 為端點的射線),算出 \(P^1\) 各條切線與 \(P^2\) 切線的交點作為 \(P^3\) 計算出 \(\min\{|P^1P^3_i| + |P^2P^3_i|\}\) 作為答案,或者確認不存在這樣一個交點 \(P^3\)

對於二分求切線,如果點的橫座標不大於或不小於凸包上的任意點的橫座標,則把凸包分成上凸包和下凸包分別進行二分即可;否則,可按點的橫座標將凸包分割成左凸包和右凸包分別進行二分。

關於線段是否與凸多邊形相交,將 \(P^1, P^2\) 視作點光源,判斷 \(P^1\)\(P^2\) 光線的可見點集是否有交,若交集內有多於 1 個元素則線段不與凸包相交;或者如果交集內正好有 1 個元素點 \(Z\),則可求 \(\overrightarrow{P^1P^2} \times \overrightarrow{P^1Z}\)\(\overrightarrow{P^2P^1} \times \overrightarrow{P^2Z}\) 進行判斷。

程式碼

#include <bits/stdc++.h>
#include <unordered_map>
#define LL long long
#define pii pair<int, int>
#define pll pair<LL, LL>
#define double long double
#define pdd pair<double, double>
#define eps 1e-9

using namespace std;

//直線
struct line
{
    pdd p, v;

    line(pll p, pll v)
    {
        this->p = (pdd)p;
        this->v = (pdd)v;
    }
};

pll operator + (pll l, pll r)
{
    return make_pair(l.first + r.first, l.second + r.second);
}

pll operator - (pll l, pll r)
{
    return make_pair(l.first - r.first, l.second - r.second);
}

//點乘,下同
LL dot(pll l, pll r)
{
    return l.first * r.first + l.second * r.second;
}

//叉乘,下同
LL cross(pll l, pll r)
{
    return l.first * r.second - l.second * r.first;
}

pdd operator + (pdd l, pdd r)
{
    return make_pair(l.first + r.first, l.second + r.second);
}

pdd operator - (pdd l, pdd r)
{
    return make_pair(l.first - r.first, l.second - r.second);
}

double dot(pdd l, pdd r)
{
    return l.first * r.first + l.second * r.second;
}

//求兩點距離
double dis(pdd l, pdd r)
{
    return sqrt(dot(l - r, l - r));
}

double cross(pdd l, pdd r)
{
    return l.first * r.second - l.second * r.first;
}

//求直線交點
pdd getNode(line l, line r)
{
    double s1 = cross(l.v, r.p - l.p), s2 = cross(l.v, r.p + r.v - l.p);
    return make_pair((r.p.first * s2 - (r.p.first + r.v.first) * s1) / (s2 - s1), (r.p.second * s2 - (r.p.second + r.v.second) * s1) / (s2 - s1));
}

int n;
pll ps[200005], ps1[100005];
map<pll, int> mp;

//二分求切點
pair<pll, pll> getPT(pll p)
{
    pair<pll, pll> res;

    if (p.first <= ps1[1].first)
    {
        int l = mp[ps1[1]], r = mp[ps1[n]];
        if (r < l)
            r += n;
        while (l < r)
        {
            int mid = l + r >> 1;
            if (cross(ps[mid + 1] - ps[mid], ps[mid + 1] - p) < 0)
                r = mid;
            else
                l = mid + 1;
        }
        res.first = ps[l];

        l = mp[ps1[n]], r = mp[ps1[1]];
        if (r < l)
            r += n;
        while (l < r)
        {
            int mid = l + r >> 1;
            if (cross(ps[mid + 1] - ps[mid], p - ps[mid]) > 0)
                l = mid + 1;
            else
                r = mid;
        }
        res.second = ps[l];
    }
    else if (p.first >= ps1[n].first)
    {
        int l = mp[ps1[n]], r = mp[ps1[1]];
        if (r < l)
            r += n;
        while (l < r)
        {
            int mid = l + r >> 1;
            if (cross(ps[mid + 1] - ps[mid], ps[mid + 1] - p) < 0)
                r = mid;
            else
                l = mid + 1;
        }
        res.first = ps[l];

        l = mp[ps1[1]], r = mp[ps1[n]];
        if (r < l)
            r += n;
        while (l < r)
        {
            int mid = l + r >> 1;
            if (cross(ps[mid + 1] - ps[mid], p - ps[mid]) > 0)
                l = mid + 1;
            else
                r = mid;
        }
        res.second = ps[l];
    }
    else
    {
        if (cross(ps1[1] - p, p - ps1[n]) < 0 || !cross(ps1[1] - p, p - ps1[n]) && !cross(ps[mp[ps1[n]] + 1] - ps1[n], ps1[1] - ps1[n]))
        {
            int L = mp[ps1[n]], R = mp[ps1[1]];
            if (R < L)
                R += n;
            int div = lower_bound(ps + L, ps + R + 1, p, greater<pll>()) - (ps + L);

            int l = L + div, r = R;

            if (r < l)
                r += n;
            while (l < r)
            {
                int mid = l + r >> 1;
                if (cross(ps[mid + 1] - ps[mid], ps[mid + 1] - p) < 0)
                    r = mid;
                else
                    l = mid + 1;
            }
            res.first = ps[l];

            l = L, r = L + div - 1;
            if (r < l)
                r += n;
            while (l < r)
            {
                int mid = l + r >> 1;
                if (cross(ps[mid + 1] - ps[mid], p - ps[mid]) > 0)
                    l = mid + 1;
                else
                    r = mid;
            }
            res.second = ps[l];
        }
        else
        {
            int L = mp[ps1[1]], R = mp[ps1[n]];
            if (R < L)
                R += n;
            int div = lower_bound(ps + L, ps + R + 1, p) - ps - L;

            int l = L + div, r = R;
            if (r < l)
                r += n;
            while (l < r)
            {
                int mid = l + r >> 1;
                if (cross(ps[mid + 1] - ps[mid], ps[mid + 1] - p) < 0)
                    r = mid;
                else
                    l = mid + 1;
            }
            res.first = ps[l];

            l = L, r = L + div - 1;
            if (r < l)
                r += n;
            while (l < r)
            {
                int mid = l + r >> 1;
                if (cross(ps[mid + 1] - ps[mid], p - ps[mid]) > 0)
                    l = mid + 1;
                else
                    r = mid;
            }
            res.second = ps[l];
        }
    }

    return res;
}

int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);

    int q;
    pll p0, p1;
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> ps[i].first >> ps[i].second;
        ps1[i] = ps[n + i] = ps[i];
        mp[ps[i]] = i;
    }
    sort(ps1 + 1, ps1 + n + 1);

    cin >> q;
    while (q--)
    {
        cin >> p0.first >> p0.second >> p1.first >> p1.second;

        pair<pll, pll> res1 = getPT(p0), res2 = getPT(p1);

        double ans = 1e25; //初始值一定要開大
        int r1 = mp[res1.first], l1 = mp[res1.second], r2 = mp[res2.first], l2 = mp[res2.second];
        if (r1 < l1)
            r1 += n;
        if (r2 < l2)
            r2 += n;

        //判斷線段是否與凸包相交
        if (!(l1 >= r2 || l2 >= r1))
            ans = dis(p0, p1);
        else if (l1 == r2 && cross(p0 - p1, ps[l1] - p1) >= 0)
            ans = dis(p0, p1);
        else if (l2 == r1 && cross(p1 - p0, ps[l2] - p0) >= 0)
            ans = dis(p0, p1);
        if (r1 > n && r2 <= n)
            l2 += n, r2 += n;
        else if (r2 > n && r1 <= n)
            l1 += n, r1 += n;
        if (!(l1 >= r2 || l2 >= r1))
            ans = dis(p0, p1);
        else if (l1 == r2 && cross(p0 - p1, ps[l1] - p1) >= 0)
            ans = dis(p0, p1);
        else if (l2 == r1 && cross(p1 - p0, ps[l2] - p0) >= 0)
            ans = dis(p0, p1);

        //求切線交點並計算答案
        if (cross(res1.first - p0, res2.first - p1))
        {
            pdd nd = getNode(line(p0, res1.first - p0), line(p1, res2.first - p1));
            if (dot((pdd)(res1.first - p0), nd - (pdd)p0) > eps && dot((pdd)(res2.first - p1), nd - (pdd)p1) > eps)
                ans = min(ans, dis(nd, p0) + dis(nd, p1));
        }
        if (cross(res1.first - p0, res2.second - p1))
        {
            pdd nd = getNode(line(p0, res1.first - p0), line(p1, res2.second - p1));
            if (dot((pdd)(res1.first - p0), nd - (pdd)p0) > eps && dot((pdd)(res2.second - p1), nd - (pdd)p1) > eps)
                ans = min(ans, dis(nd, p0) + dis(nd, p1));
        }
        if (cross(res1.second - p0, res2.first - p1))
        {
            pdd nd = getNode(line(p0, res1.second - p0), line(p1, res2.first - p1));
            if (dot((pdd)(res1.second - p0), nd - (pdd)p0) > eps && dot((pdd)(res2.first - p1), nd - (pdd)p1) > eps)
                ans = min(ans, dis(nd, p0) + dis(nd, p1));
        }
        if (cross(res1.second - p0, res2.second - p1))
        {
            pdd nd = getNode(line(p0, res1.second - p0), line(p1, res2.second - p1));
            if (dot((pdd)(res1.second - p0), nd - (pdd)p0) > eps && dot((pdd)(res2.second - p1), nd - (pdd)p1) > eps)
                ans = min(ans, dis(nd, p0) + dis(nd, p1));
        }

        if (ans >= 1e25)
            printf("-1\n");
        else
            printf("%.12LF\n", ans);
    }
}

最後祝各位順利AC。>w<