Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
austin-hoover committed Feb 13, 2025
1 parent d1d57ee commit abb9dcb
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 3 deletions.
6 changes: 6 additions & 0 deletions orbit_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@
from . import sim
from . import utils
from .core import *
from .bunch import *
from .diag import Diagnostic
from .diag import BunchHistogram
from .lattice import *
from .sim import *
from .utils import *
2 changes: 1 addition & 1 deletion tests/test_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
bunch = Bunch(mass=mass, energy=energy)
bunch = Bunch()
bunch.mass(mass)
bunch.getSyncParticle().kinEnergy(energy)
return bunch
Expand Down
56 changes: 56 additions & 0 deletions tests/test_diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import numpy as np
import pytest

from orbit.core.bunch import Bunch
from orbit.lattice import AccLattice
from orbit.lattice import AccNode
from orbit.teapot import DriftTEAPOT
from orbit.teapot import QuadTEAPOT
from orbit.teapot import TEAPOT_Lattice

from orbit_tools.bunch import set_bunch_coords
from orbit_tools.diag import BunchHistogram


def test_hist():
nbins = 100
seed = 123

rng = np.random.default_rng(seed)
x = rng.normal(size=(10_000, 6))

bunch = Bunch()
bunch.mass(0.938)
bunch.getSyncParticle().kinEnergy(1.000)
bunch.macroSize(1.0)
bunch = set_bunch_coords(bunch, x)

axis_list = []
for i in range(6):
axis_list.append((i,))

for i in range(6):
for j in range(i):
axis_list.append((i, j))

for axis in axis_list:
ndim = len(axis)
shape = tuple(ndim * [nbins])
limits = ndim * [(-5.0, 5.0)]

# Compute histogram using BunchHistogram
hist = BunchHistogram(axis=axis, shape=shape, limits=limits)
values = hist.compute_histogram(bunch)
values = values / np.max(values)

# Compute histogram using NumPy
values_np, _ = np.histogramdd(x[:, axis], bins=hist.edges)
values_np = values_np / np.max(values_np)

# Compare the histograms. There will be differences because Grid
# classes use weighting.
print(np.max(np.abs(values - values_np)))



2 changes: 1 addition & 1 deletion tests/test_ring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
bunch = Bunch(mass=mass, energy=energy)
bunch = Bunch()
bunch.mass(mass)
bunch.getSyncParticle().kinEnergy(energy)
return bunch
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
bunch = Bunch(mass=mass, energy=energy)
bunch = Bunch()
bunch.mass(mass)
bunch.getSyncParticle().kinEnergy(energy)
return bunch
Expand Down

0 comments on commit abb9dcb

Please sign in to comment.