Skip to content

Commit ddb1335

Browse files
committed
Enable on strategy run hooks
1 parent 55665d1 commit ddb1335

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

investing_algorithm_framework/app/app.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,13 @@ def run(self, number_of_iterations: int = None):
595595
"""
596596
self.initialize_config()
597597

598+
# Run all on_initialize hooks
599+
for hook in self._on_initialize_hooks:
600+
logger.info(
601+
f"Running on_initialize hook: {hook.__class__.__name__}"
602+
)
603+
hook.on_run(self.context)
604+
598605
# Load the state if a state handler is provided
599606
if self._state_handler is not None:
600607
logger.info("Detected state handler, loading state")
@@ -607,8 +614,12 @@ def run(self, number_of_iterations: int = None):
607614
event_loop_service = None
608615

609616
try:
610-
# Run all on_initialize hooks
611-
for hook in self._on_initialize_hooks:
617+
# Run all on_after_initialize hooks
618+
for hook in self._on_after_initialize_hooks:
619+
logger.info(
620+
f"Running on_after_initialize "
621+
f"hook: {hook.__class__.__name__}"
622+
)
612623
hook.on_run(self.context)
613624

614625
algorithm = self.get_algorithm()

investing_algorithm_framework/app/eventloop.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime, timedelta, timezone
22
from time import sleep
33
from typing import List, Set, Dict
4+
from logging import getLogger
45

56
import polars as pl
67

@@ -13,6 +14,8 @@
1314
from .algorithm import Algorithm
1415
from .strategy import TradingStrategy
1516

17+
logger = getLogger("investing_algorithm_framework")
18+
1619

1720
class EventLoopService:
1821
"""
@@ -256,7 +259,7 @@ def _snapshot(
256259
def initialize(
257260
self,
258261
algorithm: Algorithm,
259-
trade_order_evaluator: TradeOrderEvaluator
262+
trade_order_evaluator: TradeOrderEvaluator,
260263
):
261264
"""
262265
Initializes the event loop service by calculating the schedule for
@@ -279,6 +282,7 @@ def initialize(
279282
"""
280283
self._algorithm = algorithm
281284
self.strategies = algorithm.strategies
285+
self.tasks = algorithm.tasks
282286

283287
if len(self.strategies) == 0:
284288
raise OperationalException(
@@ -367,9 +371,10 @@ def start(
367371
INDEX_DATETIME, current_time
368372
)
369373
strategy_ids = schedule[current_time]["strategy_ids"]
370-
# task_ids = schedule[current_time]["task_ids"]
371374
strategies = self._get_strategies(strategy_ids)
372-
self._run_iteration(strategies=strategies, tasks=[])
375+
self._run_iteration(
376+
strategies=strategies
377+
)
373378

374379
else:
375380
for current_time in sorted_times:
@@ -379,14 +384,18 @@ def start(
379384
strategy_ids = schedule[current_time]["strategy_ids"]
380385
# task_ids = schedule[current_time]["task_ids"]
381386
strategies = self._get_strategies(strategy_ids)
382-
self._run_iteration(strategies=strategies, tasks=[])
387+
self._run_iteration(
388+
strategies=strategies
389+
)
383390
else:
384391
if number_of_iterations is None:
385392
try:
386393
config = self._configuration_service.config
387394
current_time = config[INDEX_DATETIME]
388395
strategies = self._get_due_strategies(current_time)
389-
self._run_iteration(strategies)
396+
self._run_iteration(
397+
strategies=strategies, tasks=self.tasks
398+
)
390399
current_time = datetime.now(timezone.utc)
391400
self._configuration_service.add_value(
392401
INDEX_DATETIME, current_time
@@ -405,7 +414,9 @@ def start(
405414
config = self._configuration_service.config
406415
current_time = config[INDEX_DATETIME]
407416
strategies = self._get_due_strategies(current_time)
408-
self._run_iteration(strategies)
417+
self._run_iteration(
418+
strategies=strategies, tasks=self.tasks
419+
)
409420
current_time = datetime.now(timezone.utc)
410421
self._configuration_service.add_value(
411422
INDEX_DATETIME, current_time
@@ -419,7 +430,9 @@ def start(
419430
config = self._configuration_service.config
420431
current_time = config[INDEX_DATETIME]
421432
strategies = self._get_due_strategies(current_time)
422-
self._run_iteration(strategies)
433+
self._run_iteration(
434+
strategies=strategies, tasks=self.tasks
435+
)
423436
current_time = datetime.now(timezone.utc)
424437
self._configuration_service.add_value(
425438
INDEX_DATETIME, current_time
@@ -526,6 +539,9 @@ def _run_iteration(
526539
if not strategies:
527540
return
528541

542+
for task in self.tasks:
543+
logger.info(f"Running task {task.__class__.__name__}")
544+
529545
for strategy in strategies:
530546

531547
if strategy.data_sources is not None:
@@ -537,13 +553,9 @@ def _run_iteration(
537553
else:
538554
data = {}
539555

540-
# Select data for the strategy
556+
logger.info(f"Running strategy {strategy.strategy_id}")
541557
strategy.run_strategy(context=self.context, data=data)
542558

543-
# # Step 6: Run all on_strategy_run hooks
544-
# for strategy in due_strategies:
545-
# strategy.run_on_strategy_run_hooks(context=self.context)
546-
547559
# Step 7: Snapshot the portfolios if needed and update history
548560
created_orders = self._order_service.get_all(
549561
{

0 commit comments

Comments
 (0)