Skip to content

Commit

Permalink
bugfixes r2
Browse files Browse the repository at this point in the history
  • Loading branch information
ggmarshall committed Feb 10, 2025
1 parent 50a97d1 commit 7b70cba
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dependencies = [
"pygama>=2",
"dspeed>=1.6",
"pylegendmeta==1.2.0a2",
"legend-pydataobj>=1.11.4",
"legend-pydataobj>=1.11.6",
"legend-daq2lh5>=1.4",
]

Expand Down
11 changes: 11 additions & 0 deletions workflow/rules/dsp_pars_geds.smk
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ rule build_pars_dsp_tau_geds:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--raw-table-name {params.raw_table_name} "
"--plot-path {output.plots} "
"--output-file {output.decay_const} "
"--pulser-file {input.pulser} "
Expand Down Expand Up @@ -80,6 +81,7 @@ rule build_pars_evtsel_geds:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--raw-table-name {params.raw_table_name} "
"--peak-file {output.peak_file} "
"--pulser-file {input.pulser_file} "
"--decay-const {input.database} "
Expand Down Expand Up @@ -120,6 +122,7 @@ rule build_pars_dsp_nopt_geds:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--raw-table-name {params.raw_table_name} "
"--inplots {input.inplots} "
"--plot-path {output.plots} "
"--dsp-pars {output.dsp_pars_nopt} "
Expand All @@ -139,6 +142,9 @@ rule build_pars_dsp_dplms_geds:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
),
output:
dsp_pars=temp(get_pattern_pars_tmp_channel(config, "dsp", "dplms")),
lh5_path=temp(get_pattern_pars_tmp_channel(config, "dsp", extension="lh5")),
Expand All @@ -159,6 +165,7 @@ rule build_pars_dsp_dplms_geds:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--raw-table-name {params.raw_table_name} "
"--dsp-pars {output.dsp_pars} "
"--lh5-path {output.lh5_path} "
"--plot-path {output.plots} "
Expand All @@ -174,6 +181,9 @@ rule build_pars_dsp_eopt_geds:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
),
output:
dsp_pars=temp(get_pattern_pars_tmp_channel(config, "dsp_eopt")),
qbb_grid=temp(
Expand All @@ -192,6 +202,7 @@ rule build_pars_dsp_eopt_geds:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--raw-table-name {params.raw_table_name} "
"--peak-file {input.peak_file} "
"--inplots {input.inplots} "
"--decay-const {input.decay_const} "
Expand Down
4 changes: 4 additions & 0 deletions workflow/rules/tcm.smk
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ rule build_pulser_ids:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
rawid=lambda wildcards: metadata.channelmap(wildcards.timestamp, system="cal")[
wildcards.channel
].daq.rawid,
output:
pulser=temp(get_pattern_pars_tmp_channel(config, "tcm", "pulser_ids")),
log:
Expand All @@ -61,6 +64,7 @@ rule build_pulser_ids:
"--datatype {params.datatype} "
"--timestamp {params.timestamp} "
"--channel {params.channel} "
"--rawid {params.rawid} "
"--tcm-files {params.input} "
"--pulser-file {output.pulser} "
"--metadata {meta} "
22 changes: 12 additions & 10 deletions workflow/src/legenddataflow/scripts/par/geds/dsp/dplms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pygama.pargen.dplms_ge_dict import dplms_ge_dict

from .....log import build_log
from ....table_name import get_table_name


def par_geds_dsp_dplms() -> None:
Expand All @@ -23,11 +22,13 @@ def par_geds_dsp_dplms() -> None:

argparser.add_argument("--log", help="log_file", type=str)
argparser.add_argument("--configs", help="configs", type=str, required=True)
argparser.add_argument("--metadata", help="metadata", type=str, required=True)

argparser.add_argument("--datatype", help="Datatype", type=str, required=True)
argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True)
argparser.add_argument("--channel", help="Channel", type=str, required=True)
argparser.add_argument(
"--raw-table-name", help="raw table name", type=str, required=True
)

