Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix option handling in pipeline.jl #70

Merged
merged 9 commits into from
Aug 9, 2024
38 changes: 24 additions & 14 deletions src/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@
Store() = Store(":memory:")
DEFAULT = Store()

"""
get_tbl_name(source::String, tmp::Bool)

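Generate a table name from the file name of `source`: the directory and the
extension are dropped, and every run of special characters (space, brackets,
braces, `\\`, `+`, `,`, `.`, `-`) is replaced by a single `_`.
If `tmp` is true, then the table name is prefixed by 't_'.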

"""
function get_tbl_name(source::String, tmp::Bool)
suvayu marked this conversation as resolved.
Show resolved Hide resolved
name, _ = splitext(basename(source))
name = replace(name, r"[ ()\[\]{}\\+,.-]+" => "_")
Expand All @@ -55,7 +62,8 @@
# TODO: support "CREATE OR REPLACE" & "IF NOT EXISTS" for all create_* functions

"""
    _create_tbl_impl(con::DB, query::String; name::String, tmp::Bool, show::Bool)

Create the table `name` on connection `con` from the result of `query`.

The statement issued is `CREATE TABLE <name> AS <query>`, or
`CREATE TEMP TABLE <name> AS <query>` when `tmp` is true.

Return the contents of the new table as a `DF.DataFrame` when `show` is true,
otherwise return `name`.
"""
function _create_tbl_impl(con::DB, query::String; name::String, tmp::Bool, show::Bool)
    # Assemble the command prefix piecewise so that the non-temporary case does
    # not end up with a doubled space ("CREATE  TABLE").
    create_table_cmd = "CREATE" * (tmp ? " TEMP" : "") * " TABLE"
    # The table must be created exactly once: a second CREATE for the same name
    # would be rejected because the table already exists.
    # NOTE(review): `name` and `query` are interpolated verbatim into the SQL
    # text -- confirm that callers only pass trusted values.
    DBInterface.execute(con, "$create_table_cmd $name AS $query")
    return show ? DF.DataFrame(DBInterface.execute(con, "SELECT * FROM $name")) : name
end

Expand Down Expand Up @@ -99,16 +107,16 @@
types = Dict(),
)
check_file(source) ? true : throw(FileNotFoundError(source))
if length(name) == 0
name = get_tbl_name(source, tmp)
end

kwargs = Dict{Symbol, String}()
if length(types) > 0
kwargs[:types] = "{" * join(("'$key': '$value'" for (key, value) in types), ",") * "}"
end
query = fmt_select(fmt_read(source; _read_opts..., kwargs...))

if (length(name) == 0)
name = get_tbl_name(source, tmp)
end

return _create_tbl_impl(con, query; name, tmp, show)
end

Expand Down Expand Up @@ -167,13 +175,13 @@
tmp::Bool = false,
show::Bool = false,
)
sources = [fmt_source(con, src) for src in (base_source, alt_source)]
query = fmt_join(sources...; on = on, cols = cols, fill = fill, fill_values = fill_values)

if (length(name) == 0)
if (check_file(alt_source) && length(name) == 0)
suvayu marked this conversation as resolved.
Show resolved Hide resolved
name = get_tbl_name(alt_source, tmp)
end

sources = [fmt_source(con, src) for src in (base_source, alt_source)]
query = fmt_join(sources...; on = on, cols = cols, fill = fill, fill_values = fill_values)

return _create_tbl_impl(con, query; name, tmp, show)
end

Expand Down Expand Up @@ -238,6 +246,9 @@
# columns? If such a feature is required, we can use
# cols::Dict{Symbol, Vector{Any}}, and get the cols and vals
# as: keys(cols), and values(cols)
if (check_file(source) && length(name) == 0)
suvayu marked this conversation as resolved.
Show resolved Hide resolved
name = get_tbl_name(source, tmp)

Check warning on line 250 in src/pipeline.jl

View check run for this annotation

Codecov / codecov/patch

src/pipeline.jl#L250

Added line #L250 was not covered by tests
end

# for now, support only one column
if length(cols) > 1
Expand Down Expand Up @@ -276,7 +287,7 @@
source::String,
cols::Dict{Symbol, T};
on::Symbol,
name::String,
name::String = "",
where_::String = "",
tmp::Bool = false,
show::Bool = false,
Expand All @@ -298,16 +309,15 @@
source::String,
cols::Dict{Symbol, T};
on::Symbol,
name::String,
name::String = "",
where_::String = "",
tmp::Bool = false,
show::Bool = false,
) where {T}
if (length(name) == 0)
if (check_file(source) && length(name) == 0)
suvayu marked this conversation as resolved.
Show resolved Hide resolved
name = get_tbl_name(source, tmp)
end

# FIXME: accept NamedTuple|Dict as cols instead of value & col
source = fmt_source(con, source)
subquery = fmt_select(source; cols...)
if length(where_) > 0
Expand Down Expand Up @@ -340,8 +350,8 @@
src = fmt_source(con, source)
query = "SELECT * FROM $src WHERE $expression"

if (length(name) == 0)
if (check_file(source) && length(name) == 0)
suvayu marked this conversation as resolved.
Show resolved Hide resolved
name = get_tbl_name(source, tmp)

Check warning on line 354 in src/pipeline.jl

View check run for this annotation

Codecov / codecov/patch

src/pipeline.jl#L353-L354

Added lines #L353 - L354 were not covered by tests
end

return _create_tbl_impl(con, query; name = name, tmp = tmp, show = show)
Expand Down
31 changes: 20 additions & 11 deletions test/test-pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end

@testset "get_tbl_name(source, tmp)" begin
for (name, tmp) in [["my_file", false], ["t_my_file", true]]
@test name == TIO.get_tbl_name("my-file.csv", tmp)
@test name == TIO.get_tbl_name("path/my-file.csv", tmp)
end
end

Expand Down Expand Up @@ -82,6 +82,7 @@ end
end

@testset "CSV -> DataFrame w/ a schema" begin
con = DBInterface.connect(DB)
mapping_csv_path = joinpath(DATA, "Norse/rep-periods-mapping.csv")
col_schema = Dict(:period => "INT", :rep_period => "VARCHAR", :weight => "DOUBLE")
TIO.create_tbl(con, mapping_csv_path; types = col_schema)
Expand Down Expand Up @@ -133,8 +134,18 @@ end
@test (DF.subset(cmp, :investable_1 => DF.ByRow(ismissing)).investable) |> all
end

con = DBInterface.connect(DB)
@testset "temporary tables" begin
con = DBInterface.connect(DB)
tbl_name = TIO.create_tbl(con, csv_path; name = "tmp_assets", tmp = true)
@test tbl_name in tmp_tbls(con)[!, :name]

tbl_name = TIO.create_tbl(con, csv_path; tmp = true)
@test tbl_name == "t_assets_data" # t_<cleaned up filename>
@test tbl_name in tmp_tbls(con)[!, :name]
end

@testset "CSV -> table" begin
con = DBInterface.connect(DB)
tbl_name = TIO.create_tbl(con, csv_path; name = "no_assets")
df_res = DF.DataFrame(DBInterface.execute(con, "SELECT * FROM $tbl_name"))
@test shape(df_org) == shape(df_res)
Expand All @@ -156,16 +167,11 @@ end
# 1 │ Asgard_Battery storage true
end

@testset "temporary tables" begin
tbl_name = TIO.create_tbl(con, csv_path; name = "tmp_assets", tmp = true)
@test tbl_name in tmp_tbls(con)[!, :name]

tbl_name = TIO.create_tbl(con, csv_path; tmp = true)
@test tbl_name == "t_assets_data" # t_<cleaned up filename>
@test tbl_name in tmp_tbls(con)[!, :name]
end

@testset "table + CSV w/ alternatives -> table" begin
# test setup
con = DBInterface.connect(DB)
TIO.create_tbl(con, csv_path; name = "no_assets")

opts = Dict(:on => [:name], :cols => [:investable])
tbl_name =
TIO.create_tbl(con, "no_assets", csv_copy; name = "alt_assets", opts..., fill = false)
Expand Down Expand Up @@ -239,6 +245,9 @@ end
con = DBInterface.connect(DB)
df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); opts...)
@test df_res.investable |> all

table_name = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); on = :name)
@test "assets_data" == table_name
end

@testset "w/ constant after filtering" begin
Expand Down
Loading