Skip to content

Commit

Permalink
Add util.copy_with. (#2562)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper authored Jan 16, 2023
1 parent 6878509 commit 0e22711
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/gluonts/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,36 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import copy
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

T = TypeVar("T")


def copy_with(obj: T, **kwargs) -> T:
"""Return copy of `obj` and update attributes of the copy using `kwargs`.
::
@dataclass
class MyClass:
value: int
a = MyClass(1)
b = copy_with(a, value=2)
assert a.value == 1
assert b.value == 2
"""

new_obj = copy.copy(obj)

for name, value in kwargs.items():
setattr(new_obj, name, value)

return new_obj


if TYPE_CHECKING:
Expand Down
12 changes: 12 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
26 changes: 26 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from gluonts.util import copy_with


def test_copy_with():
class X:
def __init__(self, value):
self.value = value

a = X(42)
b = copy_with(a, value=99)

assert a.value == 42
assert b.value == 99

0 comments on commit 0e22711

Please sign in to comment.