-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlistwise_llm.py
More file actions
90 lines (72 loc) · 3.3 KB
/
Copy pathlistwise_llm.py
File metadata and controls
90 lines (72 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import json
import os
import wandb
from tqdm import tqdm
from src.baseline import listwise
from src.evaluate import evaluate
from src.utils import load_dataset, seed_everything
def main(dataset_name, model_name, top_k_passages, args):
    """Rerank every query with a listwise LLM, evaluate, and save the results.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``; also the
            sub-directory of ``results/`` the output JSON is written to.
        model_name: Embedding model name forwarded to ``load_dataset``.
        top_k_passages: Passed to ``listwise`` as ``cutoff``; evaluation cutoffs
            from ``args.cutoff`` larger than this value are dropped.
        args: Parsed command-line namespace. This function reads
            ``wandb_disable``, ``llm_name``, ``window_size``, ``step``,
            ``verbose`` and ``cutoff``.

    Raises:
        ValueError: If ``args.llm_name`` is not one of the supported LLM names.
    """
    # Local import: only needed to build a fallback output file name.
    import time

    if not args.wandb_disable:
        configs = dict(vars(args))
        run = wandb.init(
            project="listwise",
            config=configs,
        )
    else:
        run = None

    base_path = os.path.dirname(os.path.abspath(__file__))
    (
        dataset,
        _,
        relevance_map,
        queries,
        passages,
        query_ids,
        passage_ids,
        query_embeddings,
        passage_embeddings,
    ) = load_dataset(base_path, dataset_name, model_name)

    # The LLM wrappers are imported lazily so only the selected backend is loaded.
    if args.llm_name == "unsloth/Qwen3-14B-unsloth-bnb-4bit":
        from src.baseline.rankgpt.llm import RankQwen
        llm = RankQwen()
    elif args.llm_name == "openai/gpt4o":
        from src.baseline.rankgpt.llm import RankGPT4o
        llm = RankGPT4o()
    else:
        raise ValueError(f"Unsupported model name: {args.llm_name}")

    print("\n")
    results = {}
    for query, q_id, query_embedding in tqdm(
        zip(queries, query_ids, query_embeddings),
        desc=" > RankGPT Reranking",
        total=len(query_ids),
    ):
        pred = listwise(
            query=query,
            query_embedding=query_embedding,
            passages=passages,
            passage_ids=passage_ids,
            passage_embeddings=passage_embeddings,
            llm=llm,
            cutoff=top_k_passages,
            window_size=args.window_size,
            step=args.step,
            verbose=args.verbose,
        )
        results[q_id] = {"pred": pred}

    # Only keep evaluation cutoffs that do not exceed the reranking budget.
    cutoff = [int(k) for k in args.cutoff if int(k) <= top_k_passages]
    metric, results = evaluate(results, relevance_map, cutoff, threshold=dataset.relevance_threshold)

    if run is not None:
        # Metric keys are logged with "/" in place of "@".
        wandb.log({str(k).replace("@", "/"): v for k, v in metric.items()})

    # With --wandb_disable there is no run object, so ``run.name`` cannot be
    # used; fall back to a timestamped name so results are still written.
    run_name = run.name if run is not None else None
    if not run_name:
        run_name = time.strftime("local-%Y%m%d-%H%M%S")

    os.makedirs(f"results/{dataset_name}/", exist_ok=True)
    with open(f"results/{dataset_name}/{run_name}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=1, ensure_ascii=False)
def arg_parser():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options ``dataset_name`` (required),
        ``llm_name``, ``llm_budget``, ``window_size``, ``step``, ``emb_model``,
        ``cutoff``, ``verbose`` and ``wandb_disable``.
    """
    cli = argparse.ArgumentParser(description='Reranking with RankGPT')

    # Single-valued options as (flag, type, remaining keyword arguments),
    # registered in the order they should appear in ``--help``.
    valued_options = (
        ('--dataset_name', str, {'required': True, 'help': 'dataset name'}),
        ('--llm_name', str, {'default': 'unsloth/Qwen3-14B-unsloth-bnb-4bit', 'help': 'LLM model name'}),
        ('--llm_budget', int, {'default': 50, 'help': 'top k passages for reranking'}),
        ('--window_size', int, {'default': 20, 'help': 'window size for RankGPT'}),
        ('--step', int, {'default': 10, 'help': 'step size for sliding window'}),
        ('--emb_model', str, {'default': 'all-MiniLM-L6-v2', 'help': 'embedding model'}),
    )
    for flag, value_type, extra in valued_options:
        cli.add_argument(flag, type=value_type, **extra)

    # One or more integer evaluation cutoffs.
    cli.add_argument("--cutoff", type=int, nargs="+", default=[1, 5, 10, 30, 50, 100])

    # Boolean switches; both default to False.
    switches = (
        ("--verbose", {}),
        ("--wandb_disable", {"help": "disable wandb"}),
    )
    for flag, extra in switches:
        cli.add_argument(flag, action="store_true", **extra)

    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI, seed the RNGs, then run the pipeline.
    cli_args = arg_parser()
    seed_everything()
    # Positional order matches main(dataset_name, model_name, top_k_passages, args).
    main(cli_args.dataset_name, cli_args.emb_model, cli_args.llm_budget, cli_args)