From 7996733ccaeb58befdc6b93483ef7044274bcf1e Mon Sep 17 00:00:00 2001 From: Zach Burnett Date: Wed, 21 Aug 2024 14:37:18 -0400 Subject: [PATCH] use `tmp_path` in tests that create temporary files (#243) --- tests/test_airglow.py | 7 +- tests/test_average.py | 13 +- tests/test_cosutil.py | 318 ++++++++++++++++------------------------ tests/test_extract.py | 18 +-- tests/test_shiftfile.py | 12 +- 5 files changed, 144 insertions(+), 224 deletions(-) diff --git a/tests/test_airglow.py b/tests/test_airglow.py index 6f7c753..a9ed567 100644 --- a/tests/test_airglow.py +++ b/tests/test_airglow.py @@ -3,7 +3,7 @@ import os -def test_find_airglow_limits(): +def test_find_airglow_limits(tmp_path): """ unit test for find_airglow_limits() test ran @@ -23,7 +23,7 @@ def test_find_airglow_limits(): inf = {"obstype": "SPECTROSCOPIC", "cenwave": 1055, "aperture": "PSA", "detector": "FUV", "opt_elem": "G130M", "segment": "FUVA"} seg = ["FUVA", "FUVB"] - disptab = create_disptab_file('49g17153l_disp.fits') + disptab = create_disptab_file(str(tmp_path / '49g17153l_disp.fits')) airglow_lines = ["Lyman_alpha", "N_I_1200", "O_I_1304", "O_I_1356", "N_I_1134"] actual_pxl = [ [], [], (15421.504705213156, 15738.02214190493), (8853.838672375898, 9135.702216258482)] @@ -39,6 +39,3 @@ def test_find_airglow_limits(): # Verify for i in range(len(actual_pxl)): assert actual_pxl[i] == test_pxl[i] - - # Cleanup - os.remove(disptab) diff --git a/tests/test_average.py b/tests/test_average.py index ef58178..2e81d10 100644 --- a/tests/test_average.py +++ b/tests/test_average.py @@ -7,7 +7,7 @@ from generate_tempfiles import create_count_file -def test_avg_image(): +def test_avg_image(tmp_path): """ tests avg_image() in average.py explanation of the test @@ -19,10 +19,8 @@ def test_avg_image(): pass if expected == actual fail otherwise. """ # Setup - infile = ["test_count1.fits", "test_count2.fits"] - outfile = "test_output.fits" - if os.path.exists(outfile): - os.remove(outfile) # avoid file exists error + infile = [str(tmp_path / "test_count1.fits"), str(tmp_path / "test_count2.fits")] + outfile = str(tmp_path / "test_output.fits") create_count_file(infile[0]) create_count_file(infile[1]) inhdr1, inhdr2 = fits.open(infile[0]), fits.open(infile[1]) @@ -35,8 +33,3 @@ def test_avg_image(): for (i, j, k) in zip(inhdr1[1].header, inhdr2[1].header, out_hdr[1].header): assert i == j == k np.testing.assert_array_equal((inhdr1[1].data + inhdr1[1].data) / 2, out_hdr[1].data) - - # Cleanup - for tempfile in infile: - os.remove(tempfile) - os.remove(outfile) diff --git a/tests/test_cosutil.py b/tests/test_cosutil.py index 8eb955a..4ddb25b 100644 --- a/tests/test_cosutil.py +++ b/tests/test_cosutil.py @@ -12,10 +12,10 @@ from generate_tempfiles import generate_fits_file -def test_find_column(): +def test_find_column(tmp_path): # Setup # create a test fits file - name = "findCol.fits" + name = str(tmp_path / "findCol.fits") ofd = generate_fits_file(name) target_col = 'TIME' @@ -23,14 +23,12 @@ def test_find_column(): col_exists = True # Verify assert col_exists == cosutil.findColumn(name, target_col) - # Cleanup - os.remove(name) -def test_get_table(): +def test_get_table(tmp_path): # Setup # create a test fits file - name = "getTable.fits" + name = str(tmp_path / "getTable.fits") ofd = generate_fits_file(name) truth = [tuple(ofd[1].data[3])] time = ofd[1].data[3][0] @@ -38,26 +36,22 @@ def test_get_table(): dt = list(cosutil.getTable(name, {'TIME': time}, exactly_one=True)) # Verify np.testing.assert_array_equal(truth, dt) - # Cleanup - os.remove(name) -def test_get_table_exceptions(): +def test_get_table_exceptions(tmp_path): # Raise MissingRowError - name = "getTable.fits" + name = str(tmp_path / "getTable.fits") generate_fits_file(name) # truth = [tuple(ofd[1].data[3])] t = 1.0 # non-existent value with pytest.raises(MissingRowError): cosutil.getTable(name, {'Time': t}, exactly_one=True) - # Cleanup - os.remove(name) -def test_get_col_copy(): +def test_get_col_copy(tmp_path): # Setup # create a test fits file - name = "getTable.fits" + name = str(tmp_path / "getTable.fits") ofd = generate_fits_file(name) col_name = 'XCORR' portion_of_array = ofd[1].data[:] @@ -69,27 +63,23 @@ def test_get_col_copy(): # Verify np.testing.assert_array_equal(truth_values, test1) np.testing.assert_array_equal(truth_values, test2) - # Cleanup - os.remove(name) -def test_get_col_copy_exception(): +def test_get_col_copy_exception(tmp_path): # raise RuntimeError error with pytest.raises(RuntimeError): - name = "getTable.fits" + name = str(tmp_path / "getTable.fits") ofd = generate_fits_file(name) col_name = 'XCORR' portion_of_array = ofd[1].data[:] - cosutil.getColCopy(filename="Output/getTable.fits", column=col_name, data=portion_of_array) + cosutil.getColCopy(filename=str(tmp_path / "Output/getTable.fits"), column=col_name, data=portion_of_array) cosutil.getColCopy(filename=None, column=col_name, data=None) - # Cleanup - os.remove(name) -def test_get_headers(): +def test_get_headers(tmp_path): # Setup # create a test fits file - name = "getHeaders.fits" + name = str(tmp_path / "getHeaders.fits") ofd = generate_fits_file(name) true_hdr = ofd[0].header @@ -98,21 +88,16 @@ def test_get_headers(): # Verify np.testing.assert_array_equal(true_hdr, test_hdr[0]) - # Cleanup - os.remove(name) -def test_write_output_events(): +def test_write_output_events(tmp_path): # Setup - in_file = "outputEvents.fits" - out_file = "outputEvents_cpy.fits" + in_file = str(tmp_path / "outputEvents.fits") + out_file = str(tmp_path / "outputEvents_cpy.fits") generate_fits_file(in_file) actual_lines = 10 lines = cosutil.writeOutputEvents(in_file, out_file) assert actual_lines == lines - # Cleanup - os.remove(in_file) - os.remove(out_file) def test_concat_arrays(): @@ -126,27 +111,25 @@ def test_concat_arrays(): np.testing.assert_array_equal(actual, concat_arrays) -def test_update_filename(): +def test_update_filename(tmp_path): # Setup - filename = "update_filename" - generate_fits_file("update_filename.fits") - before_update_hdr = fits.open("update_filename.fits", mode="update") + filename = str(tmp_path / "update_filename") + generate_fits_file(filename) + before_update_hdr = fits.open(filename, mode="update") # Test cosutil.updateFilename(before_update_hdr[0].header, filename) before_update_hdr.close() - after_update_hdr = fits.open("update_filename.fits") + after_update_hdr = fits.open(filename) # Verity assert filename == after_update_hdr[0].header["filename"] after_update_hdr.close() - # Cleanup - os.remove("update_filename.fits") -def test_copy_file(): +def test_copy_file(tmp_path): # Setup - infile = "input.fits" + infile = str(tmp_path / "input.fits") generate_fits_file(infile) - outfile = "output.fits" + outfile = str(tmp_path / "output.fits") # Test cosutil.copyFile(infile, outfile) # Verify @@ -155,15 +138,12 @@ def test_copy_file(): np.testing.assert_array_equal(inf[1].data, out[1].data) np.testing.assert_array_equal(inf[2].data, out[2].data) np.testing.assert_array_equal(inf[3].data, out[3].data) - # Cleanup - os.remove("input.fits") - os.remove("output.fits") -def test_is_product(): +def test_is_product(tmp_path): # Setup - product_file = "my0_product_a.fits" - raw_file = "my_raw.fits" + product_file = str(tmp_path / "my0_product_a.fits") + raw_file = str(tmp_path / "my_raw.fits") generate_fits_file(product_file) generate_fits_file(raw_file) # Test @@ -175,9 +155,6 @@ def test_is_product(): # Verify assert cosutil.isProduct(product_file) assert not cosutil.isProduct(raw_file) - # Cleanup - os.remove(product_file) - os.remove(raw_file) @@ -212,23 +189,21 @@ def test_split_int_letter(): assert truth == test -def test_create_corrtag_hdu(): +def test_create_corrtag_hdu(tmp_path): # Setup - hdu = generate_fits_file("corrtag.fits") + hdu = generate_fits_file(str(tmp_path / "corrtag.fits")) num_of_rows = 10 # Test # detector parameter is not needed consider removing it out_bin_table = cosutil.createCorrtagHDU(num_of_rows, detector="FUV", hdu=hdu[0]) assert len(out_bin_table.data) == num_of_rows assert all(out_bin_table.header) == all(hdu[0].header) - # Cleanup - os.remove("corrtag.fits") -def test_remove_wcs_keywords(): +def test_remove_wcs_keywords(tmp_path): # Setup - hdu = generate_fits_file("removeWCS.fits") - hdu2 = generate_fits_file("corrtag.fits") + hdu = generate_fits_file(str(tmp_path / "removeWCS.fits")) + hdu2 = generate_fits_file(str(tmp_path / "corrtag.fits")) inhdr = hdu2[1].header cd = hdu[1].data.columns WCS_keywords = ['TCTYP*', @@ -243,9 +218,6 @@ def test_remove_wcs_keywords(): for keys in WCS_keywords: assert keys not in newheader assert len(inhdr[keys]) > 0 - # Cleanup - os.remove("removeWCS.fits") - os.remove("corrtag.fits") def test_dummy_gti(): # Setup @@ -257,15 +229,14 @@ def test_dummy_gti(): assert dummy_hdu.data[0][1] == test_exptime_value -def test_return_gti(): +def test_return_gti(tmp_path): + filename = str(tmp_path / "gti_file.fits") # Setup - hdu = generate_fits_file("gti_file.fits") + hdu = generate_fits_file(filename) # Test - gti = cosutil.returnGTI("gti_file.fits") + gti = cosutil.returnGTI(filename) # Verify np.testing.assert_array_equal(list(hdu[2].data), gti) - # Cleanup - os.remove("gti_file.fits") def test_err_frequentist(): @@ -387,15 +358,15 @@ def test_fit_quadratic(): np.testing.assert_array_equal(expected_quadratic[1], fitted_quadratic[1]) -def test_change_segment(): +def test_change_segment(tmp_path): # Setup - filename1 = "testfits_a.fits" - filename2 = "testfits_b.fits" - filename3 = "testfits.fits" + filename1 = str(tmp_path / "testfits_a.fits") + filename2 = str(tmp_path / "testfits_b.fits") + filename3 = str(tmp_path / "testfits.fits") # Expected - fname1 = "testfits_b.fits" - fname2 = "testfits_a.fits" - fname3 = "testfits.fits" + fname1 = filename1 + fname2 = filename2 + fname3 = filename3 # Test test_name1 = cosutil.changeSegment(filename1, "FUV", "FUVB") test_name2 = cosutil.changeSegment(filename2, "FUV", "FUVA") @@ -406,12 +377,12 @@ def test_change_segment(): assert fname3 == test_name3 -def test_copy_exptime_keywords(): +def test_copy_exptime_keywords(tmp_path): # Setup # create two files - generate_fits_file("original.fits") - generate_fits_file("copy.fits") - files = ["original.fits", "copy.fits"] + files = [str(tmp_path / "original.fits"), str(tmp_path / "copy.fits")] + for file in files: + generate_fits_file(file) headers = ["expstart", "expend", "exptime", "rawtime"] # set values to the exposure time for file in files: @@ -423,26 +394,24 @@ def test_copy_exptime_keywords(): else: fits.setval(file, header, value=0, ext=1) # get header of the files - inhdr = fits.getheader("original.fits", 1) - outhdr = fits.getheader("copy.fits", 1) + inhdr = fits.getheader(files[0], 1) + outhdr = fits.getheader(files[1], 1) # Test cosutil.copyExptimeKeywords(inhdr, outhdr) # Verify for header in headers: assert inhdr[header] == outhdr[header] - # Cleanup - for tempfile in files: - os.remove(tempfile) -def test_copy_voltage_keywords(): +def test_copy_voltage_keywords(tmp_path): # Setup + original = [str(tmp_path / "originalFUV.fits"), str(tmp_path / "originalNUV.fits")] + copy = [str(tmp_path / "copyFUV.fits"), str(tmp_path / "copyNUV.fits")] + # create two files each for FUV and NUV - generate_fits_file("originalFUV.fits") - generate_fits_file("originalNUV.fits") - generate_fits_file("copyFUV.fits") - generate_fits_file("copyNUV.fits") - original = ["originalFUV.fits", "originalNUV.fits"] - copy = ["copyFUV.fits", "copyNUV.fits"] + for file in original: + generate_fits_file(file) + for file in copy: + generate_fits_file(file) detectors = ["FUV", "NUV"] fuv_headers = ["dethvla", "dethvlb", "dethvca", "dethvcb", "dethvna", "dethvnb"] nuv_headers = ["dethvl", "dethvc"] @@ -469,10 +438,10 @@ def test_copy_voltage_keywords(): for header in nuv_headers: fits.setval(fileName, header, value=0., ext=1) index += 1 - in_FUV_hdr = fits.getheader("originalFUV.fits", 1) - in_NUV_hdr = fits.getheader("originalNUV.fits", 1) - out_FUV_hdr = fits.getheader("copyFUV.fits", 1) - out_NUV_hdr = fits.getheader("copyNUV.fits", 1) + in_FUV_hdr = fits.getheader(original[0], 1) + in_NUV_hdr = fits.getheader(original[1], 1) + out_FUV_hdr = fits.getheader(copy[0], 1) + out_NUV_hdr = fits.getheader(copy[1], 1) # Test 1 cosutil.copyVoltageKeywords(in_FUV_hdr, out_FUV_hdr, detectors[0]) # Verify 1 @@ -483,17 +452,12 @@ def test_copy_voltage_keywords(): # Verify 2 for header in nuv_headers: assert in_NUV_hdr[header] == out_NUV_hdr[header] - # Cleanup - for tempfile in original: - os.remove(tempfile) - for tempfile in copy: - os.remove(tempfile) -def test_copy_sub_keywords(): +def test_copy_sub_keywords(tmp_path): # Setup - generate_fits_file("subKeywords.fits") - generate_fits_file("copySubKeywords.fits") - files = ["subKeywords.fits", "copySubKeywords.fits"] + files = [str(tmp_path / "subKeywords.fits"), str(tmp_path / "copySubKeywords.fits")] + for file in files: + generate_fits_file(file) headers = ["corner%1dx", "corner%1dy", "size%1dx", "size%1dy"] for file in files: @@ -519,9 +483,6 @@ def test_copy_sub_keywords(): cosutil.copySubKeywords(inhdr, outhdr, True) # check if nsubarry has been set to 0 assert inhdr["nsubarry"] == outhdr["nsubarry"] - # Cleanup - for tempfile in files: - os.remove(tempfile) def test_modify_asn_mtyp(): @@ -543,13 +504,13 @@ def test_modify_asn_mtyp(): assert string3 == val3 -def test_rename_file(): +def test_rename_file(tmp_path): # Setup - original_filename1 = "raw-file.fits" - original_filename2 = "product0_file_a.fits" + original_filename1 = str(tmp_path / "raw-file.fits") + original_filename2 = str(tmp_path / "product0_file_a.fits") - new_filename1 = "renamed_raw_file.fits" - new_filename2 = "renamed0_file_a.fits" + new_filename1 = str(tmp_path / "renamed_raw_file.fits") + new_filename2 = str(tmp_path / "renamed0_file_a.fits") # Create the files generate_fits_file(original_filename1) generate_fits_file(original_filename2) @@ -559,14 +520,11 @@ def test_rename_file(): # Verify assert os.path.exists(new_filename1) assert os.path.exists(new_filename2) - # Cleanup - os.remove(new_filename1) - os.remove(new_filename2) -def test_del_corrtag_wcs(): +def test_del_corrtag_wcs(tmp_path): # Setup - generate_fits_file("del_corrtagWCS.fits") + generate_fits_file(str(tmp_path / "del_corrtagWCS.fits")) thdr = fits.getheader("del_corrtagWCS.fits", 3) tkey = ["TCTYP2", "TCRVL2", "TCRPX2", "TCDLT2", "TCUNI2", "TC2_2", "TC2_3", "TCTYP3", "TCRVL3", "TCRPX3", "TCDLT3", "TCUNI3", "TC3_2", "TC3_3"] @@ -575,8 +533,6 @@ def test_del_corrtag_wcs(): # Verify for key in tkey: assert key not in thdr - # Cleanup - os.remove("del_corrtagWCS.fits") def test_set_verbosity(): @@ -819,10 +775,8 @@ def test_segment_specific_keyword(): assert key2 == root2 -def test_find_ref_file(): - # Setup - generate_fits_file("wrong_file.fits") - generate_fits_file("test.fits") +def test_find_ref_file(tmp_path): + generate_fits_file(str(tmp_path / "test.fits")) fits.setval("wrong_file.fits", 'FILETYPE', value='FLAT FIELD REFERENCE IMAGE') # Missing ref1 = {"keyword": "FLATFILE", "filename": "test_flt.fits", "calcos_ver": "3.0", @@ -854,9 +808,6 @@ def test_find_ref_file(): assert actual_missing == missing1 assert actual_bad_ver == bad_ver2 assert actual_wrong_ver == wrong_f3 - # Cleanup - os.remove("wrong_file.fits") - os.remove("test.fits") def test_cmp_version(): @@ -873,13 +824,13 @@ def test_cmp_version(): assert expected_cmp[i] == test_cmp[i] -def test_get_pedigree(): +def test_get_pedigree(tmp_path): # Setup capture_msg = io.StringIO() sys.stdout = capture_msg switch = "perform" refkey = "statflag" - filename = "test_flt.file" + filename = str(tmp_path / "test_flt.file") generate_fits_file(filename) err_msg = "Warning: STATFLAG test_flt.file is a dummy file\n" \ " so PERFORM will not be done.\n" @@ -894,37 +845,36 @@ def test_get_pedigree(): assert pedgr1 == "OK" assert pedgr2 == "DUMMY" assert err_msg == capture_msg.getvalue() - # Cleanup - os.remove(filename) -def test_get_aperture_keyword(): +def test_get_aperture_keyword(tmp_path): + filename = str(tmp_path / "aperture_test.fits") # Setup - generate_fits_file("aperture_test.fits") + generate_fits_file(filename) # condition 1 - fits.setval("aperture_test.fits", "aperture", value="PSA-FUV", ext=0) - fits.setval("aperture_test.fits", "propaper", value="PSA-FUV", ext=0) - hdr1 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "aperture", value="PSA-FUV", ext=0) + fits.setval(filename, "propaper", value="PSA-FUV", ext=0) + hdr1 = fits.getheader(filename, ext=0) # condition 2 - fits.setval("aperture_test.fits", "aperture", value="RelMvReq", ext=0) - fits.setval("aperture_test.fits", "propaper", value="WCA", ext=0) - hdr2 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "aperture", value="RelMvReq", ext=0) + fits.setval(filename, "propaper", value="WCA", ext=0) + hdr2 = fits.getheader(filename, ext=0) # condition 3 - fits.setval("aperture_test.fits", "propaper", value="NA", ext=0) - fits.setval("aperture_test.fits", "shutter", value="closed", ext=0) - fits.setval("aperture_test.fits", "lampused", value="P", ext=0) - hdr3 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "propaper", value="NA", ext=0) + fits.setval(filename, "shutter", value="closed", ext=0) + fits.setval(filename, "lampused", value="P", ext=0) + hdr3 = fits.getheader(filename, ext=0) # condition 4 - fits.setval("aperture_test.fits", "lampused", value="D", ext=0) - hdr4 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "lampused", value="D", ext=0) + hdr4 = fits.getheader(filename, ext=0) # condition 5 - fits.setval("aperture_test.fits", "lampused", value="A", ext=0) - hdr5 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "lampused", value="A", ext=0) + hdr5 = fits.getheader(filename, ext=0) # condition 6 - fits.setval("aperture_test.fits", "shutter", value="open", ext=0) - fits.setval("aperture_test.fits", "life_adj", value=2, ext=0) - fits.setval("aperture_test.fits", "aperypos", value=4.56, ext=0) - hdr6 = fits.getheader("aperture_test.fits", ext=0) + fits.setval(filename, "shutter", value="open", ext=0) + fits.setval(filename, "life_adj", value=2, ext=0) + fits.setval(filename, "aperypos", value=4.56, ext=0) + hdr6 = fits.getheader(filename, ext=0) # Expected values rtn1 = ('PSA', 'APERTURE changed from PSA-FUV to PSA') @@ -947,29 +897,25 @@ def test_get_aperture_keyword(): assert rtn4 == rtn_value4 assert rtn5 == rtn_value5 assert rtn6 == rtn_value6 - # Cleanup - os.remove("aperture_test.fits") -def test_write_version_to_trailer(): +def test_write_version_to_trailer(tmp_path): capture_msg = io.StringIO() sys.stdout = capture_msg - generate_fits_file("dummy_file.fits") - ascii_file = open("ascii.txt", mode="w") + generate_fits_file(str(tmp_path / "dummy_file.fits")) + ascii_file = open(str(tmp_path / "ascii.txt"), mode="w") cosutil.fd_trl = ascii_file cosutil.CALCOS_VERSION = '3.1.0' cosutil.writeVersionToTrailer() sys.stdout = sys.__stdout__ assert ascii_file == cosutil.fd_trl assert capture_msg.getvalue() == '' - # Cleanup - os.remove("ascii.txt") - os.remove("dummy_file.fits") -def test_get_switch(): +def test_get_switch(tmp_path): # Setup - generate_fits_file("switch.fits") - phdr = fits.getheader("switch.fits", ext=0) + filename = str(tmp_path / "switch.fits") + generate_fits_file(filename) + phdr = fits.getheader(filename, ext=0) keyword = "statflag" # Test switch = cosutil.getSwitch(phdr, keyword) @@ -980,47 +926,42 @@ def test_get_switch(): assert switch == 'PERFORM' assert switch2 == 'OMIT' assert switch3 == 'N/A' - # Cleanup - os.remove("switch.fits") -def test_temp_pulse_height_range(): - generate_fits_file("pulseHeightRef.fits") +def test_temp_pulse_height_range(tmp_path): + filename = "pulseHeightRef.fits" + generate_fits_file(filename) true_pha_value = 4 - fits.setval("pulseHeightRef.fits", "pharange", value=true_pha_value, ext=0) - fits.getheader("pulseHeightRef.fits", ext=0) - pha_value = cosutil.tempPulseHeightRange('pulseHeightRef.fits') + fits.setval(filename, "pharange", value=true_pha_value, ext=0) + fits.getheader(filename, ext=0) + pha_value = cosutil.tempPulseHeightRange(filename) assert true_pha_value == pha_value - # Cleanup - os.remove("pulseHeightRef.fits") -def test_get_pulse_height_range(): - generate_fits_file("pulseHeightRef.fits") - fits.setval("pulseHeightRef.fits", "phalowrA", value=7, ext=0) - fits.setval("pulseHeightRef.fits", "phaupprA", value=10, ext=0) - hdu = fits.getheader("pulseHeightRef.fits", ext=0) +def test_get_pulse_height_range(tmp_path): + filename = str(tmp_path / "pulseHeightRef.fits") + generate_fits_file(filename) + fits.setval(filename, "phalowrA", value=7, ext=0) + fits.setval(filename, "phaupprA", value=10, ext=0) + hdu = fits.getheader(filename, ext=0) seg = ['FUVA', 'FUVB'] actual = [' 7_10', None] for s, a in zip(seg, actual): test_str = cosutil.getPulseHeightRange(hdu, s) assert a == test_str - # Cleanup - os.remove("pulseHeightRef.fits") -def test_time_at_midpoint(): +def test_time_at_midpoint(tmp_path): # Setup - generate_fits_file("test_timeAtMidpoint.fits") - hdr = fits.getheader("test_timeAtMidpoint.fits", ext=1) + filename = str(tmp_path / "test_timeAtMidpoint.fits") + generate_fits_file(filename) + hdr = fits.getheader(filename, ext=1) info = {'expstart': hdr['TCRPX7'], 'expend': hdr['TCRVL7']} average = 4776.314586556611 # Test test_average = cosutil.timeAtMidpoint(info) # Verify assert average == test_average - # Cleanup - os.remove("test_timeAtMidpoint.fits") def test_timeline_times(): @@ -1049,18 +990,19 @@ def test_combine_stat(): assert actual[key] == test[key] -def test_override_keywords(): +def test_override_keywords(tmp_path): # Setup - generate_fits_file("overridekeywords.fits") + filename = str(tmp_path / "overridekeywords.fits") + generate_fits_file(filename) info = {"cal_ver": 3.1, "opt_elem": 2, "cenwave": 0.34, "fpoffset": 3.43, "obstype": "FUV", "exptype": "N/A", "aperture": "PSA", "x_offset": 1.2, "dispaxis": 2.5} switches = {"statflag": "PERFORM", "flatcorr": "PERFORM", "geocorr": "COMPLETE", "randcorr": "SKIPPED"} reffiles = {"flatfile": "abc_flat.fits", "flt_hdr": "lref$abc_flat.fits"} - fits.setval("overridekeywords.fits", "flt_hdr", value="NA", ext=0) - fits.setval("overridekeywords.fits", "dispaxis", value=2.9, ext=1) - fits.setval("overridekeywords.fits", "x_offset", value=1.9, ext=1) - phdr = fits.getheader("overridekeywords.fits", ext=0) - hdr = fits.getheader("overridekeywords.fits", ext=1) + fits.setval(filename, "flt_hdr", value="NA", ext=0) + fits.setval(filename, "dispaxis", value=2.9, ext=1) + fits.setval(filename, "x_offset", value=1.9, ext=1) + phdr = fits.getheader(filename, ext=0) + hdr = fits.getheader(filename, ext=1) # Actual Values val1 = True val2 = switches.values() @@ -1068,8 +1010,6 @@ def test_override_keywords(): val4 = reffiles["flt_hdr"] # Test cosutil.overrideKeywords(phdr, hdr, info, switches, reffiles) - phdr = fits.getheader("overridekeywords.fits", ext=0) + phdr = fits.getheader(filename, ext=0) # Verify assert val1 == phdr["statflag"] - # Cleanup - os.remove("overridekeywords.fits") diff --git a/tests/test_extract.py b/tests/test_extract.py index 40af48a..810c875 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -4,12 +4,12 @@ import numpy as np from generate_tempfiles import generate_fits_file -def test_get_columns(): +def test_get_columns(tmp_path): """ Test if the function is returning the right column fields """ # Setup - test_data = generate_fits_file("lbgu17qnq_corrtag_a.fits") + test_data = generate_fits_file(str(tmp_path / "lbgu17qnq_corrtag_a.fits")) dt = test_data[1].data detector = "FUV" @@ -31,18 +31,16 @@ def test_get_columns(): np.testing.assert_array_equal(yfull, yf) np.testing.assert_array_equal(dq, dq2) np.testing.assert_array_equal(epsilon, epsilon2) - # Cleanup - os.remove("lbgu17qnq_corrtag_a.fits") -def test_remove_unwanted_column(): +def test_remove_unwanted_column(tmp_path): """ Old column length should be equal to new column length + amount of the removed columns """ # Setup target_cols = ['XFULL', 'YFULL'] # Truth - fd = generate_fits_file("lbgu17qnq_lampflash.fits") + fd = generate_fits_file(str(tmp_path / "lbgu17qnq_lampflash.fits")) table = fd[1].data cols = table.columns @@ -56,8 +54,6 @@ def test_remove_unwanted_column(): deleted_cols = deleted_cols[np.argsort(temp_cols)] # assert target_cols[0] == deleted_cols[0].name # assert target_cols[1] == deleted_cols[1].name - # Cleanup - os.remove('lbgu17qnq_lampflash.fits') def test_next_power_of_two(): @@ -72,10 +68,10 @@ def test_next_power_of_two(): assert next_power == extract.next_power_of_two(7) -def test_add_column_comment(): +def test_add_column_comment(tmp_path): # verify if entered comment to a header is present in the fits file. # Setup - ofd = generate_fits_file("myFitsFile.fits") + ofd = generate_fits_file(str(tmp_path / "myFitsFile.fits")) comment = "This comment is generated by a unit-test." # Exercise @@ -83,6 +79,4 @@ def test_add_column_comment(): # Verify assert comment == test_table[1].header.comments['TTYPE1'] - # Cleanup - os.remove('myFitsFile.fits') diff --git a/tests/test_shiftfile.py b/tests/test_shiftfile.py index 6b38e25..a098910 100644 --- a/tests/test_shiftfile.py +++ b/tests/test_shiftfile.py @@ -21,20 +21,18 @@ def create_shift_file(filename): return -def test_shift_file(): - shift_file = "shift_file.txt" +def test_shift_file(tmp_path): + shift_file = str(tmp_path / "shift_file.txt") create_shift_file(shift_file) # Test ob = shiftfile.ShiftFile(shift_file, 'abc123def', 'any') # Verify assert len(ob.user_shift_dict) > 0 - # Cleanup - os.remove(shift_file) -def test_get_shifts(): +def test_get_shifts(tmp_path): # Setup - shift_file = "shift_file.txt" + shift_file = str(tmp_path / "shift_file.txt") create_shift_file(shift_file) ob1 = shiftfile.ShiftFile(shift_file, 'ghi456jkl', 'any') ob2 = shiftfile.ShiftFile(shift_file, 'abc123def', 'any') @@ -52,6 +50,4 @@ def test_get_shifts(): for i in range(len(expected_values1)): assert expected_values1[i] == test_values1[i] assert expected_values2[i] == test_values2[i] - # Cleanup - os.remove(shift_file)