-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_graph.py
More file actions
224 lines (189 loc) · 9.3 KB
/
Copy pathrun_graph.py
File metadata and controls
224 lines (189 loc) · 9.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
import json
import asyncio
import argparse
import numpy as np
from api_call import ask_gpt4_async, token_cost
from prompts import create_sys_prompts, create_user_prompts
from utils import get_paper_text
from graphs import permute_knowledge_graph
def sampling_permutations(knowledge_graph_permutations_i, num_samples=10):
    """
    Sample a fixed number of permutations, reproducibly.

    Args:
        knowledge_graph_permutations_i (dict): dictionary of permutations for
            a single experiment, keyed by a permutation id that ``int()`` can
            parse.
        num_samples (int): maximum number of permutations to return.

    Returns:
        list: ``(permutation_id, permutation)`` pairs. When the dictionary
        holds fewer than ``num_samples`` entries, every entry is returned in
        dictionary order; otherwise ``num_samples`` entries are drawn without
        replacement and returned sorted by integer permutation id.
    """
    # Keep the items in a plain list: packing them into a NumPy array would
    # coerce keys and values to a common dtype/shape and can mangle the values.
    permutes = list(knowledge_graph_permutations_i.items())
    if len(permutes) < num_samples:
        return permutes

    # A private, fixed-seed generator keeps the draw reproducible without
    # reseeding the process-wide NumPy random state on every call.
    rng = np.random.RandomState(42)
    chosen = rng.choice(len(permutes), num_samples, replace=False)

    # Sort the order of permutations in increasing order to be
    # consistent with the other permutation-related fields in json.
    return sorted((permutes[i] for i in chosen), key=lambda item: int(item[0]))
def has_n_experiments(knowledge_graph, num_experiments=1):
    """
    Tell whether a knowledge graph contains exactly ``num_experiments`` entries.

    Args:
        knowledge_graph (dict): knowledge graph dictionary.
        num_experiments (int): number of experiments to filter by.
    Returns:
        flag (bool): True if the number of entries in ``knowledge_graph``
            equals ``num_experiments``.
    """
    experiment_count = len(knowledge_graph)
    return experiment_count == num_experiments
def postprocess_json(response_json):
    """
    Extract the message content of a chat-completion response and parse it as JSON.

    Args:
        response_json (dict): decoded API response; the text is read from
            ``response_json["choices"][0]["message"]["content"]``.

    Returns:
        The object produced by ``json.loads`` on the cleaned-up content.

    Raises:
        json.JSONDecodeError: if the cleaned-up content is still not valid JSON.
    """
    import re  # local import so this function does not depend on file-level imports

    # 1. Get response
    response_content = response_json["choices"][0]["message"]["content"]

    # 2. Remove code block: drop surrounding whitespace first so a trailing
    #    newline after the closing fence does not protect the backticks, then
    #    remove the literal "json" language tag (not a character set).
    response_content = response_content.strip().strip('`')
    if response_content.startswith('json'):
        response_content = response_content[len('json'):]

    # 3. Replace escape characters: double every backslash so sequences that
    #    are not valid JSON escapes survive parsing, but leave an escaped
    #    quote (\") alone -- doubling it would terminate the string early.
    #    An already-escaped backslash (\\) is consumed as a pair so the quote
    #    that may follow it is not mistaken for an escaped quote.
    def _escape(match):
        if match.group(1) == '"':
            return match.group(0)
        return match.group(0).replace('\\', '\\\\')

    response_content = re.sub(r'\\(["\\])?', _escape, response_content)

    # 4. Load as json
    print('response_content', response_content)
    response_content = json.loads(response_content)
    return response_content
async def _query_model(sys_prompt, user_prompt):
    """
    Send one system/user prompt pair to the model and parse the reply.

    Returns:
        tuple: ``(response_content, n_input_tokens, n_output_tokens)`` where
        ``response_content`` is the output of ``postprocess_json`` and the
        token counts come from the response's ``usage`` field.
    """
    response_json = json.loads(await ask_gpt4_async(sys_prompt, user_prompt))
    response_content = postprocess_json(response_json)
    usage = response_json["usage"]
    return (
        response_content,
        int(usage["prompt_tokens"]),
        int(usage["completion_tokens"]),
    )


