@@ -26,12 +26,17 @@ protected sealed class ExtractionContext
26
26
/// <summary>
27
27
/// Specific schemas to extract
28
28
/// </summary>
29
- public readonly Dictionary < string , Schema > TargetSchemes = new Dictionary < string , Schema > ( ) ;
29
+ public readonly Dictionary < string , Schema > TargetSchemes = new ( ) ;
30
30
31
31
/// <summary>
32
- /// Extracted users.
32
+ /// Extracted users (subset of <see cref="RoleLookup"/>) .
33
33
/// </summary>
34
- public readonly Dictionary < long , string > UserLookup = new Dictionary < long , string > ( ) ;
34
+ public readonly Dictionary < long , string > UserLookup = new ( ) ;
35
+
36
+ /// <summary>
37
+ /// Extracted roles.
38
+ /// </summary>
39
+ public readonly Dictionary < long , string > RoleLookup = new ( ) ;
35
40
36
41
/// <summary>
37
42
/// Catalog to extract information.
@@ -41,46 +46,54 @@ protected sealed class ExtractionContext
41
46
/// <summary>
42
47
/// Extracted schemas.
43
48
/// </summary>
44
- public readonly Dictionary < long , Schema > SchemaMap = new Dictionary < long , Schema > ( ) ;
49
+ public readonly Dictionary < long , Schema > SchemaMap = new ( ) ;
45
50
46
51
/// <summary>
47
52
/// Extracted schemas identifiers.
48
53
/// </summary>
49
- public readonly Dictionary < Schema , long > ReversedSchemaMap = new Dictionary < Schema , long > ( ) ;
54
+ public readonly Dictionary < Schema , long > ReversedSchemaMap = new ( ) ;
50
55
51
56
/// <summary>
52
57
/// Extracted tables.
53
58
/// </summary>
54
- public readonly Dictionary < long , Table > TableMap = new Dictionary < long , Table > ( ) ;
59
+ public readonly Dictionary < long , Table > TableMap = new ( ) ;
55
60
56
61
/// <summary>
57
62
/// Extracted views.
58
63
/// </summary>
59
- public readonly Dictionary < long , View > ViewMap = new Dictionary < long , View > ( ) ;
64
+ public readonly Dictionary < long , View > ViewMap = new ( ) ;
60
65
61
66
/// <summary>
62
67
/// Extracted sequences.
63
68
/// </summary>
64
- public readonly Dictionary < long , Sequence > SequenceMap = new Dictionary < long , Sequence > ( ) ;
69
+ public readonly Dictionary < long , Sequence > SequenceMap = new ( ) ;
65
70
66
71
/// <summary>
67
72
/// Extracted index expressions.
68
73
/// </summary>
69
- public readonly Dictionary < long , ExpressionIndexInfo > ExpressionIndexMap = new Dictionary < long , ExpressionIndexInfo > ( ) ;
74
+ public readonly Dictionary < long , ExpressionIndexInfo > ExpressionIndexMap = new ( ) ;
70
75
71
76
/// <summary>
72
77
/// Extracted domains.
73
78
/// </summary>
74
- public readonly Dictionary < long , Domain > DomainMap = new Dictionary < long , Domain > ( ) ;
79
+ public readonly Dictionary < long , Domain > DomainMap = new ( ) ;
75
80
76
81
/// <summary>
77
82
/// Extracted columns connected grouped by owner (table or view)
78
83
/// </summary>
79
- public readonly Dictionary < long , Dictionary < long , TableColumn > > TableColumnMap = new Dictionary < long , Dictionary < long , TableColumn > > ( ) ;
84
+ public readonly Dictionary < long , Dictionary < long , TableColumn > > TableColumnMap = new ( ) ;
80
85
86
+ /// <summary>
87
+ /// Roles in which current user is a member, self included.
88
+ /// </summary>
89
+ public readonly List < long > CurrentUserRoles = new ( ) ;
90
+
91
+ public string CurrentUserName { get ; set ; }
81
92
public long CurrentUserSysId { get ; set ; } = - 1 ;
82
93
public long ? CurrentUserIdentifier { get ; set ; }
83
94
95
+ public bool IsMe ( string name ) => name == CurrentUserName ;
96
+
84
97
public ExtractionContext ( Catalog catalog )
85
98
{
86
99
Catalog = catalog ;
@@ -332,7 +345,7 @@ public override Catalog ExtractSchemes(string catalogName, string[] schemaNames)
332
345
{
333
346
var ( catalog , context ) = CreateCatalogAndContext ( catalogName , schemaNames ) ;
334
347
335
- ExtractUsers ( context ) ;
348
+ _ = ExtractRoles ( context , false ) ;
336
349
ExtractSchemas ( context ) ;
337
350
return catalog ;
338
351
}
@@ -343,7 +356,7 @@ public override async Task<Catalog> ExtractSchemesAsync(
343
356
{
344
357
var ( catalog , context ) = CreateCatalogAndContext ( catalogName , schemaNames ) ;
345
358
346
- await ExtractUsersAsync ( context , token ) . ConfigureAwait ( false ) ;
359
+ await ExtractRoles ( context , true , token ) . ConfigureAwait ( false ) ;
347
360
await ExtractSchemasAsync ( context , token ) . ConfigureAwait ( false ) ;
348
361
return catalog ;
349
362
}
@@ -360,48 +373,62 @@ private static (Catalog catalog, ExtractionContext context) CreateCatalogAndCont
360
373
return ( catalog , context ) ;
361
374
}
362
375
363
- private void ExtractUsers ( ExtractionContext context )
376
+ private async ValueTask ExtractRoles ( ExtractionContext context , bool isAsync , CancellationToken token = default )
364
377
{
365
378
context . UserLookup . Clear ( ) ;
366
- string me ;
367
- using ( var command = Connection . CreateCommand ( "SELECT user" ) ) {
368
- me = ( string ) command . ExecuteScalar ( ) ;
369
- }
370
379
371
- using ( var cmd = Connection . CreateCommand ( "SELECT usename, usesysid FROM pg_user" ) )
372
- using ( var dr = cmd . ExecuteReader ( ) ) {
373
- while ( dr . Read ( ) ) {
374
- ReadUserData ( dr , context , me ) ;
380
+ var extractCurentUserCommand = Connection . CreateCommand ( "SELECT user" ) ;
381
+ // Roles include users.
382
+ // Users also can have members for some reason and it doesn't make them groups,
383
+ // the only thing that defines user is ability to login :-)
384
+ const string ExtractRolesQueryTemplate = "SELECT rolname, oid, rolcanlogin, pg_has_role('{0}', oid,'member') FROM pg_roles" ;
385
+
386
+
387
+ if ( isAsync ) {
388
+ await using ( extractCurentUserCommand . ConfigureAwait ( false ) ) {
389
+ context . CurrentUserName = ( string ) await extractCurentUserCommand . ExecuteScalarAsync ( token ) . ConfigureAwait ( false ) ;
375
390
}
376
- }
377
- }
378
391
379
- private async Task ExtractUsersAsync ( ExtractionContext context , CancellationToken token = default )
380
- {
381
- context . UserLookup . Clear ( ) ;
382
- string me ;
383
- var command = Connection . CreateCommand ( "SELECT user" ) ;
384
- await using ( command . ConfigureAwait ( false ) ) {
385
- me = ( string ) await command . ExecuteScalarAsync ( token ) . ConfigureAwait ( false ) ;
392
+ var getAllUsersCommand = Connection . CreateCommand ( string . Format ( ExtractRolesQueryTemplate , context . CurrentUserName ) ) ;
393
+ await using ( getAllUsersCommand . ConfigureAwait ( false ) ) {
394
+ var reader = await getAllUsersCommand . ExecuteReaderAsync ( token ) . ConfigureAwait ( false ) ;
395
+ await using ( reader . ConfigureAwait ( false ) ) {
396
+ while ( await reader . ReadAsync ( token ) . ConfigureAwait ( false ) ) {
397
+ ReadUserData ( reader , context ) ;
398
+ }
399
+ }
400
+ }
386
401
}
402
+ else {
403
+ using ( extractCurentUserCommand ) {
404
+ context . CurrentUserName = ( string ) extractCurentUserCommand . ExecuteScalar ( ) ;
405
+ }
387
406
388
- command = Connection . CreateCommand ( "SELECT usename, usesysid FROM pg_user" ) ;
389
- await using ( command . ConfigureAwait ( false ) ) {
390
- var reader = await command . ExecuteReaderAsync ( token ) . ConfigureAwait ( false ) ;
391
- await using ( reader . ConfigureAwait ( false ) ) {
392
- while ( await reader . ReadAsync ( token ) . ConfigureAwait ( false ) ) {
393
- ReadUserData ( reader , context , me ) ;
407
+ var getAllUsersCommand = Connection . CreateCommand ( string . Format ( ExtractRolesQueryTemplate , context . CurrentUserName ) ) ;
408
+ using ( getAllUsersCommand )
409
+ using ( var dr = getAllUsersCommand . ExecuteReader ( ) ) {
410
+ while ( dr . Read ( ) ) {
411
+ ReadUserData ( dr , context ) ;
394
412
}
395
413
}
396
414
}
397
415
}
398
416
399
- private static void ReadUserData ( DbDataReader dr , ExtractionContext context , string me )
417
+ private static void ReadUserData ( DbDataReader dr , ExtractionContext context )
400
418
{
401
- var name = dr [ 0 ] . ToString ( ) ;
419
+ var name = dr . GetString ( 0 ) ;
420
+ // oid, which is basically a number, has its own type - oid! can't be read as int or long (facepalm)
402
421
var sysId = Convert . ToInt64 ( dr [ 1 ] ) ;
403
- context . UserLookup . Add ( sysId , name ) ;
404
- if ( name == me ) {
422
+ var canLogin = dr . GetBoolean ( 2 ) ;
423
+ var containsCurrentUser = dr . GetBoolean ( 3 ) ;
424
+ context . RoleLookup . Add ( sysId , name ) ;
425
+ if ( containsCurrentUser ) {
426
+ context . CurrentUserRoles . Add ( sysId ) ;
427
+ }
428
+ if ( canLogin ) {
429
+ context . UserLookup . Add ( sysId , name ) ;
430
+ }
431
+ if ( context . IsMe ( name ) ) {
405
432
context . CurrentUserSysId = sysId ;
406
433
}
407
434
}
@@ -499,7 +526,11 @@ protected virtual SqlQueryExpression BuildExtractSchemasQuery(ExtractionContext
499
526
selectPublic . Columns . Add ( namespaceTable1 [ "nspowner" ] ) ;
500
527
501
528
var selectMine = SqlDml . Select ( namespaceTable2 ) ;
502
- selectMine . Where = namespaceTable2 [ "nspowner" ] == context . CurrentUserIdentifier ;
529
+ if ( context . CurrentUserRoles . Count == 0 )
530
+ selectMine . Where = namespaceTable2 [ "nspowner" ] == context . CurrentUserIdentifier ;
531
+ else {
532
+ selectMine . Where = SqlDml . In ( namespaceTable2 [ "nspowner" ] , SqlDml . Array ( context . CurrentUserRoles . ToArray ( ) ) ) ;
533
+ }
503
534
selectMine . Columns . Add ( namespaceTable2 [ "nspname" ] ) ;
504
535
selectMine . Columns . Add ( namespaceTable2 [ "oid" ] ) ;
505
536
selectMine . Columns . Add ( namespaceTable2 [ "nspowner" ] ) ;
@@ -522,7 +553,12 @@ protected virtual void ReadSchemaData(DbDataReader dataReader, ExtractionContext
522
553
catalog . DefaultSchema = schema ;
523
554
}
524
555
525
- schema . Owner = context . UserLookup [ owner ] ;
556
+ if ( context . RoleLookup . TryGetValue ( owner , out var userName ) ) {
557
+ schema . Owner = userName ;
558
+ }
559
+ else {
560
+ throw new InvalidOperationException ( string . Format ( Resources . Strings . ExCantFindSchemaXOwnerWithIdYInTheListOfRoles , name , owner ) ) ;
561
+ }
526
562
context . SchemaMap [ oid ] = schema ;
527
563
context . ReversedSchemaMap [ schema ] = oid ;
528
564
}
@@ -594,8 +630,8 @@ protected virtual ISqlCompileUnit BuildExtractSchemaContentsQuery(ExtractionCont
594
630
select . Columns . Add ( relationsTable [ "relnamespace" ] ) ;
595
631
select . Columns . Add ( tablespacesTable [ "spcname" ] ) ;
596
632
select . Columns . Add ( new Func < SqlCase > ( ( ) => {
597
- var defCase = SqlDml . Case ( relationsTable [ "relkind" ] ) ;
598
- defCase . Add ( 'v' , SqlDml . FunctionCall ( "pg_get_viewdef" , relationsTable [ "oid" ] ) ) ;
633
+ var defCase = SqlDml . Case ( relationsTable [ "relkind" ] )
634
+ . Add ( 'v' , SqlDml . FunctionCall ( "pg_get_viewdef" , relationsTable [ "oid" ] ) ) ;
599
635
return defCase ;
600
636
} ) ( ) , "definition" ) ;
601
637
return select ;
@@ -741,7 +777,7 @@ protected virtual void ReadColumnData(DbDataReader dataReader, ExtractionContext
741
777
}
742
778
else {
743
779
var view = viewMap [ columnOwnerId ] ;
744
- view . CreateColumn ( columnName ) ;
780
+ _ = view . CreateColumn ( columnName ) ;
745
781
}
746
782
}
747
783
@@ -912,8 +948,9 @@ protected virtual int ReadTableIndexData(DbDataReader dataReader, ExtractionCont
912
948
else {
913
949
for ( int j = 0 ; j < indexKey . Length ; j ++ ) {
914
950
int colIndex = indexKey [ j ] ;
915
- if ( colIndex > 0 )
916
- index . CreateIndexColumn ( tableColumns [ tableIdentifier ] [ colIndex ] , true ) ;
951
+ if ( colIndex > 0 ) {
952
+ _ = index . CreateIndexColumn ( tableColumns [ tableIdentifier ] [ colIndex ] , true ) ;
953
+ }
917
954
else {
918
955
//column index is 0
919
956
//this means that this index column is an expression
@@ -967,12 +1004,9 @@ protected virtual void ReadIndexColumnsData(DbDataReader dataReader, ExtractionC
967
1004
var exprIndexInfo = expressionIndexMap [ Convert . ToInt64 ( dataReader [ 1 ] ) ] ;
968
1005
for ( var j = 0 ; j < exprIndexInfo . Columns . Length ; j ++ ) {
969
1006
int colIndex = exprIndexInfo . Columns [ j ] ;
970
- if ( colIndex > 0 ) {
971
- exprIndexInfo . Index . CreateIndexColumn ( tableColumns [ Convert . ToInt64 ( dataReader [ 0 ] ) ] [ colIndex ] , true ) ;
972
- }
973
- else {
974
- exprIndexInfo . Index . CreateIndexColumn ( SqlDml . Native ( dataReader [ ( j + 1 ) . ToString ( ) ] . ToString ( ) ) ) ;
975
- }
1007
+ _ = colIndex > 0
1008
+ ? exprIndexInfo . Index . CreateIndexColumn ( tableColumns [ Convert . ToInt64 ( dataReader [ 0 ] ) ] [ colIndex ] , true )
1009
+ : exprIndexInfo . Index . CreateIndexColumn ( SqlDml . Native ( dataReader [ ( j + 1 ) . ToString ( ) ] . ToString ( ) ) ) ;
976
1010
}
977
1011
}
978
1012
0 commit comments