Skip to content

Commit d45b3b4

Browse files
committed
Test service interface/impl with/without name
1 parent d43f29a commit d45b3b4

File tree

2 files changed

+209
-16
lines changed

2 files changed

+209
-16
lines changed

temporalio/workflow.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -5159,23 +5159,26 @@ def __init__(
51595159
endpoint: str,
51605160
schedule_to_close_timeout: Optional[timedelta] = None,
51615161
) -> None:
5162-
# TODO(dan): getattr here is a temporary hack. We need to be able to attach
5163-
# metadata (such as the service name) to service interfaces (perhaps setting a
5164-
# __metadata__ attribute?), as well as service implementations (they have a
5165-
# decorator that sets NexusServiceDefinition), and ensure all semantics are
5166-
# completely clear and consistent when using e.g. a service impl without an
5167-
# interface, or impl and interface with clashing names, or one without an
5168-
# explicitly-set name and the other with, etc.
5162+
# If service is not a str, then it must be a service interface or implementation
5163+
# class.
51695164
if isinstance(service, str):
51705165
self._service_name = service
51715166
elif hasattr(service, "__nexus_service__"):
51725167
self._service_name = service.__nexus_service__.name
5168+
elif hasattr(service, "__nexus_service_metadata__"):
5169+
self._service_name = service.__nexus_service_metadata__.name
51735170
else:
5174-
self._service_name = service.__name__
5171+
raise ValueError(
5172+
f"`service` may be a name (str), or a class decorated with either "
5173+
f"@nexusrpc.handler.service or @nexusrpc.interface.service. "
5174+
f"Invalid service type: {type(service)}"
5175+
)
5176+
print(f"🌈 NexusClient using service name: {self._service_name}")
51755177
self._endpoint = endpoint
51765178
self._schedule_to_close_timeout = schedule_to_close_timeout
51775179

