{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "89f2b537",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from typing import Dict, List, Optional\n",
"from mathruler.grader import extract_boxed_content, grade_answer\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8590ec56",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from pathlib import Path\n",
"from typing import List, Dict, Union\n",
"from typing import Dict, List, Any\n",
"import re\n",
"from typing import List\n",
"\n",
"def read_json(path: Union[str, Path]) -> List[Dict]:\n",
"    \"\"\"\n",
"    Load a JSON file whose top-level value is an array of objects.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    path : str or Path\n",
"        Location of the JSON file; a leading ``~`` is expanded.\n",
"\n",
"    Returns\n",
"    -------\n",
"    List[Dict]\n",
"        The parsed array, one dict per element.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the JSON root is not a list, or an element is not a dict.\n",
"    json.JSONDecodeError\n",
"        If the file is not valid JSON.\n",
"    \"\"\"\n",
"    path = Path(path).expanduser()\n",
"    data = json.loads(path.read_text(encoding=\"utf-8\"))\n",
"\n",
"    if not isinstance(data, list):\n",
"        raise ValueError(f\"{path} does not contain a JSON array at the top level.\")\n",
"\n",
"    # Reject arrays holding anything other than JSON objects.\n",
"    for item in data:\n",
"        if not isinstance(item, dict):\n",
"            raise ValueError(\"Not every element in the JSON array is an object.\")\n",
"\n",
"    return data\n",
"\n",
"\n",
"def extract_description(predict: str, tag: str = \"description\") -> Optional[str]:\n",
"    \"\"\"\n",
"    Return the text inside the first ``<tag>...</tag>`` pair found in `predict`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    predict : str\n",
"        Raw model output to search.\n",
"    tag : str, optional\n",
"        Tag name without angle brackets. It is escaped before being placed in\n",
"        the pattern, so regex metacharacters are matched literally.\n",
"        NOTE(review): the default is inferred from the function name; confirm it\n",
"        matches the tag the description model actually emits.\n",
"\n",
"    Returns\n",
"    -------\n",
"    Optional[str]\n",
"        The enclosed text with leading/trailing whitespace stripped, or None\n",
"        when `predict` is empty or contains no complete tag pair.\n",
"    \"\"\"\n",
"    if not predict:\n",
"        return None\n",
"\n",
"    name = re.escape(tag)\n",
"    # Without explicit delimiters a lazy group matches the empty string at\n",
"    # position 0, so every input would yield \"\" instead of None.\n",
"    match = re.search(rf\"<{name}>(.*?)</{name}>\", predict, re.DOTALL)\n",
"    if match is None:\n",
"        return None\n",
"    return match.group(1).strip()\n",
"\n",
"\n",
"def accuracy_reward(predict: str, ground_truth: str) -> float:\n",
"    \"\"\"\n",
"    Return 1.0 when the content `extract_boxed_content` pulls out of `predict`\n",
"    is accepted by `grade_answer` against `ground_truth`, otherwise 0.0.\n",
"    \"\"\"\n",
"    boxed_answer = extract_boxed_content(predict)\n",
"    if grade_answer(boxed_answer, ground_truth):\n",
"        return 1.0\n",
"    return 0.0"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9fb984e7",
"metadata": {},
"outputs": [],
"source": [
"def load_json_dir(root: str | Path, *, verbose: bool = True) -> Dict[str, List[Any]]:\n",
"    \"\"\"\n",
"    Traverse *root* recursively and return {file_stem: parsed_json_data}.\n",
"\n",
"    • Files that are empty, unreadable, or contain invalid JSON are skipped with\n",
"      a warning. Set verbose=False to silence the warnings.\n",
"    • Files are visited in sorted path order so the result is reproducible.\n",
"      When two files share a stem, the one that sorts last wins and a warning\n",
"      is printed (previously the winner depended on filesystem order).\n",
"    \"\"\"\n",
"    root = Path(root).expanduser().resolve()\n",
"    out: Dict[str, List[Any]] = {}\n",
"    # Remember where each stem came from so collisions can be reported.\n",
"    origin: Dict[str, Path] = {}\n",
"\n",
"    for path in sorted(root.rglob(\"*.json\")):\n",
"        try:\n",
"            with path.open(\"r\", encoding=\"utf-8\") as f:\n",
"                data = json.load(f)\n",
"        except json.JSONDecodeError as err:\n",
"            if verbose:\n",
"                print(f\"[skip] {path} – invalid JSON ({err})\")\n",
"            continue\n",
"        except Exception as err:\n",
"            # Deliberately broad: one unreadable file must not abort the scan.\n",
"            if verbose:\n",
"                print(f\"[skip] {path} – {err}\")\n",
"            continue\n",
"\n",
"        stem = path.stem\n",
"        if verbose and stem in origin:\n",
"            print(f\"[warn] {path} replaces {origin[stem]} (duplicate stem '{stem}')\")\n",
"        origin[stem] = path\n",
"        out[stem] = data\n",
"\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8e29fcb",
"metadata": {},
"outputs": [],
"source": [
"# Directory handed to load_json_dir in the next cell; switch model outputs by\n",
"# changing which assignment is left uncommented.\n",
"# folder_dir = './gemini-flash'\n",
"folder_dir = './gemini-pro'\n",
"# folder_dir = './gemini-pro-pro'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fad0547b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['realWorldQA', 'clevr_count_70k', 'mmmu-pro', 'mathvision', 'mmstar', 'mmmu-pro-vision', 'mm-vet', 'mmmu_pro_10options', 'mathvista', 'visnumbench'])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Parse every JSON file under `folder_dir`; keys are file stems.\n",
"datas = load_json_dir(folder_dir)\n",
"\n",
"# Show which files were loaded.\n",
"datas.keys()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e74dd8dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"realWorldQA: 0.6862745098039216\n",
"clevr_count_70k: 0.7108571428571429\n",
"mmmu-pro: 0.6105527638190955\n",
"mathvision: 0.36875\n",
"mmstar: 0.6633333333333333\n",
"mmmu-pro-vision: 0.5256410256410257\n",
"mm-vet: 0.3302752293577982\n",
"mmmu_pro_10options: 0.49243379571248425\n",
"mathvista: 0.554\n",
"visnumbench: 0.28835978835978837\n"
]
}
],
"source": [
"# Score the first prediction of every item and remember the list positions\n",
"# of the items that were graded correct.\n",
"indices = {}\n",
"\n",
"for file, answers in datas.items():\n",
"    indices[file] = []\n",
"    acc = 0\n",
"    for index, ele in enumerate(answers):\n",
"        solution = ele['solution']\n",
"        prediction = ele['predictions'][0]\n",
"        accuracy = accuracy_reward(prediction, solution)\n",
"        acc += accuracy\n",
"\n",
"        if accuracy == 1:\n",
"            indices[file].append(index)\n",
"\n",
"    # A file holding an empty list would otherwise raise ZeroDivisionError.\n",
"    mean_acc = acc / len(answers) if answers else 0.0\n",
"    print(f'{file}: {mean_acc}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "99761358",
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'MLLM_rlvr_train'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[43mdatas\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mMLLM_rlvr_train\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m]\u001b[49m)\n",
"\u001b[31mKeyError\u001b[39m: 'MLLM_rlvr_train'"
]
}
],
"source": [
"# The key only exists when the loaded folder contains MLLM_rlvr_train.json;\n",
"# report 0 instead of raising KeyError when it is absent.\n",
"len(datas.get('MLLM_rlvr_train', []))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cb380a0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['realWorldQA', 'MLLM_hotpot_train', 'mmmu-pro', 'mmstar', 'mm-vet', 'mathvista'])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dataset names that currently have an entry in `indices`.\n",
"indices.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9367bc67",
"metadata": {},
"outputs": [],
"source": [
"# Accuracy figures apparently pasted from an earlier run. Kept as comments:\n",
"# as bare text this cell is a SyntaxError (e.g. `mmmu-pro: ...`).\n",
"# realWorldQA: 0.6972477064220184\n",
"# mmmu-pro: 0.5646606914212549\n",
"# mmstar: 0.6061433447098976\n",
"# mm-vet: 0.6018518518518519\n",
"# mathvista: 0.5822401614530777"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08286602",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d033bd06",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f7a73e5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "84f260ed",
"metadata": {},
"source": [
"# Construct indices to merge datasets"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6c771d63",
"metadata": {},
"outputs": [],
"source": [
"# Load the JSON files under ./gpt_o1_outputs (keyed by file stem); a later\n",
"# cell reads image descriptions out of their 'predictions' entries.\n",
"description_folder_dir = './gpt_o1_outputs'\n",
"description_outputs = load_json_dir(description_folder_dir)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d8e03bf5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"891\n",
"2694\n"
]
}
],
"source": [
"# Compare how many items each source holds for one benchmark.\n",
"file = 'mathvision'\n",
"print(len(description_outputs[file]))\n",
"print(len(datas[file]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6a814cba",
"metadata": {},
"outputs": [],
"source": [
"# idx = 1200\n",
"# print(description_outputs[file][idx])\n",
"# print('-'*10)\n",
"# print(datas[file][idx])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7c08cff1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'index': 0,\n",
" 'problem': 'Are there any states with a number of reporters between 376-385? Options:\\nA. No\\nB. Yes',\n",
" 'solution': 'B',\n",
" 'predictions': [' The task is to determine if there are any states with a number of reporters between 376-385. The map shows states in two different colors corresponding to two ranges of reporters: 373-375 and 376-385. I need to identify states in the color representing the range 376-385.\\n\\nThe legend indicates two colors:\\n- A lighter color for the range 373-375.\\n- A darker color for the range 376-385.\\n\\nI will look at the map to identify states shaded with the darker color. The states shaded in the darker color are:\\n- California\\n- North Dakota\\n- South Dakota\\n- Iowa\\n- Missouri\\n- Illinois\\n- Kentucky\\n- West Virginia\\n- New Jersey\\n- Massachusetts\\n\\nThese states are not in the lighter color range, hence they must have a number of reporters between 376-385. Therefore, there are indeed states with a number of reporters in the range 376-385.\\n\\nThe correct option is B. Yes, because there are several states shaded in the darker color that indicates the range 376-385. \\\\boxed{B}']}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the first record of the MLLM_hotpot_train file.\n",
"datas['MLLM_hotpot_train'][0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c3e8619a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLLM_hotpot_train: 0.2949054259284827; dataset len: 14486\n",
"mathverse: 0.18071065989847715; dataset len: 3940\n"
]
}
],
"source": [
"# Score the first prediction of every item, recording each correct item's own\n",
"# 'index' field rather than its position in the list.\n",
"indices = {}\n",
"\n",
"for file, answers in datas.items():\n",
"    indices[file] = []\n",
"    acc = 0\n",
"    for ele in answers:\n",
"        solution = ele['solution']\n",
"        prediction = ele['predictions'][0]\n",
"        datas_index = ele['index']\n",
"\n",
"        accuracy = accuracy_reward(prediction, solution)\n",
"        if accuracy == 1:\n",
"            indices[file].append(datas_index)\n",
"            acc += accuracy\n",
"\n",
"    # A file holding an empty list would otherwise raise ZeroDivisionError.\n",
"    mean_acc = acc / len(answers) if answers else 0.0\n",
"    print(f'{file}: {mean_acc}; dataset len: {len(answers)}')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ca869a96",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception caught: name 'description_outputs' is not defined for file: MLLM_hotpot_train\n",
"Exception caught: name 'description_outputs' is not defined for file: mathverse\n"
]
}
],
"source": [
"# For every correctly answered item that also has an extracted description,\n",
"# keep its index and build the text '<newline>description<newline>prediction'.\n",
"indices = {}\n",
"texts = {}\n",
"for file, answers in datas.items():\n",
"    try:\n",
"        indices[file] = []\n",
"        texts[file] = []\n",
"        description_data = description_outputs[file]\n",
"        # ---------- 1) make a hash-map: index -> description item ----------\n",
"        desc_by_idx = {item[\"index\"]: item for item in description_data}\n",
"\n",
"        acc = 0\n",
"        for ele in answers:\n",
"            solution = ele['solution']\n",
"            prediction = ele['predictions'][0]\n",
"            data_idx = ele[\"index\"]  # the index in the answers item\n",
"\n",
"            # A missing record, missing/empty 'predictions', or a non-string\n",
"            # entry means \"no description\"; anything else should surface.\n",
"            try:\n",
"                extracted_description = extract_description(\n",
"                    desc_by_idx[data_idx]['predictions'][0]\n",
"                )\n",
"            except (KeyError, IndexError, TypeError):\n",
"                extracted_description = None\n",
"\n",
"            accuracy = accuracy_reward(prediction, solution)\n",
"\n",
"            if accuracy == 1:\n",
"                if extracted_description is not None:\n",
"                    indices[file].append(data_idx)\n",
"                    # The separator was the literal two characters '/n'; a\n",
"                    # newline is what the leading '\\n' shows was intended.\n",
"                    curr_text = '\\n' + extracted_description + '\\n' + prediction\n",
"                    texts[file].append(curr_text)\n",
"\n",
"                # NOTE(review): counted for every correct prediction, with or\n",
"                # without a description; confirm this is the intended statistic.\n",
"                acc += accuracy\n",
"\n",
"        # A file holding an empty list would otherwise raise ZeroDivisionError.\n",
"        mean_acc = acc / len(answers) if answers else 0.0\n",
"        print(f'{file}: {mean_acc}; dataset len: {len(answers)}')\n",
"    except Exception as e:\n",
"        print(f\"Exception caught: {e} for file: {file}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "2d3594e0",
"metadata": {},
"outputs": [],
"source": [
"# Alias read by the counting and merge cells further down.\n",
"indices_by_dataset = indices"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4b0a1872",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"K: realWorldQA; V len: 514\n",
"K: MLLM_hotpot_train; V len: 0\n",
"K: mmmu-pro; V len: 389\n",
"K: mathvision; V len: 328\n",
"K: mmstar; V len: 512\n",
"K: mm-vet; V len: 65\n",
"K: mathvista; V len: 457\n"
]
},
{
"data": {
"text/plain": [
"2265"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Report how many indices were kept per dataset, then show the grand total.\n",
"for k, v in indices_by_dataset.items():\n",
"    print(f'K: {k}; V len: {len(v)}')\n",
"\n",
"total = sum(len(kept) for kept in indices_by_dataset.values())\n",
"total"
]
},
{
"cell_type": "markdown",
"id": "4dba6e3c",
"metadata": {},
"source": [
"### Add it for MLLM hotpot train"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "5d453890",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[skip] /apdcephfs_cq11/share_1603164/user/zongxia/workspace/C-gemini-answers/gemini-flash/clevr_count_70k.json – invalid JSON (Expecting value: line 1 column 1 (char 0))\n",
"14486\n",
"MLLM_hotpot_train: 0.2949054259284827; dataset len: 14486\n",
"3940\n",
"mathverse: 0.18071065989847715; dataset len: 3940\n"
]
},
{
"data": {
"text/plain": [
"4272"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rebuild `indices` and `texts` together from the ./gemini-flash outputs so\n",
"# the two stay parallel and the cell does not depend on an earlier cell\n",
"# having defined `texts`.\n",
"indices = {}\n",
"texts = {}\n",
"\n",
"hotpot_description_folder_dir = './gemini-flash'\n",
"hotpot_description_outs = load_json_dir(hotpot_description_folder_dir)\n",
"\n",
"for file, answers in hotpot_description_outs.items():\n",
"    try:\n",
"        print(len(answers))\n",
"        indices[file] = []\n",
"        texts[file] = []\n",
"        acc = 0\n",
"        for ele in answers:\n",
"            solution = ele['solution']\n",
"            prediction = ele['predictions'][0]\n",
"            datas_index = ele['index']\n",
"\n",
"            accuracy = accuracy_reward(prediction, solution)\n",
"            if accuracy == 1:\n",
"                # Appended in lockstep: texts[file][k] belongs to indices[file][k].\n",
"                indices[file].append(datas_index)\n",
"                texts[file].append(prediction)\n",
"                acc += accuracy\n",
"\n",
"        # A file holding an empty list would otherwise raise ZeroDivisionError.\n",
"        mean_acc = acc / len(answers) if answers else 0.0\n",
"        print(f'{file}: {mean_acc}; dataset len: {len(answers)}')\n",
"    except Exception as e:\n",
"        print(f\"Exception caught: {e} for file: {file}\")\n",
"\n",
"len(indices['MLLM_hotpot_train'])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8f4fe74e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"len(idxs) = 14486 min = 0 max = 14485\n",
"missing count : 0\n",
"first 20 gaps : []\n"
]
}
],
"source": [
"# Check that the 'index' values of the hotpot records form a gap-free range.\n",
"idxs = [ele['index'] for ele in hotpot_description_outs['MLLM_hotpot_train']]\n",
"\n",
"lowest, highest = min(idxs), max(idxs)\n",
"print(\"len(idxs) =\", len(idxs), \" min =\", lowest, \" max =\", highest)\n",
"\n",
"# Every integer between the smallest and largest index should be present.\n",
"expected = set(range(lowest, highest + 1))\n",
"missing = sorted(expected.difference(idxs))\n",
"\n",
"print(\"missing count :\", len(missing))\n",
"print(\"first 20 gaps :\", missing[:20])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "411dcfc7",
"metadata": {},
"outputs": [],
"source": [
"# Point the alias at the freshly rebuilt `indices` for the merge cell below.\n",
"indices_by_dataset = indices"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ce4cea20",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['MLLM_hotpot_train', 'mathverse'])\n",
"dict_keys(['MLLM_hotpot_train', 'mathverse'])\n"
]
}
],
"source": [
"# The merge cell looks up texts[name] for every name in indices_by_dataset,\n",
"# so the two key sets are printed for comparison.\n",
"print(indices_by_dataset.keys())\n",
"print(texts.keys())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "2a3ea275",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4272"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of MLLM_hotpot_train indices that were kept.\n",
"len(indices_by_dataset['MLLM_hotpot_train'])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "08197397",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[14471, 14473, 14474, 14476, 14477, 14478, 14480, 14481, 14484, 14485]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Last ten kept indices for MLLM_hotpot_train.\n",
"indices_by_dataset['MLLM_hotpot_train'][-10:]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bd2b91ff",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"filename: zli12321/MLLM_hotpot_train\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Flattening the indices: 100%|██████████| 4272/4272 [00:03<00:00, 1282.44 examples/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"filename: zli12321/mathverse\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating test split: 3940 examples [00:00, 13229.68 examples/s]\n",
"Flattening the indices: 100%|██████████| 712/712 [00:00<00:00, 48814.82 examples/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['problem', 'answer', 'images', 'outputs'],\n",
" num_rows: 4984\n",
"})\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from datasets import load_dataset, concatenate_datasets\n",
"\n",
"BASE_REPO = \"zli12321/\"  # prefix for every dataset id\n",
"kept_splits = []\n",
"\n",
"for short_name, keep in indices_by_dataset.items():\n",
"    if not keep:  # nothing to keep → skip\n",
"        continue\n",
"\n",
"    # Reset per dataset so the error handler never reports a dataset left\n",
"    # over from a previous iteration (or an unbound name on the first one).\n",
"    ds = None\n",
"    try:\n",
"        # -------------------------------------------------------------\n",
"        # 1) sort `keep` and its matching texts *together*\n",
"        # -------------------------------------------------------------\n",
"        # idxs and outs were built in parallel, so one permutation that\n",
"        # sorts the indices keeps every text attached to its own row.\n",
"        idxs = keep\n",
"        outs = texts[short_name]\n",
"        order = sorted(range(len(idxs)), key=idxs.__getitem__)\n",
"        idxs = [idxs[i] for i in order]  # sorted indices\n",
"        outs = [outs[i] for i in order]  # matching outputs\n",
"\n",
"        # -------------------------------------------------------------\n",
"        # 2) load, slice, and keep only the three original columns\n",
"        # -------------------------------------------------------------\n",
"        full_name = f\"{BASE_REPO}{short_name}\"\n",
"\n",
"        print(f'filename: {full_name}')\n",
"        split = \"train\" if \"MLLM_hotpot_train\" in short_name else \"test\"\n",
"\n",
"        ds = load_dataset(full_name, split=split, trust_remote_code=True)\n",
"        ds = ds.select(idxs)  # keep only those rows\n",
"\n",
"        cols_to_keep = {\"problem\", \"images\", \"answer\"}\n",
"        ds = ds.remove_columns([c for c in ds.column_names if c not in cols_to_keep])\n",
"\n",
"        # -------------------------------------------------------------\n",
"        # 3) add the new column\n",
"        # -------------------------------------------------------------\n",
"        ds = ds.add_column(\"outputs\", outs)  # len(outs) == len(ds)\n",
"\n",
"        kept_splits.append(ds)\n",
"    except Exception as e:\n",
"        # `ds` is still None when the failure happened before loading.\n",
"        if ds is not None:\n",
"            print(f\"dataset len: {len(ds)}\")\n",
"        print(f'{short_name} Failed: {e}')\n",
"\n",
"# ---------------------------------------------------------------------\n",
"# 4) concatenate everything into one big dataset\n",
"# ---------------------------------------------------------------------\n",
"combined = concatenate_datasets(kept_splits)\n",
"\n",
"print(combined)  # verify\n",
"# combined.save_to_disk(\"combined.arrow\") # or .to_parquet(...)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "cb8bfe20",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating parquet from Arrow format: 100%|██████████| 39/39 [00:17<00:00, 2.18ba/s]\n"
]
},
{
"data": {
"text/plain": [
"909006342"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Write the merged dataset to one Parquet file.\n",
"# NOTE(review): assumes ./hf_upload_train already exists — confirm.\n",
"combined.to_parquet(\"./hf_upload_train/train.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5b7aed77",
"metadata": {},
"outputs": [],
"source": [
"def _open_image(source):\n",
"    \"\"\"Open *source* (a path or binary file object) with Pillow, imported lazily.\"\"\"\n",
"    from PIL import Image\n",
"\n",
"    return Image.open(source)\n",
"\n",
"\n",
"def save_any_image(img_obj, out_base: Path) -> Path:\n",
"    \"\"\"\n",
"    Save *img_obj* (str | os.PathLike | dict | PIL.Image) to disk.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    img_obj\n",
"        A filesystem path, a dict with a truthy \"path\" or a \"bytes\" entry, or an\n",
"        already-decoded image object exposing `.mode`, `.convert()` and `.save()`.\n",
"    out_base : Path\n",
"        Destination path; its suffix is replaced by the chosen extension.\n",
"\n",
"    Returns\n",
"    -------\n",
"    Path\n",
"        The path actually written: \".png\" for modes RGBA, LA and P, \".jpg\" otherwise.\n",
"    \"\"\"\n",
"    import io\n",
"    import os\n",
"\n",
"    # 1) resolve an image object ----------------------------------------------\n",
"    if isinstance(img_obj, (str, os.PathLike)):  # already a path\n",
"        pil = _open_image(img_obj)\n",
"    elif isinstance(img_obj, dict):  # dict with \"path\" or \"bytes\"\n",
"        if img_obj.get(\"path\"):\n",
"            pil = _open_image(img_obj[\"path\"])\n",
"        else:\n",
"            pil = _open_image(io.BytesIO(img_obj[\"bytes\"]))\n",
"    else:  # treated as an already-decoded image\n",
"        pil = img_obj\n",
"\n",
"    # 2) choose format & filename ---------------------------------------------\n",
"    img_mode = pil.mode\n",
"    # RGBA / LA / P are written as PNG instead of being flattened for JPEG.\n",
"    suffix = \".png\" if img_mode in (\"RGBA\", \"LA\", \"P\") else \".jpg\"\n",
"    out_path = Path(out_base).with_suffix(suffix)\n",
"\n",
"    # 3) every remaining non-RGB mode is converted before the JPEG write\n",
"    if suffix == \".jpg\" and img_mode != \"RGB\":\n",
"        pil = pil.convert(\"RGB\")\n",
"\n",
"    # 4) write -----------------------------------------------------------------\n",
"    pil.save(out_path)\n",
"    return out_path\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "358edaa6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"writing images: 100%|██████████| 4984/4984 [14:38<00:00, 5.67it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Done: 4984 items saved.\n"
]
}
],
"source": [
"import os, io, json, shutil\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm # optional progress bar\n",
"\n",
"# ------------------------------------------------------------------ #\n",
"# output folder for the image files\n",
"# ------------------------------------------------------------------ #\n",
"OUT_DIR = Path(\"sft_description\")\n",
"OUT_DIR.mkdir(exist_ok=True)  # no error if the folder is already there\n",
"\n",
"json_records = []\n",
"\n",
"# ------------------------------------------------------------------ #\n",
"# one image file and one two-turn chat record per dataset row\n",
"# ------------------------------------------------------------------ #\n",
"for idx, row in enumerate(tqdm(combined, desc=\"writing images\")):\n",
"    img_path = save_any_image(row[\"images\"], OUT_DIR / str(idx))\n",
"    user_turn = {\"content\": row[\"problem\"], \"role\": \"user\"}\n",
"    assistant_turn = {\"content\": row[\"outputs\"], \"role\": \"assistant\"}\n",
"    json_records.append({\n",
"        \"messages\": [user_turn, assistant_turn],\n",
"        \"images\": [str(img_path)],\n",
"    })\n",
"\n",
"# ------------------------------------------------------------------ #\n",
"# dump all records as a single JSON array\n",
"# ------------------------------------------------------------------ #\n",
"with open(\"sft_description.json\", \"w\", encoding=\"utf-8\") as f:\n",
"    f.write(json.dumps(json_records, ensure_ascii=False, indent=2))\n",
"\n",
"print(f\"✅ Done: {len(json_records)} items saved.\")"
]
},
{
"cell_type": "markdown",
"id": "d4e56b70",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"id": "adc502bc",
"metadata": {},
"source": [
"### Now process the data for Hotpot Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e84f2aa2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"id": "54356d4e",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from openai import OpenAI\n",
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
"from time import sleep\n",
"from typing import List, Dict, Any, Optional\n",
"from openai import OpenAI\n",
"from __future__ import annotations\n",
"import json\n",
"from pathlib import Path\n",
"from typing import Any, Dict, Iterable, List, Union"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5caaaa06",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'index': 0,\n",
" 'problem': 'Are there any states with a number of reporters between 376-385? Options:\\nA. No\\nB. Yes',\n",
" 'solution': 'B',\n",
" 'predictions': ['The image is a map of the United States, with each state colored according to the number of reporters in that state. The title of the map is \"The Number of reporters in the USA\". There is a legend in the bottom right corner. States colored in a light beige color have between 373-375 reporters. States colored in a dark purple color have between 376-385 reporters. Several states are colored dark purple, including Washington, Montana, North Dakota, South Dakota, Iowa, Missouri, Louisiana, Utah, Nevada, California, Virginia, Maryland, and New Hampshire. Alaska and Hawaii are also shown. \\nThe question asks if there are any states with a number of reporters between 376-385. The legend indicates that states with 376-385 reporters are colored dark purple. The map shows several states colored dark purple. Therefore, the answer is yes. \\n\\\\boxed{Yes}']}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE(review): `data` is not assigned in any cell visible in this notebook;\n",
"# presumably it was loaded in a cell that has since been removed — confirm.\n",
"data[0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}