Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/acquisition/covid_hosp/common/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,16 @@ def nan_safe_dtype(dtype, value):
for csv_name in self.key_columns:
dataframe.loc[:, csv_name] = dataframe[csv_name].map(self.columns_and_types[csv_name].dtype)

num_columns = 2 + len(dataframe_columns_and_types) + len(self.additional_fields)
value_placeholders = ', '.join(['%s'] * num_columns)
col_names = [f'`{i.sql_name}`' for i in dataframe_columns_and_types + self.additional_fields]
columns = ', '.join(col_names)
updates = ', '.join(f'{c}=new_values.{c}' for c in col_names)
# NOTE: list in `updates` presumes `publication_col_name` is part of the unique key and thus not needed in UPDATE
sql = f'INSERT INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columns}) ' \
f'VALUES ({value_placeholders}) AS new_values ' \
f'ON DUPLICATE KEY UPDATE {updates}'
value_placeholders = ', '.join(['%s'] * (2 + len(col_names))) # extra 2 for `id` and `self.publication_col_name` cols
columnstring = ', '.join(col_names)
sql = f'REPLACE INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columnstring}) VALUES ({value_placeholders})'
id_and_publication_date = (0, publication_date)
num_values = len(dataframe.index)
if logger:
logger.info('updating values', count=len(dataframe.index))
logger.info('updating values', count=num_values)
n = 0
rows_affected = 0
many_values = []
with self.new_cursor() as cursor:
for index, row in dataframe.iterrows():
Expand All @@ -212,6 +209,7 @@ def nan_safe_dtype(dtype, value):
if n % 5_000 == 0:
try:
cursor.executemany(sql, many_values)
rows_affected += cursor.rowcount
many_values = []
except Exception as e:
if logger:
Expand All @@ -220,6 +218,11 @@ def nan_safe_dtype(dtype, value):
# insert final batch
if many_values:
cursor.executemany(sql, many_values)
rows_affected += cursor.rowcount
if logger:
# NOTE: REPLACE INTO marks 2 rows affected for a "replace" (one for a delete and one for a re-insert)
# which allows us to count rows which were updated
logger.info('rows affected', total=rows_affected, updated=rows_affected-num_values)

# deal with non/seldomly updated columns used like a fk table (if this database needs it)
if hasattr(self, 'AGGREGATE_KEY_COLS'):
Expand Down
2 changes: 1 addition & 1 deletion tests/acquisition/covid_hosp/common/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_insert_dataset(self):

actual_sql = mock_cursor.executemany.call_args[0][0]
self.assertIn(
'INSERT INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)',
'REPLACE INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)',
actual_sql)

expected_values = [
Expand Down