Commit

Added initial proposal for lazy loading model initialization.
JoelNiklaus committed Jan 11, 2025
1 parent ca8331a commit e22d478
Showing 1 changed file with 93 additions and 64 deletions.
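
The sketch below is an editorial illustration of the lazy-initialization pattern this commit proposes; all names are hypothetical and none of it is lighteval code. The idea is that the expensive model object is only constructed by the first step that actually needs it, so paths that replay cached responses never load any weights.

# Minimal sketch of lazy model initialization (hypothetical names, not lighteval code).
class LazyEvalPipeline:
    def __init__(self, model_factory):
        self._model_factory = model_factory  # how to build the model, if it is ever needed
        self._model = None                   # nothing is loaded at construction time

    @property
    def model(self):
        if self._model is None:              # first access triggers the expensive load
            self._model = self._model_factory()
        return self._model

    def replay_cached_responses(self, responses):
        # Bookkeeping-only path: never touches self.model, so no weights are loaded.
        return list(responses)

    def run_model(self, prompts):
        # Model-dependent path: the property above loads the model on first use.
        return [self.model(prompt) for prompt in prompts]
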
157 changes: 93 additions & 64 deletions src/lighteval/pipeline.py
@@ -29,12 +29,14 @@
from dataclasses import dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import Dict

import numpy as np
from tqdm import tqdm

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.utils.metric_utils import MetricCategory
from lighteval.models.abstract_model import ModelInfo
from lighteval.models.model_loader import TransformersModel, load_model
from lighteval.models.model_output import (
GenerativeMultiturnResponse,
@@ -43,6 +45,9 @@
LoglikelihoodSingleTokenResponse,
ModelResponse,
)
from lighteval.models.transformers.transformers_model import TransformersModelConfig
from lighteval.models.utils import _simplify_name
from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector
from lighteval.tasks.requests import RequestType, SampleUid
@@ -142,42 +147,49 @@ def __init__(
"--max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING."
)

self.tasks = tasks
self.model = model
self.model_config = model_config
self.evaluation_tracker = evaluation_tracker
self.accelerator, self.parallel_context = self._init_parallelism_manager()
self.model = self._init_model(model_config, model)

self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
self._init_tasks_and_requests(tasks=tasks)
self._init_parallelism_manager()
self._init_random_seeds()

self.evaluation_tracker.general_config_logger.log_model_info(self._get_model_info())
self._init_tasks()

# Final results
self.final_dict: dict = None

def _get_model_info(self):
if isinstance(self.model_config, (VLLMModelConfig, TransformersModelConfig)):
# At this point we only need the model name to know the details path
return ModelInfo(model_name=_simplify_name(self.model_config.pretrained))
else:
return self._init_model().model_info
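
(Editorial aside, not part of the diff; the value below is hypothetical.) The point of the branch above is that for vLLM and Transformers configs the tracker only needs a ModelInfo built from the simplified pretrained name, so no model is instantiated:

# Illustrative only: metadata logging from the config name alone, without load_model().
info = ModelInfo(model_name=_simplify_name("org/some-pretrained-model"))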

def _init_parallelism_manager(self):
accelerator, parallel_context = None, None
self.accelerator, self.parallel_context = None, None
if self.launcher_type == ParallelismManager.ACCELERATE:
if not is_accelerate_available():
raise ValueError("You are trying to launch an accelerate model, but accelerate is not installed")
accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
test_all_gather(accelerator=accelerator)
self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
test_all_gather(accelerator=self.accelerator)
elif self.launcher_type == ParallelismManager.NANOTRON:
if not is_nanotron_available():
raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed")
dist.initialize_torch_distributed()
parallel_context = ParallelContext(
self.parallel_context = ParallelContext(
tensor_parallel_size=self.model_config.lighteval_config.parallelism.tp,
pipeline_parallel_size=self.model_config.lighteval_config.parallelism.pp,
data_parallel_size=self.model_config.lighteval_config.parallelism.dp,
)
test_all_gather(parallel_context=parallel_context)
test_all_gather(parallel_context=self.parallel_context)

return accelerator, parallel_context

def _init_model(self, model_config, model):
def _init_model(self):
logger.info("--- LOADING MODEL ---")
if model_config is not None:
if self.model_config is not None:
if self.parallel_context:
return NanotronLightevalModel(
self.model = NanotronLightevalModel(
checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
if self.pipeline_parameters.nanotron_checkpoint_path
else "",
@@ -188,46 +200,42 @@ def _init_model(self, model_config, model):
env_config=self.pipeline_parameters.env_config,
)
else:
return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
if isinstance(model, TransformersModel):
return model
else:
return TransformersModel.from_model(
model=model,
self.model = load_model(config=self.model_config, env_config=self.pipeline_parameters.env_config)
if not isinstance(self.model, TransformersModel):
self.model = TransformersModel.from_model(
model=self.model,
use_chat_template=self.pipeline_parameters.use_chat_template,
env_config=self.pipeline_parameters.env_config,
accelerator=self.accelerator,
)
return self.model

def _init_tasks_and_requests(self, tasks: str):
def _init_tasks(self):
with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
logger.info("--- LOADING TASKS ---")
registry = Registry(
cache_dir=self.pipeline_parameters.env_config.cache_dir,
custom_tasks=self.pipeline_parameters.custom_tasks_directory,
)
task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
task_dict = registry.get_task_dict(task_names_list)
LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)

self.evaluation_tracker.task_config_logger.log(task_dict)

requests, docs = create_requests_from_tasks(
task_dict=task_dict,
fewshot_dict=fewshots_dict,
num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
lm=self.model,
max_samples=self.pipeline_parameters.max_samples,
evaluation_tracker=self.evaluation_tracker,
use_chat_template=self.pipeline_parameters.use_chat_template,
system_prompt=self.pipeline_parameters.system_prompt,
self.task_names_list, self.fewshots_dict = taskinfo_selector(self.tasks, registry)
self.task_dict = registry.get_task_dict(self.task_names_list)
LightevalTask.load_datasets(
list(self.task_dict.values()), self.pipeline_parameters.dataset_loading_processes
)

self.task_names_list = task_names_list
self.task_dict = task_dict
self.fewshot_dict = fewshots_dict
self.requests = requests
self.docs = docs
self.evaluation_tracker.task_config_logger.log(self.task_dict)

def _init_requests(self):
self.requests, self.docs = create_requests_from_tasks(
task_dict=self.task_dict,
fewshot_dict=self.fewshots_dict,
num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
lm=self.model,
max_samples=self.pipeline_parameters.max_samples,
evaluation_tracker=self.evaluation_tracker,
use_chat_template=self.pipeline_parameters.use_chat_template,
system_prompt=self.pipeline_parameters.system_prompt,
)

def _init_random_seeds(self):
logger.info("--- INIT SEEDS ---")
Expand Down Expand Up @@ -280,16 +288,37 @@ def evaluate(self):
except OSError:
pass

@staticmethod
def _metric_category_to_request_type() -> Dict[MetricCategory, RequestType]:
"""Maps MetricCategory to their corresponding RequestType."""
return {
MetricCategory.TARGET_PERPLEXITY: RequestType.LOGLIKELIHOOD,
MetricCategory.PERPLEXITY: RequestType.LOGLIKELIHOOD_ROLLING,
MetricCategory.GENERATIVE_SAMPLING: RequestType.GREEDY_UNTIL,
MetricCategory.GENERATIVE: RequestType.GREEDY_UNTIL,
MetricCategory.GENERATIVE_LOGPROB: RequestType.GREEDY_UNTIL,
MetricCategory.MULTICHOICE: RequestType.LOGLIKELIHOOD,
MetricCategory.MULTICHOICE_PMI: RequestType.LOGLIKELIHOOD,
MetricCategory.MULTICHOICE_ONE_TOKEN: RequestType.LOGLIKELIHOOD_SINGLE_TOKEN,
MetricCategory.LLM_AS_JUDGE_MULTI_TURN: RequestType.GREEDY_UNTIL_MULTI_TURN,
MetricCategory.LLM_AS_JUDGE: RequestType.GREEDY_UNTIL,
}

@staticmethod
def _request_type_to_response() -> Dict[RequestType, type[ModelResponse]]:
return {
RequestType.LOGLIKELIHOOD: LoglikelihoodResponse,
RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: LoglikelihoodSingleTokenResponse,
RequestType.LOGLIKELIHOOD_ROLLING: LoglikelihoodResponse,
RequestType.GREEDY_UNTIL_MULTI_TURN: GenerativeMultiturnResponse,
RequestType.GREEDY_UNTIL: GenerativeResponse,
}
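
Usage sketch (editorial; assumes these static helpers live on the Pipeline class defined in this file): a metric category first maps to a request type, which then resolves to the ModelResponse subclass expected when replaying details.

request_type = Pipeline._metric_category_to_request_type()[MetricCategory.GENERATIVE]
response_cls = Pipeline._request_type_to_response()[request_type]
assert response_cls is GenerativeResponse  # greedy-until requests deserialize to GenerativeResponse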

def _load_responses_from_details(self):
logger.info("--- LOADING RESPONSES FROM DETAILS ---")
sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)

request_types = list(self.requests.keys())
if len(request_types) > 1:
raise ValueError(
"Loading responses from details when there are multiple request types is currently not supported"
)
model_response_type = self._get_model_response_type(request_types[0])
model_response_type = self._get_model_response_type()

details_datasets = self.evaluation_tracker.load_details_datasets(
self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
@@ -333,25 +362,25 @@ def _load_responses_from_details(self):
sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
return sample_id_to_responses

def _get_model_response_type(self, request_type):
if request_type == RequestType.LOGLIKELIHOOD:
model_response_type = LoglikelihoodResponse
elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
model_response_type = LoglikelihoodSingleTokenResponse
elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
model_response_type = LoglikelihoodResponse
elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
model_response_type = GenerativeMultiturnResponse
elif request_type == RequestType.GREEDY_UNTIL:
model_response_type = GenerativeResponse
else:
raise ValueError(
f"Loading responses from details for request type {request_type} is currently not supported"
)

def _get_model_response_type(self):
model_response_type = None
for task in self.task_dict.values():
for metric_category, has_metric_category in task.has_metric_category.items():
if has_metric_category:
request_type = self._metric_category_to_request_type()[metric_category]
new_model_response_type = self._request_type_to_response()[request_type]
if model_response_type and new_model_response_type != model_response_type:
raise ValueError(
f"Loading responses from details with multiple model response types ({model_response_type} and {new_model_response_type}) is currently not supported"
)
model_response_type = new_model_response_type
return model_response_type

def _run_model(self):
# Init model stuff lazily to avoid loading the model if not needed
self._init_model()
self._init_requests() # Needs the model to be initialized

# Running all requests depending on the model call type (log likelihood, generative, ...)
# to be able to batch them
logger.info("--- RUNNING MODEL ---")
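
Taken together, the reordering above lets the details-replay path finish without ever loading model weights. The wrapper below is a hypothetical illustration of that control flow, not the actual evaluate() body:

def evaluate_lazily(pipeline, replay_from_details: bool):
    if replay_from_details:
        # Replaying saved details only needs task bookkeeping and the model *name*
        # (see _get_model_info), so _init_model() is never called.
        return pipeline._load_responses_from_details()
    # Otherwise _run_model() initializes the model and builds requests on demand.
    return pipeline._run_model()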
