@@ -2,13 +2,17 @@ use hir::Semantics;
2
2
use ide_db:: {
3
3
base_db:: { FileId , FileRange } ,
4
4
defs:: Definition ,
5
- search:: SearchScope ,
5
+ search:: { SearchScope , UsageSearchResult } ,
6
6
RootDatabase ,
7
7
} ;
8
8
use syntax:: {
9
- ast:: { self , make:: impl_trait_type, HasGenericParams , HasName , HasTypeBounds } ,
10
- ted, AstNode ,
9
+ ast:: {
10
+ self , make:: impl_trait_type, HasGenericParams , HasName , HasTypeBounds , Name , NameLike ,
11
+ PathType ,
12
+ } ,
13
+ match_ast, ted, AstNode ,
11
14
} ;
15
+ use text_edit:: TextRange ;
12
16
13
17
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
14
18
@@ -36,87 +40,131 @@ pub(crate) fn replace_named_generic_with_impl(
36
40
let type_bound_list = type_param. type_bound_list ( ) ?;
37
41
38
42
let fn_ = type_param. syntax ( ) . ancestors ( ) . find_map ( ast:: Fn :: cast) ?;
39
- let params = fn_
40
- . param_list ( ) ?
41
- . params ( )
42
- . filter_map ( |param| {
43
- // function parameter type needs to match generic type name
44
- if let ast:: Type :: PathType ( path_type) = param. ty ( ) ? {
45
- let left = path_type. path ( ) ?. segment ( ) ?. name_ref ( ) ?. ident_token ( ) ?. to_string ( ) ;
46
- let right = type_param_name. to_string ( ) ;
47
- if left == right {
48
- Some ( param)
49
- } else {
50
- None
51
- }
52
- } else {
53
- None
54
- }
55
- } )
56
- . collect :: < Vec < _ > > ( ) ;
57
-
58
- if params. is_empty ( ) {
59
- return None ;
60
- }
43
+ let param_list_text_range = fn_. param_list ( ) ?. syntax ( ) . text_range ( ) ;
61
44
62
45
let type_param_hir_def = ctx. sema . to_def ( & type_param) ?;
63
46
let type_param_def = Definition :: GenericParam ( hir:: GenericParam :: TypeParam ( type_param_hir_def) ) ;
64
47
65
- if is_referenced_outside ( & ctx. sema , type_param_def, & fn_, ctx. file_id ( ) ) {
48
+ // get all usage references for the type param
49
+ let usage_refs = find_usages ( & ctx. sema , & fn_, type_param_def, ctx. file_id ( ) ) ;
50
+ if usage_refs. is_empty ( ) {
66
51
return None ;
67
52
}
68
53
54
+ // All usage references need to be valid (inside the function param list)
55
+ if !check_valid_usages ( & usage_refs, param_list_text_range) {
56
+ return None ;
57
+ }
58
+
59
+ let mut path_types_to_replace = Vec :: new ( ) ;
60
+ for ( _a, refs) in usage_refs. iter ( ) {
61
+ for usage_ref in refs {
62
+ let param_node = find_path_type ( & ctx. sema , & type_param_name, & usage_ref. name ) ?;
63
+ path_types_to_replace. push ( param_node) ;
64
+ }
65
+ }
66
+
69
67
let target = type_param. syntax ( ) . text_range ( ) ;
70
68
71
69
acc. add (
72
70
AssistId ( "replace_named_generic_with_impl" , AssistKind :: RefactorRewrite ) ,
73
- "Replace named generic with impl" ,
71
+ "Replace named generic with impl trait " ,
74
72
target,
75
73
|edit| {
76
74
let type_param = edit. make_mut ( type_param) ;
77
75
let fn_ = edit. make_mut ( fn_) ;
78
76
79
- // get all params
80
- let param_types = params
81
- . iter ( )
82
- . filter_map ( |param| match param. ty ( ) {
83
- Some ( ast:: Type :: PathType ( param_type) ) => Some ( edit. make_mut ( param_type) ) ,
84
- _ => None ,
85
- } )
77
+ let path_types_to_replace = path_types_to_replace
78
+ . into_iter ( )
79
+ . map ( |param| edit. make_mut ( param) )
86
80
. collect :: < Vec < _ > > ( ) ;
87
81
82
+ // remove trait from generic param list
88
83
if let Some ( generic_params) = fn_. generic_param_list ( ) {
89
84
generic_params. remove_generic_param ( ast:: GenericParam :: TypeParam ( type_param) ) ;
90
85
if generic_params. generic_params ( ) . count ( ) == 0 {
91
86
ted:: remove ( generic_params. syntax ( ) ) ;
92
87
}
93
88
}
94
89
95
- // get type bounds in signature type: `P` -> `impl AsRef<Path>`
96
90
let new_bounds = impl_trait_type ( type_bound_list) ;
97
- for param_type in param_types . iter ( ) . rev ( ) {
98
- ted:: replace ( param_type . syntax ( ) , new_bounds. clone_for_update ( ) . syntax ( ) ) ;
91
+ for path_type in path_types_to_replace . iter ( ) . rev ( ) {
92
+ ted:: replace ( path_type . syntax ( ) , new_bounds. clone_for_update ( ) . syntax ( ) ) ;
99
93
}
100
94
} ,
101
95
)
102
96
}
103
97
104
- fn is_referenced_outside (
98
+ fn find_path_type (
99
+ sema : & Semantics < ' _ , RootDatabase > ,
100
+ type_param_name : & Name ,
101
+ param : & NameLike ,
102
+ ) -> Option < PathType > {
103
+ let path_type =
104
+ sema. ancestors_with_macros ( param. syntax ( ) . clone ( ) ) . find_map ( ast:: PathType :: cast) ?;
105
+
106
+ // Ignore any path types that look like `P::Assoc`
107
+ if path_type. path ( ) ?. as_single_name_ref ( ) ?. text ( ) != type_param_name. text ( ) {
108
+ return None ;
109
+ }
110
+
111
+ let ancestors = sema. ancestors_with_macros ( path_type. syntax ( ) . clone ( ) ) ;
112
+
113
+ let mut in_generic_arg_list = false ;
114
+ let mut is_associated_type = false ;
115
+
116
+ // walking the ancestors checks them in a heuristic way until the `Fn` node is reached.
117
+ for ancestor in ancestors {
118
+ match_ast ! {
119
+ match ancestor {
120
+ ast:: PathSegment ( ps) => {
121
+ match ps. kind( ) ? {
122
+ ast:: PathSegmentKind :: Name ( _name_ref) => ( ) ,
123
+ ast:: PathSegmentKind :: Type { .. } => return None ,
124
+ _ => return None ,
125
+ }
126
+ } ,
127
+ ast:: GenericArgList ( _) => {
128
+ in_generic_arg_list = true ;
129
+ } ,
130
+ ast:: AssocTypeArg ( _) => {
131
+ is_associated_type = true ;
132
+ } ,
133
+ ast:: ImplTraitType ( _) => {
134
+ if in_generic_arg_list && !is_associated_type {
135
+ return None ;
136
+ }
137
+ } ,
138
+ ast:: DynTraitType ( _) => {
139
+ if !is_associated_type {
140
+ return None ;
141
+ }
142
+ } ,
143
+ ast:: Fn ( _) => return Some ( path_type) ,
144
+ _ => ( ) ,
145
+ }
146
+ }
147
+ }
148
+
149
+ None
150
+ }
151
+
152
+ /// Returns all usage references for the given type parameter definition.
153
+ fn find_usages (
105
154
sema : & Semantics < ' _ , RootDatabase > ,
106
- type_param_def : Definition ,
107
155
fn_ : & ast:: Fn ,
156
+ type_param_def : Definition ,
108
157
file_id : FileId ,
109
- ) -> bool {
110
- // limit search scope to function body & return type
111
- let search_ranges = vec ! [
112
- fn_. body( ) . map( |body| body. syntax( ) . text_range( ) ) ,
113
- fn_. ret_type( ) . map( |ret_type| ret_type. syntax( ) . text_range( ) ) ,
114
- ] ;
115
-
116
- search_ranges. into_iter ( ) . flatten ( ) . any ( |search_range| {
117
- let file_range = FileRange { file_id, range : search_range } ;
118
- !type_param_def. usages ( sema) . in_scope ( SearchScope :: file_range ( file_range) ) . all ( ) . is_empty ( )
119
- } )
158
+ ) -> UsageSearchResult {
159
+ let file_range = FileRange { file_id, range : fn_. syntax ( ) . text_range ( ) } ;
160
+ type_param_def. usages ( sema) . in_scope ( SearchScope :: file_range ( file_range) ) . all ( )
161
+ }
162
+
163
+ fn check_valid_usages ( usages : & UsageSearchResult , param_list_range : TextRange ) -> bool {
164
+ usages
165
+ . iter ( )
166
+ . flat_map ( |( _, usage_refs) | usage_refs)
167
+ . all ( |usage_ref| param_list_range. contains_range ( usage_ref. range ) )
120
168
}
121
169
122
170
#[ cfg( test) ]
@@ -152,6 +200,96 @@ mod tests {
152
200
) ;
153
201
}
154
202
203
+ #[ test]
204
+ fn replace_generic_trait_applies_to_generic_arguments_in_params ( ) {
205
+ check_assist (
206
+ replace_named_generic_with_impl,
207
+ r#"
208
+ fn foo<P$0: Trait>(
209
+ _: P,
210
+ _: Option<P>,
211
+ _: Option<Option<P>>,
212
+ _: impl Iterator<Item = P>,
213
+ _: &dyn Iterator<Item = P>,
214
+ ) {}
215
+ "# ,
216
+ r#"
217
+ fn foo(
218
+ _: impl Trait,
219
+ _: Option<impl Trait>,
220
+ _: Option<Option<impl Trait>>,
221
+ _: impl Iterator<Item = impl Trait>,
222
+ _: &dyn Iterator<Item = impl Trait>,
223
+ ) {}
224
+ "# ,
225
+ ) ;
226
+ }
227
+
228
+ #[ test]
229
+ fn replace_generic_not_applicable_when_one_param_type_is_invalid ( ) {
230
+ check_assist_not_applicable (
231
+ replace_named_generic_with_impl,
232
+ r#"
233
+ fn foo<P$0: Trait>(
234
+ _: i32,
235
+ _: Option<P>,
236
+ _: Option<Option<P>>,
237
+ _: impl Iterator<Item = P>,
238
+ _: &dyn Iterator<Item = P>,
239
+ _: <P as Trait>::Assoc,
240
+ ) {}
241
+ "# ,
242
+ ) ;
243
+ }
244
+
245
+ #[ test]
246
+ fn replace_generic_not_applicable_when_referenced_in_where_clause ( ) {
247
+ check_assist_not_applicable (
248
+ replace_named_generic_with_impl,
249
+ r#"fn foo<P$0: Trait, I>() where I: FromRef<P> {}"# ,
250
+ ) ;
251
+ }
252
+
253
+ #[ test]
254
+ fn replace_generic_not_applicable_when_used_with_type_alias ( ) {
255
+ check_assist_not_applicable (
256
+ replace_named_generic_with_impl,
257
+ r#"fn foo<P$0: Trait>(p: <P as Trait>::Assoc) {}"# ,
258
+ ) ;
259
+ }
260
+
261
+ #[ test]
262
+ fn replace_generic_not_applicable_when_used_as_argument_in_outer_trait_alias ( ) {
263
+ check_assist_not_applicable (
264
+ replace_named_generic_with_impl,
265
+ r#"fn foo<P$0: Trait>(_: <() as OtherTrait<P>>::Assoc) {}"# ,
266
+ ) ;
267
+ }
268
+
269
+ #[ test]
270
+ fn replace_generic_not_applicable_with_inner_associated_type ( ) {
271
+ check_assist_not_applicable (
272
+ replace_named_generic_with_impl,
273
+ r#"fn foo<P$0: Trait>(_: P::Assoc) {}"# ,
274
+ ) ;
275
+ }
276
+
277
+ #[ test]
278
+ fn replace_generic_not_applicable_when_passed_into_outer_impl_trait ( ) {
279
+ check_assist_not_applicable (
280
+ replace_named_generic_with_impl,
281
+ r#"fn foo<P$0: Trait>(_: impl OtherTrait<P>) {}"# ,
282
+ ) ;
283
+ }
284
+
285
+ #[ test]
286
+ fn replace_generic_not_applicable_when_used_in_passed_function_parameter ( ) {
287
+ check_assist_not_applicable (
288
+ replace_named_generic_with_impl,
289
+ r#"fn foo<P$0: Trait>(_: &dyn Fn(P)) {}"# ,
290
+ ) ;
291
+ }
292
+
155
293
#[ test]
156
294
fn replace_generic_with_multiple_generic_params ( ) {
157
295
check_assist (
0 commit comments