Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

**BREAKING** change of sampleplot behaviour and defaults #213

Merged
merged 27 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
54355de
fix error message
st-- Aug 24, 2021
8fccaaf
pop! for custom-defined kwarg
st-- Aug 24, 2021
2f1204b
only single label for multiple samples by default
st-- Aug 24, 2021
3f47368
remove markers
st-- Aug 24, 2021
98ec20b
bump version
st-- Aug 24, 2021
0aff474
update docstring example
st-- Aug 24, 2021
71eeb10
add explanation of `label` handling to docstring
st-- Aug 24, 2021
cbeb5a8
demonstrate `label` handling in regression-1d example
st-- Aug 24, 2021
e4ddf18
formatting
st-- Aug 24, 2021
58a8e71
treat multiple samples as single series
st-- Aug 24, 2021
c75de4f
do not set seriestype
st-- Aug 24, 2021
aeab678
Apply suggestions from code review
st-- Aug 24, 2021
f5a7772
adjust tests for new one-series-for-all-samples
st-- Aug 24, 2021
477e416
Merge branch 'st/sampleplot_update' of github.com:JuliaGaussianProces…
st-- Aug 24, 2021
707ef25
update test
st-- Aug 25, 2021
528124e
fix NaN comparison in test
st-- Aug 25, 2021
830249b
adjust linealpha to 0.35 to make up for no longer plotting markers
st-- Aug 25, 2021
d0c1f5b
fix mistake;
st-- Aug 25, 2021
bd3b3ed
clean up other plot code as well
st-- Aug 25, 2021
c4fa4d7
update docstring
st-- Aug 25, 2021
2c47712
Apply suggestions from code review
st-- Aug 25, 2021
add2dec
Merge branch 'master' into st/sampleplot_update
st-- Aug 25, 2021
a750d15
sampleplot: remove seriescolor=red default
st-- Aug 26, 2021
3d25778
Merge branch 'st/sampleplot_update' of github.com:JuliaGaussianProces…
st-- Aug 26, 2021
36f91b3
bump minor version (not backwards compatible)
st-- Aug 26, 2021
ab37628
bump docs/examples compat
st-- Aug 26, 2021
3c6a2ce
explicitly set seriescolor="red" to account for backwards-incompatibl…
st-- Aug 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.4.0"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
AbstractGPs = "0.4"
AbstractGPs = "0.4, 0.5"
Documenter = "0.27"
2 changes: 1 addition & 1 deletion examples/regression-1d/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
AbstractGPs = "0.4"
AbstractGPs = "0.4, 0.5"
AdvancedHMC = "0.2"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
DynamicHMC = "2.2, 3.1"
Expand Down
50 changes: 19 additions & 31 deletions examples/regression-1d/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,19 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
# We sample 5 functions from each posterior GP given by the final 100 samples of kernel
# parameters.

plt = scatter(
x_train,
y_train;
xlim=(0, 1),
xlabel="x",
ylabel="y",
title="posterior (AdvancedHMC)",
label="Train Data",
)
for p in samples[(end - 100):end]
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); samples=5)
plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (AdvancedHMC)")
for (i, p) in enumerate(samples[(end - 100):end])
sampleplot!(
plt,
0:0.02:1,
gp_posterior(x_train, y_train, p);
samples=5,
seriescolor="red",
label=(i == 1 ? "samples" : nothing),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just

Suggested change
label=(i == 1 ? "samples" : nothing),
label=(i == 1 ? "samples" : ""),

? I did not know that label=nothing is a thing 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, one of the complaints about Plots.jl is that there are plenty of different ways of saying the same thing!
I find label = nothing to be a bit more self-explanatory, but feel free to add&commit your suggestion if you feel moderately strongly about it:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I don't feel strongly about it. The main reason would be to avoid having two different types here but I'm fine with nothing as well.

)
end
scatter!(plt, x_test, y_test; label="Test Data")
scatter!(plt, x_train, y_train; label="Train Data", markercolor=1)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
scatter!(plt, x_test, y_test; label="Test Data", markercolor=2)
plt

# #### DynamicHMC
Expand Down Expand Up @@ -290,18 +290,11 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
# We sample a function from the posterior GP for the final 100 samples of kernel
# parameters.

plt = scatter(
x_train,
y_train;
xlim=(0, 1),
xlabel="x",
ylabel="y",
title="posterior (DynamicHMC)",
label="Train Data",
)
plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (DynamicHMC)")
scatter!(plt, x_train, y_train; label="Train Data")
scatter!(plt, x_test, y_test; label="Test Data")
for p in samples[(end - 100):end]
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p))
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red")
end
plt

Expand Down Expand Up @@ -349,18 +342,13 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
# We sample a function from the posterior GP for the final 100 samples of kernel
# parameters.

