[cleanup] removing global compiler configuration
This is part of an ongoing effort to remove any global state left over
from pybuda.

There should be no need for a global compiler configuration. If a user
wants to override any of the exposed settings, it should be done
locally: create a new (default) config, modify it, and pass it into the
compile function, as in the sketch below. All tests which need to
modify the config should do the same.
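
A minimal sketch of the intended usage, assuming compile_main is exposed
as forge.compile (its compiler_cfg parameter is added in this diff, and
compile_depth is one of the exposed settings); my_module and
sample_inputs are placeholders:

    import forge
    from forge import CompilerConfig, CompileDepth

    # Create a fresh default config and override settings locally.
    compiler_cfg = CompilerConfig()
    compiler_cfg.compile_depth = CompileDepth.FULL

    # Pass the config explicitly; no global state is consulted.
    compiled_model = forge.compile(my_module, sample_inputs, compiler_cfg=compiler_cfg)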

No feature should rely on reading from a "volatile global
configuration". Instead, each feature must either be designed to receive
the user-specified compiler configuration directly or defer
config-dependent decisions until the compilation stage, where the
appropriate configuration is available. A sketch of the first option
follows.
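
For illustration, a hypothetical pass written in that style — the
function name and the enable_fusing setting are made up for this
sketch; only CompilerConfig comes from the codebase:

    from forge import CompilerConfig

    def run_fusing_pass(graph, compiler_cfg: CompilerConfig):
        # The feature consumes the user-specified config it was handed,
        # rather than reading from a mutable global.
        if not getattr(compiler_cfg, "enable_fusing", False):
            return graph
        # ... perform the config-dependent transformation here ...
        return graph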

In the majority of cases, I would assume that direct access to the
user-passed config is readily available.

Closes #147
pilkicTT committed Feb 7, 2025
1 parent 8baf3d2 commit 637149d
Showing 37 changed files with 64 additions and 1,673 deletions.
10 changes: 0 additions & 10 deletions forge/forge/__init__.py
@@ -65,16 +65,6 @@ def set_home_paths():
 from .config import (
     CompilerConfig,
     CompileDepth,
-    set_configuration_options,
-    set_epoch_break,
-    set_chip_break,
-    override_op_size,
-    PerfTraceLevel,
-    insert_buffering_nop,
-    insert_nop,
-    _internal_insert_fj_buffering_nop,
-    override_dram_queue_placement,
-    configure_mixed_precision,
 )
 from .verify import DepricatedVerifyConfig
 from .forgeglobal import forge_reset, set_device_pipeline, is_silicon, get_tenstorrent_device
18 changes: 3 additions & 15 deletions forge/forge/compile.py
@@ -14,7 +14,6 @@
 from forge.config import (
     CompilerConfig,
     CompileDepth,
-    _get_global_compiler_config,
 )
 from forge._C import (
     link_past_cache_ios,
@@ -183,6 +182,7 @@ def compile_main(
     optimizer: Optional[Union[torch.optim.Optimizer, forge.optimizers.Optimizer]] = None,
     training: bool = False,
     attach_to: Optional[CompiledModel] = None,
+    compiler_cfg: CompilerConfig = CompilerConfig(),
 ) -> CompiledModel:
     """
     Main entry point for compiling modules from different frameworks for Tenstorrent devices.
@@ -218,9 +218,6 @@

     assert isinstance(module, AnyModule), "Only PyTorch, TensorFlow, and Forge modules are supported."

-    compiler_cfg = _get_global_compiler_config()
-    compiler_cfg.apply_env_config_overrides()
-
     if module_name is None:
         module_name = module.__class__.__name__

@@ -440,7 +437,7 @@ def forge_compile_torch(

     inputs = list(inputs)

-    compiler_cfg = _get_global_compiler_config()
+    compiler_cfg = CompilerConfig()
     compiler_cfg.apply_env_config_overrides()

     compile_context: CompileContext = CompileContext(
@@ -461,7 +458,7 @@ def forge_compile(
     graph_name: str,
     *inputs: Union[Tensor, List[Any], Dict[str, Any]],
     targets: List[Tensor] = [],
-    compiler_cfg: Optional[CompilerConfig] = None,
+    compiler_cfg: CompilerConfig = CompilerConfig(),
     verify_cfg: Optional[DepricatedVerifyConfig] = None,
     losses: Optional[List[Tensor]] = None,
     microbatch_size: int = 1,
@@ -505,9 +502,6 @@
     if verify_cfg is None:
         verify_cfg = DepricatedVerifyConfig.disabled()  # no verification config provided, disable by default

-    if compiler_cfg is None:
-        compiler_cfg = _get_global_compiler_config()
-
     compiler_cfg.apply_env_config_overrides()

     compile_context: CompileContext = CompileContext(
@@ -642,9 +636,6 @@ def init_compile(context: CompileContext) -> CompileDepth:
     if force_full:
         compiler_cfg.compile_depth = CompileDepth.FULL

-    context.backend_output_directory = compiler_cfg.backend_output_dir
-    ci.initialize_output_build_directory(context.backend_output_directory)
-
     # compiler_cfg is fully formed
     if "FORGE_LOAD_CONFIG" in os.environ:
         compiler_cfg = load_compiler_cfg(compiler_cfg)
@@ -1098,9 +1089,6 @@ def generate_graph(
     graph = Graph(graph_name)
     graph.set_microbatch(1)

-    if compiler_cfg is None:
-        compiler_cfg = _get_global_compiler_config()
-
     # Trace through the modules
     all_subgraph_outputs = []
     outputs = inputs
