Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions oellm/contrib/regiondial_bench/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def to_contrib_flags(self) -> str | None:
name = Path(self._path).name.lower()
if "regionreasoner" in name or "region_reasoner" in name:
return "vision_reasoner"
if "qwen2.5" in name:
return "qwen2.5"
if "qwen2" in name:
return "qwen2"
if "qwen" in name:
Expand Down
4 changes: 2 additions & 2 deletions oellm/contrib/regiondial_bench/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Metrics
-------
- **GIoU**: mean of per-sample mask IoU (intersection / union).
- **CIoU**: cumulative IoU — sum of all intersections / sum of all unions.
- **CIoU**: sum of all intersections / sum of all unions.
- **BboxAP**: fraction of samples where bbox IoU > 0.5.
- **PassRate**: fraction of samples where mask IoU > *threshold*.
"""
Expand Down Expand Up @@ -78,7 +78,7 @@ def compute(self, predictions: list[str], references: list[str]) -> float:


class CIoU(BaseMetric):
"""Cumulative IoU (cIoU as reported in RegionDial-Bench).
"""cIoU as reported in RegionDial-Bench.

Formula: ``sum(all intersections) / sum(all unions)``.
"""
Expand Down
40 changes: 20 additions & 20 deletions oellm/contrib/regiondial_bench/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@ def run(
"--num_parts",
"1",
"--batch_size",
"2",
"1",
"--task_router_model_path",
"Ricky06662/TaskRouter-1.5B",
"--binarize_bbox_iou",
]
logger.info("Starting shard %d/%d: %s", idx + 1, num_gpus, " ".join(cmd))
proc = subprocess.Popen(cmd, env=shard_env, cwd=str(Path(test_json).parent))
Expand Down Expand Up @@ -356,27 +357,26 @@ def _aggregate_shards(shard_dir: str) -> dict[str, float]:
metrics[m.name] = val
logger.debug("%s = %.4f", m.name, val)

# Infer per-round membership by counting each image_id's occurrence order in
# the output (mirrors calculate_iou_with_bbox_by_turns.py). The inference
# script emits turns in sequential order per image, so the k-th time an
# image_id appears corresponds to turn k (1-indexed).
rounds_map: dict[int, list[str]] = defaultdict(list)
image_turn_counter: dict[str, int] = {}
for sample_dict, sample_str in zip(all_samples, samples, strict=True):
rnd = sample_dict.get("round")
if rnd is not None:
rounds_map[int(rnd)].append(sample_str)

if rounds_map:
per_round_metrics = [GIoU(), BboxAP()]
for rnd in sorted(rounds_map):
rnd_samples = rounds_map[rnd]
rnd_refs = [""] * len(rnd_samples)
for m in per_round_metrics:
val = m.compute(rnd_samples, rnd_refs)
metrics[f"{m.name}_R{rnd}"] = val
logger.debug("%s_R%d = %.4f", m.name, rnd, val)
else:
logger.warning(
"No 'round' field found in samples — skipping per-round breakdown. "
"Per-round metrics (R1–R7) require the inference script to output "
"a 'round' field in each sample."
)
image_id = str(sample_dict.get("image_id", ""))
image_turn_counter[image_id] = image_turn_counter.get(image_id, 0) + 1
rnd = image_turn_counter[image_id]
rounds_map[rnd].append(sample_str)

per_round_metrics = [GIoU(), BboxAP()]
for rnd in sorted(rounds_map):
rnd_samples = rounds_map[rnd]
rnd_refs = [""] * len(rnd_samples)
for m in per_round_metrics:
val = m.compute(rnd_samples, rnd_refs)
metrics[f"{m.name}_R{rnd}"] = val
logger.debug("%s_R%d = %.4f", m.name, rnd, val)

return metrics

Expand Down
29 changes: 16 additions & 13 deletions tests/test_regiondial_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_detect_model_flags_region_reasoner_model(self, suite):
assert suite.detect_model_flags("lmsdss/RegionReasoner-7B") == "vision_reasoner"