async def process_paper(fname, txt_dir, outputs_dir, prompt_versions, num_samples):
    """
    Run the knowledge-graph pipeline for one paper and write the result as JSON.

    Stages: summarize methods, create the initial knowledge graph, convert the
    graph to text, identify semantic groups, create permuted graphs, and
    convert a sample of the permuted graphs to text. Token costs are recorded
    per stage under ``outputs["token_cost"]``.

    Args:
        fname (str): file name of the paper inside ``txt_dir``.
        txt_dir (str): directory containing the paper text files.
        outputs_dir (str): directory the output JSON is written to; the output
            file name is ``fname`` with ``.txt`` replaced by ``.json``.
        prompt_versions (dict): prompt version per stage, read with the keys
            ``"sum_met"``, ``"kg_init"``, ``"kg_txt"`` and ``"kg_sema"``.
        num_samples (int): number of permutations per experiment that are
            converted to text (passed to ``sampling_permutations``).

    Returns:
        None. Nothing is written when the output file already exists or when
        the knowledge graph does not contain exactly one experiment.
    """
    outputs_fpath = os.path.join(outputs_dir, fname.replace('.txt', '.json'))
    # Check for existing outputs before doing any work, including loading the paper.
    if os.path.exists(outputs_fpath):
        print(f"Skipping {fname}, because outputs already exists")
        return

    paper_text = get_paper_text.load(os.path.join(txt_dir, fname))
    outputs = {}
    total_cost = {}
    sys_prompt = create_sys_prompts.prompts()

    # Summarize methods
    print(f"\n# Summarizing methods for {fname}..")
    user_prompt = create_user_prompts.summarize_methods(paper_text, v=prompt_versions["sum_met"])
    response_content, n_input_tokens, n_output_tokens = await _query_model(sys_prompt, user_prompt)
    outputs["methods"] = response_content["methods"]
    total_cost["sum_met"] = token_cost(n_input_tokens, n_output_tokens)

    # Step 1: Create initial knowledge graph conditioned on full paper.
    print(f"\n# Creating initial knowledge graph for {fname}..")
    kg_creator = create_user_prompts.KnowledgeGraphCreator(paper_text)
    user_prompt = kg_creator.create_initial_kg(v=prompt_versions["kg_init"])
    response_content, n_input_tokens, n_output_tokens = await _query_model(sys_prompt, user_prompt)
    outputs["knowledge_graph"] = response_content["knowledge_graph"]
    total_cost["res_to_kg"] = token_cost(n_input_tokens, n_output_tokens)

    # Step 2-5: Convert KG to text, identify semantic groups, and permute KGs
    # For now, we only proceed with papers that have exactly 1 experiment.
    if not has_n_experiments(outputs["knowledge_graph"], num_experiments=1):
        print(f"{fname} does not have exactly 1 experiment, ignored.")
        return

    # Step 2: Convert original KG to text (per experiment)
    print(f"\n# Converting knowledge graph to text for {fname}..")
    outputs["results"] = {}
    n_input_tokens = 0
    n_output_tokens = 0
    for experiment_i in range(1, len(outputs["knowledge_graph"]) + 1):
        experiment_key = f'experiment_{experiment_i}'
        user_prompt = kg_creator.convert_kg_to_text_single_experiment(
            outputs["knowledge_graph"][experiment_key],
            v=prompt_versions["kg_txt"]
        )
        response_content, n_in, n_out = await _query_model(sys_prompt, user_prompt)
        n_input_tokens += n_in
        n_output_tokens += n_out
        outputs["results"][experiment_key] = response_content["results"]
    total_cost["kg_to_text"] = token_cost(n_input_tokens, n_output_tokens)

    # Step 3: Identify semantic groups
    print(f"\n# Identifying semantic groups for {fname}..")
    user_prompt = kg_creator.identify_semantic_groups(
        outputs["knowledge_graph"],
        v=prompt_versions["kg_sema"]
    )
    response_content, n_input_tokens, n_output_tokens = await _query_model(sys_prompt, user_prompt)
    outputs["semantic_groups"] = response_content["semantic_groups"]
    total_cost["kg_to_semantic_groups"] = token_cost(n_input_tokens, n_output_tokens)

    # Step 4: Create permuted knowledge graphs
    print(f"\n# Creating permuted knowledge graphs for {fname}..")
    knowledge_graph_permutations, node_swaps_tracker, triple_deviation_pct \
        = permute_knowledge_graph.create_permutations(
            outputs["knowledge_graph"],
            outputs["semantic_groups"],
        )
    outputs["knowledge_graph_permutations"] = knowledge_graph_permutations
    outputs["node_swaps_tracker"] = node_swaps_tracker
    outputs["triple_deviation_pct"] = triple_deviation_pct

    # Step 5: Convert permuted KGs to text (per experiment and per permutation)
    print(f"\n# Converting permuted knowledge graphs to text for {fname}..")
    outputs["results_permutations"] = {}
    n_input_tokens_kg_to_text = 0
    n_output_tokens_kg_to_text = 0
    num_graph_permutations = {}
    for experiment_i, knowledge_graph_permutations_i in knowledge_graph_permutations.items():
        outputs["results_permutations"][experiment_i] = {}
        # Count every generated permutation, not only the sampled ones.
        num_graph_permutations[experiment_i] = len(knowledge_graph_permutations_i)
        for permutation_i, knowledge_graph_i in sampling_permutations(knowledge_graph_permutations_i, num_samples):
            print(f"  Converting permutation {permutation_i} for {experiment_i}, {fname}..")
            user_prompt = kg_creator.convert_kg_to_text_single_experiment(
                knowledge_graph_i, v=prompt_versions["kg_txt"],
                orig_results_as_example=outputs["results"][experiment_i]
            )
            response_content, n_in, n_out = await _query_model(sys_prompt, user_prompt)
            n_input_tokens_kg_to_text += n_in
            n_output_tokens_kg_to_text += n_out
            outputs["results_permutations"][experiment_i][permutation_i] = response_content["results"]

    # Record number of permutations
    num_graph_permutations["total"] = sum(num_graph_permutations.values())
    outputs["num_graph_permutations"] = num_graph_permutations

    # Calculate costs
    total_cost["kg_permutes_to_text"] = token_cost(n_input_tokens_kg_to_text, n_output_tokens_kg_to_text)
    total_cost["total"] = sum(total_cost.values())
    outputs["token_cost"] = total_cost

    with open(outputs_fpath, 'w') as f:
        json.dump(outputs, f, indent=4)
