Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] Better Implementation for CircularEval #770

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,28 @@ demo.ipynb
*json
.vscode
*.swp
GPT4o_MINI/
GPT4o_MINI/

2weiyun*
script.py
Gemini*
Claude3-5V*
GLM4V*
GPT4o*
GPT4V*
mmmu_debug
bailingMM
BailingMM*
SenseChat*
Step*
DoubaoVL
arch
BlueLM*
mmb_*
Reka*
Taiyi
TeleMM
apple.jpg
assets/LOGO.png
api_list.txt
vlmeval/gemini_tmp.py
15 changes: 10 additions & 5 deletions vlmeval/dataset/utils/multiple_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,32 +406,37 @@ def mcq_circular_eval(model, data, meta, nproc, result_file, dataset_name=None):

for idx in list(meta['index']) + list(data['index']):
assert istype(idx, int)
if 'g_index' not in data:
data['g_index'] = [int(x % 1e6) for x in data['index']]

# Only keep those lines in the meta data
data = data[data['index'].isin(answer_map)]
data['GT'] = [answer_map[idx] for idx in data['index']]
data_main = data[data['index'] < int(1e6)]


data['tmp_flag'] = [x == y for x, y in zip(data['index'], data['g_index'])]
data_main = data[data['tmp_flag']]
data_main.pop('tmp_flag')

data_groups = []
for i in range(len(data_main)):
# Dealing with the normal part
idx = data_main.iloc[i]['index']
if idx not in result:
sub_data = data[data['index'] % int(1e6) == idx]
sub_data = data[data['g_index'] == idx]
data_groups.append(sub_data)

if len(data_groups):
prefetched = [prefetch_circular_group(g, verbose=False) for g in data_groups]
remain = []
for dg, pf in zip(data_groups, prefetched):
if pf is not None:
result[dg.iloc[0]['index'] % 1e6] = pf
result[dg.iloc[0]['g_index']] = pf
else:
remain.append(dg)
dump(result, result_file)

tups = [dict(model=model, sub_data=x, dataset_name=dataset_name) for x in remain]
keys = [x.iloc[0]['index'] % 1e6 for x in remain]
keys = [x.iloc[0]['g_index'] for x in remain]

if len(tups) == 0:
pass
Expand Down
33 changes: 0 additions & 33 deletions vlmeval/smp/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,36 +144,3 @@ def gpt_key_set():
def apiok(wrapper):
s = wrapper.generate('Hello!')
return wrapper.fail_msg not in s


def circular_pred(df, extract_func=None):
    """Compute the circular-consistency score over a circularized result frame.

    Each base question has index ``i < shift`` (``shift = 1e6``); its rotated
    variants carry indices ``i + k * shift``.  A base question is counted as
    consistent (``True``) only if every rotated prediction is exactly one
    letter "ahead" of the prediction one rotation earlier, i.e.
    ``(ord(pred[i]) - ord(pred[i - shift])) % 4 == 1``.  Questions where any
    compared prediction is not a single uppercase letter are dropped from the
    average entirely (marked invalid rather than wrong).

    NOTE(review): the ``% 4`` assumes the 4-choice rotation scheme used by the
    legacy CIRCULAR builder; for 2-/3-choice questions the wrap-around step is
    not 1 mod 4 — confirm before applying to mixed-choice data.

    :param df: DataFrame with integer ``index`` and string ``prediction``
        columns (base rows plus shifted circular rows).
    :param extract_func: optional callable mapping a raw prediction to a
        choice letter; defaults to the identity.
    :return: mean of the per-question consistency flags (numpy float).
    """
    if extract_func is None:
        extract_func = lambda x: x  # noqa: E731
    df = df.sort_values('index')

    shift = int(1e6)
    uppercase = set(string.ascii_uppercase)  # built once, O(1) membership

    choices = [extract_func(x) for x in df['prediction']]
    pred_map = {i: c for i, c in zip(df['index'], choices)}
    # flag: rotation-consistent so far; valid: all compared preds were letters
    flag_map = {i: True for i in pred_map if i < shift}
    valid_map = {i: True for i in pred_map if i < shift}
    for i in df['index']:
        # Only compare a shifted row against the previous rotation when both
        # predictions are non-empty.
        if i >= shift and pred_map[i] and pred_map[i - shift]:
            if pred_map[i] not in uppercase or pred_map[i - shift] not in uppercase:
                valid_map[i % shift] = False
                continue
            if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 != 1:
                flag_map[i % shift] = False
    flags = [v for k, v in flag_map.items() if valid_map[k]]
    return np.mean(flags)
165 changes: 91 additions & 74 deletions vlmeval/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from collections import deque
from vlmeval.dataset import SUPPORTED_DATASETS
from vlmeval.config import *
from vlmeval.smp import *
Expand Down Expand Up @@ -157,88 +158,104 @@ def MISSING(lvl):


def CIRCULAR(inp):
def proc_str(s):
chs = set(s)
chs = [x for x in chs if x not in string.ascii_letters and x != ' ']
for ch in chs:
s = s.replace(ch, ' ')
return s

def abnormal_entry(line):
choices = {k: line[k] for k in string.ascii_uppercase if k in line and not pd.isna(line[k])}
for k in choices:
s = proc_str(choices[k]).split()
hit_words = [x for x in s if x in choices]
hit_words = set(hit_words)
if len(hit_words) > 1:
return True
return False

assert inp.endswith('.tsv')
data = load(inp)
OFFSET = 1e6
while max(data['index']) >= OFFSET:
OFFSET *= 10