argparser.add_argument("--dsp-pars", help="dsp_pars", type=str, required=True)
argparser.add_argument("--lh5-path", help="lh5_path", type=str, required=True)
Expand All @@ -41,8 +42,6 @@ def par_geds_dsp_dplms() -> None:
log = build_log(config_dict, args.log)
sto = lh5.LH5Store()

channel = get_table_name(args.metadata, args.timestamp, args.datatype, args.channel)

configs = TextDB(args.configs).on(args.timestamp, system=args.datatype)
dsp_config = config_dict["inputs"]["proc_chain"][args.channel]

Expand All @@ -57,10 +56,13 @@ def par_geds_dsp_dplms() -> None:

t0 = time.time()
log.info("\nLoad fft data")
energies = sto.read(f"{channel}/raw/daqenergy", fft_files)[0]
energies = sto.read(f"{args.raw_table_name}/daqenergy", fft_files)[0]
idxs = np.where(energies.nda == 0)[0]
raw_fft = sto.read(
f"{channel}/raw", fft_files, n_rows=dplms_dict["n_baselines"], idx=idxs
f"{args.raw_table_name}/raw",
fft_files,
n_rows=dplms_dict["n_baselines"],
idx=idxs,
)[0]
t1 = time.time()
log.info(f"Time to load fft data {(t1-t0):.2f} s, total events {len(raw_fft)}")
Expand All @@ -70,14 +72,14 @@ def par_geds_dsp_dplms() -> None:
# kev_widths = [tuple(kev_width) for kev_width in dplms_dict["kev_widths"]]

peaks_rounded = [int(peak) for peak in peaks_kev]
peaks = sto.read(f"{channel}/raw", args.peak_file, field_mask=["peak"])[0][
peaks = sto.read(args.raw_table_name, args.peak_file, field_mask=["peak"])[0][
"peak"
].nda
ids = np.isin(peaks, peaks_rounded)
peaks = peaks[ids]
# idx_list = [np.where(peaks == peak)[0] for peak in peaks_rounded]

raw_cal = sto.read(f"{channel}/raw", args.peak_file, idx=ids)[0]
raw_cal = sto.read(args.raw_table_name, args.peak_file, idx=ids)[0]
log.info(
f"Time to run event selection {(time.time()-t1):.2f} s, total events {len(raw_cal)}"
)
Expand Down Expand Up @@ -111,7 +113,7 @@ def par_geds_dsp_dplms() -> None:
coeffs = out_dict["dplms"].pop("coefficients")
dplms_pars = Table(col_dict={"coefficients": Array(coeffs)})
out_dict["dplms"]["coefficients"] = (
f"loadlh5('{args.lh5_path}', '{channel}/dplms/coefficients')"
f"loadlh5('{args.lh5_path}', '{args.channel}/dplms/coefficients')"
)

log.info(f"DPLMS creation finished in {(time.time()-t0)/60} minutes")
Expand All @@ -129,7 +131,7 @@ def par_geds_dsp_dplms() -> None:
Path(args.lh5_path).parent.mkdir(parents=True, exist_ok=True)
sto.write(
Table(col_dict={"dplms": dplms_pars}),
name=channel,
name=args.channel,
lh5_file=args.lh5_path,
wo_mode="overwrite",
)
Expand Down
11 changes: 5 additions & 6 deletions workflow/src/legenddataflow/scripts/par/geds/dsp/eopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)

from .....log import build_log
from ....table_name import get_table_name

warnings.filterwarnings(action="ignore", category=RuntimeWarning)
warnings.filterwarnings(action="ignore", category=np.RankWarning)
Expand All @@ -34,11 +33,13 @@ def par_geds_dsp_eopt() -> None:

argparser.add_argument("--log", help="log_file", type=str)
argparser.add_argument("--configs", help="configs", type=str, required=True)
argparser.add_argument("--metadata", help="metadata", type=str, required=True)

argparser.add_argument("--datatype", help="Datatype", type=str, required=True)
argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True)
argparser.add_argument("--channel", help="Channel", type=str, required=True)
argparser.add_argument(
"--raw-table-name", help="raw table name", type=str, required=True
)

