"""
This file contains tests for Monitor-Guided Decoding running 2 monitors simultaneously
"""
import torch
import transformers
import pytest
from pathlib import PurePath
from monitors4codegen.multilspy.language_server import SyncLanguageServer
from monitors4codegen.multilspy.multilspy_config import Language
from tests.test_utils import create_test_context, is_cuda_available
from transformers import AutoTokenizer, AutoModelForCausalLM
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.monitor_guided_decoding.monitors.switch_enum_monitor import SwitchEnumMonitor
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from monitors4codegen.multilspy.multilspy_types import Position
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper
pytest_plugins = ("pytest_asyncio",)
@pytest.mark.asyncio
@pytest.mark.skip(reason="TODO: This runs too slow. Reimplement joint monitoring")
async def test_multilspy_csharp_ryujinx_joint_switch_enum_dereferences() -> None:
"""
Test the working of Joint monitoring with SwitchEnumMonitor and DereferencesMonitor with C# repository - Ryujinx
"""
code_language = Language.CSHARP
params = {
"code_language": code_language,
"repo_url": "https://github.com/Ryujinx/Ryujinx/",
"repo_commit": "e768a54f17b390c3ac10904c7909e3bef020edbd"
}
device = torch.device('cuda' if is_cuda_available() else 'cpu')
model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
"bigcode/santacoder", trust_remote_code=True
).to(device)
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
with create_test_context(params) as context:
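        # One SyncLanguageServer instance per monitor: each MonitorFileBuffer created
        # below is backed by its own server, so the two monitors issue their LSP
        # requests independently of each other.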
        lsp1 = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        lsp2 = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
        with lsp1.start_server(), lsp2.start_server():
            completions_filepath = "src/ARMeilleure/CodeGen/X86/CodeGenerator.cs"
            with lsp1.open_file(completions_filepath), lsp2.open_file(completions_filepath):
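                # Delete the same range from both open copies of the file; the removed
                # code is the "hole" that the model is asked to regenerate below.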
                deleted_text1 = lsp1.delete_text_between_positions(
                    completions_filepath,
                    Position(line=224, character=37),
                    Position(line=243, character=28)
                )
                deleted_text2 = lsp2.delete_text_between_positions(
                    completions_filepath,
                    Position(line=224, character=37),
                    Position(line=243, character=28)
                )
                assert deleted_text1 == deleted_text2
                assert deleted_text1 == """Intrinsic.X86Comisdlt:
                    context.Assembler.Comisd(src1, src2);
                    context.Assembler.Setcc(dest, X86Condition.Below);
                    break;

                case Intrinsic.X86Comisseq:
                    context.Assembler.Comiss(src1, src2);
                    context.Assembler.Setcc(dest, X86Condition.Equal);
                    break;

                case Intrinsic.X86Comissge:
                    context.Assembler.Comiss(src1, src2);
                    context.Assembler.Setcc(dest, X86Condition.AboveOrEqual);
                    break;

                case Intrinsic.X86Comisslt:
                    context.Assembler.Comiss(src1, src2);
                    context.Assembler.Setcc(dest, X86Condition.Below);
                    break;
"""
                filebuffer_enum = MonitorFileBuffer(
                    lsp1.language_server,
                    completions_filepath,
                    (224, 37),
                    (224, 37),
                    code_language,
                )
                monitor_switch_enum = SwitchEnumMonitor(HFTokenizerWrapper(tokenizer), filebuffer_enum)
                mgd_logits_processor_switch_enum = MGDLogitsProcessor([monitor_switch_enum], lsp1.language_server.server.loop)
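
                # Second monitor: a DereferencesMonitor wired to the second language server.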
                filebuffer_dereferences = MonitorFileBuffer(
                    lsp2.language_server,
                    completions_filepath,
                    (224, 37),
                    (224, 37),
                    code_language,
                )
                monitor_dereferences = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer_dereferences)
                mgd_logits_processor_dereferences = MGDLogitsProcessor([monitor_dereferences], lsp2.language_server.server.loop)
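
                # Rebuild the prompt from the on-disk file contents up to the deletion
                # point; the prompt is expected to end with "case ".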
                with open(str(PurePath(context.source_directory, completions_filepath)), "r") as f:
                    filecontent = f.read()

                pos_idx = TextUtils.get_index_from_line_col(filecontent, 224, 37)
                assert filecontent[:pos_idx].endswith('case ')

                prompt = filecontent[:pos_idx]
                assert prompt[-1] == " "

                # Keep only the last (2048 - 512) prompt tokens so the input fits within
                # SantaCoder's 2048-token context with room left over for generation.
                prompt_tokenized = tokenizer.encode(prompt, return_tensors="pt").to(device)[:, -(2048 - 512):]

                generated_code_without_mgd = model.generate(
                    prompt_tokenized, do_sample=False, max_new_tokens=50, early_stopping=True
                )
                generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -50:])
                assert (
                    generated_code_without_mgd
                    == " Intrinsic.X86Comisdgt:\n                    context.Assembler.Comisd(src1, src2);\n                    context.Assembler.Setcc(dest, X86Condition.GreaterThan);\n                    break;\n\n                case In"
                )
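
                # Unconstrained, the model invents Intrinsic.X86Comisdgt, which is not a
                # valid member of Ryujinx's Intrinsic enum; the MGD-guided run below is
                # expected to stay within valid enum members and well-typed dereferences.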
                # Generate code using the santacoder model with both MGD logits processors chained and greedy decoding
                logits_processor = LogitsProcessorList([mgd_logits_processor_switch_enum, mgd_logits_processor_dereferences])
                generated_code = model.generate(
                    prompt_tokenized,
                    do_sample=False,
                    max_new_tokens=50,
                    logits_processor=logits_processor,
                    early_stopping=True,
                )
                generated_code = tokenizer.decode(generated_code[0, -50:])

                assert (
                    generated_code
                    == "Intrinsic.X86Comisdlt:\n                    context.Assembler.Comisd(src1, src2);\n                    context.Assembler.Setcc(dest, X86Condition.Below);\n                    break;\n\n                case Intrinsic"
                )