plt = scatter(
x_train,
y_train;
xlim=(0, 1),
xlabel="x",
ylabel="y",
title="posterior (EllipticalSliceSampling)",
label="Train Data",
plt = plot(;
xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (EllipticalSliceSampling)"
)
scatter!(plt, x_train, y_train; label="Train Data")
scatter!(plt, x_test, y_test; label="Test Data")
for p in samples[(end - 100):end]
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p))
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red")
end
plt

Expand Down
26 changes: 13 additions & 13 deletions src/util/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
length(x) == length(gp.x) ||
throw(DimensionMismatch("length of `x` and `gp.x` has to be equal"))
scale::Float64 = pop!(plotattributes, :ribbon_scale, 1.0)
scale > 0.0 || error("`bandwidth` keyword argument must be non-negative")
scale >= 0.0 || error("`ribbon_scale` keyword argument must be non-negative")

# compute marginals
μ, σ2 = mean_and_var(gp)
Expand Down Expand Up @@ -82,16 +82,19 @@ Plot samples from the projection `f` of a Gaussian process versus `x`.
Make sure to load [Plots.jl](https://github.com/JuliaPlots/Plots.jl) before you use
this function.

When plotting multiple samples, these are treated as a _single_ series (i.e.,
only a single entry will be added to the legend when providing a `label`).

# Example

```julia
using Plots

gp = GP(SqExponentialKernel())
sampleplot(gp(rand(5)); samples=10, markersize=5)
sampleplot(gp(rand(5)); samples=10, linealpha=1.0)
```
The given example plots 10 samples from the projection of the GP `gp`. The `markersize` is modified
from default of 0.5 to 5.
The given example plots 10 samples from the projection of the GP `gp`.
The `linealpha` is modified from default of 0.35 to 1.

---
sampleplot(x::AbstractVector, gp::AbstractGP; samples=1, kwargs...)
Expand All @@ -115,18 +118,15 @@ SamplePlot((f,)::Tuple{<:FiniteGP}) = SamplePlot((f.x, f))
SamplePlot((x, gp)::Tuple{<:AbstractVector,<:AbstractGP}) = SamplePlot((gp(x, 1e-9),))

@recipe function f(sp::SamplePlot)
nsamples::Int = get(plotattributes, :samples, 1)
nsamples::Int = pop!(plotattributes, :samples, 1)
samples = rand(sp.f, nsamples)

flat_x = repeat(vcat(sp.x, NaN), nsamples)
flat_f = vec(vcat(samples, fill(NaN, 1, nsamples)))

# Set default attributes
seriestype --> :line
linealpha --> 0.2
markershape --> :circle
markerstrokewidth --> 0.0
markersize --> 0.5
markeralpha --> 0.3
seriescolor --> "red"
linealpha --> 0.35
label --> ""

return sp.x, samples
return flat_x, flat_f
end
6 changes: 3 additions & 3 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
gp = f(x, 0.1)

plt = @test_deprecated sampleplot(gp, 10)
@test plt.n == 10
@test plt.n == 1

@test_deprecated sampleplot!(gp, 4)
@test plt.n == 14
@test plt.n == 2

@test_deprecated sampleplot!(Plots.current(), gp, 3)
@test plt.n == 17
@test plt.n == 3
end
22 changes: 12 additions & 10 deletions test/util/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
z = rand(10)
plt1 = sampleplot(z, gp)
@test plt1.n == 1
@test plt1.series_list[1].plotattributes[:x] == sort(z)
@test isequal(plt1.series_list[1].plotattributes[:x], vcat(z, NaN))

plt2 = sampleplot(gp; samples=10)
@test plt2.n == 10
sort_x = sort(x)
@test all(series.plotattributes[:x] == sort_x for series in plt2.series_list)
plt2 = sampleplot(gp; samples=3)
@test plt2.n == 1
plt2_x = plt2.series_list[1].plotattributes[:x]
plt2_y = plt2.series_list[1].plotattributes[:y]
@test isequal(plt2_x, vcat(x, NaN, x, NaN, x, NaN))
@test length(plt2_y) == length(plt2_x)
@test isnan(plt2_y[length(z) + 1]) && isnan(plt2_y[2length(z) + 2])

z = rand(7)
plt3 = sampleplot(z, f; samples=8)
@test plt3.n == 8
sort_z = sort(z)
@test all(series.plotattributes[:x] == sort_z for series in plt3.series_list)
z3 = rand(7)
plt3 = sampleplot(z3, f; samples=2)
@test plt3.n == 1
@test isequal(plt3.series_list[1].plotattributes[:x], vcat(z3, NaN, z3, NaN))

# Check recipe dispatches for `FiniteGP`s
rec = RecipesBase.apply_recipe(Dict{Symbol,Any}(), gp)
Expand Down