Skip to content

Commit

Permalink
Fix examples to run with batched computations
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 6, 2024
1 parent 5b4e39c commit 94064cb
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 295 deletions.
43 changes: 25 additions & 18 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(
super().__init__(name=name)

assert isinstance(transfer_map, torch.Tensor)
assert transfer_map.shape == (1, 7, 7)
assert transfer_map.shape[-2:] == (7, 7)

self._transfer_map = torch.as_tensor(transfer_map, **factory_kwargs)
self.length = (
Expand Down Expand Up @@ -219,6 +219,13 @@ def from_merging_elements(
def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
return self._transfer_map

def broadcast(self, shape: Size) -> Element:
return self.__class__(
self._transfer_map.repeat((*shape, 1, 1)),
length=self.length.repeat(shape),
name=self.name,
)

@property
def is_skippable(self) -> bool:
return True
Expand Down Expand Up @@ -394,9 +401,9 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.k1) if self.is_active else 1)
height = 0.8 * (np.sign(self.k1[0]) if self.is_active else 1)
patch = Rectangle(
(s, 0), self.length, height, color="tab:red", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:red", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -621,10 +628,10 @@ def defining_features(self) -> list[str]:

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length, height, color="tab:green", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:green", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -713,7 +720,7 @@ def __init__(
self.angle = (
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.tensor(0.0, **factory_kwargs)
else torch.zeros_like(self.length)
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -758,10 +765,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length, height, color="tab:blue", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -803,7 +810,7 @@ def __init__(
self.angle = (
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.tensor(0.0, **factory_kwargs)
else torch.zeros_like(self.length)
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -847,10 +854,10 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle) if self.is_active else 1)
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length, height, color="tab:cyan", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:cyan", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -1211,7 +1218,7 @@ def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length, height, color="gold", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="gold", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -1771,7 +1778,7 @@ def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
height = 0.4

patch = Rectangle(
(s, 0), self.length, height, color="tab:purple", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:purple", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -1895,7 +1902,7 @@ def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
height = 0.8

patch = Rectangle(
(s, 0), self.length, height, color="tab:orange", alpha=alpha, zorder=2
(s, 0), self.length[0], height, color="tab:orange", alpha=alpha, zorder=2
)
ax.add_patch(patch)

Expand Down Expand Up @@ -2251,7 +2258,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]:

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
element_lengths = [
element.length if hasattr(element, "length") else 0.0
element.length[0] if hasattr(element, "length") else 0.0
for element in self.elements
]
element_ss = [0] + [
Expand Down Expand Up @@ -2291,7 +2298,7 @@ def plot_reference_particle_traces(
splits = reference_segment.split(resolution=torch.tensor(resolution))

split_lengths = [
split.length if hasattr(split, "length") else 0.0 for split in splits
split.length[0] if hasattr(split, "length") else 0.0 for split in splits
]
ss = [0] + [sum(split_lengths[: i + 1]) for i, _ in enumerate(split_lengths)]

Expand Down Expand Up @@ -2324,7 +2331,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
xs = [
float(reference_beam.xs[particle_index].cpu())
float(reference_beam.xs[0, particle_index].cpu())
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand All @@ -2335,7 +2342,7 @@ def plot_reference_particle_traces(

for particle_index in range(num_particles):
ys = [
float(reference_beam.ys[particle_index].cpu())
float(reference_beam.ys[0, particle_index].cpu())
for reference_beam in references
if reference_beam is not Beam.empty
]
Expand Down
64 changes: 32 additions & 32 deletions cheetah/converters/dontbmad.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,43 +466,43 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
["element_type", "alias", "type", "l"], bmad_parsed
)
if "l" in bmad_parsed:
return cheetah.Drift(length=torch.tensor(bmad_parsed["l"]), name=name)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
else:
return cheetah.Marker(name=name)
elif bmad_parsed["element_type"] == "instrument":
validate_understood_properties(
["element_type", "alias", "type", "l"], bmad_parsed
)
if "l" in bmad_parsed:
return cheetah.Drift(length=torch.tensor(bmad_parsed["l"]), name=name)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
else:
return cheetah.Marker(name=name)
elif bmad_parsed["element_type"] == "pipe":
validate_understood_properties(
["element_type", "alias", "type", "l", "descrip"], bmad_parsed
)
return cheetah.Drift(length=torch.tensor(bmad_parsed["l"]), name=name)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
elif bmad_parsed["element_type"] == "drift":
validate_understood_properties(
["element_type", "l", "type", "descrip"], bmad_parsed
)
return cheetah.Drift(length=torch.tensor(bmad_parsed["l"]), name=name)
return cheetah.Drift(length=torch.tensor([bmad_parsed["l"]]), name=name)
elif bmad_parsed["element_type"] == "hkicker":
validate_understood_properties(
["element_type", "type", "alias"], bmad_parsed
)
return cheetah.HorizontalCorrector(
length=torch.tensor(bmad_parsed.get("l", 0.0)),
angle=torch.tensor(bmad_parsed.get("kick", 0.0)),
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
name=name,
)
elif bmad_parsed["element_type"] == "vkicker":
validate_understood_properties(
["element_type", "type", "alias"], bmad_parsed
)
return cheetah.VerticalCorrector(
length=torch.tensor(bmad_parsed.get("l", 0.0)),
angle=torch.tensor(bmad_parsed.get("kick", 0.0)),
length=torch.tensor([bmad_parsed.get("l", 0.0)]),
angle=torch.tensor([bmad_parsed.get("kick", 0.0)]),
name=name,
)
elif bmad_parsed["element_type"] == "sbend":
Expand All @@ -526,15 +526,15 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
bmad_parsed,
)
return cheetah.Dipole(
length=torch.tensor(bmad_parsed["l"]),
gap=torch.tensor(bmad_parsed.get("hgap", 0.0)),
angle=torch.tensor(bmad_parsed.get("angle", 0.0)),
e1=torch.tensor(bmad_parsed["e1"]),
e2=torch.tensor(bmad_parsed.get("e2", 0.0)),
tilt=torch.tensor(bmad_parsed.get("ref_tilt", 0.0)),
fringe_integral=torch.tensor(bmad_parsed.get("fint", 0.0)),
length=torch.tensor([bmad_parsed["l"]]),
gap=torch.tensor([bmad_parsed.get("hgap", 0.0)]),
angle=torch.tensor([bmad_parsed.get("angle", 0.0)]),
e1=torch.tensor([bmad_parsed["e1"]]),
e2=torch.tensor([bmad_parsed.get("e2", 0.0)]),
tilt=torch.tensor([bmad_parsed.get("ref_tilt", 0.0)]),
fringe_integral=torch.tensor([bmad_parsed.get("fint", 0.0)]),
fringe_integral_exit=(
torch.tensor(bmad_parsed["fintx"])
torch.tensor([bmad_parsed["fintx"]])
if "fintx" in bmad_parsed
else None
),
Expand All @@ -547,18 +547,18 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
bmad_parsed,
)
return cheetah.Quadrupole(
length=torch.tensor(bmad_parsed["l"]),
k1=torch.tensor(bmad_parsed["k1"]),
tilt=torch.tensor(bmad_parsed.get("tilt", 0.0)),
length=torch.tensor([bmad_parsed["l"]]),
k1=torch.tensor([bmad_parsed["k1"]]),
tilt=torch.tensor([bmad_parsed.get("tilt", 0.0)]),
name=name,
)
elif bmad_parsed["element_type"] == "solenoid":
validate_understood_properties(
["element_type", "l", "ks", "alias"], bmad_parsed
)
return cheetah.Solenoid(
length=torch.tensor(bmad_parsed["l"]),
k=torch.tensor(bmad_parsed["ks"]),
length=torch.tensor([bmad_parsed["l"]]),
k=torch.tensor([bmad_parsed["ks"]]),
name=name,
)
elif bmad_parsed["element_type"] == "lcavity":
Expand All @@ -577,12 +577,12 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
bmad_parsed,
)
return cheetah.Cavity(
length=torch.tensor(bmad_parsed["l"]),
voltage=torch.tensor(bmad_parsed.get("voltage", 0.0)),
length=torch.tensor([bmad_parsed["l"]]),
voltage=torch.tensor([bmad_parsed.get("voltage", 0.0)]),
phase=torch.tensor(
-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)
[-np.degrees(bmad_parsed.get("phi0", 0.0) * 2 * np.pi)]
),
frequency=torch.tensor(bmad_parsed["rf_frequency"]),
frequency=torch.tensor([bmad_parsed["rf_frequency"]]),
name=name,
)
elif bmad_parsed["element_type"] == "rcollimator":
Expand All @@ -591,8 +591,8 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
bmad_parsed,
)
return cheetah.Aperture(
x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)),
y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)),
x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]),
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
shape="rectangular",
name=name,
)
Expand All @@ -602,8 +602,8 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
bmad_parsed,
)
return cheetah.Aperture(
x_max=torch.tensor(bmad_parsed.get("x_limit", np.inf)),
y_max=torch.tensor(bmad_parsed.get("y_limit", np.inf)),
x_max=torch.tensor([bmad_parsed.get("x_limit", np.inf)]),
y_max=torch.tensor([bmad_parsed.get("y_limit", np.inf)]),
shape="elliptical",
name=name,
)
Expand All @@ -622,12 +622,12 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
],
bmad_parsed,
)
return cheetah.Undulator(length=torch.tensor(bmad_parsed["l"]), name=name)
return cheetah.Undulator(length=torch.tensor([bmad_parsed["l"]]), name=name)
elif bmad_parsed["element_type"] == "patch":
# TODO: Does this need to be implemented in Cheetah in a more proper way?
validate_understood_properties(["element_type", "tilt"], bmad_parsed)
return cheetah.Drift(
length=torch.tensor(bmad_parsed.get("l", 0.0)), name=name
length=torch.tensor([bmad_parsed.get("l", 0.0)]), name=name
)
else:
print(
Expand All @@ -636,7 +636,7 @@ def convert_element(name: str, context: dict) -> "cheetah.Element":
)
# TODO: Remove the length by adding markers to Cheetah
return cheetah.Drift(
name=name, length=torch.tensor(bmad_parsed.get("l", 0.0))
name=name, length=torch.tensor([bmad_parsed.get("l", 0.0)])
)
else:
raise ValueError(f"Unknown Bmad element type for {name = }")
Expand Down
Loading

0 comments on commit 94064cb

Please sign in to comment.