Skip to content

Commit

Permalink
Fix batching in cavity for particle beam
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 6, 2024
1 parent 4b7c15f commit df48d85
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 26 deletions.
57 changes: 32 additions & 25 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,8 +992,8 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
delta_energy = self.voltage * torch.cos(phi)

T566 = 1.5 * self.length * igamma2 / beta0**3
T556 = 0
T555 = 0
T556 = 0.0
T555 = 0.0
if any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
outgoing_energy = incoming.energy + delta_energy
Expand Down Expand Up @@ -1021,13 +1021,20 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
else: # ParticleBeam
outgoing_particles[:, :, 5] = (
incoming.particles[:, :, 5]
+ incoming.energy * beta0 / (outgoing_energy * beta1)
+ self.voltage
* beta0
/ (outgoing_energy * beta1)
+ incoming.energy.unsqueeze(-1)
* beta0.unsqueeze(-1)
/ (outgoing_energy.unsqueeze(-1) * beta1.unsqueeze(-1))
+ self.voltage.unsqueeze(-1)
* beta0.unsqueeze(-1)
/ (outgoing_energy.unsqueeze(-1) * beta1.unsqueeze(-1))
* (
torch.cos(incoming.particles[:, :, 4] * beta0 * k + phi)
- torch.cos(phi)
torch.cos(
incoming.particles[:, :, 4]
* beta0.unsqueeze(-1)
* k.unsqueeze(-1)
+ phi.unsqueeze(-1)
)
- torch.cos(phi).unsqueeze(-1)
)
)

Expand Down Expand Up @@ -1089,7 +1096,7 @@ def _track_beam(self, incoming: ParticleBeam) -> ParticleBeam:
outgoing_cov[:, 5, 4] = outgoing_cov[:, 4, 5]
else: # ParticleBeam
outgoing_particles[:, :, 4] = (
T566 * incoming.particles[:, :, 5] ** 2
T566.unsqueeze(-1) * incoming.particles[:, :, 5] ** 2
+ T556 * incoming.particles[:, :, 4] * incoming.particles[:, :, 5]
+ T555 * incoming.particles[:, :, 4] ** 2
)
Expand Down Expand Up @@ -1126,7 +1133,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
assert Ei > 0, "Initial energy must be larger than 0"
assert all(Ei > 0), "Initial energy must be larger than 0"

alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei)

Expand Down Expand Up @@ -1161,8 +1168,8 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
beta1 = torch.tensor(1.0)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = 0
if self.voltage != 0 and energy != 0: # TODO: Do we need this if?
r55_cor = 0.0
if any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)

Expand All @@ -1183,19 +1190,19 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r66 = Ei / Ef * beta0 / beta1
r65 = k * torch.sin(phi) * self.voltage / (Ef * beta1 * electron_mass_eV)

R = torch.eye(7, device=device, dtype=dtype)
R[0, 0] = r11
R[0, 1] = r12
R[1, 0] = r21
R[1, 1] = r22
R[2, 2] = r11
R[2, 3] = r12
R[3, 2] = r21
R[3, 3] = r22
R[4, 4] = 1 + r55_cor
R[4, 5] = r56
R[5, 4] = r65
R[5, 5] = r66
R = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
R[:, 0, 0] = r11
R[:, 0, 1] = r12
R[:, 1, 0] = r21
R[:, 1, 1] = r22
R[:, 2, 2] = r11
R[:, 2, 3] = r12
R[:, 3, 2] = r21
R[:, 3, 3] = r22
R[:, 4, 4] = 1 + r55_cor
R[:, 4, 5] = r56
R[:, 5, 4] = r65
R[:, 5, 5] = r66

return R

Expand Down
2 changes: 1 addition & 1 deletion cheetah/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,7 @@ def broadcast(self, shape: torch.Size) -> "ParticleBeam":
return self.__class__(
particles=self.particles.repeat((*shape, 1, 1)),
energy=self.energy.repeat(shape),
particle_charges=self.particle_charges.repeat(shape),
particle_charges=self.particle_charges.repeat((*shape, 1)),
)

def __repr__(self) -> str:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_cavity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

import cheetah


def test_assert_ei_greater_zero():
"""
Reproduces
```
1127 Ef = (energy + delta_energy) / electron_mass_eV
1128 Ep = (Ef - Ei) / self.length # Derivative of the energy
-> 1129 assert Ei > 0, "Initial energy must be larger than 0"
1131 alpha = torch.sqrt(eta / 8) / torch.cos(phi) * torch.log(Ef / Ei)
1133 r11 = torch.cos(alpha) - torch.sqrt(2 / eta) * torch.cos(phi) * torch.sin(alpha) # noqa: E501
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
```
"""
cavity = cheetah.Cavity(
length=torch.tensor([3.0441, 3.0441, 3.0441]),
voltage=torch.tensor([48198468.0, 48198468.0, 48198468.0]),
phase=torch.tensor([48198468.0, 48198468.0, 48198468.0]),
frequency=torch.tensor([2.8560e09, 2.8560e09, 2.8560e09]),
name="k26_2a",
)
beam = cheetah.ParticleBeam.from_parameters(
num_particles=100_000, sigma_x=torch.tensor([1e-5])
).broadcast((3,))

_ = cavity.track(beam)

0 comments on commit df48d85

Please sign in to comment.