diff --git a/tools/who_what_benchmark/tests/test_cli_text.py b/tools/who_what_benchmark/tests/test_cli_text.py index 750ff2eec1..53cbae472d 100644 --- a/tools/who_what_benchmark/tests/test_cli_text.py +++ b/tools/who_what_benchmark/tests/test_cli_text.py @@ -159,7 +159,7 @@ def test_text_verbose(): assert "## Diff:" in result.stderr -def test_text_language_autodetect(): +def test_text_language(): with tempfile.TemporaryDirectory() as temp_dir: temp_file_name = os.path.join(temp_dir, "gt.csv") result = run_wwb( @@ -172,6 +172,8 @@ def test_text_language_autodetect(): "2", "--device", "CPU", + "--language", + "cn", ] ) assert result.returncode == 0 diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py index 433521a186..73fb1d1928 100644 --- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py +++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py @@ -73,21 +73,6 @@ } -def autodetect_language(model): - model2language = { - "chatglm": "cn", - "qwen2": "cn", - "qwen": "cn", - "baichuan": "cn", - "minicpmv": "cn", - "internlm": "cn", - } - - if not hasattr(model, "config"): - return "en" - return model2language.get(model.config.model_type, "en") - - @register_evaluator( "text" ) @@ -103,7 +88,7 @@ def __init__( max_new_tokens=128, crop_question=True, num_samples=None, - language=None, + language="en", gen_answer_fn=None, generation_config=None, generation_config_base=None, @@ -130,9 +115,6 @@ def __init__( # Take language from the base model if provided self.language = language - if self.language is None: - if base_model is not None: - self.language = autodetect_language(base_model) if base_model: self.gt_data = self._generate_data( @@ -233,11 +215,6 @@ def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, data = {"prompts": list(self.test_data)} data = pd.DataFrame.from_dict(data) else: - if self.language is None: - print( - "No language detecting in the base model or ground truth data. Taking language from target model." - ) - self.language = autodetect_language(model) data = pd.DataFrame.from_dict(default_data[self.language]) prompt_data = data["prompts"] diff --git a/tools/who_what_benchmark/whowhatbench/wwb.py b/tools/who_what_benchmark/whowhatbench/wwb.py index e58c1c2aaf..69397b23a7 100644 --- a/tools/who_what_benchmark/whowhatbench/wwb.py +++ b/tools/who_what_benchmark/whowhatbench/wwb.py @@ -129,8 +129,8 @@ def parse_args(): "--language", type=str, choices=["en", "cn"], - default=None, - help="Used to select default prompts based on the primary model language, e.g. 'en', 'ch'.", + default="en", + help="Used to select default prompts based on the primary model language, e.g. 'en', 'cn'.", ) parser.add_argument( "--hf",