Support Kosmos-2.5 #31711

Open · tic-top wants to merge 341 commits into main
Conversation

@tic-top commented Jun 29, 2024

What does this PR do?

Implementation of Kosmos-2.5 in transformers (#30877).
Reference: https://huggingface.co/kirp/kosmos2_5/blob/main/README.md

Usage

```python
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoConfig
import re

repo = "kirp/kosmos2_5"
device = "cuda:0"
config = AutoConfig.from_pretrained(repo)

NAME = {
    "f" : "flash_attention_2",
    "s" : "sdpa",
    "e" : "eager",
}

# all sdpa fp16
dtype = torch.float16
config._attn_implementation = NAME["s"]
config.vision_config._attn_implementation = NAME["s"]
config.text_config._attn_implementation = NAME["s"]

# # all eager bf16
# dtype = torch.bfloat16
# config._attn_implementation = NAME["e"]
# config.text_config._attn_implementation = NAME["e"]
# config.vision_config._attn_implementation = NAME["e"]


model = AutoModelForVision2Seq.from_pretrained(repo, device_map=device, torch_dtype=dtype, config=config)
processor = AutoProcessor.from_pretrained(repo)

url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<ocr>" # <md>

inputs = processor(text=prompt, images=image, return_tensors="pt")
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width

inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

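# Rescale the predicted bounding boxes back to the original image size and
# emit "x0,y0,x1,y0,x1,y1,x0,y1,text" quads, one per recognized line.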
def postprocess(y, scale_height, scale_width):
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        box = bboxs[i]
        x0, y0, x1, y1 = box
        if not (x0 >= x1 or y0 >= y1):
            x0 = int(x0 * scale_width)
            y0 = int(y0 * scale_height)
            x1 = int(x1 * scale_width)
            y1 = int(y1 * scale_height)
            info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
    return info

output_text = postprocess(generated_text[0], scale_height, scale_width)
print(output_text)
```

@amyeroberts (Collaborator)

cc @ydshieh

ydshieh self-assigned this Jul 1, 2024
@ydshieh (Collaborator) commented Jul 9, 2024

Thanks a lot for this hard work @tic-top. This is going to benefit the community 🤗! I will check this tomorrow!

@ydshieh (Collaborator) commented Jul 10, 2024

I will push an empty commit to trigger a CI run on GPU 🙏

@ydshieh (Collaborator) left a comment

Hi. Apologies for being late.

I left one quick question and will focus on reviewing this PR tomorrow.

@ydshieh (Collaborator) commented Jul 16, 2024

No worries about the failing CI above. It's fine; I checked the tests running on an A10 and they pass ✅

Collaborator, on benchmark/benchmark.py (resolved):

(for me): I will revert this
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/kosmos_2_overview.jpg"
alt="drawing" width="600"/>

<small> Overview of tasks that KOSMOS-2 can handle. Taken from the <a href="https://arxiv.org/abs/2306.14824">original paper</a>. </small>
@ydshieh (Collaborator), Jul 17, 2024:
If we don't keep the image above, we should remove this too. Otherwise, KOSMOS-2 -> KOSMOS-2.5.

Comment on lines 35 to 81:

```python
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
import re

repo = "microsoft/kosmos-2.5"
device = "cuda:0"
dtype = torch.bfloat16
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
processor = AutoProcessor.from_pretrained(repo)
url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<ocr>"  # <md>
inputs = processor(text=prompt, images=image, return_tensors="pt")
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

def postprocess(y, scale_height, scale_width):
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        box = bboxs[i]
        x0, y0, x1, y1 = box
        if not (x0 >= x1 or y0 >= y1):
            x0 = int(x0 * scale_width)
            y0 = int(y0 * scale_height)
            x1 = int(x1 * scale_width)
            y1 = int(y1 * scale_height)
            info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}"
    return info

output_text = postprocess(generated_text[0], scale_height, scale_width)
print(output_text)
```
Collaborator:

nit (not necessary): it might be nice/interesting to refer to https://github.com/microsoft/unilm/blob/master/kosmos-2.5/draw_bbox.py and attach a screenshot of the output images.

@tic-top (Author), Jul 22, 2024:

This one is easier to use. The Python file above needs to convert the string to JSON first, then draw.

Collaborator:

Yes. But if I understand correctly, this only gives the string, while people are more interested in seeing the final images with bounding boxes or the structured MD layout.

I am not saying to use draw_bbox.py in this documentation. Just mention that there is such a file to draw things and give the link as a reference.

If you have any reason not to mention it, I am OK not to have it here.

Author:

Do you mean something like "How to use"?
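For reference, a minimal sketch (not part of this PR) of how the quads emitted by postprocess could be drawn with PIL, assuming the "x0,y0,x1,y0,x1,y1,x0,y1,text" line format from the usage example above:

```python
from PIL import Image, ImageDraw

def draw_quads(image: Image.Image, output_text: str) -> Image.Image:
    # Each line is "x0,y0,x1,y0,x1,y1,x0,y1,text": 8 coordinates, then the text.
    draw = ImageDraw.Draw(image)
    for line in output_text.splitlines():
        parts = line.split(",", 8)
        if len(parts) < 9:
            continue
        coords = [int(p) for p in parts[:8]]
        draw.polygon(coords, outline="red")
    return image

# e.g. draw_quads(image.copy(), output_text).save("receipt_with_boxes.png")
```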

Comment on lines +287 to +292:

```python
seqlen = self.model_tester.text_model_tester.seq_length
inputs_dict["input_ids"] = torch.arange(seqlen, device=torch_device).unsqueeze(0).expand(bs, seqlen)
inputs_dict["input_ids"] = inputs_dict["input_ids"] % self.model_tester.text_model_tester.vocab_size
inputs_dict["attention_mask"] = torch.ones((bs, seqlen), device=torch_device)
inputs_dict["image_embeds_position_mask"] = torch.zeros((bs, seqlen), device=torch_device)
inputs_dict["image_embeds_position_mask"][:, : self.model_tester.latent_query_num] = 1
```
Collaborator:

Is this necessary only to adjust the batch size? I think the default value will give the same batch size.

(For Kosmos-2, I don't need this extra block to adjust anything.)

Author:

When I test it, it returns an unexpected result.

Comment on lines 561 to 611:

```python
class Kosmos2_5ModelIntegrationTest(unittest.TestCase):
    def run_example(self, prompt, image, model, processor):
        print("Prompt:", prompt)
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        _, _ = inputs.pop("height"), inputs.pop("width")
        inputs = {k: v.to(torch_device) if v is not None else None for k, v in inputs.items()}
        inputs["flattened_patches"] = inputs["flattened_patches"].to(model.dtype)

        generation_outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
        generated_ids = generation_outputs
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

        return generated_ids, generated_text

    def test_receipt_image_ocr(self):
        url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
        url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
        image = Image.open(requests.get(url, stream=True).raw)

        dtype = torch.bfloat16
        repo = "microsoft/kosmos-2.5"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=torch_device, torch_dtype=dtype)
        processor = AutoProcessor.from_pretrained(repo)
        prompt = "<ocr>"
        generated_ids, generated_text = self.run_example(prompt, image, model, processor)

        EXPECTED_TEXT = [
            "<ocr><bbox><x_53><y_573><x_69><y_606></bbox>1\n<bbox><x_79><y_573><x_464><y_611></bbox>[REG] BLACK SAKURA\n<bbox><x_690><y_569><x_810><y_606></bbox>45,455\n<bbox><x_53><y_614><x_69><y_648></bbox>1\n<bbox><x_79><y_614><x_468><y_650></bbox>COOKIE DOH SAUCES\n<bbox><x_788><y_609><x_812><y_644></bbox>0\n<bbox><x_50><y_658><x_69><y_693></bbox>1\n<bbox><x_79><y_658><x_358><y_693></bbox>NATA DE COCO\n<bbox><x_790><y_652><x_814><y_687></bbox>0\n<bbox><x_31><y_742><x_820><y_781></bbox>Sub Total 45,455\n<bbox><x_27><y_781><x_822><y_827></bbox>PB1 (10%) 4,545\n<bbox><x_27><y_826><x_824><y_872></bbox>Rounding 0\n<bbox><x_24><y_872><x_827><y_921></bbox>Total 50,000\n<bbox><x_17><y_1056><x_836><y_1108></bbox>Card Payment 50,000\n"
        ]

        self.assertListEqual(generated_text, EXPECTED_TEXT)

    def test_receipt_image_md(self):
        url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
        url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png"
        image = Image.open(requests.get(url, stream=True).raw)

        dtype = torch.bfloat16
        repo = "microsoft/kosmos-2.5"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=torch_device, torch_dtype=dtype)
        processor = AutoProcessor.from_pretrained(repo)
        prompt = "<md>"
        generated_ids, generated_text = self.run_example(prompt, image, model, processor)
        print(generated_text)
        EXPECTED_TEXT = [
            "<md>- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000"
        ]
        self.assertListEqual(generated_text, EXPECTED_TEXT)
```
Collaborator:

Since we skip some SDPA/Flash attention tests above 🙏, could you have them in the integration tests? The tests are likely identical, just using eager/SDPA/Flash attention.

Author:

Added.

Collaborator:

Thank you. Since our CI is still using the T4 runner, and none of these 3 tests passes (GPU OOM), I am thinking of reducing max_new_tokens=1024 to something smaller.

Do you have any comment about this?

Collaborator:

If you think that makes sense too, I can run on a T4 and update the expected output values on my side.

Collaborator:

Well, it turns out that the GPU OOM happens already at vision_model_output = self.vision_model. So reducing max_new_tokens won't work here. Forget my comment above.

Comment on lines 31 to 38:

```python
@require_vision
class LlavaProcessorTest(unittest.TestCase):
    def test_can_load_various_tokenizers(self):
        # for checkpoint in ["microsoft/kosmos-2.5", "microsoft/kosmos-2.5"]:
        for checkpoint in ["kirp/kosmos2_5"]:
            processor = AutoProcessor.from_pretrained(checkpoint)
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
```
Collaborator:

This should be similar to Kosmos2ProcessorTest. The most important one is test_full_processor.

I did very extensive tests there, but you can keep it simple.
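A minimal sketch of what a simple test_full_processor could look like (the key names are assumptions taken from the usage example above; the Kosmos-2 version is much more thorough):

```python
# Sketch only: check that text+image inputs produce the expected keys.
def test_full_processor(self):
    processor = AutoProcessor.from_pretrained("kirp/kosmos2_5")
    image = Image.new("RGB", (1024, 1024))
    inputs = processor(text="<ocr>", images=image, return_tensors="pt")
    for key in ("input_ids", "attention_mask", "flattened_patches"):
        self.assertIn(key, inputs)
```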

Author:

Added.

@ydshieh (Collaborator) left a comment

Hi @tic-top, it looks very good! A few nit comments plus a few things for the tests.

I will have to continue and finalize the review tomorrow, but I am submitting what I have so far.

One important point for me is the comment I left on the model integration tests.

@ydshieh (Collaborator) commented Jul 17, 2024

Ah, I forgot. We should also have a test file (class Kosmos2_5ImageProcessorTest) for class Kosmos2_5ImageProcessor, similar to Pix2StructImageProcessingTest for Pix2StructImageProcessor.

It should be easy: just copy Pix2StructImageProcessingTest; only a few changes are needed.

(But for this, you can wait. I still have a few things to review tomorrow.)
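A possible starting skeleton (the class name and imports are assumptions by analogy with Pix2StructImageProcessingTest, not final code):

```python
import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

if is_vision_available():
    from transformers import Kosmos2_5ImageProcessor


@require_torch
@require_vision
class Kosmos2_5ImageProcessingTest(unittest.TestCase):
    # Adapted from Pix2StructImageProcessingTest, as suggested above.
    image_processing_class = Kosmos2_5ImageProcessor if is_vision_available() else None
```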

docs/source/en/model_doc/kosmos-2.5.md (resolved)
@@ -1149,6 +1154,7 @@

```python
_import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
_import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
_import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
_import_structure["models.kosmos2_5"].extend(["Kosmos2_5ImageProcessor", "Kosmos2_5Processor"])
```
Author:

Where should I add it?

@@ -5821,6 +5839,7 @@

```python
from .models.idefics2 import Idefics2ImageProcessor
from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
from .models.instructblipvideo import InstructBlipVideoImageProcessor
from .models.kosmos2_5 import Kosmos2_5ImageProcessor, Kosmos2_5Processor
```
Author:

Where?


@tic-top (Author) left a comment

Where should I add the processor?

```python
repo = "microsoft/kosmos-2.5"
model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo, device_map=torch_device, torch_dtype=dtype
)  # , attn_implementation="eager")
```
Collaborator:

When not specified and SDPA is available, it will actually use SDPA. Hence we have to specify "eager" in order to test "eager".
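In other words, the fix is to pass the implementation explicitly (a minimal sketch; from_pretrained accepts attn_implementation as with other transformers models):

```python
# Request eager attention explicitly so the test exercises the eager path;
# otherwise SDPA is selected automatically when available.
model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo, device_map=torch_device, torch_dtype=dtype, attn_implementation="eager"
)
```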

Collaborator:

I pushed a commit for this.

```python
        ]
        self.assertListEqual(generated_text, EXPECTED_TEXT)

    def test_FA2(self):
```
Collaborator:

Need

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow

(before def test_FA2(self):)
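Assembled, the decorated test would look like this (a sketch; the test body is elided):

```python
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_FA2(self):
    ...
```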

Collaborator:

I pushed the change for this.

ydshieh and others added 12 commits February 3, 2025 14:20
@ydshieh (Collaborator) commented Feb 5, 2025

@ArthurZucker It would be great if you could take a final look 🙏. We hope to have this in the next release 🚀

ydshieh requested review from ArthurZucker and removed the request for ArthurZucker, February 5, 2025 09:34
@yonigozlan (Member) left a comment

Thanks @tic-top for working on this!

@ydshieh It would be great if you could add a fast image processor for this model.
With the recent refactor, you can use transformers-cli add-fast-image-processor --model-name kosmos2_5 to add all the necessary imports, then check the fast image processor of llava-next for an example of how to implement it.
Since most of the functions involved in the image processing already use torch here, it should be quite straightforward. You can also use torchvision functional transforms instead of pure torch functions in the fast image processor.
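A rough sketch of the shape such a fast processor might take (the base class comes from the fast-image-processor refactor; the method body here is a placeholder, not final code):

```python
from transformers.image_processing_utils_fast import BaseImageProcessorFast


class Kosmos2_5ImageProcessorFast(BaseImageProcessorFast):
    # Hypothetical skeleton; the real scaffold would be generated by
    # `transformers-cli add-fast-image-processor --model-name kosmos2_5`.
    def _preprocess(self, images, **kwargs):
        # Normalize and extract flattened patches in torch/torchvision, then
        # build the attention mask (see the torch sketch further below).
        raise NotImplementedError
```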

"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "Kosmos2_5ImageProcessor"
Member:

You'll have to change this to AutoImageProcessor once the fast image processor is added :)
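That is, once the fast processor exists, the attribute would become:

```python
image_processor_class = "AutoImageProcessor"
```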

Comment on lines +307 to +339:

```python
flattened_patches, width, height, rows, cols, attention_masks = [], [], [], [], [], []
for image in images:
    if do_normalize:
        image = self.normalize(image=image, input_data_format=input_data_format)

    # convert to torch tensor and permute
    f, w, h, r, c = self.extract_flattened_patches(
        image=image,
        max_patches=max_patches,
        patch_size=patch_size,
        input_data_format=input_data_format,
    )
    flattened_patches.append(f)
    width.append(w)
    height.append(h)
    rows.append(r)
    cols.append(c)
    # create attention mask in numpy
    attention_masks.append((f.sum(axis=-1) != 0).astype(np.float32))

encoded_outputs = BatchFeature(
    data={
        "flattened_patches": flattened_patches,
        "attention_mask": attention_masks,
        "width": width,
        "height": height,
        "rows": rows,
        "cols": cols,
    },
    tensor_type=return_tensors,
)

return encoded_outputs
```
Member:

Only this part should be needed in the _preprocess function of the fast image processor.
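For instance, the numpy attention-mask line above has a direct torch equivalent that could live in the fast processor's _preprocess (a sketch, not code from this PR):

```python
import torch

def make_attention_mask(flattened_patches: torch.Tensor) -> torch.Tensor:
    # A patch row is "real" if any of its values is non-zero; mirrors
    # (f.sum(axis=-1) != 0).astype(np.float32) from the numpy version above.
    return (flattened_patches.sum(dim=-1) != 0).to(torch.float32)
```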
