From b5042f6fcf8d7864eb60bccea3cbf9bb8ecfcb92 Mon Sep 17 00:00:00 2001 From: zacharyburnett Date: Wed, 10 Jul 2024 09:23:13 -0400 Subject: [PATCH] use `tmp_path` in tests that create temporary files --- tests/test_airglow.py | 23 +-- tests/test_average.py | 19 +-- tests/test_cosutil.py | 301 +++++++++++++++++----------------------- tests/test_extract.py | 26 ++-- tests/test_shiftfile.py | 81 ++++++++--- 5 files changed, 219 insertions(+), 231 deletions(-) diff --git a/tests/test_airglow.py b/tests/test_airglow.py index 6f7c753..5e4e44b 100644 --- a/tests/test_airglow.py +++ b/tests/test_airglow.py @@ -3,7 +3,7 @@ import os -def test_find_airglow_limits(): +def test_find_airglow_limits(tmp_path): """ unit test for find_airglow_limits() test ran @@ -20,13 +20,23 @@ def test_find_airglow_limits(): """ # Setup - inf = {"obstype": "SPECTROSCOPIC", "cenwave": 1055, "aperture": "PSA", "detector": "FUV", - "opt_elem": "G130M", "segment": "FUVA"} + inf = { + "obstype": "SPECTROSCOPIC", + "cenwave": 1055, + "aperture": "PSA", + "detector": "FUV", + "opt_elem": "G130M", + "segment": "FUVA", + } seg = ["FUVA", "FUVB"] - disptab = create_disptab_file('49g17153l_disp.fits') + disptab = create_disptab_file(str(tmp_path / "49g17153l_disp.fits")) airglow_lines = ["Lyman_alpha", "N_I_1200", "O_I_1304", "O_I_1356", "N_I_1134"] actual_pxl = [ - [], [], (15421.504705213156, 15738.02214190493), (8853.838672375898, 9135.702216258482)] + [], + [], + (15421.504705213156, 15738.02214190493), + (8853.838672375898, 9135.702216258482), + ] # Test test_pxl = [[], []] # only works for FUV @@ -39,6 +49,3 @@ def test_find_airglow_limits(): # Verify for i in range(len(actual_pxl)): assert actual_pxl[i] == test_pxl[i] - - # Cleanup - os.remove(disptab) diff --git a/tests/test_average.py b/tests/test_average.py index ef58178..dc7f95a 100644 --- a/tests/test_average.py +++ b/tests/test_average.py @@ -7,7 +7,7 @@ from generate_tempfiles import create_count_file -def test_avg_image(): +def test_avg_image(tmp_path): """ tests avg_image() in average.py explanation of the test @@ -19,10 +19,8 @@ def test_avg_image(): pass if expected == actual fail otherwise. """ # Setup - infile = ["test_count1.fits", "test_count2.fits"] - outfile = "test_output.fits" - if os.path.exists(outfile): - os.remove(outfile) # avoid file exists error + infile = [str(tmp_path / "test_count1.fits"), str(tmp_path / "test_count2.fits")] + outfile = str(tmp_path / "test_output.fits") create_count_file(infile[0]) create_count_file(infile[1]) inhdr1, inhdr2 = fits.open(infile[0]), fits.open(infile[1]) @@ -32,11 +30,8 @@ def test_avg_image(): # Verify assert os.path.exists(outfile) - for (i, j, k) in zip(inhdr1[1].header, inhdr2[1].header, out_hdr[1].header): + for i, j, k in zip(inhdr1[1].header, inhdr2[1].header, out_hdr[1].header): assert i == j == k - np.testing.assert_array_equal((inhdr1[1].data + inhdr1[1].data) / 2, out_hdr[1].data) - - # Cleanup - for tempfile in infile: - os.remove(tempfile) - os.remove(outfile) + np.testing.assert_array_equal( + (inhdr1[1].data + inhdr1[1].data) / 2, out_hdr[1].data + ) diff --git a/tests/test_cosutil.py b/tests/test_cosutil.py index 2271a38..a4a484e 100644 --- a/tests/test_cosutil.py +++ b/tests/test_cosutil.py @@ -12,10 +12,10 @@ from generate_tempfiles import generate_fits_file -def test_find_column(): +def test_find_column(tmp_path): # Setup # create a test fits file - name = "findCol.fits" + name = str(tmp_path / "findCol.fits") ofd = generate_fits_file(name) target_col = "TIME" @@ -23,14 +23,12 @@ def test_find_column(): col_exists = True # Verify assert col_exists == cosutil.findColumn(name, target_col) - # Cleanup - os.remove(name) def test_get_table(tmp_path): # Setup # create a test fits file - name = tmp_path / "getTable.fits" + name = str(tmp_path / "getTable.fits") ofd = generate_fits_file(name) truth = [tuple(ofd[1].data[3])] time = ofd[1].data[3][0] @@ -42,7 +40,7 @@ def test_get_table(tmp_path): def test_get_table_exceptions(tmp_path): # Raise MissingRowError - name = tmp_path / "getTable.fits" + name = str(tmp_path / "getTable.fits") generate_fits_file(name) # truth = [tuple(ofd[1].data[3])] t = 1.0 # non-existent value @@ -53,7 +51,7 @@ def test_get_table_exceptions(tmp_path): def test_get_col_copy(tmp_path): # Setup # create a test fits file - name = tmp_path / "getTable.fits" + name = str(tmp_path / "getTable.fits") ofd = generate_fits_file(name) col_name = "XCORR" portion_of_array = ofd[1].data[:] @@ -68,9 +66,9 @@ def test_get_col_copy(tmp_path): def test_get_col_copy_exception(tmp_path): + name = str(tmp_path / "getTable.fits") # raise RuntimeError error with pytest.raises(RuntimeError): - name = tmp_path / "getTable.fits" ofd = generate_fits_file(name) col_name = "XCORR" portion_of_array = ofd[1].data[:] @@ -80,10 +78,10 @@ def test_get_col_copy_exception(tmp_path): cosutil.getColCopy(filename=None, column=col_name, data=None) -def test_get_headers(): +def test_get_headers(tmp_path): # Setup # create a test fits file - name = "getHeaders.fits" + name = str(tmp_path / "getHeaders.fits") ofd = generate_fits_file(name) true_hdr = ofd[0].header @@ -92,21 +90,16 @@ def test_get_headers(): # Verify np.testing.assert_array_equal(true_hdr, test_hdr[0]) - # Cleanup - os.remove(name) -def test_write_output_events(): +def test_write_output_events(tmp_path): # Setup - in_file = "outputEvents.fits" - out_file = "outputEvents_cpy.fits" + in_file = str(tmp_path / "outputEvents.fits") + out_file = str(tmp_path / "outputEvents_cpy.fits") generate_fits_file(in_file) actual_lines = 10 lines = cosutil.writeOutputEvents(in_file, out_file) assert actual_lines == lines - # Cleanup - os.remove(in_file) - os.remove(out_file) def test_concat_arrays(): @@ -120,27 +113,25 @@ def test_concat_arrays(): np.testing.assert_array_equal(actual, concat_arrays) -def test_update_filename(): +def test_update_filename(tmp_path): # Setup - filename = "update_filename" - generate_fits_file("update_filename.fits") - before_update_hdr = fits.open("update_filename.fits", mode="update") + filename = str(tmp_path / "update_filename.fits") + generate_fits_file(filename) + before_update_hdr = fits.open(filename, mode="update") # Test - cosutil.updateFilename(before_update_hdr[0].header, filename) + cosutil.updateFilename(before_update_hdr[0].header, "update_filename") before_update_hdr.close() - after_update_hdr = fits.open("update_filename.fits") + after_update_hdr = fits.open(filename) # Verity - assert filename == after_update_hdr[0].header["filename"] + assert "update_filename" == after_update_hdr[0].header["filename"] after_update_hdr.close() - # Cleanup - os.remove("update_filename.fits") -def test_copy_file(): +def test_copy_file(tmp_path): # Setup - infile = "input.fits" + infile = str(tmp_path / "input.fits") generate_fits_file(infile) - outfile = "output.fits" + outfile = str(tmp_path / "output.fits") # Test cosutil.copyFile(infile, outfile) # Verify @@ -149,15 +140,12 @@ def test_copy_file(): np.testing.assert_array_equal(inf[1].data, out[1].data) np.testing.assert_array_equal(inf[2].data, out[2].data) np.testing.assert_array_equal(inf[3].data, out[3].data) - # Cleanup - os.remove("input.fits") - os.remove("output.fits") -def test_is_product(): +def test_is_product(tmp_path): # Setup - product_file = "my0_product_a.fits" - raw_file = "my_raw.fits" + product_file = str(tmp_path / "my0_product_a.fits") + raw_file = str(tmp_path / "my_raw.fits") generate_fits_file(product_file) generate_fits_file(raw_file) # Test @@ -169,9 +157,6 @@ def test_is_product(): # Verify assert cosutil.isProduct(product_file) assert not cosutil.isProduct(raw_file) - # Cleanup - os.remove(product_file) - os.remove(raw_file) def test_cmp_part_exception(): @@ -204,23 +189,24 @@ def test_split_int_letter(): assert truth == test -def test_create_corrtag_hdu(): +def test_create_corrtag_hdu(tmp_path): + filename = str(tmp_path / "corrtag.fits") # Setup - hdu = generate_fits_file("corrtag.fits") + hdu = generate_fits_file(filename) num_of_rows = 10 # Test # detector parameter is not needed consider removing it out_bin_table = cosutil.createCorrtagHDU(num_of_rows, detector="FUV", hdu=hdu[0]) assert len(out_bin_table.data) == num_of_rows assert all(out_bin_table.header) == all(hdu[0].header) - # Cleanup - os.remove("corrtag.fits") -def test_remove_wcs_keywords(): +def test_remove_wcs_keywords(tmp_path): + hdu_filename = str(tmp_path / "removeWCS.fits") + hdu2_filename = str(tmp_path / "corrtag.fits") # Setup - hdu = generate_fits_file("removeWCS.fits") - hdu2 = generate_fits_file("corrtag.fits") + hdu = generate_fits_file(hdu_filename) + hdu2 = generate_fits_file(hdu2_filename) inhdr = hdu2[1].header cd = hdu[1].data.columns WCS_keywords = [ @@ -236,9 +222,6 @@ def test_remove_wcs_keywords(): for keys in WCS_keywords: assert keys not in newheader assert len(inhdr[keys]) > 0 - # Cleanup - os.remove("removeWCS.fits") - os.remove("corrtag.fits") def test_dummy_gti(): @@ -251,15 +234,14 @@ def test_dummy_gti(): assert dummy_hdu.data[0][1] == test_exptime_value -def test_return_gti(): +def test_return_gti(tmp_path): + filename = str(tmp_path / "gti_file.fits") # Setup - hdu = generate_fits_file("gti_file.fits") + hdu = generate_fits_file(filename) # Test - gti = cosutil.returnGTI("gti_file.fits") + gti = cosutil.returnGTI(filename) # Verify np.testing.assert_array_equal(list(hdu[2].data), gti) - # Cleanup - os.remove("gti_file.fits") def test_err_frequentist(): @@ -427,12 +409,12 @@ def test_change_segment(): assert fname3 == test_name3 -def test_copy_exptime_keywords(): +def test_copy_exptime_keywords(tmp_path): + files = [str(tmp_path / "original.fits"), str(tmp_path / "copy.fits")] # Setup # create two files - generate_fits_file("original.fits") - generate_fits_file("copy.fits") - files = ["original.fits", "copy.fits"] + for file in files: + generate_fits_file(file) headers = ["expstart", "expend", "exptime", "rawtime"] # set values to the exposure time for file in files: @@ -444,27 +426,22 @@ def test_copy_exptime_keywords(): else: fits.setval(file, header, value=0, ext=1) # get header of the files - inhdr = fits.getheader("original.fits", 1) - outhdr = fits.getheader("copy.fits", 1) + inhdr = fits.getheader(files[0], 1) + outhdr = fits.getheader(files[1], 1) # Test cosutil.copyExptimeKeywords(inhdr, outhdr) # Verify for header in headers: assert inhdr[header] == outhdr[header] - # Cleanup - for tempfile in files: - os.remove(tempfile) -def test_copy_voltage_keywords(): +def test_copy_voltage_keywords(tmp_path): + original = [str(tmp_path / "originalFUV.fits"), str(tmp_path / "originalNUV.fits")] + copy = [str(tmp_path / "copyFUV.fits"), str(tmp_path / "copyNUV.fits")] # Setup # create two files each for FUV and NUV - generate_fits_file("originalFUV.fits") - generate_fits_file("originalNUV.fits") - generate_fits_file("copyFUV.fits") - generate_fits_file("copyNUV.fits") - original = ["originalFUV.fits", "originalNUV.fits"] - copy = ["copyFUV.fits", "copyNUV.fits"] + for fileName in original + copy: + generate_fits_file(fileName) detectors = ["FUV", "NUV"] fuv_headers = ["dethvla", "dethvlb", "dethvca", "dethvcb", "dethvna", "dethvnb"] nuv_headers = ["dethvl", "dethvc"] @@ -491,10 +468,10 @@ def test_copy_voltage_keywords(): for header in nuv_headers: fits.setval(fileName, header, value=0.0, ext=1) index += 1 - in_FUV_hdr = fits.getheader("originalFUV.fits", 1) - in_NUV_hdr = fits.getheader("originalNUV.fits", 1) - out_FUV_hdr = fits.getheader("copyFUV.fits", 1) - out_NUV_hdr = fits.getheader("copyNUV.fits", 1) + in_FUV_hdr = fits.getheader(original[0], 1) + in_NUV_hdr = fits.getheader(original[1], 1) + out_FUV_hdr = fits.getheader(copy[0], 1) + out_NUV_hdr = fits.getheader(copy[1], 1) # Test 1 cosutil.copyVoltageKeywords(in_FUV_hdr, out_FUV_hdr, detectors[0]) # Verify 1 @@ -505,18 +482,13 @@ def test_copy_voltage_keywords(): # Verify 2 for header in nuv_headers: assert in_NUV_hdr[header] == out_NUV_hdr[header] - # Cleanup - for tempfile in original: - os.remove(tempfile) - for tempfile in copy: - os.remove(tempfile) -def test_copy_sub_keywords(): +def test_copy_sub_keywords(tmp_path): + files = [str(tmp_path / "subKeywords.fits"), str(tmp_path / "copySubKeywords.fits")] # Setup - generate_fits_file("subKeywords.fits") - generate_fits_file("copySubKeywords.fits") - files = ["subKeywords.fits", "copySubKeywords.fits"] + for file in files: + generate_fits_file(file) headers = ["corner%1dx", "corner%1dy", "size%1dx", "size%1dy"] for file in files: @@ -542,9 +514,6 @@ def test_copy_sub_keywords(): cosutil.copySubKeywords(inhdr, outhdr, True) # check if nsubarry has been set to 0 assert inhdr["nsubarry"] == outhdr["nsubarry"] - # Cleanup - for tempfile in files: - os.remove(tempfile) def test_modify_asn_mtyp(): @@ -566,13 +535,13 @@ def test_modify_asn_mtyp(): assert string3 == val3 -def test_rename_file(): +def test_rename_file(tmp_path): # Setup - original_filename1 = "raw-file.fits" - original_filename2 = "product0_file_a.fits" + original_filename1 = str(tmp_path / "raw-file.fits") + original_filename2 = str(tmp_path / "product0_file_a.fits") - new_filename1 = "renamed_raw_file.fits" - new_filename2 = "renamed0_file_a.fits" + new_filename1 = str(tmp_path / "renamed_raw_file.fits") + new_filename2 = str(tmp_path / "renamed0_file_a.fits") # Create the files generate_fits_file(original_filename1) generate_fits_file(original_filename2) @@ -582,15 +551,13 @@ def test_rename_file(): # Verify assert os.path.exists(new_filename1) assert os.path.exists(new_filename2) - # Cleanup - os.remove(new_filename1) - os.remove(new_filename2) -def test_del_corrtag_wcs(): +def test_del_corrtag_wcs(tmp_path): + filename = str(tmp_path / "del_corrtagWCS.fits") # Setup - generate_fits_file("del_corrtagWCS.fits") - thdr = fits.getheader("del_corrtagWCS.fits", 3) + generate_fits_file(filename) + thdr = fits.getheader(filename, 3) tkey = [ "TCTYP2", "TCRVL2", @@ -612,8 +579,6 @@ def test_del_corrtag_wcs(): # Verify for key in tkey: assert key not in thdr - # Cleanup - os.remove("del_corrtagWCS.fits") def test_set_verbosity(): @@ -865,11 +830,13 @@ def test_segment_specific_keyword(): assert key2 == root2 -def test_find_ref_file(): +def test_find_ref_file(tmp_path): + wrong_filename = str(tmp_path / "wrong_file.fits") + test_filename = str(tmp_path / "test.fits") # Setup - generate_fits_file("wrong_file.fits") - generate_fits_file("test.fits") - fits.setval("wrong_file.fits", "FILETYPE", value="FLAT FIELD REFERENCE IMAGE") + generate_fits_file(wrong_filename) + generate_fits_file(test_filename) + fits.setval(wrong_filename, "FILETYPE", value="FLAT FIELD REFERENCE IMAGE") # Missing ref1 = { "keyword": "FLATFILE", @@ -881,7 +848,7 @@ def test_find_ref_file(): # Bad version ref2 = { "keyword": "FLATFILE", - "filename": "test.fits", + "filename": test_filename, "calcos_ver": "2.21", "min_ver": "3.3", "filetype": "FLAT FIELD REFERENCE IMAGE", @@ -889,7 +856,7 @@ def test_find_ref_file(): # Wrong file ref3 = { "keyword": "FLATFILE", - "filename": "wrong_file.fits", + "filename": wrong_filename, "calcos_ver": "3.0", "min_ver": "2.3", "filetype": "IMAGE", @@ -906,9 +873,9 @@ def test_find_ref_file(): # Actual values actual_missing = {"FLATFILE": "test_flt.fits"} actual_bad_ver = { - "FLATFILE": ("test.fits", " the reference file must be at least version 3.3") + "FLATFILE": (test_filename, " the reference file must be at least version 3.3") } - actual_wrong_ver = {"FLATFILE": ("wrong_file.fits", "IMAGE")} + actual_wrong_ver = {"FLATFILE": (wrong_filename, "IMAGE")} # Test cosutil.findRefFile(ref1, missing1, wrong_f1, bad_ver1) cosutil.findRefFile(ref2, missing2, wrong_f2, bad_ver2) @@ -917,9 +884,6 @@ def test_find_ref_file(): assert actual_missing == missing1 assert actual_bad_ver == bad_ver2 assert actual_wrong_ver == wrong_f3 - # Cleanup - os.remove("wrong_file.fits") - os.remove("test.fits") def test_cmp_version(): @@ -978,16 +942,16 @@ def test_cmp_version(): assert expected_cmp[i] == test_cmp[i] -def test_get_pedigree(): +def test_get_pedigree(tmp_path): # Setup capture_msg = io.StringIO() sys.stdout = capture_msg switch = "perform" refkey = "statflag" - filename = "test_flt.file" + filename = str(tmp_path / "test_flt.file") generate_fits_file(filename) err_msg = ( - "Warning: STATFLAG test_flt.file is a dummy file\n" + f"Warning: STATFLAG {filename} is a dummy file\n" " so PERFORM will not be done.\n" ) # Test @@ -1001,37 +965,36 @@ def test_get_pedigree(): assert pedgr1 == "OK" assert pedgr2 == "DUMMY" assert err_msg == capture_msg.getvalue() - # Cleanup - os.remove(filename) -def test_get_aperture_keyword(): +def test_get_aperture_keyword(tmp_path): + filename = str(tmp_path / "aperture_test.fits") # Setup - generate_fits_file("aperture_test.fits") + generate_fits_file(filename) # condition 1 - fits.setval("aperture_test.fits", "aperture", value="PSA-FUV", ext=0) - fits.setval("aperture_test.fits", "propaper", value="PSA-FUV", ext=0) - hdr1 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "aperture", value="PSA-FUV", ext=0) + fits.setval(filename, "propaper", value="PSA-FUV", ext=0) + hdr1 = fits.getheader(filename, ext=0) # condition 2 - fits.setval("aperture_test.fits", "aperture", value="RelMvReq", ext=0) - fits.setval("aperture_test.fits", "propaper", value="WCA", ext=0) - hdr2 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "aperture", value="RelMvReq", ext=0) + fits.setval(filename, "propaper", value="WCA", ext=0) + hdr2 = fits.getheader(filename, ext=0) # condition 3 - fits.setval("aperture_test.fits", "propaper", value="NA", ext=0) - fits.setval("aperture_test.fits", "shutter", value="closed", ext=0) - fits.setval("aperture_test.fits", "lampused", value="P", ext=0) - hdr3 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "propaper", value="NA", ext=0) + fits.setval(filename, "shutter", value="closed", ext=0) + fits.setval(filename, "lampused", value="P", ext=0) + hdr3 = fits.getheader(filename, ext=0) # condition 4 - fits.setval("aperture_test.fits", "lampused", value="D", ext=0) - hdr4 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "lampused", value="D", ext=0) + hdr4 = fits.getheader(filename, ext=0) # condition 5 - fits.setval("aperture_test.fits", "lampused", value="A", ext=0) - hdr5 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "lampused", value="A", ext=0) + hdr5 = fits.getheader(filename, ext=0) # condition 6 - fits.setval("aperture_test.fits", "shutter", value="open", ext=0) - fits.setval("aperture_test.fits", "life_adj", value=2, ext=0) - fits.setval("aperture_test.fits", "aperypos", value=4.56, ext=0) - hdr6 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "shutter", value="open", ext=0) + fits.setval(filename, "life_adj", value=2, ext=0) + fits.setval(filename, "aperypos", value=4.56, ext=0) + hdr6 = fits.getheader(filename, ext=0) # Expected values rtn1 = ("PSA", "APERTURE changed from PSA-FUV to PSA") @@ -1054,30 +1017,26 @@ def test_get_aperture_keyword(): assert rtn4 == rtn_value4 assert rtn5 == rtn_value5 assert rtn6 == rtn_value6 - # Cleanup - os.remove("aperture_test.fits") -def test_write_version_to_trailer(): +def test_write_version_to_trailer(tmp_path): capture_msg = io.StringIO() sys.stdout = capture_msg - generate_fits_file("dummy_file.fits") - ascii_file = open("ascii.txt", mode="w") + generate_fits_file(str(tmp_path / "dummy_file.fits")) + ascii_file = open(str(tmp_path / "ascii.txt"), mode="w") cosutil.fd_trl = ascii_file cosutil.CALCOS_VERSION = "3.1.0" cosutil.writeVersionToTrailer() sys.stdout = sys.__stdout__ assert ascii_file == cosutil.fd_trl assert capture_msg.getvalue() == "" - # Cleanup - os.remove("ascii.txt") - os.remove("dummy_file.fits") -def test_get_switch(): +def test_get_switch(tmp_path): + filename = str(tmp_path / "switch.fits") # Setup - generate_fits_file("switch.fits") - phdr = fits.getheader("switch.fits", ext=0) + generate_fits_file(filename) + phdr = fits.getheader(filename, ext=0) keyword = "statflag" # Test switch = cosutil.getSwitch(phdr, keyword) @@ -1088,26 +1047,24 @@ def test_get_switch(): assert switch == "PERFORM" assert switch2 == "OMIT" assert switch3 == "N/A" - # Cleanup - os.remove("switch.fits") def test_temp_pulse_height_range(tmp_path): - generate_fits_file(tmp_path / "pulseHeightRef.fits") + name = str(tmp_path / "pulseHeightRef.fits") + generate_fits_file(name) true_pha_value = 4 - fits.setval( - tmp_path / "pulseHeightRef.fits", "pharange", value=true_pha_value, ext=0 - ) - fits.getheader(tmp_path / "pulseHeightRef.fits", ext=0) - pha_value = cosutil.tempPulseHeightRange(tmp_path / "pulseHeightRef.fits") + fits.setval(name, "pharange", value=true_pha_value, ext=0) + fits.getheader(name, ext=0) + pha_value = cosutil.tempPulseHeightRange(name) assert true_pha_value == pha_value def test_get_pulse_height_range(tmp_path): - generate_fits_file(tmp_path / "pulseHeightRef.fits") - fits.setval(tmp_path / "pulseHeightRef.fits", "phalowrA", value=7, ext=0) - fits.setval(tmp_path / "pulseHeightRef.fits", "phaupprA", value=10, ext=0) - hdu = fits.getheader(tmp_path / "pulseHeightRef.fits", ext=0) + name = str(tmp_path / "pulseHeightRef.fits") + generate_fits_file(name) + fits.setval(name, "phalowrA", value=7, ext=0) + fits.setval(name, "phaupprA", value=10, ext=0) + hdu = fits.getheader(name, ext=0) seg = ["FUVA", "FUVB"] actual = [" 7_10", None] for s, a in zip(seg, actual): @@ -1115,18 +1072,17 @@ def test_get_pulse_height_range(tmp_path): assert a == test_str -def test_time_at_midpoint(): +def test_time_at_midpoint(tmp_path): + filename = str(tmp_path / "test_timeAtMidpoint.fits") # Setup - generate_fits_file("test_timeAtMidpoint.fits") - hdr = fits.getheader("test_timeAtMidpoint.fits", ext=1) + generate_fits_file(filename) + hdr = fits.getheader(filename, ext=1) info = {"expstart": hdr["TCRPX7"], "expend": hdr["TCRVL7"]} average = 4776.314586556611 # Test test_average = cosutil.timeAtMidpoint(info) # Verify assert average == test_average - # Cleanup - os.remove("test_timeAtMidpoint.fits") def test_timeline_times(): @@ -1173,9 +1129,10 @@ def test_combine_stat(): assert actual[key] == test[key] -def test_override_keywords(): +def test_override_keywords(tmp_path): + filename = str(tmp_path / "overridekeywords.fits") # Setup - generate_fits_file("overridekeywords.fits") + generate_fits_file(filename) info = { "cal_ver": 3.1, "opt_elem": 2, @@ -1194,11 +1151,11 @@ def test_override_keywords(): "randcorr": "SKIPPED", } reffiles = {"flatfile": "abc_flat.fits", "flt_hdr": "lref$abc_flat.fits"} - fits.setval("overridekeywords.fits", "flt_hdr", value="NA", ext=0) - fits.setval("overridekeywords.fits", "dispaxis", value=2.9, ext=1) - fits.setval("overridekeywords.fits", "x_offset", value=1.9, ext=1) - phdr = fits.getheader("overridekeywords.fits", ext=0) - hdr = fits.getheader("overridekeywords.fits", ext=1) + fits.setval(filename, "flt_hdr", value="NA", ext=0) + fits.setval(filename, "dispaxis", value=2.9, ext=1) + fits.setval(filename, "x_offset", value=1.9, ext=1) + phdr = fits.getheader(filename, ext=0) + hdr = fits.getheader(filename, ext=1) # Actual Values val1 = True val2 = switches.values() @@ -1206,8 +1163,6 @@ def test_override_keywords(): val4 = reffiles["flt_hdr"] # Test cosutil.overrideKeywords(phdr, hdr, info, switches, reffiles) - phdr = fits.getheader("overridekeywords.fits", ext=0) + phdr = fits.getheader(filename, ext=0) # Verify assert val1 == phdr["statflag"] - # Cleanup - os.remove("overridekeywords.fits") diff --git a/tests/test_extract.py b/tests/test_extract.py index 40af48a..5212b36 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -4,12 +4,13 @@ import numpy as np from generate_tempfiles import generate_fits_file -def test_get_columns(): + +def test_get_columns(tmp_path): """ Test if the function is returning the right column fields """ # Setup - test_data = generate_fits_file("lbgu17qnq_corrtag_a.fits") + test_data = generate_fits_file(str(tmp_path / "lbgu17qnq_corrtag_a.fits")) dt = test_data[1].data detector = "FUV" @@ -31,18 +32,16 @@ def test_get_columns(): np.testing.assert_array_equal(yfull, yf) np.testing.assert_array_equal(dq, dq2) np.testing.assert_array_equal(epsilon, epsilon2) - # Cleanup - os.remove("lbgu17qnq_corrtag_a.fits") -def test_remove_unwanted_column(): +def test_remove_unwanted_column(tmp_path): """ Old column length should be equal to new column length + amount of the removed columns """ # Setup - target_cols = ['XFULL', 'YFULL'] + target_cols = ["XFULL", "YFULL"] # Truth - fd = generate_fits_file("lbgu17qnq_lampflash.fits") + fd = generate_fits_file(str(tmp_path / "lbgu17qnq_lampflash.fits")) table = fd[1].data cols = table.columns @@ -56,8 +55,6 @@ def test_remove_unwanted_column(): deleted_cols = deleted_cols[np.argsort(temp_cols)] # assert target_cols[0] == deleted_cols[0].name # assert target_cols[1] == deleted_cols[1].name - # Cleanup - os.remove('lbgu17qnq_lampflash.fits') def test_next_power_of_two(): @@ -72,17 +69,14 @@ def test_next_power_of_two(): assert next_power == extract.next_power_of_two(7) -def test_add_column_comment(): +def test_add_column_comment(tmp_path): # verify if entered comment to a header is present in the fits file. # Setup - ofd = generate_fits_file("myFitsFile.fits") + ofd = generate_fits_file(str(tmp_path / "myFitsFile.fits")) comment = "This comment is generated by a unit-test." # Exercise - test_table = extract.add_column_comment(ofd, 'TIME', comment) + test_table = extract.add_column_comment(ofd, "TIME", comment) # Verify - assert comment == test_table[1].header.comments['TTYPE1'] - # Cleanup - os.remove('myFitsFile.fits') - + assert comment == test_table[1].header.comments["TTYPE1"] diff --git a/tests/test_shiftfile.py b/tests/test_shiftfile.py index 6b38e25..b284eeb 100644 --- a/tests/test_shiftfile.py +++ b/tests/test_shiftfile.py @@ -2,45 +2,85 @@ from calcos import shiftfile -def create_shift_file(filename): + +def create_shift_file(shift_file): # Create the shift file for use in tests - shift_file = "shift_file.txt" with open(shift_file, "w") as file: file.write("#dataset\tfpoffset\tflash #\tstripe\tshift1\tshift2\n") for i in range(10): if i % 3 == 0: - file.write("{}\t{}\t{}\t{}\t{}\t{}\n".format("abc123def", "any", "1", "NUVA", "45.234435", "7")) + file.write( + "{}\t{}\t{}\t{}\t{}\t{}\n".format( + "abc123def", "any", "1", "NUVA", "45.234435", "7" + ) + ) elif i % 5 == 0: - file.write("{}\t{}\t{}\t{}\t{}\t{}\n".format("ghi456jkl", "any", "2", "NUVB", "34.543453", "7")) + file.write( + "{}\t{}\t{}\t{}\t{}\t{}\n".format( + "ghi456jkl", "any", "2", "NUVB", "34.543453", "7" + ) + ) elif i % 6 == 0 or i % 8 == 0: - file.write("{}\t{}\t{}\t{}\t{}\t{}\n".format("ghi456jkl", "any", "2", "FUVA", "19.543453", "5")) + file.write( + "{}\t{}\t{}\t{}\t{}\t{}\n".format( + "ghi456jkl", "any", "2", "FUVA", "19.543453", "5" + ) + ) elif i == 9: - file.write("{}\t{}\t{}\t{}\t{}\t{}\n".format("mno789pqr", "any", "2", "FUVB", "52.723453", "6")) + file.write( + "{}\t{}\t{}\t{}\t{}\t{}\n".format( + "mno789pqr", "any", "2", "FUVB", "52.723453", "6" + ) + ) else: - file.write("{}\t{}\t{}\t{}\t{}\t{}\n".format("mno789pqr", "any", "1", "NUVC", "-34.543453", "7")) + file.write( + "{}\t{}\t{}\t{}\t{}\t{}\n".format( + "mno789pqr", "any", "1", "NUVC", "-34.543453", "7" + ) + ) return -def test_shift_file(): - shift_file = "shift_file.txt" + +def test_shift_file(tmp_path): + shift_file = str(tmp_path / "shift_file.txt") create_shift_file(shift_file) # Test - ob = shiftfile.ShiftFile(shift_file, 'abc123def', 'any') + ob = shiftfile.ShiftFile(shift_file, "abc123def", "any") # Verify assert len(ob.user_shift_dict) > 0 - # Cleanup - os.remove(shift_file) -def test_get_shifts(): +def test_get_shifts(tmp_path): # Setup - shift_file = "shift_file.txt" + shift_file = str(tmp_path / "shift_file.txt") create_shift_file(shift_file) - ob1 = shiftfile.ShiftFile(shift_file, 'ghi456jkl', 'any') - ob2 = shiftfile.ShiftFile(shift_file, 'abc123def', 'any') - keys = [('any', 'nuva'), ('any', 'nuvb'), (2, 'nuvc'), ('any', 'any'), ('any', 'fuva'), ('any', 'fuvb')] - expected_values1 = [((None, None), 0), ((34.543453, 7.0), 1), ((None, None), 0), ((19.543453, 5.0), 2), ((19.543453, 5.0), 1), ((None, None), 0)] - expected_values2 = [((45.234435, 7.0), 1), ((None, None), 0), ((None, None), 0), ((45.234435, 7.0), 1), ((None, None), 0), ((None, None), 0)] + ob1 = shiftfile.ShiftFile(shift_file, "ghi456jkl", "any") + ob2 = shiftfile.ShiftFile(shift_file, "abc123def", "any") + keys = [ + ("any", "nuva"), + ("any", "nuvb"), + (2, "nuvc"), + ("any", "any"), + ("any", "fuva"), + ("any", "fuvb"), + ] + expected_values1 = [ + ((None, None), 0), + ((34.543453, 7.0), 1), + ((None, None), 0), + ((19.543453, 5.0), 2), + ((19.543453, 5.0), 1), + ((None, None), 0), + ] + expected_values2 = [ + ((45.234435, 7.0), 1), + ((None, None), 0), + ((None, None), 0), + ((45.234435, 7.0), 1), + ((None, None), 0), + ((None, None), 0), + ] # Test test_values1 = [] test_values2 = [] @@ -52,6 +92,3 @@ def test_get_shifts(): for i in range(len(expected_values1)): assert expected_values1[i] == test_values1[i] assert expected_values2[i] == test_values2[i] - # Cleanup - os.remove(shift_file) -