
Commit

fixes
jlunder00 committed Jan 29, 2025
1 parent 4409a06 commit f9dda7b
Showing 7 changed files with 253 additions and 141 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -166,4 +166,5 @@ cython_debug/

embedding_cache/
embedding_cache/*

embedding_cache1/
embedding_cache1/*
93 changes: 8 additions & 85 deletions TMN_DataGen/configs/default_feature_config.yaml
@@ -3,10 +3,15 @@ feature_extraction:
  word_embedding_model: bert-base-uncased
  use_gpu: true
  cache_embeddings: true
- embedding_cache_dir: embedding_cache
- do_not_store_word_embs: True
- do_not_compute_word_embs: False
+ embedding_cache_dir: embedding_cache1
+ do_not_store_word_embeddings: True
+ do_not_compute_word_embeddings: False
  batch_size: 32
  num_workers: 12
  shard_size: 10000
  remove_old_cache: False
  is_runtime: False


feature_mappings:
pos_tags:
@@ -28,65 +33,6 @@ feature_mappings:
- VERB
- X

# dep_types:
# # Base types
# - nsubj
# - obj
# - iobj
# - csubj
# - ccomp
# - xcomp
# - obl
# - vocative
# - expl
# - dislocated
# - advcl
# - advmod
# - discourse
# - aux
# - cop
# - mark
# - nmod
# - appos
# - nummod
# - acl
# - amod
# - det
# - clf
# - case
# - conj
# - cc
# - fixed
# - flat
# - compound
# - list
# - parataxis
# - orphan
# - goeswith
# - reparandum
# - punct
# - dep
# - root

# # Semi-mandatory subtypes
# - acl:relcl
# - advcl:relcl
# - aux:pass
# - csubj:outer
# - csubj:pass
# - expl:impers
# - expl:pass
# - expl:pv
# - nsubj:outer
# - nsubj:pass
# - obl:agent
# - obl:tmod
# - obl:arg
# - obl:lmod
# - nmod:poss
# - nmod:tmod
# - obl:npmod

dep_types:
# Core arguments
- nsubj
@@ -167,29 +113,6 @@
- obl:npmod
- nmod:npmod

# dep_types:
# - nsubj
# - obj
# - iobj
# - det
# - nmod
# - amod
# - advmod
# - nummod
# - appos
# - conj
# - cc
# - punct
# - root
# - aux
# - cop
# - case
# - mark
# - compound
# - acl
# - fixed
# - flat

morph_features:
- Number
- Person
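For orientation, a minimal sketch of how the renamed keys under feature_extraction (embedding_cache_dir, shard_size, num_workers, remove_old_cache) could be wired into the ParallelEmbeddingCache introduced later in this commit. The config path and the use of yaml.safe_load here are illustrative assumptions, not part of the changeset:

    import yaml
    from pathlib import Path

    from TMN_DataGen.utils.embedding_cache import ParallelEmbeddingCache

    # Illustrative wiring only; the real FeatureExtractor plumbing is not shown in this diff.
    with open("TMN_DataGen/configs/default_feature_config.yaml") as f:
        cfg = yaml.safe_load(f)["feature_extraction"]

    cache = ParallelEmbeddingCache(
        cache_dir=Path(cfg["embedding_cache_dir"]),  # embedding_cache1
        shard_size=cfg.get("shard_size", 10000),
        num_workers=cfg.get("num_workers"),
        config=cfg,  # consulted for 'verbose' and 'remove_old_cache'
    )
    cache.load()  # loads existing shards, or migrates an old-style single-file cache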
6 changes: 3 additions & 3 deletions TMN_DataGen/tree/dependency_tree.py
@@ -1,4 +1,4 @@
-#dependency_tree.py
+#TMN_DataGen/tree/dependency_tree.py
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
import torch
@@ -94,7 +94,7 @@ def to_graph_data(self) -> Dict[str, Any]:
         for i, node in enumerate(nodes):
             try:
                 node_features[i] = extractor.create_node_features(node)
-                node_texts[i] = (node.word, node.lemma)
+                node_texts.append((node.word, node.lemma))
             except Exception as e:
                 raise ValueError(f"Failed to create features for node {node}: {e}")

@@ -121,7 +121,7 @@ def to_graph_data(self) -> Dict[str, Any]:
             'graph_idx': [0] * len(nodes),
             'n_graphs': 1,
             'node_texts': node_texts,  # Save word instead of embedding to save space, compute (or retrieve from cache) at training time per batch
-            'node_features_need_word_embs_prepended': extractor.do_not_compute_word_embeddings or extractor.do_not_store_word_embeddings,
+            'node_features_need_word_embs_prepended': extractor.do_not_store_word_embeddings,
             'text': self.text
         }

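Given the flag change above, a hedged sketch of what a downstream training loop might do with node_texts when node_features_need_word_embs_prepended is set; embed_words below is a hypothetical stand-in for the BERT lookup or cache retrieval and is not part of this commit:

    import torch

    def maybe_prepend_word_embs(graph_data, embed_words):
        # If word embeddings were not stored with the graph, compute them (or fetch
        # from the cache) per batch and prepend them to the stored node features.
        if not graph_data['node_features_need_word_embs_prepended']:
            return graph_data['node_features']
        words = [word for word, lemma in graph_data['node_texts']]
        word_embs = embed_words(words)  # hypothetical helper: (n_nodes, emb_dim) tensor
        feats = torch.as_tensor(graph_data['node_features'], dtype=word_embs.dtype)
        return torch.cat([word_embs, feats], dim=-1)  # word embeddings prepended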
169 changes: 169 additions & 0 deletions TMN_DataGen/utils/embedding_cache.py
@@ -0,0 +1,169 @@
#TMN_DataGen/utils/embedding_cache.py
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from ..utils.logging_config import setup_logger
from typing import Dict, Optional
import gc

class ParallelEmbeddingCache:
    """A drop-in replacement for FeatureExtractor's embedding cache that supports parallel processing"""

    def __init__(self,
                 cache_dir: Path,
                 shard_size: int = 10000,
                 num_workers: Optional[int] = None,
                 config: Optional[Dict] = None):
        """
        Initialize the parallel embedding cache system.
        Args:
            cache_dir: Directory to store cache files
            shard_size: Number of embeddings per shard
            num_workers: Number of parallel workers (defaults to CPU count)
            config: Configuration dictionary
        """
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.shard_size = shard_size
        self.num_workers = num_workers or min(mp.cpu_count(), 4)  # Limit default workers
        self.embedding_cache: Dict[str, torch.Tensor] = {}
        self._current_shard = 0
        self._items_in_current_shard = 0
        self.config = config or {}
        self.logger = setup_logger(
            self.__class__.__name__,
            self.config.get('verbose', 'normal')
        )

    def _get_shard_path(self, shard_idx: int) -> Path:
        """Get the path for a specific shard file."""
        return self.cache_dir / f"embedding_cache_shard_{shard_idx}.npz"

    def __getitem__(self, word: str) -> Optional[torch.Tensor]:
        """Enable dictionary-like access to cache."""
        return self.embedding_cache.get(word)

    def __setitem__(self, word: str, embedding: torch.Tensor):
        """Enable dictionary-like setting of cache values."""
        self.embedding_cache[word] = embedding
        self._items_in_current_shard += 1

        # Auto-save shards when they reach the size limit
        if self._items_in_current_shard >= self.shard_size:
            self._save_current_shard()
            self._current_shard += 1
            self._items_in_current_shard = 0

    def __contains__(self, word: str) -> bool:
        """Enable 'in' operator for cache."""
        return word in self.embedding_cache

    def items(self):
        """Enable items() access for cache."""
        return self.embedding_cache.items()

    def _save_current_shard(self):
        """Save the current shard of embeddings."""
        if not self.embedding_cache:
            return

        shard_path = self._get_shard_path(self._current_shard)
        items_to_save = list(self.embedding_cache.items())
        start_idx = self._current_shard * self.shard_size
        end_idx = start_idx + self.shard_size
        shard_items = items_to_save[start_idx:end_idx]

        if shard_items:
            # Convert to numpy and save
            np_data = {word: emb.cpu().numpy() for word, emb in shard_items}
            np.savez(shard_path, **np_data)

    @staticmethod
    def _load_shard(shard_path: Path) -> Dict[str, np.ndarray]:
        """Load a single shard of embeddings."""
        if not shard_path.exists():
            return {}

        try:
            with np.load(shard_path, allow_pickle=True) as npz:
                # Return as numpy arrays initially to avoid PyTorch shared memory issues
                return {word: emb for word, emb in npz.items()}
        except Exception as e:
            print(f"Error loading shard {shard_path}: {e}")
            return {}

    def _convert_to_torch(self, numpy_dict: Dict[str, np.ndarray]) -> None:
        """Convert numpy arrays to torch tensors and add to cache."""
        for word, emb in numpy_dict.items():
            self.embedding_cache[word] = torch.from_numpy(emb)

    def load(self):
        """Load all cached embeddings from shards."""
        shard_paths = sorted(self.cache_dir.glob("embedding_cache_shard_*.npz"))

        if not shard_paths:
            old_cache = self.cache_dir / "embedding_cache.npz"
            if old_cache.exists():
                self.logger.info("Found old-style cache file, loading it")
                with np.load(old_cache, allow_pickle=True) as cache_data:
                    # Load as numpy first
                    numpy_cache = {
                        word: emb for word, emb in
                        tqdm(cache_data.items(), desc="Loading old cache")
                    }
                    # Convert to torch tensors
                    for word, emb in tqdm(numpy_cache.items(), desc="Converting to torch tensors"):
                        self.embedding_cache[word] = torch.from_numpy(emb)

                self.logger.info(f"Loaded {len(self.embedding_cache)} items from old cache")
                # Save in new format and optionally remove old cache
                self.save()
                if self.config.get('remove_old_cache', False):
                    self.logger.info("Removing old cache file")
                    old_cache.unlink()
                return
            else:
                self.logger.info("No cache files found")
                return

        self.logger.info(f"Loading embeddings from {len(shard_paths)} shards")

        # Process shards sequentially if only one worker
        if self.num_workers == 1:
            for shard_path in tqdm(shard_paths, desc="Loading embedding shards"):
                numpy_data = self._load_shard(shard_path)
                self._convert_to_torch(numpy_data)
                gc.collect()  # Help manage memory
        else:
            # Process shards in parallel but convert to torch tensors sequentially
            with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
                for numpy_data in tqdm(
                    executor.map(self._load_shard, shard_paths),
                    total=len(shard_paths),
                    desc="Loading embedding shards"
                ):
                    self._convert_to_torch(numpy_data)
                    gc.collect()  # Help manage memory

        self._current_shard = len(shard_paths)
        self.logger.info(f"Successfully loaded {len(self.embedding_cache)} embeddings")

    def save(self):
        """Save all cached embeddings to shards."""
        if not self.embedding_cache:
            self.logger.info("No embeddings to save")
            return

        # Save any remaining items in the current shard
        if self._items_in_current_shard > 0:
            self._save_current_shard()

        self.logger.info(f"Saved {len(self.embedding_cache)} embeddings across {self._current_shard + 1} shards")

    def __len__(self):
        """Enable len() operator for cache."""
        return len(self.embedding_cache)
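For reference, a small usage sketch of the new cache as a dict-like store (the word and the 768-dim size below are made-up examples, assuming a bert-base-uncased embedding):

    from pathlib import Path
    import torch

    from TMN_DataGen.utils.embedding_cache import ParallelEmbeddingCache

    cache = ParallelEmbeddingCache(cache_dir=Path("embedding_cache1"), shard_size=10000)
    cache.load()                              # picks up existing shards or an old-style embedding_cache.npz

    if "sentence" not in cache:               # __contains__
        cache["sentence"] = torch.randn(768)  # __setitem__; auto-saves a shard every shard_size items

    emb = cache["sentence"]                   # __getitem__; returns None for unknown words
    print(len(cache))                         # __len__
    cache.save()                              # flushes any partially filled shard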