51785180
# TODO(dan): overloads: no-input, operation name, ret type
5181+
# TODO(dan): should it be an error to use a reference to a mathod on a class other than that supplied?
51795182
async def start_operation(
51805183
self,
51815184
operation: Union[

tests/worker/test_nexus.py

+198-8
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
# Service interface
3737
#
3838
class CallerReference(StrEnum):
39-
IMPLEMENTATION = "implementation"
39+
IMPLEMENTATION_WITHOUT_INTERFACE = "implementation-without-interface"
40+
IMPLEMENTATION_OF_INTERFACE = "implementation-of-interface"
4041
INTERFACE = "interface"
4142

4243

@@ -189,7 +190,7 @@ def __init__(
189190
) -> None:
190191
self.nexus_service = workflow.NexusClient(
191192
service={
192-
CallerReference.IMPLEMENTATION: MyServiceImpl,
193+
CallerReference.IMPLEMENTATION_OF_INTERFACE: MyServiceImpl,
193194
CallerReference.INTERFACE: MyServiceInterface,
194195
}[input.caller_reference],
195196
endpoint=make_nexus_endpoint_name(task_queue),
@@ -237,7 +238,7 @@ def _get_operation(
237238
(
238239
SyncResponse,
239240
OpDefinitionType.SHORTHAND,
240-
CallerReference.IMPLEMENTATION,
241+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
241242
): MyServiceImpl.my_sync_operation,
242243
(
243244
SyncResponse,
@@ -247,7 +248,7 @@ def _get_operation(
247248
(
248249
SyncResponse,
249250
OpDefinitionType.LONGHAND,
250-
CallerReference.IMPLEMENTATION,
251+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
251252
): MyServiceImpl.my_sync_or_async_operation,
252253
(
253254
SyncResponse,
@@ -257,7 +258,7 @@ def _get_operation(
257258
(
258259
AsyncResponse,
259260
OpDefinitionType.SHORTHAND,
260-
CallerReference.IMPLEMENTATION,
261+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
261262
): MyServiceImpl.my_async_operation,
262263
(
263264
AsyncResponse,
@@ -267,7 +268,7 @@ def _get_operation(
267268
(
268269
AsyncResponse,
269270
OpDefinitionType.LONGHAND,
270-
CallerReference.IMPLEMENTATION,
271+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
271272
): MyServiceImpl.my_sync_or_async_operation,
272273
(
273274
AsyncResponse,
@@ -297,7 +298,8 @@ def _get_operation(
297298
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
298299
)
299300
@pytest.mark.parametrize(
300-
"caller_reference", [CallerReference.IMPLEMENTATION, CallerReference.INTERFACE]
301+
"caller_reference",
302+
[CallerReference.IMPLEMENTATION_OF_INTERFACE, CallerReference.INTERFACE],
301303
)
302304
async def test_sync_response(
303305
client: Client,
@@ -343,7 +345,8 @@ async def test_sync_response(
343345
"op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND]
344346
)
345347
@pytest.mark.parametrize(
346-
"caller_reference", [CallerReference.IMPLEMENTATION, CallerReference.INTERFACE]
348+
"caller_reference",
349+
[CallerReference.IMPLEMENTATION_OF_INTERFACE, CallerReference.INTERFACE],
347350
)
348351
async def test_async_response(
349352
client: Client,
@@ -459,6 +462,193 @@ async def _start_wf_and_nexus_op(
459462
return caller_wf_handle, handler_wf_handle
460463

461464

465+
@nexusrpc.interface.service
466+
class MyServiceInterfaceWithoutNameOverride:
467+
my_op: nexusrpc.interface.Operation[None, str]
468+
469+
470+
@nexusrpc.interface.service(name="my-service-interface-🌈")
471+
class MyServiceInterfaceWithNameOverride:
472+
my_op: nexusrpc.interface.Operation[None, str]
473+
474+
475+
@nexusrpc.handler.service
476+
class MyServiceImplInterfaceWithNeitherInterfaceNorNameOverride:
477+
@nexusrpc.handler.sync_operation
478+
async def my_op(
479+
self, input: None, options: nexusrpc.handler.StartOperationOptions
480+
) -> str:
481+
return self.__class__.__name__
482+
483+
484+
@nexusrpc.handler.service(interface=MyServiceInterfaceWithoutNameOverride)
485+
class MyServiceImplInterfaceWithoutNameOverride:
486+
@nexusrpc.handler.sync_operation
487+
async def my_op(
488+
self, input: None, options: nexusrpc.handler.StartOperationOptions
489+
) -> str:
490+
return self.__class__.__name__
491+
492+
493+
@nexusrpc.handler.service(interface=MyServiceInterfaceWithNameOverride)
494+
class MyServiceImplInterfaceWithNameOverride:
495+
@nexusrpc.handler.sync_operation
496+
async def my_op(
497+
self, input: None, options: nexusrpc.handler.StartOperationOptions
498+
) -> str:
499+
return self.__class__.__name__
500+
501+
502+
@nexusrpc.handler.service(name="my-service-impl-🌈")
503+
class MyServiceImplWithNameOverride:
504+
@nexusrpc.handler.sync_operation
505+
async def my_op(
506+
self, input: None, options: nexusrpc.handler.StartOperationOptions
507+
) -> str:
508+
return self.__class__.__name__
509+
510+
511+
class NameOverride(StrEnum):
512+
YES = "yes"
513+
NO = "no"
514+
515+
516+
@workflow.defn
517+
class MyServiceInterfaceAndImplCallerWorkflow:
518+
@workflow.run
519+
async def run(
520+
self,
521+
caller_reference: CallerReference,
522+
name_override: NameOverride,
523+
task_queue: str,
524+
) -> str:
525+
service_cls = {
526+
(
527+
CallerReference.INTERFACE,
528+
NameOverride.YES,
529+
): MyServiceInterfaceWithNameOverride,
530+
(
531+
CallerReference.INTERFACE,
532+
NameOverride.NO,
533+
): MyServiceInterfaceWithoutNameOverride,
534+
(
535+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
536+
NameOverride.YES,
537+
): MyServiceImplWithNameOverride,
538+
(
539+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
540+
NameOverride.NO,
541+
): MyServiceImplInterfaceWithoutNameOverride,
542+
(
543+
CallerReference.IMPLEMENTATION_WITHOUT_INTERFACE,
544+
NameOverride.NO,
545+
): MyServiceImplInterfaceWithNeitherInterfaceNorNameOverride,
546+
}[caller_reference, name_override]
547+
nexus_client = workflow.NexusClient(
548+
service=service_cls,
549+
endpoint=make_nexus_endpoint_name(task_queue),
550+
)
551+
return await nexus_client.execute_operation(service_cls.my_op, None)
552+
553+
554+
# TODO(dan): make it possible to refer to an impl that doesn't itself refer to an interface in its decorator
555+
# TODO(dan): allow decorator to be used without calling it
556+
# TODO(dan): check missing decorator behavior
557+
558+
559+
async def test_service_interface_and_implementation_names(client: Client):
560+
# Note that:
561+
# - The caller can specify the service & operation via a reference to either the
562+
# interface or implementation class.
563+
# - An interface class may optionally override its name.
564+
# - An implementation class may either override its name or specify an interface that
565+
# it is implementing.
566+
# - On registering a service implementation with a worker, the name by which the
567+
# service is addressed in requests is the interface name if the implementation
568+
# supplies one, or else the name override made by the impl class, or else the impl
569+
# class name.
570+
#
571+
# This test checks that the request is routed to the expected service under a variety
572+
# of scenarios related to the above considerations.
573+
task_queue = str(uuid.uuid4())
574+
async with Worker(
575+
client,
576+
nexus_services=[
577+
MyServiceImplWithNameOverride(),
578+
MyServiceImplInterfaceWithNameOverride(),
579+
MyServiceImplInterfaceWithoutNameOverride(),
580+
MyServiceImplInterfaceWithNeitherInterfaceNorNameOverride(),
581+
],
582+
workflows=[MyServiceInterfaceAndImplCallerWorkflow],
583+
task_queue=task_queue,
584+
workflow_runner=UnsandboxedWorkflowRunner(),
585+
):
586+
await create_nexus_endpoint(task_queue, client)
587+
assert (
588+
await client.execute_workflow(
589+
MyServiceInterfaceAndImplCallerWorkflow.run,
590+
args=(CallerReference.INTERFACE, NameOverride.YES, task_queue),
591+
id=str(uuid.uuid4()),
592+
task_queue=task_queue,
593+
)
594+
== "MyServiceImplInterfaceWithNameOverride"
595+
)
596+
assert (
597+
await client.execute_workflow(
598+
MyServiceInterfaceAndImplCallerWorkflow.run,
599+
args=(CallerReference.INTERFACE, NameOverride.NO, task_queue),
600+
id=str(uuid.uuid4()),
601+
task_queue=task_queue,
602+
)
603+
== "MyServiceImplInterfaceWithoutNameOverride"
604+
)
605+
assert (
606+
await client.execute_workflow(
607+
MyServiceInterfaceAndImplCallerWorkflow.run,
608+
args=(
609+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
610+
NameOverride.YES,
611+
task_queue,
612+
),
613+
id=str(uuid.uuid4()),
614+
task_queue=task_queue,
615+
)
616+
== "MyServiceImplWithNameOverride"
617+
)
618+
assert (
619+
await client.execute_workflow(
620+
MyServiceInterfaceAndImplCallerWorkflow.run,
621+
args=(
622+
CallerReference.IMPLEMENTATION_OF_INTERFACE,
623+
NameOverride.NO,
624+
task_queue,
625+
),
626+
id=str(uuid.uuid4()),
627+
task_queue=task_queue,
628+
)
629+
== "MyServiceImplInterfaceWithoutNameOverride"
630+
)
631+
assert (
632+
await client.execute_workflow(
633+
MyServiceInterfaceAndImplCallerWorkflow.run,
634+
args=(
635+
CallerReference.IMPLEMENTATION_WITHOUT_INTERFACE,
636+
NameOverride.NO,
637+
task_queue,
638+
),
639+
id=str(uuid.uuid4()),
640+
task_queue=task_queue,
641+
)
642+
== "MyServiceImplInterfaceWithNeitherInterfaceNorNameOverride"
643+
)
644+
645+
646+
# TODO(dan): test invalid service interface implementations
647+
# TODO(dan): test service impls and interfaces with and without names, conflicting names, etc.
648+
# TODO(dan): test impl used without interface
649+
# TODO(dan): test empty service impl/interface names
650+
651+
462652
def make_nexus_endpoint_name(task_queue: str) -> str:
463653
# Create endpoints for different task queues without name collisions.
464654
return f"nexus-endpoint-{task_queue}"

0 commit comments

Comments
 (0)