2929reference to `A` to `B`'s `fields` attribute.
3030"""
3131
32-
3332import builtins
3433import re
35- import textwrap
3634from dataclasses import (
3735 dataclass ,
3836 field ,
4947)
5048
5149import betterproto
52- from betterproto import which_one_of
53- from betterproto .casing import sanitize_name
54- from betterproto .compile .importing import (
55- get_type_reference ,
56- parse_source_type_name ,
57- )
5850from betterproto .compile .naming import (
5951 pythonize_class_name ,
6052 pythonize_field_name ,
7264)
7365from betterproto .lib .google .protobuf .compiler import CodeGeneratorRequest
7466
67+ from .. import which_one_of
7568from ..compile .importing import (
7669 get_type_reference ,
7770 parse_source_type_name ,
8275 pythonize_field_name ,
8376 pythonize_method_name ,
8477)
78+ from .typing_compiler import (
79+ DirectImportTypingCompiler ,
80+ TypingCompiler ,
81+ )
8582
8683
8784# Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
173170 """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
174171
175172 source_file : FileDescriptorProto
173+ typing_compiler : TypingCompiler
176174 path : List [int ]
177175 comment_indent : int = 4
178176 parent : Union ["betterproto.Message" , "OutputTemplate" ]
@@ -242,7 +240,6 @@ class OutputTemplate:
242240 input_files : List [str ] = field (default_factory = list )
243241 imports : Set [str ] = field (default_factory = set )
244242 datetime_imports : Set [str ] = field (default_factory = set )
245- typing_imports : Set [str ] = field (default_factory = set )
246243 pydantic_imports : Set [str ] = field (default_factory = set )
247244 builtins_import : bool = False
248245 messages : List ["MessageCompiler" ] = field (default_factory = list )
@@ -251,6 +248,7 @@ class OutputTemplate:
251248 imports_type_checking_only : Set [str ] = field (default_factory = set )
252249 pydantic_dataclasses : bool = False
253250 output : bool = True
251+ typing_compiler : TypingCompiler = field (default_factory = DirectImportTypingCompiler )
254252
255253 @property
256254 def package (self ) -> str :
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
289287 """Representation of a protobuf message."""
290288
291289 source_file : FileDescriptorProto
290+ typing_compiler : TypingCompiler
292291 parent : Union ["MessageCompiler" , OutputTemplate ] = PLACEHOLDER
293292 proto_obj : DescriptorProto = PLACEHOLDER
294293 path : List [int ] = PLACEHOLDER
@@ -319,7 +318,7 @@ def py_name(self) -> str:
319318 @property
320319 def annotation (self ) -> str :
321320 if self .repeated :
322- return f"List[ { self .py_name } ]"
321+ return self .typing_compiler . list ( self . py_name )
323322 return self .py_name
324323
325324 @property
@@ -434,18 +433,6 @@ def datetime_imports(self) -> Set[str]:
434433 imports .add ("datetime" )
435434 return imports
436435
437- @property
438- def typing_imports (self ) -> Set [str ]:
439- imports = set ()
440- annotation = self .annotation
441- if "Optional[" in annotation :
442- imports .add ("Optional" )
443- if "List[" in annotation :
444- imports .add ("List" )
445- if "Dict[" in annotation :
446- imports .add ("Dict" )
447- return imports
448-
449436 @property
450437 def pydantic_imports (self ) -> Set [str ]:
451438 return set ()
@@ -458,7 +445,6 @@ def use_builtins(self) -> bool:
458445
459446 def add_imports_to (self , output_file : OutputTemplate ) -> None :
460447 output_file .datetime_imports .update (self .datetime_imports )
461- output_file .typing_imports .update (self .typing_imports )
462448 output_file .pydantic_imports .update (self .pydantic_imports )
463449 output_file .builtins_import = output_file .builtins_import or self .use_builtins
464450
@@ -488,7 +474,9 @@ def optional(self) -> bool:
488474 @property
489475 def mutable (self ) -> bool :
490476 """True if the field is a mutable type, otherwise False."""
491- return self .annotation .startswith (("List[" , "Dict[" ))
477+ return self .annotation .startswith (
478+ ("typing.List[" , "typing.Dict[" , "dict[" , "list[" , "Dict[" , "List[" )
479+ )
492480
493481 @property
494482 def field_type (self ) -> str :
@@ -562,6 +550,7 @@ def py_type(self) -> str:
562550 package = self .output_file .package ,
563551 imports = self .output_file .imports ,
564552 source_type = self .proto_obj .type_name ,
553+ typing_compiler = self .typing_compiler ,
565554 pydantic = self .output_file .pydantic_dataclasses ,
566555 )
567556 else :
@@ -573,9 +562,9 @@ def annotation(self) -> str:
573562 if self .use_builtins :
574563 py_type = f"builtins.{ py_type } "
575564 if self .repeated :
576- return f"List[ { py_type } ]"
565+ return self . typing_compiler . list ( py_type )
577566 if self .optional :
578- return f"Optional[ { py_type } ]"
567+ return self . typing_compiler . optional ( py_type )
579568 return py_type
580569
581570
@@ -623,11 +612,13 @@ def __post_init__(self) -> None:
623612 source_file = self .source_file ,
624613 parent = self ,
625614 proto_obj = nested .field [0 ], # key
615+ typing_compiler = self .typing_compiler ,
626616 ).py_type
627617 self .py_v_type = FieldCompiler (
628618 source_file = self .source_file ,
629619 parent = self ,
630620 proto_obj = nested .field [1 ], # value
621+ typing_compiler = self .typing_compiler ,
631622 ).py_type
632623
633624 # Get proto types
@@ -645,7 +636,7 @@ def field_type(self) -> str:
645636
646637 @property
647638 def annotation (self ) -> str :
648- return f"Dict[ { self .py_k_type } , { self .py_v_type } ]"
639+ return self .typing_compiler . dict ( self . py_k_type , self .py_v_type )
649640
650641 @property
651642 def repeated (self ) -> bool :
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
702693 def __post_init__ (self ) -> None :
703694 # Add service to output file
704695 self .output_file .services .append (self )
705- self .output_file .typing_imports .add ("Dict" )
706696 super ().__post_init__ () # check for unset fields
707697
708698 @property
@@ -725,22 +715,6 @@ def __post_init__(self) -> None:
725715 # Add method to service
726716 self .parent .methods .append (self )
727717
728- # Check for imports
729- if "Optional" in self .py_output_message_type :
730- self .output_file .typing_imports .add ("Optional" )
731-
732- # Check for Async imports
733- if self .client_streaming :
734- self .output_file .typing_imports .add ("AsyncIterable" )
735- self .output_file .typing_imports .add ("Iterable" )
736- self .output_file .typing_imports .add ("Union" )
737-
738- # Required by both client and server
739- if self .client_streaming or self .server_streaming :
740- self .output_file .typing_imports .add ("AsyncIterator" )
741-
742- # add imports required for request arguments timeout, deadline and metadata
743- self .output_file .typing_imports .add ("Optional" )
744718 self .output_file .imports_type_checking_only .add ("import grpclib.server" )
745719 self .output_file .imports_type_checking_only .add (
746720 "from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ def py_input_message_type(self) -> str:
806780 package = self .output_file .package ,
807781 imports = self .output_file .imports ,
808782 source_type = self .proto_obj .input_type ,
783+ typing_compiler = self .output_file .typing_compiler ,
809784 unwrap = False ,
810785 pydantic = self .output_file .pydantic_dataclasses ,
811786 ).strip ('"' )
@@ -835,6 +810,7 @@ def py_output_message_type(self) -> str:
835810 package = self .output_file .package ,
836811 imports = self .output_file .imports ,
837812 source_type = self .proto_obj .output_type ,
813+ typing_compiler = self .output_file .typing_compiler ,
838814 unwrap = False ,
839815 pydantic = self .output_file .pydantic_dataclasses ,
840816 ).strip ('"' )
0 commit comments