async def main():
    """
    Process every ``.txt`` paper in ``data/txt_articles`` concurrently.

    Reads the module-level names ``prompt_versions`` and ``num_samples``,
    which this file assigns in its ``if __name__ == '__main__':`` block; they
    must be defined before this coroutine runs.

    The output directory is ``outputs/fullpaper`` suffixed with one
    ``_{key}{version}`` segment per entry of ``prompt_versions``, so results
    from different prompt versions are kept apart.
    """
    txt_dir = 'data/txt_articles'
    outputs_dir = 'outputs/fullpaper'
    if prompt_versions is not None:  # version the summed articles based on prompt versions.
        outputs_dir += "".join(f"_{k}{v}" for k, v in prompt_versions.items())

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(outputs_dir, exist_ok=True)

    async_tasks = []
    for fname in os.listdir(txt_dir):
        if not fname.endswith('.txt'):
            print(f"Skipping {fname}")
            continue
        async_tasks.append(
            process_paper(
                fname,
                txt_dir,
                outputs_dir,
                prompt_versions,
                num_samples
            )
        )

    # Run all per-paper coroutines together and wait for every one to finish.
    await asyncio.gather(*async_tasks)
if __name__ == '__main__':
    # Each command-line flag selects the prompt version used by one pipeline
    # stage; the pairs below are (flag name, default version).
    version_defaults = (
        ("sum_met", 2),
        ("kg_init", 2),
        ("kg_sema", 2),
        ("kg_txt", 3),
    )
    parser = argparse.ArgumentParser()
    for stage, default_version in version_defaults:
        parser.add_argument(f"--{stage}", type=int, default=default_version)
    args = parser.parse_args()

    # main() reads these two module-level names.
    prompt_versions = {stage: getattr(args, stage) for stage, _ in version_defaults}
    num_samples = 10

    print(f"Prompt versions: {prompt_versions}")
    asyncio.run(main())