-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodel.py
142 lines (104 loc) · 3.97 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List
from pydantic import BaseModel
import rosnav_rl.cfg.sb3_cfg as sb3_cfg
from rosnav_rl.utils.type_aliases import (
EncodedObservationDict,
_SupportedRosnavRLModels,
)
if TYPE_CHECKING:
from rosnav_rl.rl_agent import RL_Agent
from rosnav_rl.spaces import BaseObservationSpace
class RL_Model(ABC):
    """
    Abstract base class for reinforcement learning models.

    Subclasses must implement ``setup_model``, ``train``, ``save``, ``load``
    and ``get_action``. The remaining hooks (``transfer_weights``,
    ``observation_space_list``, ``observation_space_kwargs`` and
    ``parameter_number``) are optional overrides: their base implementations
    raise ``NotImplementedError``.

    Attributes:
        _model: The underlying model used for reinforcement learning.
            ``None`` until it is assigned (directly or via the ``model``
            setter).
        _algorithm_cfg: Configuration for the reinforcement learning algorithm.
        _rl_agent: The reinforcement learning agent this model belongs to.

    Methods:
        __init__(rl_agent, algorithm_cfg, *args, **kwargs):
            Stores the given agent and algorithm configuration.
        setup_model(*args, **kwargs):
            Abstract. Sets up the model.
        train(*args, **kwargs):
            Abstract. Trains the model.
        save(*args, **kwargs):
            Abstract. Saves the model.
        load(*args, **kwargs):
            Abstract. Loads the model.
        get_action(observation, *args, **kwargs):
            Abstract. Returns the action for a given observation.
        transfer_weights(*args, **kwargs):
            Optional. Raises NotImplementedError unless overridden.

    Properties:
        is_model_initialized:
            Whether a model has been assigned (``_model is not None``).
        model:
            Gets or sets the underlying model; the getter raises ValueError
            when no model has been assigned.
        algorithm_cfg:
            Gets the algorithm configuration.
        observation_space_list:
            Optional. Raises NotImplementedError unless overridden.
        observation_space_kwargs:
            Optional. Raises NotImplementedError unless overridden.
        stack_size:
            Gets the stack size. Default is 1.
        parameter_number:
            Optional. Raises NotImplementedError unless overridden.
        config:
            Gets the configuration dictionary. Default is an empty dictionary.
    """

    # A bare annotation does not create an attribute, so without this
    # class-level default ``is_model_initialized`` / ``model`` would raise
    # AttributeError before a subclass assigns ``self._model``. Using a class
    # attribute (rather than assigning in ``__init__``) also avoids clobbering
    # a model that a subclass sets before calling ``super().__init__()``.
    _model: Any = None
    _algorithm_cfg: BaseModel
    _rl_agent: "RL_Agent"

    def __init__(
        self, rl_agent: "RL_Agent", algorithm_cfg: BaseModel, *args, **kwargs
    ) -> None:
        """
        Store the owning agent and the algorithm configuration.

        Args:
            rl_agent: The agent this model belongs to.
            algorithm_cfg: Configuration for the RL algorithm.
            *args: Accepted for subclass compatibility; ignored here.
            **kwargs: Accepted for subclass compatibility; ignored here.
        """
        self._rl_agent = rl_agent
        self._algorithm_cfg = algorithm_cfg

    @abstractmethod
    def setup_model(self, *args, **kwargs):
        """Set up (construct) the underlying model."""
        pass

    @abstractmethod
    def train(self, *args, **kwargs):
        """Train the underlying model."""
        raise NotImplementedError()

    @abstractmethod
    def save(self, *args, **kwargs):
        """Persist the underlying model."""
        pass

    @abstractmethod
    def load(self, *args, **kwargs):
        """Load the underlying model."""
        pass

    @abstractmethod
    def get_action(self, observation: "EncodedObservationDict", *args, **kwargs):
        """Return the action for the given encoded observation."""
        pass

    def transfer_weights(self, *args, **kwargs):
        """Optional hook; raises NotImplementedError unless overridden."""
        raise NotImplementedError()

    @property
    def is_model_initialized(self) -> bool:
        """Whether a model has been assigned to this instance."""
        return self._model is not None

    @property
    def model(self) -> _SupportedRosnavRLModels:
        """
        The underlying model.

        Raises:
            ValueError: If no model has been assigned yet.
        """
        if self._model is None:
            raise ValueError("Model not initialized. Call 'setup_model' first.")
        return self._model

    @model.setter
    def model(self, model):
        self._model = model

    @property
    def algorithm_cfg(self) -> "sb3_cfg.SBAlgorithmCfg":
        """The algorithm configuration passed to ``__init__``."""
        return self._algorithm_cfg

    @property
    def observation_space_list(self) -> List["BaseObservationSpace"]:
        """Optional; raises NotImplementedError unless overridden."""
        raise NotImplementedError()

    @property
    def observation_space_kwargs(self) -> Dict[str, Any]:
        """Optional; raises NotImplementedError unless overridden."""
        raise NotImplementedError()

    @property
    def stack_size(self) -> int:
        """Number of stacked observations; 1 unless overridden."""
        return 1

    @property
    def parameter_number(self) -> int:
        """Optional; raises NotImplementedError unless overridden."""
        raise NotImplementedError()

    @property
    def config(self) -> Dict[str, Any]:
        """Configuration dictionary; empty unless overridden."""
        return {}