-
Notifications
You must be signed in to change notification settings - Fork 439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Weights Metrics #340
base: main
Are you sure you want to change the base?
Weights Metrics #340
Conversation
… to plot (simple) interactive graph
…yers. heatmaps for other metrics.
run_metrics.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file should probably be in mergekit/scripts
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also would be good to use click
to turn the hardcoded values into arguments.
from typing import List, Dict, Optional, Any, Tuple | ||
from mergekit.graph import Task | ||
import networkx as nx | ||
import plotly.graph_objects as go |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should capture these new dependencies in pyproject.toml
. Probably under a feature, so headless installs don't need to bring them in.
mergekit/plan.py
Outdated
) | ||
finalize = FinalizeModel( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally fine to not do the finalize
task when we're doing metrics, but this is needed for merges - I think as is this makes merges not write out correctly.
**_kwargs, | ||
) -> Task: | ||
|
||
if 'self_attn' in output_weight.name: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Down the line we probably want this split to be done based on new fields in ArchitectureInfo but this is good for now!
|
||
res = {} | ||
|
||
scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we be doing something here to guard against dividing by zero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep - norms are non-negative so adding small epsilon will be fine
mergekit/architecture.py
Outdated
@@ -53,6 +57,9 @@ class WeightInfo(BaseModel, frozen=True): | |||
aliases: Optional[Tuple[str, ...]] = None | |||
force_dtype: Optional[str] = None | |||
|
|||
GQA_groups: Optional[int] = None # None if not GQA, 1 if MQA, >1 if GQA |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be gqa_groups
num_heads=32 # hard-coded for now | ||
) | ||
self.block_count += 1 | ||
return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this end up creating N AttnTasks for each block? I don't think it's actually a problem as the tasks will be deduplicated downstream - should be fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should only be one AttnTask for each block - the if statement on line 351 is only satisfied once all the tensors (K,Q,V,O) have been collected. Then self.attn_weight_dict is reset to {} and the (one) AttnTask is created. I might also add individual tensor metrics for comparing just the Qs, Vs etc, which would be simpler.
self._method = merge_methods.get(config.merge_method) | ||
if getattr(config, "merge_method", None): | ||
self._method = merge_methods.get(config.merge_method) | ||
elif getattr(config, "metric_method", None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to add a validator to MergeConfig that checks that exactly one of these fields is set.
mergekit/measure.py
Outdated
) | ||
|
||
res = [] | ||
for _task, value in exec.run(quiet=options.quiet): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking this over, I kinda think we might not need a separate file here - maybe it should just early out in merge.py
if there's a metric_method
set instead of merge_method?
mergekit/graph.py
Outdated
@@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True): | |||
Abstract base class representing a task in a computational graph. | |||
|
|||
This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy. | |||
Note that PyDantic BaseModel requires that all attributes are defined in the class initialisation, and cannot be changed after. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super nitpick here: I think the official capitalization is Pydantic, not PyDantic.
…rge OR Metri, not both.
… to separate case
…eralised substitute function in architecture
Implemented:
Not Implemented: