diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 95fdac5d..8927551e 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -23,6 +23,15 @@ class ProviderStatus(Enum): class FeatureProvider(typing.Protocol): # pragma: no cover + def attach( + self, + on_emit: typing.Callable[ + [FeatureProvider, ProviderEvent, ProviderEventDetails], None + ], + ) -> None: ... + + def detach(self) -> None: ... + def initialize(self, evaluation_context: EvaluationContext) -> None: ... def shutdown(self) -> None: ... @@ -68,6 +77,18 @@ def resolve_object_details( class AbstractProvider(FeatureProvider): + def attach( + self, + on_emit: typing.Callable[ + [FeatureProvider, ProviderEvent, ProviderEventDetails], None + ], + ) -> None: + self._on_emit = on_emit + + def detach(self) -> None: + if hasattr(self, "_on_emit"): + del self._on_emit + def initialize(self, evaluation_context: EvaluationContext) -> None: pass @@ -141,6 +162,5 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_STALE, details) def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: - from openfeature.provider._registry import provider_registry - - provider_registry.dispatch_event(self, event, details) + if hasattr(self, "_on_emit"): + self._on_emit(self, event, details) diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index 902a204e..e2ec2e53 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -71,6 +71,7 @@ def _get_evaluation_context(self) -> EvaluationContext: return get_evaluation_context() def _initialize_provider(self, provider: FeatureProvider) -> None: + provider.attach(self.dispatch_event) try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) @@ -106,6 +107,7 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: error_code=ErrorCode.PROVIDER_FATAL, ), ) + provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: return self._provider_status.get(provider, ProviderStatus.NOT_READY) diff --git a/tests/test_api.py b/tests/test_api.py index aaea26b8..019037db 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -235,6 +235,8 @@ def test_clear_providers_shutdowns_every_provider_and_resets_default_provider(): def test_provider_events(): # Given spy = MagicMock() + provider = NoOpProvider() + set_provider(provider) add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready) add_handler( @@ -243,8 +245,6 @@ def test_provider_events(): add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) add_handler(ProviderEvent.PROVIDER_STALE, spy.provider_stale) - provider = NoOpProvider() - provider_details = ProviderEventDetails(message="message") details = EventDetails.from_provider_event_details( provider.get_metadata().name, provider_details