Skip to content

Commit

Permalink
Overhaul computation graph system
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Dec 12, 2023
1 parent ca80afe commit f3e328e
Show file tree
Hide file tree
Showing 13 changed files with 834 additions and 751 deletions.
30 changes: 21 additions & 9 deletions mergekit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def evaluate_setting(
if (
(cond.filter is None)
or (cond.filter == "*")
or cond.filter in tensor_name
or (tensor_name and cond.filter in tensor_name)
):
res = evaluate_setting(tensor_name, cond.value, t)
return res
Expand Down Expand Up @@ -113,10 +113,10 @@ def validate(self):

class ConfigReader(BaseModel):
config: MergeConfiguration
tensor_name: str
t: float
slice_out: Optional[OutputSliceDefinition]
slices_in: Optional[List[InputSliceDefinition]]
tensor_name: Optional[str] = None
slice_out: Optional[OutputSliceDefinition] = None
slices_in: Optional[List[InputSliceDefinition]] = None

@property
def base_model(self) -> Optional[ModelReference]:
Expand All @@ -129,6 +129,20 @@ def base_model(self) -> Optional[ModelReference]:
return ModelReference.parse(res)
return None

def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
return ConfigReader(self.model_dump(exclude=["slice_out"]), slice_out=slice)

def for_in_slices(self, slices: List[InputSliceDefinition]) -> "ConfigReader":
return ConfigReader(self.model_dump(exclude=["slices_in"]), slices_in=slices)

def for_tensor(self, tensor_name: str) -> "ConfigReader":
return ConfigReader(
self.model_dump(exclude=["tensor_name"]), tensor_name=tensor_name
)

def with_t(self, t: float) -> "ConfigReader":
return ConfigReader(self.model_dump(exclude=["t"]), t=t)

def parameter(
self,
name: str,
Expand Down Expand Up @@ -177,10 +191,8 @@ def parameter(
return value

if required:
suffix = (
f" for {str(model)}.{self.tensor_name}"
if model
else f" for {self.tensor_name}"
)
path_paths = [str(s) for s in [model, self.tensor_name] if s]
p = ".".join(path_paths)
suffix = f" for {p}" if p else ""
raise RuntimeError(f"Missing required parameter {name}{suffix}")
return default
Loading

0 comments on commit f3e328e

Please sign in to comment.