assert 'E' not in data, 'Currently build_circular only works for up to 4-choice questions'
data_2c = data[pd.isna(data['C'])]
data_3c = data[~pd.isna(data['C']) & pd.isna(data['D'])]
data_4c = data[~pd.isna(data['D'])]
map_2c = [('AB', 'BA')]
map_3c = [('ABC', 'BCA'), ('ABC', 'CAB')]
map_4c = [('ABCD', 'BCDA'), ('ABCD', 'CDAB'), ('ABCD', 'DABC')]

def okn(o, n=4):
ostr = o.replace(',', ' ')
osplits = ostr.split()
if sum([c in osplits for c in string.ascii_uppercase[:n - 1]]) == n - 1:
return False
olower = o.lower()
olower = olower.replace(',', ' ')
olower_splits = olower.split()
if 'all' in olower_splits or 'none' in olower_splits:
return False
return True

yay4, nay4 = [], []
lt4 = len(data_4c)
for i in range(lt4):
if okn(data_4c.iloc[i]['D'], 4):
yay4.append(i)
n_opt = 2
for i, ch in enumerate(string.ascii_uppercase):
if ch in data:
n_opt = ord(ch) - ord('A') + 1
else:
nay4.append(i)
data_4c_y = data_4c.iloc[yay4]
data_4c_n = data_4c.iloc[nay4]
data_3c = pd.concat([data_4c_n, data_3c])

yay3, nay3 = [], []
lt3 = len(data_3c)
for i in range(lt3):
if okn(data_3c.iloc[i]['C'], 3):
yay3.append(i)
for j in range(i + 1, 26):
assert string.ascii_uppercase[j] not in data
groups = defaultdict(list)
for i in range(len(data)):
item = data.iloc[i]
this_n_opt = 0
for j, ch in enumerate(string.ascii_uppercase[:n_opt]):
if not pd.isna(item[ch]):
this_n_opt = j + 1
else:
for k in range(j + 1, n_opt):
assert pd.isna(item[string.ascii_uppercase[k]]), (k, item)
assert this_n_opt >= 2 or this_n_opt == 0
flag = abnormal_entry(item)
if flag or this_n_opt == 0:
groups['abnormal'].append(item)
elif ord(item['answer']) - ord('A') + 1 > this_n_opt:
groups['abnormal'].append(item)
else:
nay3.append(i)
data_3c_y = data_3c.iloc[yay3]
data_3c_n = data_3c.iloc[nay3]
data_2c = pd.concat([data_3c_n, data_2c])

def remap(data_in, tup, off):
off = int(off)
data = data_in.copy()
char_map = {k: v for k, v in zip(*tup)}
idx = data.pop('index')
answer = data.pop('answer')
answer_new = [char_map[x] if x in char_map else x for x in answer]
data['answer'] = answer_new
options = {}
for c in char_map:
options[char_map[c]] = data.pop(c)
for c in options:
data[c] = options[c]
data.pop('image')
data['image'] = idx
idx = [x + off for x in idx]
data['index'] = idx
return data

data_all = pd.concat([
data_2c,
data_3c_y,
data_4c_y,
remap(data_2c, map_2c[0], OFFSET),
remap(data_3c_y, map_3c[0], OFFSET),
remap(data_4c_y, map_4c[0], OFFSET),
remap(data_3c_y, map_3c[1], OFFSET * 2),
remap(data_4c_y, map_4c[1], OFFSET * 2),
remap(data_4c_y, map_4c[2], OFFSET * 3),
])

tgt_file = inp.replace('.tsv', '_CIRC.tsv')
groups[this_n_opt].append(item)
for k in groups:
groups[k] = pd.concat(groups[k], axis=1).T
print(f'{k if k == "abnormal" else str(k) + "-choice"} records: {len(groups[k])}')

data_all = []

for k in groups:
if k == 'abnormal':
warnings.warn(
f"{len(groups['abnormal'])} abnormal entries detected. The problems can be: "
"1. Choice labels found in some choice contents; 2. No choices found for this question; "
"3. The answer is not a valid choice. Will not apply circular to those samples."
)
abdata = groups['abnormal']
abdata['g_index'] = abdata['index']
data_all.append(abdata)
else:
cir_data = []
assert isinstance(k, int) and k >= 2
labels = string.ascii_uppercase[:k]
rotates = [labels]
dq = deque(labels)
for i in range(k - 1):
dq.rotate(1)
rotates.append(list(dq))
for i, rot in enumerate(rotates):
if i == 0:
data = groups[k].copy()
data['g_index'] = data['index']
cir_data.append(data)
else:
try:
data = groups[k].copy()
data['index'] = [x + OFFSET * i for x in data['index']]
data['g_index'] = [x % OFFSET for x in data['index']]
c_map = {k: v for k, v in zip(rotates[0], rot)}
data['answer'] = [c_map[x] for x in data['answer']]
for s, t in c_map.items():
data[t] = groups[k][s]
cir_data.append(data)
except:
print(set(data['answer']))
raise NotImplementedError
data_all.append(pd.concat(cir_data))
data_all = pd.concat(data_all)
data_all['index'] = [int(x) for x in data_all['index']]
data_all['g_index'] = [int(x) for x in data_all['g_index']]

tgt_file = inp.replace('.tsv', '_circular.tsv')
dump(data_all, tgt_file)
print(f'The circularized data is saved to {tgt_file}')
print(f'Processed data are saved to {tgt_file}: {len(load(inp))} raw records, {len(data_all)} circularized records.')
assert osp.exists(tgt_file)
print(f'The MD5 for the circularized data is {md5(tgt_file)}')

Expand Down