11// Adjacency List
2+ use std:: collections:: VecDeque ;
23type Graph = Vec < Vec < usize > > ;
34
45pub struct BipartiteMatching {
56 pub adj : Graph ,
67 pub num_vertices_grp1 : usize ,
78 pub num_vertices_grp2 : usize ,
8- // mt[i] = v is the matching of i in grp1 to v in grp2
9- pub mt : Vec < i32 > ,
9+ // mt1[i] = v is the matching of i in grp1 to v in grp2
10+ pub mt1 : Vec < i32 > ,
11+ pub mt2 : Vec < i32 > ,
1012 pub used : Vec < bool > ,
1113}
1214impl BipartiteMatching {
@@ -15,15 +17,15 @@ impl BipartiteMatching {
1517 adj : vec ! [ vec![ ] ; num_vertices_grp1 + 1 ] ,
1618 num_vertices_grp1,
1719 num_vertices_grp2,
18- mt : vec ! [ -1 ; num_vertices_grp2 + 1 ] ,
20+ mt2 : vec ! [ -1 ; num_vertices_grp2 + 1 ] ,
21+ mt1 : vec ! [ -1 ; num_vertices_grp1 + 1 ] ,
1922 used : vec ! [ false ; num_vertices_grp1 + 1 ] ,
2023 }
2124 }
2225 #[ inline]
23- // Add an undirected edge u-v in the graph
26+ // Add an directed edge u-> v in the graph
2427 pub fn add_edge ( & mut self , u : usize , v : usize ) {
2528 self . adj [ u] . push ( v) ;
26- // self.adj[v].push(u);
2729 }
2830
2931 fn try_kuhn ( & mut self , cur : usize ) -> bool {
@@ -33,34 +35,111 @@ impl BipartiteMatching {
3335 self . used [ cur] = true ;
3436 for i in 0 ..self . adj [ cur] . len ( ) {
3537 let to = self . adj [ cur] [ i] ;
36- if self . mt [ to] == -1 || self . try_kuhn ( self . mt [ to] as usize ) {
37- self . mt [ to] = cur as i32 ;
38+ if self . mt2 [ to] == -1 || self . try_kuhn ( self . mt2 [ to] as usize ) {
39+ self . mt2 [ to] = cur as i32 ;
3840 return true ;
3941 }
4042 }
4143 false
4244 }
45+ // Note: It does not modify self.mt1, it only works on self.mt2
4346 pub fn kuhn ( & mut self ) {
44- self . mt = vec ! [ -1 ; self . num_vertices_grp2 + 1 ] ;
47+ self . mt2 = vec ! [ -1 ; self . num_vertices_grp2 + 1 ] ;
4548 for v in 1 ..self . num_vertices_grp1 + 1 {
4649 self . used = vec ! [ false ; self . num_vertices_grp1 + 1 ] ;
4750 self . try_kuhn ( v) ;
4851 }
4952 }
5053 pub fn print_matching ( & self ) {
5154 for i in 1 ..self . num_vertices_grp2 + 1 {
52- if self . mt [ i] == -1 {
55+ if self . mt2 [ i] == -1 {
5356 continue ;
5457 }
55- println ! ( "Vertex {} in grp1 matched with {} grp2" , self . mt [ i] , i)
58+ println ! ( "Vertex {} in grp1 matched with {} grp2" , self . mt2 [ i] , i)
5659 }
5760 }
61+ fn bfs ( & self , dist : & mut [ i32 ] ) -> bool {
62+ let mut q = VecDeque :: new ( ) ;
63+ for ( u, d_i) in dist
64+ . iter_mut ( )
65+ . enumerate ( )
66+ . skip ( 1 )
67+ . take ( self . num_vertices_grp1 )
68+ {
69+ if self . mt1 [ u] == 0 {
70+ // u is not matched
71+ * d_i = 0 ;
72+ q. push_back ( u) ;
73+ } else {
74+ // else set the vertex distance as infinite because it is matched
75+ // this will be considered the next time
76+
77+ * d_i = i32:: max_value ( ) ;
78+ }
79+ }
80+ dist[ 0 ] = i32:: max_value ( ) ;
81+ while !q. is_empty ( ) {
82+ let u = * q. front ( ) . unwrap ( ) ;
83+ q. pop_front ( ) ;
84+ if dist[ u] < dist[ 0 ] {
85+ for i in 0 ..self . adj [ u] . len ( ) {
86+ let v = self . adj [ u] [ i] ;
87+ if dist[ self . mt2 [ v] as usize ] == i32:: max_value ( ) {
88+ dist[ self . mt2 [ v] as usize ] = dist[ u] + 1 ;
89+ q. push_back ( self . mt2 [ v] as usize ) ;
90+ }
91+ }
92+ }
93+ }
94+ dist[ 0 ] != i32:: max_value ( )
95+ }
96+ fn dfs ( & mut self , u : i32 , dist : & mut Vec < i32 > ) -> bool {
97+ if u == 0 {
98+ return true ;
99+ }
100+ for i in 0 ..self . adj [ u as usize ] . len ( ) {
101+ let v = self . adj [ u as usize ] [ i] ;
102+ if dist[ self . mt2 [ v] as usize ] == dist[ u as usize ] + 1 && self . dfs ( self . mt2 [ v] , dist) {
103+ self . mt2 [ v] = u;
104+ self . mt1 [ u as usize ] = v as i32 ;
105+ return true ;
106+ }
107+ }
108+ dist[ u as usize ] = i32:: max_value ( ) ;
109+ false
110+ }
111+ pub fn hopcroft_karp ( & mut self ) -> i32 {
112+ // NOTE: how to use: https://cses.fi/paste/7558dba8d00436a847eab8/
113+ self . mt2 = vec ! [ 0 ; self . num_vertices_grp2 + 1 ] ;
114+ self . mt1 = vec ! [ 0 ; self . num_vertices_grp1 + 1 ] ;
115+ let mut dist = vec ! [ i32 :: max_value( ) ; self . num_vertices_grp1 + 1 ] ;
116+ let mut res = 0 ;
117+ while self . bfs ( & mut dist) {
118+ for u in 1 ..self . num_vertices_grp1 + 1 {
119+ if self . mt1 [ u] == 0 && self . dfs ( u as i32 , & mut dist) {
120+ res += 1 ;
121+ }
122+ }
123+ }
124+ // for x in self.mt2 change x to -1 if it is 0
125+ for x in self . mt2 . iter_mut ( ) {
126+ if * x == 0 {
127+ * x = -1 ;
128+ }
129+ }
130+ for x in self . mt1 . iter_mut ( ) {
131+ if * x == 0 {
132+ * x = -1 ;
133+ }
134+ }
135+ res
136+ }
58137}
59138#[ cfg( test) ]
60139mod tests {
61140 use super :: * ;
62141 #[ test]
63- fn small_graph ( ) {
142+ fn small_graph_kuhn ( ) {
64143 let n1 = 6 ;
65144 let n2 = 6 ;
66145 let mut g = BipartiteMatching :: new ( n1, n2) ;
@@ -78,29 +157,73 @@ mod tests {
78157 g. kuhn ( ) ;
79158 g. print_matching ( ) ;
80159 let answer: Vec < i32 > = vec ! [ -1 , 2 , -1 , 1 , 3 , 4 , 6 ] ;
81- for i in 1 ..g. mt . len ( ) {
82- if g. mt [ i] == -1 {
160+ for i in 1 ..g. mt2 . len ( ) {
161+ if g. mt2 [ i] == -1 {
162+ // 5 in group2 has no pair
163+ assert_eq ! ( i, 5 ) ;
164+ continue ;
165+ }
166+ // 2 in group1 has no pair
167+ assert ! ( g. mt2[ i] != 2 ) ;
168+ assert_eq ! ( i as i32 , answer[ g. mt2[ i] as usize ] ) ;
169+ }
170+ }
171+ #[ test]
172+ fn small_graph_hopcroft ( ) {
173+ let n1 = 6 ;
174+ let n2 = 6 ;
175+ let mut g = BipartiteMatching :: new ( n1, n2) ;
176+ // vertex 1 in grp1 to vertex 1 in grp 2
177+ // denote the ith grp2 vertex as n1+i
178+ g. add_edge ( 1 , 2 ) ;
179+ g. add_edge ( 1 , 3 ) ;
180+ // 2 is not connected to any vertex
181+ g. add_edge ( 3 , 4 ) ;
182+ g. add_edge ( 3 , 1 ) ;
183+ g. add_edge ( 4 , 3 ) ;
184+ g. add_edge ( 5 , 3 ) ;
185+ g. add_edge ( 5 , 4 ) ;
186+ g. add_edge ( 6 , 6 ) ;
187+ let x = g. hopcroft_karp ( ) ;
188+ assert_eq ! ( x, 5 ) ;
189+ g. print_matching ( ) ;
190+ let answer: Vec < i32 > = vec ! [ -1 , 2 , -1 , 1 , 3 , 4 , 6 ] ;
191+ for i in 1 ..g. mt2 . len ( ) {
192+ if g. mt2 [ i] == -1 {
83193 // 5 in group2 has no pair
84194 assert_eq ! ( i, 5 ) ;
85195 continue ;
86196 }
87197 // 2 in group1 has no pair
88- assert ! ( g. mt [ i] != 2 ) ;
89- assert_eq ! ( i as i32 , answer[ g. mt [ i] as usize ] ) ;
198+ assert ! ( g. mt2 [ i] != 2 ) ;
199+ assert_eq ! ( i as i32 , answer[ g. mt2 [ i] as usize ] ) ;
90200 }
91201 }
92202 #[ test]
93- fn super_small_graph ( ) {
203+ fn super_small_graph_kuhn ( ) {
94204 let n1 = 1 ;
95205 let n2 = 1 ;
96206 let mut g = BipartiteMatching :: new ( n1, n2) ;
97207 g. add_edge ( 1 , 1 ) ;
98208 g. kuhn ( ) ;
99209 g. print_matching ( ) ;
100- assert_eq ! ( g. mt [ 1 ] , 1 ) ;
210+ assert_eq ! ( g. mt2 [ 1 ] , 1 ) ;
101211 }
102212 #[ test]
103- fn only_one_vertex_graph ( ) {
213+ fn super_small_graph_hopcroft ( ) {
214+ let n1 = 1 ;
215+ let n2 = 1 ;
216+ let mut g = BipartiteMatching :: new ( n1, n2) ;
217+ g. add_edge ( 1 , 1 ) ;
218+ let x = g. hopcroft_karp ( ) ;
219+ assert_eq ! ( x, 1 ) ;
220+ g. print_matching ( ) ;
221+ assert_eq ! ( g. mt2[ 1 ] , 1 ) ;
222+ assert_eq ! ( g. mt1[ 1 ] , 1 ) ;
223+ }
224+
225+ #[ test]
226+ fn only_one_vertex_graph_kuhn ( ) {
104227 let n1 = 10 ;
105228 let n2 = 10 ;
106229 let mut g = BipartiteMatching :: new ( n1, n2) ;
@@ -116,9 +239,32 @@ mod tests {
116239 g. add_edge ( 10 , 1 ) ;
117240 g. kuhn ( ) ;
118241 g. print_matching ( ) ;
119- assert_eq ! ( g. mt[ 1 ] , 1 ) ;
120- for i in 2 ..g. mt . len ( ) {
121- assert ! ( g. mt[ i] == -1 ) ;
242+ assert_eq ! ( g. mt2[ 1 ] , 1 ) ;
243+ for i in 2 ..g. mt2 . len ( ) {
244+ assert ! ( g. mt2[ i] == -1 ) ;
245+ }
246+ }
247+ #[ test]
248+ fn only_one_vertex_graph_hopcroft ( ) {
249+ let n1 = 10 ;
250+ let n2 = 10 ;
251+ let mut g = BipartiteMatching :: new ( n1, n2) ;
252+ g. add_edge ( 1 , 1 ) ;
253+ g. add_edge ( 2 , 1 ) ;
254+ g. add_edge ( 3 , 1 ) ;
255+ g. add_edge ( 4 , 1 ) ;
256+ g. add_edge ( 5 , 1 ) ;
257+ g. add_edge ( 6 , 1 ) ;
258+ g. add_edge ( 7 , 1 ) ;
259+ g. add_edge ( 8 , 1 ) ;
260+ g. add_edge ( 9 , 1 ) ;
261+ g. add_edge ( 10 , 1 ) ;
262+ let x = g. hopcroft_karp ( ) ;
263+ assert_eq ! ( x, 1 ) ;
264+ g. print_matching ( ) ;
265+ assert_eq ! ( g. mt2[ 1 ] , 1 ) ;
266+ for i in 2 ..g. mt2 . len ( ) {
267+ assert ! ( g. mt2[ i] == -1 ) ;
122268 }
123269 }
124270}
0 commit comments