argparser.add_argument(
"--final-dsp-pars", help="final_dsp_pars", type=str, required=True
Expand All @@ -59,8 +60,6 @@ def par_geds_dsp_eopt() -> None:
sto = lh5.LH5Store()
t0 = time.time()

channel = get_table_name(args.metadata, args.timestamp, args.datatype, args.channel)

dsp_config = config_dict["inputs"]["processing_chain"][args.channel]
opt_json = config_dict["inputs"]["optimiser_config"][args.channel]

Expand Down Expand Up @@ -107,14 +106,14 @@ def par_geds_dsp_eopt() -> None:
)

peaks_rounded = [int(peak) for peak in peaks_kev]
peaks = sto.read(f"{channel}/raw", args.peak_file, field_mask=["peak"])[0][
peaks = sto.read(args.raw_table_name, args.peak_file, field_mask=["peak"])[0][
"peak"
].nda
ids = np.isin(peaks, peaks_rounded)
peaks = peaks[ids]
idx_list = [np.where(peaks == peak)[0] for peak in peaks_rounded]

tb_data = sto.read(f"{channel}/raw", args.peak_file, idx=ids)[0]
tb_data = sto.read(args.raw_table_name, args.peak_file, idx=ids)[0]

t1 = time.time()
log.info(f"Data Loaded in {(t1-t0)/60} minutes")
Expand Down
31 changes: 14 additions & 17 deletions workflow/src/legenddataflow/scripts/par/geds/dsp/evtsel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from .....log import build_log
from ....pulser_removal import get_pulser_mask
from ....table_name import get_table_name

warnings.filterwarnings(action="ignore", category=RuntimeWarning)

Expand Down Expand Up @@ -85,10 +84,10 @@ def par_geds_dsp_evtsel() -> None:
argparser = argparse.ArgumentParser()
argparser.add_argument("--raw-filelist", help="raw_filelist", type=str)
argparser.add_argument(
"--tcm-filelist", help="tcm_filelist", type=str, required=False
"--pulser-file", help="pulser_file", type=str, required=False
)
argparser.add_argument(
"--pulser-file", help="pulser_file", type=str, required=False
"-p", "--no-pulse", help="no pulser present", action="store_true"
)

argparser.add_argument("--decay_const", help="decay_const", type=str, required=True)
Expand All @@ -98,11 +97,13 @@ def par_geds_dsp_evtsel() -> None:

argparser.add_argument("--log", help="log_file", type=str)
argparser.add_argument("--configs", help="configs", type=str, required=True)
argparser.add_argument("--metadata", help="metadata", type=str, required=True)

argparser.add_argument("--datatype", help="Datatype", type=str, required=True)
argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True)
argparser.add_argument("--channel", help="Channel", type=str, required=True)
argparser.add_argument(
"--raw-table-name", help="raw table name", type=str, required=True
)

argparser.add_argument("--peak-file", help="peak_file", type=str, required=True)
args = argparser.parse_args()
Expand All @@ -115,8 +116,6 @@ def par_geds_dsp_evtsel() -> None:
sto = lh5.LH5Store()
t0 = time.time()

channel = get_table_name(args.metadata, args.timestamp, args.datatype, args.channel)

dsp_config = config_dict["inputs"]["processing_chain"][args.channel]
peak_json = config_dict["inputs"]["peak_config"][args.channel]

Expand All @@ -134,16 +133,7 @@ def par_geds_dsp_evtsel() -> None:
files = f.read().splitlines()
raw_files = sorted(files)

mask = get_pulser_mask(
pulser_file=args.pulser_file,
tcm_filelist=args.tcm_filelist,
channel=channel,
pulser_multiplicity_threshold=peak_dict.get(
"pulser_multiplicity_threshold"
),
)

raw_dict = Props.read_from(args.raw_cal)[channel]["pars"]["operations"]
raw_dict = Props.read_from(args.raw_cal)[args.channel]["pars"]["operations"]

peaks_kev = peak_dict["peaks"]
kev_widths = peak_dict["kev_widths"]
Expand All @@ -152,7 +142,7 @@ def par_geds_dsp_evtsel() -> None:
final_cut_field = peak_dict["final_cut_field"]
energy_parameter = peak_dict.get("energy_parameter", "trapTmax")

lh5_path = f"{channel}/raw"
lh5_path = args.raw_table_name

