diff --git a/nibabel/nicom/csareader.py b/nibabel/nicom/csareader.py index 7327c53f72..8ca5a6d2d3 100644 --- a/nibabel/nicom/csareader.py +++ b/nibabel/nicom/csareader.py @@ -18,6 +18,8 @@ 'IS': int, # integer string } +MAX_CSA_ITEMS = 199 + class CSAError(Exception): pass @@ -116,7 +118,9 @@ def read(csa_str): # CSA1 specific length modifier if tag_no == 1: tag0_n_items = n_items - assert n_items < 100 + if n_items > MAX_CSA_ITEMS: + raise CSAReadError('Expected <= {0} tags, got {1}'.format( + MAX_CSA_ITEMS, n_items)) items = [] for item_no in range(n_items): x0,x1,x2,x3 = up_str.unpack('4i') diff --git a/nibabel/nicom/tests/data/csa_str_200n_items.bin b/nibabel/nicom/tests/data/csa_str_200n_items.bin new file mode 100644 index 0000000000..780f196ffa Binary files /dev/null and b/nibabel/nicom/tests/data/csa_str_200n_items.bin differ diff --git a/nibabel/nicom/tests/data/csa_str_valid.bin b/nibabel/nicom/tests/data/csa_str_valid.bin new file mode 100644 index 0000000000..6779d2c0f1 Binary files /dev/null and b/nibabel/nicom/tests/data/csa_str_valid.bin differ diff --git a/nibabel/nicom/tests/test_csareader.py b/nibabel/nicom/tests/test_csareader.py index b0d36a3d8b..2b93768c07 100644 --- a/nibabel/nicom/tests/test_csareader.py +++ b/nibabel/nicom/tests/test_csareader.py @@ -18,6 +18,8 @@ CSA2_B0 = open(pjoin(IO_DATA_PATH, 'csa2_b0.bin'), 'rb').read() CSA2_B1000 = open(pjoin(IO_DATA_PATH, 'csa2_b1000.bin'), 'rb').read() CSA2_0len = gzip.open(pjoin(IO_DATA_PATH, 'csa2_zero_len.bin.gz'), 'rb').read() +CSA_STR_valid = open(pjoin(IO_DATA_PATH, 'csa_str_valid.bin'), 'rb').read() +CSA_STR_200n_items = open(pjoin(IO_DATA_PATH, 'csa_str_200n_items.bin'), 'rb').read() @dicom_test @@ -65,6 +67,22 @@ def test_csa_len0(): assert_equal(len(tags), 44) +def test_csa_nitem(): + # testing csa.read's ability to raise an error when n_items >= 200 + assert_raises(csa.CSAReadError, csa.read, CSA_STR_200n_items) + # OK when < 200 + csa_info = csa.read(CSA_STR_valid) + assert_equal(len(csa_info['tags']), 1) + # OK after changing module global + n_items_thresh = csa.MAX_CSA_ITEMS + try: + csa.MAX_CSA_ITEMS = 1000 + csa_info = csa.read(CSA_STR_200n_items) + assert_equal(len(csa_info['tags']), 1) + finally: + csa.MAX_CSA_ITEMS = n_items_thresh + + def test_csa_params(): for csa_str in (CSA2_B0, CSA2_B1000): csa_info = csa.read(csa_str)