-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsmart_nbconvert.py
104 lines (91 loc) · 3.53 KB
/
smart_nbconvert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from openai import OpenAI
from pathlib import Path
import argparse
import json
import nbformat
import os
import re
from prompt import SYSTEM_PROMPT
def get_oai_client() -> OpenAI:
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``OpenAI`` client constructed with the key read from the environment.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, empty, or only whitespace.
    """
    key = os.getenv("OPENAI_API_KEY")
    # A blank value (e.g. ``export OPENAI_API_KEY=``) is as unusable as a
    # missing one, so reject both here rather than handing an empty key on
    # to the client.
    if key is None or not key.strip():
        raise ValueError("OPENAI_API_KEY must be set")
    return OpenAI(api_key=key)
def get_notebook_state(notebook: str) -> dict:
    """Load a notebook and reduce it to a JSON-friendly state plus its images.

    Non-code cells are passed through unchanged. Code cells are copied with
    their outputs rewritten: any output carrying ``image/png`` data has that
    data replaced by ``{'image_idx': N}``, where ``N`` is the 1-based position
    of the image in the returned ``images`` list. Other outputs keep their
    ``data`` as-is.

    Args:
        notebook: Path to the ``.ipynb`` file to read.

    Returns:
        A dict with two keys:
            ``'cells'``: the processed cells, in notebook order.
            ``'images'``: chat ``image_url`` content parts (low detail), one
            per extracted PNG, in order of appearance.
    """
    # Read explicitly as UTF-8 so loading does not depend on the platform's
    # default text encoding.
    with open(notebook, "r", encoding="utf-8") as f:
        nb = nbformat.read(f, as_version=4)
    print(f"Loaded notebook: {len(nb.cells)} total cells")

    processed_cells = []
    images = []
    for cell in nb.cells:
        if cell['cell_type'] != 'code':
            processed_cells.append(cell)
            continue

        # Copy everything except the outputs, which are rebuilt below.
        temp_cell = {k: v for k, v in cell.items() if k != 'outputs'}
        temp_cell['outputs'] = []
        for output in cell['outputs']:
            temp_output = {p: q for p, q in output.items() if p != 'data'}
            if 'data' in output:
                if 'image/png' in output['data']:
                    # The stored base64 payload may contain newlines; drop all
                    # whitespace so the data URL stays on a single line.
                    png_b64 = "".join(output['data']['image/png'].split())
                    images.append({
                        "type": "image_url",
                        "image_url": {
                            # The payload is PNG, so label it as such.
                            "url": f"data:image/png;base64,{png_b64}",
                            "detail": "low"
                        }
                    })
                    # 1-based index: len(images) after the append.
                    temp_output['data'] = {'image_idx': len(images)}
                else:
                    temp_output['data'] = output['data']
            temp_cell['outputs'].append(temp_output)
        processed_cells.append(temp_cell)

    return {'cells': processed_cells, 'images': images}
def get_chat_messages(notebook_state: dict, instructions: str) -> list:
    """Assemble the system and user messages for the chat completion call.

    The user message is a list of content parts: an optional goal line (only
    when ``instructions`` is not the empty string), the JSON-serialized cells
    wrapped in ``<JupyterNotebookState>`` tags, then every image part from
    ``notebook_state['images']``.

    Args:
        notebook_state: Dict with ``'cells'`` and ``'images'`` entries.
        instructions: Free-text goal for the report; ``''`` to omit it.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    goal_parts = []
    if instructions != '':
        goal_parts = [
            {"type": "text", "text": f"Goal for the report/project: {instructions}"}
        ]

    cells_json = json.dumps(notebook_state['cells'])
    notebook_part = {
        "type": "text",
        "text": f"<JupyterNotebookState>{cells_json}</JupyterNotebookState>",
    }

    user_content = goal_parts + [notebook_part] + notebook_state['images']

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]
def replace_image_links(report_content: str, images: list) -> str:
    """Replace ``<image_idx>N</image_idx>`` placeholders with Markdown images.

    ``N`` is a 1-based index into ``images``. Each valid placeholder becomes
    an inline Markdown image whose target is that image's ``image_url.url``.
    A placeholder whose index is outside ``1..len(images)`` is left in the
    text unchanged.

    Args:
        report_content: Report text that may contain placeholders.
        images: Image parts, each shaped like ``{"image_url": {"url": ...}}``.

    Returns:
        The report text with every valid placeholder substituted.
    """
    def replace_link(match):
        position = int(match.group(1))
        # Guard the range explicitly: an index of 0 would otherwise become
        # images[-1] and silently embed the wrong image, and an index past
        # the end would raise IndexError.
        if not 1 <= position <= len(images):
            return match.group(0)
        url = images[position - 1]["image_url"]["url"]
        return f'\n ![image_{position}]({url}) \n'

    return re.sub(r'<image_idx>(\d+)</image_idx>', replace_link, report_content)
def _extract_tag(text: str, tag: str):
    """Return the stripped text between ``<tag>`` and ``</tag>``, or None if absent."""
    match = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
    return match.group(1).strip() if match else None


def main() -> None:
    """Command-line entry point: turn a notebook into a Markdown report.

    Parses the arguments, sends the notebook state to the chat model, prints
    the returned summary, and writes the report. The report goes to
    ``--output`` when given, otherwise to ``<notebook stem>_report.md`` next
    to the notebook.

    Raises:
        SystemExit: If the model response has no ``<report>`` section.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--notebook", "-n", type=str, required=True)
    parser.add_argument("--model", "-m", type=str, default="gpt-4o-mini")
    parser.add_argument("--instructions", "-i", type=str, default='')
    parser.add_argument("--output", "-o", type=str, default=None)
    args = parser.parse_args()

    oai_client = get_oai_client()
    notebook_state = get_notebook_state(args.notebook)
    messages = get_chat_messages(notebook_state, args.instructions)
    response = oai_client.chat.completions.create(
        model=args.model,
        messages=messages
    )

    # The message content can be None; treat that like an empty reply so the
    # tag checks below produce a clear error instead of a TypeError.
    reply = response.choices[0].message.content or ""
    summary_content = _extract_tag(reply, "summary")
    report_content = _extract_tag(reply, "report")
    if report_content is None:
        raise SystemExit("Model response did not contain a <report>...</report> section")
    # The summary is only informational, so a missing one is not fatal.
    if summary_content is None:
        summary_content = "(no summary returned)"

    report_content = replace_image_links(report_content, notebook_state['images'])
    print(f"\n\nSummary: {summary_content}\n\n")

    if args.output is not None:
        filepath = Path(args.output)
    else:
        notebook_path = Path(args.notebook)
        filepath = notebook_path.parent / (notebook_path.stem + "_report.md")

    # Write as UTF-8 so non-ASCII report text does not depend on the
    # platform's default encoding.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(report_content)
    print(f"Report saved to {filepath}")


if __name__ == "__main__":
    main()