Skip to content

Latest commit

 

History

History
397 lines (333 loc) · 9.98 KB

File metadata and controls

397 lines (333 loc) · 9.98 KB

English Version

题目描述

给定一个由 n 个节点组成的网络,用 n x n 个邻接矩阵 graph 表示。在节点网络中,只有当 graph[i][j] = 1 时,节点 i 能够直接连接到另一个节点 j

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

我们可以从 initial删除一个节点并完全移除该节点以及从该节点到任何其他节点的任何连接。

请返回移除后能够使 M(initial) 最小化的节点。如果有多个节点满足条件,返回索引 最小的节点

 

示例 1:

输出:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输入:0

示例 2:

输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
输出:1

示例 3:

输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
输出:1

 

提示:

  • n == graph.length
  • n == graph[i].length
  • 2 <= n <= 300
  • graph[i][j] 是 0 或 1.
  • graph[i][j] == graph[j][i]
  • graph[i][i] == 1
  • 1 <= initial.length < n
  • 0 <= initial[i] <= n - 1
  •  initial 中每个整数都不同

解法

逆向思维并查集。对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。

然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。

被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。

以下是并查集的几个常用模板。

模板 1——朴素并查集:

# 初始化,p存储每个点的父节点
p = list(range(n))

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]


# 合并a和b所在的两个集合
p[find(a)] = find(b)

模板 2——维护 size 的并查集:

# 初始化,p存储每个点的父节点,size只有当节点是祖宗节点时才有意义,表示祖宗节点所在集合中,点的数量
p = list(range(n))
size = [1] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        # 路径压缩
        p[x] = find(p[x])
    return p[x]

# 合并a和b所在的两个集合
if find(a) != find(b):
    size[find(b)] += size[find(a)]
    p[find(a)] = find(b)

模板 3——维护到祖宗节点距离的并查集:

# 初始化,p存储每个点的父节点,d[x]存储x到p[x]的距离
p = list(range(n))
d = [0] * n

# 返回x的祖宗节点
def find(x):
    if p[x] != x:
        t = find(p[x])
        d[x] += d[p[x]]
        p[x] = t
    return p[x]

# 合并a和b所在的两个集合
p[find(a)] = find(b)
d[find(a)] = distance

Python3

class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        def find(x):
            if p[x] != x:
                p[x] = find(p[x])
            return p[x]

        def union(a, b):
            pa, pb = find(a), find(b)
            if pa != pb:
                size[pb] += size[pa]
                p[pa] = pb

        n = len(graph)
        p = list(range(n))
        size = [1] * n
        clean = [True] * n
        for i in initial:
            clean[i] = False
        for i in range(n):
            if not clean[i]:
                continue
            for j in range(i + 1, n):
                if clean[j] and graph[i][j] == 1:
                    union(i, j)
        cnt = Counter()
        mp = {}
        for i in initial:
            s = {find(j) for j in range(n) if clean[j] and graph[i][j] == 1}
            for root in s:
                cnt[root] += 1
            mp[i] = s

        mx, ans = -1, 0
        for i, s in mp.items():
            t = sum(size[root] for root in s if cnt[root] == 1)
            if mx < t or mx == t and i < ans:
                mx, ans = t, i
        return ans

Java