if not isinstance(kev_widths, list):
kev_widths = [kev_widths]
Expand All @@ -164,6 +154,13 @@ def par_geds_dsp_evtsel() -> None:
lh5_path, raw_files, field_mask=["daqenergy", "t_sat_lo", "timestamp"]
)[0]

if args.no_pulse is False:
mask = get_pulser_mask(
args.pulser_file,
)
else:
mask = np.full(len(tb), False)

discharges = tb["t_sat_lo"].nda > 0
discharge_timestamps = np.where(tb["timestamp"].nda[discharges])[0]
is_recovering = np.full(len(tb), False, dtype=bool)
Expand Down
26 changes: 15 additions & 11 deletions workflow/src/legenddataflow/scripts/par/geds/dsp/nopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pygama.pargen.noise_optimization as pno
from dbetto import TextDB
from dbetto.catalog import Props
from legendmeta import LegendMetadata
from pygama.pargen.data_cleaning import generate_cuts, get_cut_indexes
from pygama.pargen.dsp_optimize import run_one_dsp

Expand All @@ -24,12 +23,14 @@ def par_geds_dsp_nopt() -> None:
argparser.add_argument("--inplots", help="inplots", type=str)

argparser.add_argument("--configs", help="configs", type=str, required=True)
argparser.add_argument("--metadata", help="metadata", type=str, required=True)
argparser.add_argument("--log", help="log_file", type=str)

argparser.add_argument("--datatype", help="Datatype", type=str, required=True)
argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True)
argparser.add_argument("--channel", help="Channel", type=str, required=True)
argparser.add_argument(
"--raw-table-name", help="raw table name", type=str, required=True
)

argparser.add_argument("--dsp-pars", help="dsp_pars", type=str, required=True)
argparser.add_argument("--plot-path", help="plot_path", type=str)
Expand All @@ -43,10 +44,6 @@ def par_geds_dsp_nopt() -> None:

t0 = time.time()

meta = LegendMetadata(path=args.metadata)
channel_dict = meta.channelmap(args.timestamp, system=args.datatype)
channel = f"ch{channel_dict[args.channel].daq.rawid:07}"

dsp_config = config_dict["inputs"]["processing_chain"][args.channel]
opt_json = config_dict["inputs"]["optimiser_config"][args.channel]

Expand All @@ -59,10 +56,12 @@ def par_geds_dsp_nopt() -> None:

raw_files = sorted(files)

energies = sto.read(f"{channel}/raw/daqenergy", raw_files)[0]
energies = sto.read(
f"{args.raw_table_name}", raw_files, field_mask=["daqenergy"]
)["daqenergy"][0]
idxs = np.where(energies.nda == 0)[0]
tb_data = sto.read(
f"{channel}/raw", raw_files, n_rows=opt_dict["n_events"], idx=idxs
"args.raw_table_name", raw_files, n_rows=opt_dict["n_events"], idx=idxs
)[0]
t1 = time.time()
log.info(f"Time to open raw files {t1-t0:.2f} s, n. baselines {len(tb_data)}")
Expand All @@ -72,7 +71,7 @@ def par_geds_dsp_nopt() -> None:
cut_dict = generate_cuts(dsp_data, cut_dict=opt_dict.pop("cut_pars"))
cut_idxs = get_cut_indexes(dsp_data, cut_dict)
tb_data = sto.read(
f"{channel}/raw",
args.raw_table_name,
raw_files,
n_rows=opt_dict.pop("n_events"),
idx=idxs[cut_idxs],
Expand All @@ -84,11 +83,16 @@ def par_geds_dsp_nopt() -> None:

if args.plot_path:
out_dict, plot_dict = pno.noise_optimization(
tb_data, dsp_config, db_dict.copy(), opt_dict, channel, display=1
tb_data,
dsp_config,
db_dict.copy(),
opt_dict,
args.raw_table_name,
display=1,
)
else:
out_dict = pno.noise_optimization(
raw_files, dsp_config, db_dict.copy(), opt_dict, channel
raw_files, dsp_config, db_dict.copy(), opt_dict, args.raw_table_name
)

t2 = time.time()
Expand Down
Loading

0 comments on commit 7b70cba

Please sign in to comment.