Skip to content

Commit

Permalink
more functions, better adopted switching
Browse files Browse the repository at this point in the history
  • Loading branch information
kelle committed Jul 24, 2024
1 parent 76af162 commit 402ac81
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 113 deletions.
158 changes: 85 additions & 73 deletions simple/utils/spectral_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ def ingest_spectral_type(
else:
db_name = db_name[0]

source_spt_data = (
db.query(db.SpectralTypes).filter(db.SpectralTypes.c.source == db_name).table()
)
logger.debug(f"Pre-existing Spectral Type data: \n {source_spt_data}")

# Check for duplicates
duplicate_check = (
db.query(db.SpectralTypes.c.source)
Expand All @@ -101,37 +96,9 @@ def ingest_spectral_type(
else:
logger.debug(f"No duplicates found for : {db_name}, {regime}, {reference}")

# set adopted flag
if len(source_spt_data) == 0:
adopted = True
old_adopted = None
logger.debug(
"No Spectral Type data for this source in the database, setting adopted flag to True"
)
elif len(source_spt_data) > 0:
# Spectral Type Data already exists
adopted_ind = source_spt_data["adopted"] == 1
if sum(adopted_ind):
old_adopted = source_spt_data[adopted_ind]
logger.debug(f"Old adopted data: {old_adopted}")
if (
old_adopted["spectral_type_error"] is not None
and spectral_type_error is not None
):
if spectral_type_error < min(old_adopted["spectral_type_error"]):
adopted = True
else:
adopted = False
logger.debug(f"The new spectral type's adopted flag is:, {adopted}")
else:
adopted = True
logger.debug(
"No spectral type error found, setting adopted flag to True"
)
else:
adopted = True
old_adopted = None
logger.debug("No adopted data found, setting adopted flag to True")
adopted = adopt_spectral_type(db, db_name, spectral_type_error)
if adopted:
unset_previously_adopted(db, db_name)

if spectral_type_code is None:
spectral_type_code = convert_spt_string_to_code(spectral_type_string)
Expand Down Expand Up @@ -159,22 +126,6 @@ def ingest_spectral_type(
session.add(spt_obj)
session.commit()
logger.info(f"Spectral type added to database: {spt_data}\n")

# unset old adopted only after ingest is successful!
if adopted and old_adopted is not None:
with db.engine.connect() as conn:
conn.execute(
db.SpectralTypes.update()
.where(
and_(
db.SpectralTypes.c.source == old_adopted["source"][0],
db.SpectralTypes.c.regime == old_adopted["regime"][0],
db.SpectralTypes.c.reference == old_adopted["reference"][0],
)
)
.values(adopted=False)
)
conn.commit()
except sqlalchemy.exc.IntegrityError as e:
if (
db.query(db.Publications)
Expand All @@ -198,32 +149,23 @@ def ingest_spectral_type(
else:
logger.warning(msg)

# check that adopted flag is successfully changed
if old_adopted is not None:
old_adopted_data = (
db.query(db.SpectralTypes)
.filter(
and_(
db.SpectralTypes.c.source == old_adopted["source"][0],
db.SpectralTypes.c.regime == old_adopted["regime"][0],
db.SpectralTypes.c.reference == old_adopted["reference"][0],
)
)
.table()
)
logger.debug("Old adopted measurement unset")
logger.debug(f"Old adopted data:\n {old_adopted_data}")

# check that there is only one adopted measurement
results = (
db.query(db.SpectralTypes)
.filter(
and_(db.SpectralTypes.c.source == db_name, db.SpectralTypes.c.adopted == 1)
and_(
db.SpectralTypes.c.source == db_name, db.SpectralTypes.c.adopted == True
)
)
.table()
)
logger.debug(f"Adopted measurements for {db_name}: {results}")
if len(results) > 1:
logger.debug(f"Adopted measurements for {db_name}:{results}")
if logger.level <= 10:
results.pprint_all()
logger.debug(f"adopted column: {results['adopted']}")
if len(results) == 1:
logger.debug(f"One adopted measurement for {db_name}")
elif len(results) > 2:
msg = f"Multiple adopted measurements for {db_name}"
if raise_error:
logger.error(msg)
Expand All @@ -237,8 +179,6 @@ def ingest_spectral_type(
raise AstroDBError(msg)
else:
logger.warning(msg)
else:
logger.debug(f"Adopted measurement for {db_name}: {results}")


def convert_spt_string_to_code(spectral_type_string):
Expand Down Expand Up @@ -317,3 +257,75 @@ def convert_spt_code_to_string(spectral_code, decimals=1):
logger.debug(f"Converting: {spectral_code} -> {spt_type}")

return spt_type


def adopt_spectral_type(db, source, spectral_type_error):
    """Decide whether a new spectral type should be flagged as adopted.

    The decision is based on the rows already stored in ``db.SpectralTypes``
    for ``source``:

    * no rows at all, or no row currently flagged as adopted -> ``True``;
    * the new measurement or every currently adopted row lacks an error
      (``None``) -> ``True``;
    * otherwise ``True`` only if ``spectral_type_error`` is strictly smaller
      than the smallest error among the currently adopted rows.

    This function only reads from the database; it does not modify any row.

    Parameters
    ----------
    db
        Database object exposing ``query`` and the ``SpectralTypes`` table.
    source : str
        Source name as stored in the ``source`` column.
    spectral_type_error : float or None
        Uncertainty of the spectral type about to be ingested.

    Returns
    -------
    bool
        ``True`` if the new spectral type should be the adopted one.
    """
    source_spt_data = (
        db.query(db.SpectralTypes).filter(db.SpectralTypes.c.source == source).table()
    )

    if len(source_spt_data) == 0:
        logger.debug(
            "No Spectral Type data for this source in the database, setting adopted flag to True"
        )
        return True

    # Spectral type data already exists for this source
    logger.debug("Pre-existing Spectral Type data:")
    if logger.level <= 10:
        source_spt_data.pprint_all()

    adopted_ind = source_spt_data["adopted"] == 1
    if not sum(adopted_ind):
        # Log *before* returning; the message used to be unreachable.
        logger.debug("No adopted data found, setting adopted flag to True")
        return True

    old_adopted = source_spt_data[adopted_ind]
    logger.debug(f"Old adopted data: {old_adopted}")

    # Testing the column object against None is always true, so look at the
    # individual values instead. Only ``None`` is filtered here.
    # NOTE(review): if missing errors come back as NaN or masked values rather
    # than None, they are still compared as before -- confirm with the backend.
    old_errors = [err for err in old_adopted["spectral_type_error"] if err is not None]
    if spectral_type_error is None or not old_errors:
        logger.debug("No spectral type error found, setting adopted flag to True")
        return True

    # Strict comparison: a tie keeps the previously adopted measurement.
    adopted = bool(spectral_type_error < min(old_errors))
    logger.debug(f"The new spectral type's adopted flag is: {adopted}")
    return adopted


def unset_previously_adopted(db, source):
    """Clear the adopted flag on every adopted spectral type of a source.

    Looks up all ``db.SpectralTypes`` rows for ``source`` and, for each row
    whose ``adopted`` value equals 1, issues an UPDATE (matched on source,
    regime and reference) that sets ``adopted`` to ``False``. All updates are
    executed on one connection and committed together.

    Parameters
    ----------
    db
        Database object exposing ``query``, ``engine`` and the
        ``SpectralTypes`` table.
    source : str
        Source name as stored in the ``source`` column.

    Returns
    -------
    None
    """
    source_spt_data = (
        db.query(db.SpectralTypes).filter(db.SpectralTypes.c.source == source).table()
    )
    logger.debug(f"Pre-existing Spectral Type data: \n {source_spt_data}")

    if len(source_spt_data) == 0:
        logger.debug("No previous data for this source in the database, doing nothing")
        return

    # Spectral type data already exists for this source
    adopted_ind = source_spt_data["adopted"] == 1
    if not sum(adopted_ind):
        logger.debug("No previously adopted data found, doing nothing")
        return

    old_adopted = source_spt_data[adopted_ind]
    with db.engine.connect() as conn:
        # Unset every adopted row, not just the first one, so that a source
        # never keeps a stale adopted measurement alongside the new one.
        for old_source, old_regime, old_reference in zip(
            old_adopted["source"], old_adopted["regime"], old_adopted["reference"]
        ):
            conn.execute(
                db.SpectralTypes.update()
                .where(
                    and_(
                        db.SpectralTypes.c.source == old_source,
                        db.SpectralTypes.c.regime == old_regime,
                        db.SpectralTypes.c.reference == old_reference,
                    )
                )
                .values(adopted=False)
            )
        conn.commit()
86 changes: 46 additions & 40 deletions tests/test_spectral_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,63 @@ def test_convert_spt_code_to_string():
assert convert_spt_code_to_string(92, decimals=0) == "Y2"


def test_ingest_spectral_type(temp_db):
spt_data1 = {
"source": "Fake 1",
"spectral_type": "M5.6",
"regime": "nir",
"reference": "Ref 1",
}
spt_data2 = {
"source": "Fake 2",
"spectral_type": "T0.1",
"regime": "nir",
"reference": "Ref 1",
}
spt_data3 = {
"source": "Fake 3",
"spectral_type": "Y2pec",
"regime": "nir",
"reference": "Ref 2",
}
for spt_data in [spt_data1, spt_data2, spt_data3]:
ingest_spectral_type(
temp_db,
source=spt_data["source"],
spectral_type_string=spt_data["spectral_type"],
spectral_type_error=1.0,
reference=spt_data["reference"],
regime=spt_data["regime"],
)
results = (
temp_db.query(temp_db.SpectralTypes)
.filter(temp_db.SpectralTypes.c.source == spt_data["source"])
.table()
)
assert len(results) == 1, f"Expecting this data: {spt_data} in \n {results}"
assert results["adopted"][0] == True # noqa: E712
@pytest.mark.parametrize(
    "test_input",
    [
        dict(source="Fake 1", spectral_type="M5.6", regime="nir", reference="Ref 1"),
        dict(source="Fake 2", spectral_type="T0.1", regime="nir", reference="Ref 1"),
        dict(source="Fake 3", spectral_type="Y2pec", regime="nir", reference="Ref 2"),
    ],
)
def test_ingest_spectral_type(temp_db, test_input):
    """Ingest one spectral type per case and check it is stored as adopted.

    Each parametrized case ingests a single measurement for its source, then
    verifies that exactly one SpectralTypes row exists for that source and
    that its adopted flag is set.
    """
    source_name = test_input["source"]
    ingest_spectral_type(
        temp_db,
        source=source_name,
        spectral_type_string=test_input["spectral_type"],
        spectral_type_error=1.0,
        reference=test_input["reference"],
        regime=test_input["regime"],
    )

    spt_table = temp_db.SpectralTypes
    source_query = temp_db.query(spt_table).filter(spt_table.c.source == source_name)
    results = source_query.table()

    assert len(results) == 1, f"Expecting this data: {test_input} in \n {results}"
    # Equality (not identity) on purpose: the stored flag may not be a Python bool.
    assert results["adopted"][0] == True  # noqa: E712


def test_ingest_spectral_type_multiple(temp_db):
assert (
temp_db.query(temp_db.SpectralTypes)
.filter(temp_db.SpectralTypes.c.reference == "Ref 1")
.count()
== 2
)
results = (
results_ref2 = (
temp_db.query(temp_db.SpectralTypes)
.filter(temp_db.SpectralTypes.c.reference == "Ref 2")
.table()
)
assert len(results) == 1
assert results["source"][0] == "Fake 3"
assert results["spectral_type_string"][0] == "Y2pec"
assert results["spectral_type_code"][0] == [92]
assert len(results_ref2) == 1
assert results_ref2["source"][0] == "Fake 3"
assert results_ref2["spectral_type_string"][0] == "Y2pec"
assert results_ref2["spectral_type_code"][0] == [92]


def test_ingest_spectral_type_adopted(temp_db):
Expand Down

0 comments on commit 402ac81

Please sign in to comment.