Skip to content

Commit

Permalink
added runner script, added more robust feature handling, added feature config
Browse files Browse the repository at this point in the history
  • Loading branch information
jlunder00 committed Nov 25, 2024
1 parent fd729a4 commit 8ca0151
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 16 deletions.
61 changes: 61 additions & 0 deletions TMN_DataGen/configs/default_feature_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# TMN_DataGen/TMN_DataGen/configs/default_feature_config.yaml
# Defaults merged into the main config by DatasetGenerator._load_configs;
# consumers read the 'feature_extraction' and 'feature_mappings' sections.
# NOTE: pos_tags / dep_types / morph_features must nest UNDER
# feature_mappings — the loader looks up config['feature_mappings'] and then
# indexes ['pos_tags'] / ['dep_types'] inside it.
feature_extraction:
  word_embedding_model: bert-base-uncased  # HuggingFace model identifier
  use_gpu: true              # honored only when CUDA is actually available
  cache_embeddings: true
  embedding_cache_dir: embedding_cache
  batch_size: 32

feature_mappings:
  # Universal Dependencies coarse POS tag set (one-hot dim = len + 1,
  # the extra slot covering unseen tags)
  pos_tags:
    - ADJ
    - ADP
    - ADV
    - AUX
    - CCONJ
    - DET
    - INTJ
    - NOUN
    - NUM
    - PART
    - PRON
    - PROPN
    - PUNCT
    - SCONJ
    - SYM
    - VERB
    - X

  # Dependency relation labels (one-hot dim = len + 1)
  dep_types:
    - nsubj
    - obj
    - iobj
    - det
    - nmod
    - amod
    - advmod
    - nummod
    - appos
    - conj
    - cc
    - punct
    - root
    - aux
    - cop
    - case
    - mark
    - compound
    - acl
    - fixed
    - flat

  # Morphological feature names considered during extraction
  morph_features:
    - Number
    - Person
    - Tense
    - VerbForm
    - Case
    - Gender
    - Mood
    - Voice
    - Aspect
15 changes: 14 additions & 1 deletion TMN_DataGen/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self):
def _load_configs(
self,
parser_config: Optional[Union[str, Dict]] = None,
preprocessing_config: Optional[Union[str, Dict]] = None,
preprocessing_config: Optional[Union[str, Dict]] = None,
feature_config: Optional[Union[str, Dict]] = None,
verbosity: str = 'normal',
override_pkg_config: Optional[Union[str, Dict]] = None
) -> Dict:
Expand All @@ -40,6 +41,9 @@ def _load_configs(

with open(config_dir / 'default_preprocessing_config.yaml') as f:
config.update(yaml.safe_load(f))

with open(config_dir / 'default_feature_config.yaml') as f:
config.update(yaml.safe_load(f))

# Add verbosity
config['verbose'] = verbosity
Expand All @@ -66,6 +70,15 @@ def _load_configs(
preprocessing_config = yaml.safe_load(f)
config['preprocessing'].update(preprocessing_config)

# Override feature config if provided
if feature_config:
if isinstance(feature_config, str):
with open(feature_config) as f:
feature_config = yaml.safe_load(f)
for key in ['feature_extraction', 'feature_mappings']:
if key in feature_config:
config[key].update(feature_config[key])

return OmegaConf.create(config), pkg_config

def generate_dataset(
Expand Down
22 changes: 11 additions & 11 deletions TMN_DataGen/utils/feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ def __init__(self, config: Optional[DictConfig] = None):
self.config.get('verbose', 'normal')
)

# Load model configurations
# 1. Initialize feature mappings first
self.feature_mappings = self._initialize_feature_mappings()

# 2. Then calculate dimensions
self.morph_dim = self._calculate_morph_dim()
self.pos_dim = len(self.feature_mappings['pos_tags']) + 1
self.dep_dim = len(self.feature_mappings['dep_types']) + 1

# 3. Load model/tokenizer after mappings are ready
model_cfg = self.config.get('feature_extraction', {})
self.model_name = model_cfg.get('word_embedding_model', 'bert-base-uncased')
self.use_gpu = model_cfg.get('use_gpu', True) and torch.cuda.is_available()
Expand All @@ -45,20 +53,12 @@ def __init__(self, config: Optional[DictConfig] = None):
self.model = AutoModel.from_pretrained(self.model_name)
if self.use_gpu:
self.model = self.model.to(self.device)
self.model.eval() # Set to evaluation mode
self.model.eval()
self.embedding_dim = self.model.config.hidden_size
except Exception as e:
self.logger.error(f"Failed to load model: {e}")
raise

# Feature dimensions
self.embedding_dim = self.model.config.hidden_size
self.pos_dim = len(self._default_pos_tags()) + 1
self.dep_dim = len(self._default_dep_types()) + 1
self.morph_dim = self._calculate_morph_dim()

# Initialize feature mappings
self.feature_mappings = self._initialize_feature_mappings()

self.logger.info(f"Feature dimensions - Embedding: {self.embedding_dim}, "
f"POS: {self.pos_dim}, Dep: {self.dep_dim}, "
f"Morph: {self.morph_dim}")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="TMN_DataGen",
version='0.5.0',
version='0.6.0',
description="Tree Matching Network Data Generator",
author="toast",
packages=find_packages(),
Expand Down
17 changes: 14 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,29 @@ def default_config():
"""Load all default configs for testing"""
config_dir = Path(__file__).parent.parent / 'TMN_DataGen' / 'configs'

# Load package config
# Load configs
with open(config_dir / 'default_package_config.yaml') as f:
pkg_config = yaml.safe_load(f)

# Load and merge parser and preprocessing configs
with open(config_dir / 'default_parser_config.yaml') as f:
config = yaml.safe_load(f)

with open(config_dir / 'default_preprocessing_config.yaml') as f:
preproc = yaml.safe_load(f)
config.update(preproc)

with open(config_dir / 'default_feature_config.yaml') as f:
feature_config = yaml.safe_load(f)

# Merge configs
config.update(preproc)
config.update(feature_config)

# # Override some settings for testing
# config['feature_extraction'].update({
# 'use_gpu': False,
# 'cache_embeddings': False
# })

return OmegaConf.create(config), pkg_config

@pytest.fixture
Expand Down

0 comments on commit 8ca0151

Please sign in to comment.