@@ -545,6 +545,8 @@ def setup(self):
545
545
def _setup_model (
546
546
self , inference_spec_type : model_spec_pb2 .InferenceSpecType
547
547
):
548
+ self ._ai_platform_prediction_model_spec = (
549
+ inference_spec_type .ai_platform_prediction_model_spec )
548
550
self ._api_client = None
549
551
550
552
project_id = (
@@ -589,10 +591,10 @@ def _make_request(self, body: Mapping[Text, List[Any]]) -> http.HttpRequest:
589
591
return self ._api_client .projects ().predict (
590
592
name = self ._full_model_name , body = body )
591
593
592
- @classmethod
593
- def _prepare_instances (
594
- cls , elements : List [ExampleType ]
594
+ def _prepare_instances_dict (
595
+ self , elements : List [ExampleType ]
595
596
) -> Generator [Mapping [Text , Any ], None , None ]:
597
+ """Prepare instances by converting features to dictionary."""
596
598
for example in elements :
597
599
# TODO(b/151468119): support tf.train.SequenceExample
598
600
if not isinstance (example , tf .train .Example ):
@@ -604,13 +606,28 @@ def _prepare_instances(
604
606
if attr_name is None :
605
607
continue
606
608
attr = getattr (feature , attr_name )
607
- values = cls ._parse_feature_content (attr . value , attr_name ,
608
- cls ._sending_as_binary (input_name ))
609
+ values = self ._parse_feature_content (
610
+ attr . value , attr_name , self ._sending_as_binary (input_name ))
609
611
# Flatten a sequence if its length is 1
610
612
values = (values [0 ] if len (values ) == 1 else values )
611
613
instance [input_name ] = values
612
614
yield instance
613
615
616
+ def _prepare_instances_serialized (
617
+ self , elements : List [ExampleType ]
618
+ ) -> Generator [Mapping [Text , Text ], None , None ]:
619
+ """Prepare instances by base64 encoding serialized examples."""
620
+ for example in elements :
621
+ yield {'b64' : base64 .b64encode (example .SerializeToString ()).decode ()}
622
+
623
+ def _prepare_instances (
624
+ self , elements : List [ExampleType ]
625
+ ) -> Generator [Mapping [Text , Any ], None , None ]:
626
+ if self ._ai_platform_prediction_model_spec .use_serialization_config :
627
+ return self ._prepare_instances_serialized (elements )
628
+ else :
629
+ return self ._prepare_instances_dict (elements )
630
+
614
631
@staticmethod
615
632
def _sending_as_binary (input_name : Text ) -> bool :
616
633
"""Whether data should be sent as binary."""
0 commit comments