def test_detect_model_flags_qwen2_model(self, suite):
assert suite.detect_model_flags("Qwen/Qwen2.5-VL-7B-Instruct") == "qwen2"
assert suite.detect_model_flags("Qwen/Qwen2.5-VL-7B-Instruct") == "qwen2.5"

def test_detect_model_flags_qwen1_model(self, suite):
assert suite.detect_model_flags("Qwen/Qwen-VL-Chat") == "qwen"
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_contrib_flags_region_reasoner(self, adapter_cls):

def test_contrib_flags_qwen2(self, adapter_cls):
cls, _ = adapter_cls
assert cls("Qwen/Qwen2.5-VL-7B").to_contrib_flags() == "qwen2"
assert cls("Qwen/Qwen2.5-VL-7B").to_contrib_flags() == "qwen2.5"

def test_contrib_flags_qwen(self, adapter_cls):
cls, _ = adapter_cls
Expand All @@ -490,7 +490,7 @@ def test_detect_model_flags_delegates_to_adapter(self):
import oellm.contrib.regiondial_bench.suite as s

assert s.detect_model_flags("lmsdss/RegionReasoner-7B") == "vision_reasoner"
assert s.detect_model_flags("Qwen/Qwen2.5-VL-7B") == "qwen2"
assert s.detect_model_flags("Qwen/Qwen2.5-VL-7B") == "qwen2.5"


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -651,17 +651,20 @@ def test_empty_shard_raises(self, tmp_path):
_aggregate_shards(str(tmp_path))

def test_per_round_metrics_present(self, tmp_path):
"""Samples with 'round' field produce per-round gIoU and bbox_AP keys."""
"""Two images with two turns each produce per-round gIoU and bbox_AP keys."""
from oellm.contrib.regiondial_bench.suite import _aggregate_shards

# Turns are consecutive per image (img1 T1, img1 T2, img2 T1, img2 T2).
# The turn counter assigns: first occurrence of each image_id → R1,
# second occurrence → R2.
self._write_shard(
tmp_path,
0,
[
{"intersection": 100, "union": 100, "bbox_iou": 1.0, "round": 1},
{"intersection": 50, "union": 100, "bbox_iou": 0.6, "round": 1},
{"intersection": 0, "union": 100, "bbox_iou": 0.0, "round": 2},
{"intersection": 80, "union": 100, "bbox_iou": 0.8, "round": 2},
{"image_id": "img1", "intersection": 100, "union": 100, "bbox_iou": 1.0},
{"image_id": "img1", "intersection": 0, "union": 100, "bbox_iou": 0.0},
{"image_id": "img2", "intersection": 50, "union": 100, "bbox_iou": 0.6},
{"image_id": "img2", "intersection": 80, "union": 100, "bbox_iou": 0.8},
],
)
m = _aggregate_shards(str(tmp_path))
Expand All @@ -679,18 +682,18 @@ def test_per_round_metrics_present(self, tmp_path):
# R2 bbox_AP: one >0.5 (0.8), one =0.0 → 0.5
assert m["bbox_AP_R2"] == pytest.approx(0.5)

def test_per_round_metrics_absent_without_round_field(self, tmp_path):
"""Samples without 'round' field produce no per-round keys."""
def test_per_round_metrics_always_present(self, tmp_path):
"""Per-round keys are always produced — turns are inferred from image_id order."""
from oellm.contrib.regiondial_bench.suite import _aggregate_shards

self._write_shard(
tmp_path,
0,
[{"intersection": 100, "union": 100, "bbox_iou": 1.0}],
[{"image_id": "img1", "intersection": 100, "union": 100, "bbox_iou": 1.0}],
)
m = _aggregate_shards(str(tmp_path))
round_keys = [k for k in m if "_R" in k]
assert round_keys == []
assert "gIoU_R1" in m
assert "bbox_AP_R1" in m

def test_per_round_metrics_seven_rounds(self, tmp_path):
"""All 7 rounds produce per-round metrics when present."""
Expand Down
Loading