class Solution {
    private int[] p;
    private int[] size;

    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        p = new int[n];
        size = new int[n];
        for (int i = 0; i < n; ++i) {
            p[i] = i;
            size[i] = 1;
        }
        boolean[] clean = new boolean[n];
        Arrays.fill(clean, true);
        for (int i : initial) {
            clean[i] = false;
        }
        for (int i = 0; i < n; ++i) {
            if (!clean[i]) {
                continue;
            }
            for (int j = i + 1; j < n; ++j) {
                if (clean[j] && graph[i][j] == 1) {
                    union(i, j);
                }
            }
        }
        int[] cnt = new int[n];
        Map<Integer, Set<Integer>> mp = new HashMap<>();
        for (int i : initial) {
            Set<Integer> s = new HashSet<>();
            for (int j = 0; j < n; ++j) {
                if (clean[j] && graph[i][j] == 1) {
                    s.add(find(j));
                }
            }
            for (int root : s) {
                cnt[root] += 1;
            }
            mp.put(i, s);
        }
        int mx = -1;
        int ans = 0;
        for (Map.Entry<Integer, Set<Integer>> entry : mp.entrySet()) {
            int i = entry.getKey();
            int t = 0;
            for (int root : entry.getValue()) {
                if (cnt[root] == 1) {
                    t += size[root];
                }
            }
            if (mx < t || (mx == t && i < ans)) {
                mx = t;
                ans = i;
            }
        }
        return ans;
    }

    private int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }

    private void union(int a, int b) {
        int pa = find(a);
        int pb = find(b);
        if (pa != pb) {
            size[pb] += size[pa];
            p[pa] = pb;
        }
    }
}

C++

class Solution {
public:
    vector<int> p;
    vector<int> size;

    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        int n = graph.size();
        p.resize(n);
        size.resize(n);
        for (int i = 0; i < n; ++i) p[i] = i;
        fill(size.begin(), size.end(), 1);
        vector<bool> clean(n, true);
        for (int i : initial) clean[i] = false;
        for (int i = 0; i < n; ++i) {
            if (!clean[i]) continue;
            for (int j = i + 1; j < n; ++j)
                if (clean[j] && graph[i][j] == 1) merge(i, j);
        }
        vector<int> cnt(n, 0);
        unordered_map<int, unordered_set<int>> mp;
        for (int i : initial) {
            unordered_set<int> s;
            for (int j = 0; j < n; ++j)
                if (clean[j] && graph[i][j] == 1) s.insert(find(j));
            for (int e : s) ++cnt[e];
            mp[i] = s;
        }
        int mx = -1, ans = 0;
        for (auto& [i, s] : mp) {
            int t = 0;
            for (int root : s)
                if (cnt[root] == 1)
                    t += size[root];
            if (mx < t || (mx == t && i < ans)) {
                mx = t;
                ans = i;
            }
        }
        return ans;
    }

    int find(int x) {
        if (p[x] != x) p[x] = find(p[x]);
        return p[x];
    }

    void merge(int a, int b) {
        int pa = find(a), pb = find(b);
        if (pa != pb) {
            size[pb] += size[pa];
            p[pa] = pb;
        }
    }
};

Go

func minMalwareSpread(graph [][]int, initial []int) int {
	n := len(graph)
	p := make([]int, n)
	size := make([]int, n)
	clean := make([]bool, n)
	for i := 0; i < n; i++ {
		p[i], size[i], clean[i] = i, 1, true
	}
	for _, i := range initial {
		clean[i] = false
	}

	var find func(x int) int
	find = func(x int) int {
		if p[x] != x {
			p[x] = find(p[x])
		}
		return p[x]
	}
	union := func(a, b int) {
		pa, pb := find(a), find(b)
		if pa != pb {
			size[pb] += size[pa]
			p[pa] = pb
		}
	}

	for i := 0; i < n; i++ {
		if !clean[i] {
			continue
		}
		for j := i + 1; j < n; j++ {
			if clean[j] && graph[i][j] == 1 {
				union(i, j)
			}
		}
	}
	cnt := make([]int, n)
	mp := make(map[int]map[int]bool)
	for _, i := range initial {
		s := make(map[int]bool)
		for j := 0; j < n; j++ {
			if clean[j] && graph[i][j] == 1 {
				s[find(j)] = true
			}
		}
		for root, _ := range s {
			cnt[root]++
		}
		mp[i] = s
	}
	mx, ans := -1, 0
	for i, s := range mp {
		t := 0
		for root, _ := range s {
			if cnt[root] == 1 {
				t += size[root]
			}
		}
		if mx < t || (mx == t && i < ans) {
			mx, ans = t, i
		}
	}
	return ans
}

...