定义

并查集,在一些有 N 个元素的集合应用问题中,我们通常是在开始时让每个元素构成一个单元素的集合,然后按一定顺序将属于同一组的元素所在的集合合并,其间要反复查找一个元素在哪个集合中。

模版

class Unionfind
{
  private:
    vector<int> father;

  public:
    Unionfind(int max_size) : father(std::vector<int>(max_size+1))
    {
        for (int i = 0; i <= max_size; i++)
            father[i] = i;
    }

    int find_father(int x)
    {
        return father[x] == x ? x : father[x] = find_father(father[x]);
    }

    void to_union(int x, int y)
    {
        x = find_father(x);
        y = find_father(y);
        if (x == y)
            return;
        father[x] = y;
    }

    bool is_same(int x, int y)
    {
        return find_father(x) == find_father(y);
    }
};

示例

奶酪 NOIP2017 D2T1

题目

现有一块大奶酪,它的高度为 h,它的长度和宽度我们可以认为是无限大的,奶酪中间有许多半径相同的球形空洞。我们可以在这块奶酪中建立空间坐标系,在坐标系中,奶酪的下表面为z = 0,奶酪的上表面为z = h

现在,奶酪的下表面有一只小老鼠 Jerry,它知道奶酪中所有空洞的球心所在的坐 标。如果两个空洞相切或是相交,则 Jerry 可以从其中一个空洞跑到另一个空洞,特别地,如果一个空洞与下表面相切或是相交,Jerry 则可以从奶酪下表面跑进空洞;如果 一个空洞与上表面相切或是相交,Jerry 则可以从空洞跑到奶酪上表面。

位于奶酪下表面的 Jerry 想知道,在不破坏奶酪的情况下,能否利用已有的空洞跑到奶酪的上表面去?

  • 每个输入文件包含多组数据。
  • 第一行,包含一个正整数 T,代表该输入文件中所含的数据组数。
  • 接下来是 T 组数据,每组数据的格式如下:第一行包含三个正整数 n,h 和 r,两个数之间以一个空格分开,分别代表奶酪中空洞的数量,奶酪的高度和空洞的半径。
  • 接下来的 n 行,每行包含三个整数 x,y,z。x,y,z,两个数之间以一个空格分开,表示空洞球心坐标为 (x,y,z)(x,y,z)。

空间坐标系两点距离:dist(P1,P2)=(x1x2)2+(y1y2)2+(z1z2)2\mathrm{dist}(P_1,P_2)=\sqrt{(x_1-x_2)^2+(y_1-y_2)^2+(z_1-z_2)^2}

输入输出

输入:
3
2 4 1
0 0 1
0 0 3
2 5 1
0 0 1
0 0 4
2 5 2
0 0 2
2 0 4
输出:
Yes
No
Yes

核心思想

在自己实验了 3 个想法之后发现了一个及其巧妙的做法:为什么不将奶酪之上和奶酪之下看作两个空洞,在并查集操作完后直接判断他们是不是连通的不就好了吗?
条件也十分简单 (将 0 视为地下,n+1 视为奶酪之上) 则:

  • hole.y - r <= 0 <=> 与下底面连通
  • hole.y + r >= h <=> 与上表面连通
  • sqrt(power(hole1.x - hole2.x) + power(hole1.y - hole2.y) + power(hole1.z - hole2.z)) <= r * 2 <=> 互相连通

代码

#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;

int f[1002], n, h, r, t;
long long x[1002], y[1002], z[1002];

inline int find(int x) { return x == f[x] ? x : f[x] = find(f[x]); }

inline long long power(long long x) { return x * x; }

inline void merge(int x, int y)
{
    f[find(x)] = find(y);
    return;
}

int main()
{
    cin >> t;
    while (t--)
    {
        cin >> n >> h >> r;
        for (int i = 0; i <= n + 1; i++)
        {
            f[i] = i;
        }
        for (int i = 1; i <= n; i++)
        {
            cin >> x[i] >> y[i] >> z[i];
            if (z[i] - r <= 0)
                merge(0, i);
            if (z[i] + r >= h)
                merge(i, n + 1);
        }

        for (int i = 1; i <= n - 1; i++)
        {
            for (int j = i + 1; j <= n; j++)
            {
                if (sqrt(power(x[i] - x[j]) + power(y[i] - y[j]) + power(z[i] - z[j])) <= r * 2)
                    merge(i, j);
            }
        }

        if (find(0) == find(n + 1))
            printf("Yes\n");
        else
            printf("No\n");
    }
    return 0;
}