Commit

Refactor optimizer registration (#1053)
cbalioglu authored Feb 20, 2025
1 parent 434f76d commit 9246037
Showing 12 changed files with 98 additions and 42 deletions.
2 changes: 2 additions & 0 deletions src/fairseq2/optim/__init__.py
@@ -10,9 +10,11 @@
from fairseq2.optim._adamw import AdamW as AdamW
from fairseq2.optim._adamw import AdamWConfig as AdamWConfig
from fairseq2.optim._adamw import AdamWHandler as AdamWHandler
from fairseq2.optim._adamw import register_adamw as register_adamw
from fairseq2.optim._dynamic_loss_scaler import DynamicLossScaler as DynamicLossScaler
from fairseq2.optim._dynamic_loss_scaler import LossScaleResult as LossScaleResult
from fairseq2.optim._error import UnknownOptimizerError as UnknownOptimizerError
from fairseq2.optim._handler import OptimizerHandler as OptimizerHandler
from fairseq2.optim._optimizer import AbstractOptimizer as AbstractOptimizer
from fairseq2.optim._optimizer import ParameterCollection as ParameterCollection
from fairseq2.optim._setup import register_optimizers as register_optimizers
7 changes: 7 additions & 0 deletions src/fairseq2/optim/_adamw.py
@@ -16,6 +16,7 @@
from torch.optim.adamw import adamw # type: ignore[attr-defined]
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.error import NotSupportedError
from fairseq2.optim._handler import OptimizerHandler
from fairseq2.optim._optimizer import AbstractOptimizer, ParameterCollection
@@ -347,3 +348,9 @@ def create(self, params: ParameterCollection, config: object) -> Optimizer:
@override
def config_kls(self) -> type[object]:
return AdamWConfig


def register_adamw(context: RuntimeContext) -> None:
registry = context.get_registry(OptimizerHandler)

registry.register(ADAMW_OPTIMIZER, AdamWHandler())
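The register_adamw function above is the shape that every optimizer registration now takes: fetch the registry keyed by OptimizerHandler from the RuntimeContext, then register a handler instance under a string key. The snippet below is a minimal sketch of how an extension could hook into the same registry; the alias "my_adamw" and the function register_my_optimizer are hypothetical names made up for illustration, while RuntimeContext, OptimizerHandler, and AdamWHandler are the names this commit exports.

from fairseq2.context import RuntimeContext
from fairseq2.optim import AdamWHandler, OptimizerHandler


def register_my_optimizer(context: RuntimeContext) -> None:
    # Same two-step pattern as register_adamw: look up the registry keyed by
    # the OptimizerHandler type, then register a handler instance under a
    # string name. "my_adamw" is a hypothetical alias, not part of fairseq2.
    registry = context.get_registry(OptimizerHandler)

    registry.register("my_adamw", AdamWHandler())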
14 changes: 14 additions & 0 deletions src/fairseq2/optim/_setup.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.context import RuntimeContext
from fairseq2.optim._adamw import register_adamw


def register_optimizers(context: RuntimeContext) -> None:
register_adamw(context)
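register_optimizers is a plain aggregator, so adding a new optimizer only requires writing its register_* function and calling it here. The standalone sketch below illustrates the mechanism these functions rely on; Registry, RuntimeContext, and OptimizerHandler in the sketch are simplified stand-ins written for this illustration and are not fairseq2's actual implementations.

from __future__ import annotations


class Registry:
    # Maps string keys to handler instances of a single handler type.
    def __init__(self) -> None:
        self._handlers: dict[str, object] = {}

    def register(self, name: str, handler: object) -> None:
        if name in self._handlers:
            raise ValueError(f"'{name}' is already registered.")

        self._handlers[name] = handler


class RuntimeContext:
    # Hands out one Registry per handler type, creating it on first use.
    def __init__(self) -> None:
        self._registries: dict[type, Registry] = {}

    def get_registry(self, kls: type) -> Registry:
        return self._registries.setdefault(kls, Registry())


class OptimizerHandler:
    # Stand-in for fairseq2's optimizer handler base type.
    pass


def register_adamw(context: RuntimeContext) -> None:
    # Per-module registration, analogous to the real register_adamw above.
    registry = context.get_registry(OptimizerHandler)

    registry.register("adamw", OptimizerHandler())


def register_optimizers(context: RuntimeContext) -> None:
    # Aggregator, analogous to fairseq2.optim._setup.register_optimizers.
    register_adamw(context)


context = RuntimeContext()
register_optimizers(context)

In fairseq2 itself, the context and its registries come from setup_library (see src/fairseq2/setup/_lib.py below); the sketch only mirrors the call shapes visible in this diff.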
14 changes: 14 additions & 0 deletions src/fairseq2/optim/lr_scheduler/__init__.py
@@ -18,6 +18,9 @@
from fairseq2.optim.lr_scheduler._cosine_annealing import (
CosineAnnealingLRHandler as CosineAnnealingLRHandler,
)
from fairseq2.optim.lr_scheduler._cosine_annealing import (
register_cosine_annealing_lr as register_cosine_annealing_lr,
)
from fairseq2.optim.lr_scheduler._error import (
UnknownLRSchedulerError as UnknownLRSchedulerError,
)
@@ -39,10 +42,12 @@
from fairseq2.optim.lr_scheduler._myle import MyleLR as MyleLR
from fairseq2.optim.lr_scheduler._myle import MyleLRConfig as MyleLRConfig
from fairseq2.optim.lr_scheduler._myle import MyleLRHandler as MyleLRHandler
from fairseq2.optim.lr_scheduler._myle import register_myle_lr as register_myle_lr
from fairseq2.optim.lr_scheduler._noam import NOAM_LR as NOAM_LR
from fairseq2.optim.lr_scheduler._noam import NoamLR as NoamLR
from fairseq2.optim.lr_scheduler._noam import NoamLRConfig as NoamLRConfig
from fairseq2.optim.lr_scheduler._noam import NoamLRHandler as NoamLRHandler
from fairseq2.optim.lr_scheduler._noam import register_noam_lr as register_noam_lr
from fairseq2.optim.lr_scheduler._polynomial_decay import (
POLYNOMIAL_DECAY_LR as POLYNOMIAL_DECAY_LR,
)
@@ -55,9 +60,18 @@
from fairseq2.optim.lr_scheduler._polynomial_decay import (
PolynomialDecayLRHandler as PolynomialDecayLRHandler,
)
from fairseq2.optim.lr_scheduler._polynomial_decay import (
register_polynomial_decay_lr as register_polynomial_decay_lr,
)
from fairseq2.optim.lr_scheduler._setup import (
register_lr_schedulers as register_lr_schedulers,
)
from fairseq2.optim.lr_scheduler._tri_stage import TRI_STAGE_LR as TRI_STAGE_LR
from fairseq2.optim.lr_scheduler._tri_stage import TriStageLR as TriStageLR
from fairseq2.optim.lr_scheduler._tri_stage import TriStageLRConfig as TriStageLRConfig
from fairseq2.optim.lr_scheduler._tri_stage import (
TriStageLRHandler as TriStageLRHandler,
)
from fairseq2.optim.lr_scheduler._tri_stage import (
register_tri_stage_lr as register_tri_stage_lr,
)
7 changes: 7 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_cosine_annealing.py
@@ -14,6 +14,7 @@
from torch.optim import Optimizer
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.logging import log
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler
from fairseq2.optim.lr_scheduler._lr_scheduler import (
@@ -275,3 +276,9 @@ def requires_num_steps(self) -> bool:
@override
def config_kls(self) -> type[object]:
return CosineAnnealingLRConfig


def register_cosine_annealing_lr(context: RuntimeContext) -> None:
registry = context.get_registry(LRSchedulerHandler)

registry.register(COSINE_ANNEALING_LR, CosineAnnealingLRHandler())
7 changes: 7 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_myle.py
@@ -13,6 +13,7 @@
from torch.optim import Optimizer
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler
from fairseq2.optim.lr_scheduler._lr_scheduler import (
AbstractLRScheduler,
@@ -135,3 +136,9 @@ def requires_num_steps(self) -> bool:
@override
def config_kls(self) -> type[object]:
return MyleLRConfig


def register_myle_lr(context: RuntimeContext) -> None:
registry = context.get_registry(LRSchedulerHandler)

registry.register(MYLE_LR, MyleLRHandler())
7 changes: 7 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_noam.py
@@ -12,6 +12,7 @@
from torch.optim import Optimizer
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler
from fairseq2.optim.lr_scheduler._lr_scheduler import AbstractLRScheduler, LRScheduler
from fairseq2.utils.structured import structure
@@ -108,3 +109,9 @@ def requires_num_steps(self) -> bool:
@override
def config_kls(self) -> type[object]:
return NoamLRConfig


def register_noam_lr(context: RuntimeContext) -> None:
registry = context.get_registry(LRSchedulerHandler)

registry.register(NOAM_LR, NoamLRHandler())
7 changes: 7 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_polynomial_decay.py
@@ -13,6 +13,7 @@
from torch.optim import Optimizer
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler._error import UnspecifiedNumberOfStepsError
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler
from fairseq2.optim.lr_scheduler._lr_scheduler import (
@@ -169,3 +170,9 @@ def requires_num_steps(self) -> bool:
@override
def config_kls(self) -> type[object]:
return PolynomialDecayLRConfig


def register_polynomial_decay_lr(context: RuntimeContext) -> None:
registry = context.get_registry(LRSchedulerHandler)

registry.register(POLYNOMIAL_DECAY_LR, PolynomialDecayLRHandler())
22 changes: 22 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_setup.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler._cosine_annealing import register_cosine_annealing_lr
from fairseq2.optim.lr_scheduler._myle import register_myle_lr
from fairseq2.optim.lr_scheduler._noam import register_noam_lr
from fairseq2.optim.lr_scheduler._polynomial_decay import register_polynomial_decay_lr
from fairseq2.optim.lr_scheduler._tri_stage import register_tri_stage_lr


def register_lr_schedulers(context: RuntimeContext) -> None:
register_cosine_annealing_lr(context)
register_myle_lr(context)
register_noam_lr(context)
register_polynomial_decay_lr(context)
register_tri_stage_lr(context)
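The LR scheduler side mirrors the optimizer pattern, with LRSchedulerHandler as the registry key type. As a hedged sketch, an extension could register an additional handler in the same way; "my_noam" and register_my_lr_scheduler are hypothetical names for illustration, while NoamLRHandler is exported from fairseq2.optim.lr_scheduler and LRSchedulerHandler is imported from fairseq2.optim.lr_scheduler._handler exactly as in the files above.

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler import NoamLRHandler
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler


def register_my_lr_scheduler(context: RuntimeContext) -> None:
    # Mirrors register_noam_lr: the registry is keyed by LRSchedulerHandler,
    # and the handler instance is registered under a string name. "my_noam"
    # is a hypothetical alias, not part of fairseq2.
    registry = context.get_registry(LRSchedulerHandler)

    registry.register("my_noam", NoamLRHandler())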
7 changes: 7 additions & 0 deletions src/fairseq2/optim/lr_scheduler/_tri_stage.py
@@ -14,6 +14,7 @@
from torch.optim import Optimizer
from typing_extensions import override

from fairseq2.context import RuntimeContext
from fairseq2.optim.lr_scheduler._error import UnspecifiedNumberOfStepsError
from fairseq2.optim.lr_scheduler._handler import LRSchedulerHandler
from fairseq2.optim.lr_scheduler._lr_scheduler import (
@@ -193,3 +194,9 @@ def requires_num_steps(self) -> bool:
@override
def config_kls(self) -> type[object]:
return TriStageLRConfig


def register_tri_stage_lr(context: RuntimeContext) -> None:
registry = context.get_registry(LRSchedulerHandler)

registry.register(TRI_STAGE_LR, TriStageLRHandler())
7 changes: 4 additions & 3 deletions src/fairseq2/setup/_lib.py
@@ -22,6 +22,8 @@
from fairseq2.extensions import run_extensions
from fairseq2.metrics import register_metric_descriptors
from fairseq2.metrics.recorders import register_metric_recorders
from fairseq2.optim import register_optimizers
from fairseq2.optim.lr_scheduler import register_lr_schedulers
from fairseq2.profilers import register_profilers
from fairseq2.recipes.setup import register_recipes
from fairseq2.setup._generation import (
@@ -31,7 +33,6 @@
     _register_seq_generators,
 )
 from fairseq2.setup._models import _register_models
-from fairseq2.setup._optim import _register_lr_schedulers, _register_optimizers
 from fairseq2.setup._text_tokenizers import _register_text_tokenizers
 from fairseq2.utils.file import LocalFileSystem

@@ -99,11 +100,11 @@ def setup_library() -> RuntimeContext:
     register_chatbots(context)
     register_clusters(context)
     register_datasets(context)
-    _register_lr_schedulers(context)
+    register_lr_schedulers(context)
     register_metric_descriptors(context)
     register_metric_recorders(context)
     _register_models(context)
-    _register_optimizers(context)
+    register_optimizers(context)
     register_profilers(context)
     register_recipes(context)
     _register_samplers(context)
39 changes: 0 additions & 39 deletions src/fairseq2/setup/_optim.py

This file was deleted.
