From 074318f79bfebccf15cfc69c638bac798ba8e308 Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Tue, 6 Aug 2024 13:48:45 -0400 Subject: [PATCH 01/10] Automate testing for llm ocr functions. --- pyproject.toml | 2 + tests/test_llm_ocr_functions.py | 98 +++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 tests/test_llm_ocr_functions.py diff --git a/pyproject.toml b/pyproject.toml index 100d44c..d0144ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ test = [ "pytest", "pytest-datadir", "requests_mock", + "pytest_mock", + "openai" ] # Dependencies only needed to run the streamlit app go here diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py new file mode 100644 index 0000000..d39076a --- /dev/null +++ b/tests/test_llm_ocr_functions.py @@ -0,0 +1,98 @@ +import unittest +from unittest.mock import patch +import logging +from typing import Optional + +import openai +from openai import APIConnectionError, AuthenticationError + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Process data using AI +class AIHandler: + """Process data using AI.""" + + MODEL = "gpt4o" + + def __init__(self, openai_key: str) -> None: + """ + Initialize the class with the OpenAI API key. + + Args: + openai_key: The API key for accessing the OpenAI service. + """ + self.openai_key = openai_key + openai.api_key = self.openai_key + + def query_api(self, query: str) -> Optional[str]: + """ + Query the AI API. + + Args: + query: Query to send to the API + + Returns: + Response message from the API + """ + logging.info("Querying API...") + + # No need to query the API if there is no query content + if not query: + return None + + message = [{"role": "user", "content": query}] + + result = None + try: + completion = openai.ChatCompletion.create( + model=self.MODEL, messages=message + ) + result = completion.choices[0].message['content'] + except AuthenticationError as ex: + logger.error("Authentication error: %s", ex) + except APIConnectionError as ex: + logger.error("APIConnection error: %s", ex) + + return result + + +# Mocking classes for testing +class MockedChoice: + def __init__(self, content: str) -> None: + self.message = {"content": content} + + +class MockedCompletion: + def __init__(self, content: str) -> None: + self.choices = [MockedChoice(content)] + + +class TestAIHandler(unittest.TestCase): + + @patch('openai.ChatCompletion.create') + def test_query_api(self, mock_create): + # Mocking the response from OpenAI API + mock_create.return_value = MockedCompletion("This is a mocked response from AI.") + + # Initialize the handler with a fake API key + handler = AIHandler(openai_key="fake_api_key") + + # Call the method with a test query + response = handler.query_api("Test query") + + # Check that the mocked response is returned + self.assertEqual(response, "This is a mocked response from AI.") + + # Check that the API was called with the correct parameters + mock_create.assert_called_once_with( + model="gpt4o", messages=[{"role": "user", "content": "Test query"}] + ) + + + + + + + From 5d25554ac963f8b74ddcc0baea99b1f1c00cd80f Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Tue, 6 Aug 2024 15:21:19 -0400 Subject: [PATCH 02/10] update test script --- tests/test_llm_ocr_functions.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index d39076a..9d11a76 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -2,19 +2,20 @@ from unittest.mock import patch import logging from typing import Optional +from requests.models import Response import openai -from openai import APIConnectionError, AuthenticationError +from openai import APIConnectionError, AuthenticationError, APIStatusError # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Process data using AI + class AIHandler: """Process data using AI.""" - MODEL = "gpt4o" + MODEL = "gpt-4" def __init__(self, openai_key: str) -> None: """ @@ -53,11 +54,12 @@ def query_api(self, query: str) -> Optional[str]: except AuthenticationError as ex: logger.error("Authentication error: %s", ex) except APIConnectionError as ex: - logger.error("APIConnection error: %s", ex) + logger.error("API connection error: %s", ex) + except APIStatusError as ex: + logger.error("API status error: %s", ex) return result - # Mocking classes for testing class MockedChoice: def __init__(self, content: str) -> None: @@ -87,12 +89,29 @@ def test_query_api(self, mock_create): # Check that the API was called with the correct parameters mock_create.assert_called_once_with( - model="gpt4o", messages=[{"role": "user", "content": "Test query"}] + model="gpt-4", messages=[{"role": "user", "content": "Test query"}] ) + @patch('openai.ChatCompletion.create') + def test_query_api_authentication_error(self, mock_create): + # Create a mock response object + mock_response = Response() + mock_response.status_code = 401 # Unauthorized status code + + # Mock the error + mock_create.side_effect = AuthenticationError( + message="Invalid API key", + response=mock_response, + body={} + ) + handler = AIHandler(openai_key="fake_api_key") + response = handler.query_api("Test query") - + self.assertIsNone(response) + mock_create.assert_called_once_with( + model="gpt-4", messages=[{"role": "user", "content": "Test query"}] + ) From 4284bb17b1e1e202d1af0f09255cd01ac17c93ca Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Tue, 6 Aug 2024 15:25:46 -0400 Subject: [PATCH 03/10] add docstring --- tests/test_llm_ocr_functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index 9d11a76..6c549ad 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -13,14 +13,11 @@ class AIHandler: - """Process data using AI.""" - - MODEL = "gpt-4" + MODEL = "gpt-4o" def __init__(self, openai_key: str) -> None: """ Initialize the class with the OpenAI API key. - Args: openai_key: The API key for accessing the OpenAI service. """ @@ -33,7 +30,6 @@ def query_api(self, query: str) -> Optional[str]: Args: query: Query to send to the API - Returns: Response message from the API """ @@ -110,7 +106,7 @@ def test_query_api_authentication_error(self, mock_create): self.assertIsNone(response) mock_create.assert_called_once_with( - model="gpt-4", messages=[{"role": "user", "content": "Test query"}] + model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] ) From af6a7bab891c6b9c4264aac92f311d2e0d456051 Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Tue, 6 Aug 2024 15:43:02 -0400 Subject: [PATCH 04/10] bug fix --- tests/test_llm_ocr_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index 6c549ad..af73810 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -85,7 +85,7 @@ def test_query_api(self, mock_create): # Check that the API was called with the correct parameters mock_create.assert_called_once_with( - model="gpt-4", messages=[{"role": "user", "content": "Test query"}] + model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] ) @patch('openai.ChatCompletion.create') From efe8c55bfb62a710fedfbd8fb438a298f45158ce Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Wed, 7 Aug 2024 11:00:20 -0400 Subject: [PATCH 05/10] add tests for ocr functions --- tests/test_llm_ocr_functions.py | 164 ++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index af73810..07207b7 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -3,10 +3,174 @@ import logging from typing import Optional from requests.models import Response +import pandas as pd +from PIL import Image, ImageOps,ExifTags +from io import BytesIO +import base64 import openai from openai import APIConnectionError, AuthenticationError, APIStatusError +from msfocr.llm import ocr_functions + +'Part1-testing llm_ocr_function' + +def test_parse_table_data(): + result = { + 'tables': [ + { + 'table_name': 'Paediatric vaccination target group', + 'headers': ['', '0-11m', '12-59m', '5-14y'], + 'data': [['Paed (0-59m) vacc target population', '', '', '']] + }, + { + 'table_name': 'Routine paediatric vaccinations', + 'headers': ['', '0-11m', '12-59m', '5-14y'], + 'data': [ + ['BCG', '45+29', '-', '-'], + ['HepB (birth dose, within 24h)', '-', '-', '-'], + ['HepB (birth dose, 24h or later)', '-', '-', '-'], + ['Polio (OPV) 0 (birth dose)', '30+18', '-', '-'], + ['Polio (OPV) 1 (from 6 wks)', '55+29', '-', '-'], + ['Polio (OPV) 2', '77+19', '8', '-'], + ['Polio (OPV) 3', '116+8', '15+3', '-'], + ['Polio (IPV)', '342+42', '-', '-'], + ['DTP+Hib+HepB (pentavalent) 1', '88+37', '3', '-'], + ['DTP+Hib+HepB (pentavalent) 2', '125+16', '14+1', '-'], + ['DTP+Hib+HepB (pentavalent) 3', '107+5', '23+6', '-'] + ] + } + ], + 'non_table_data': { + 'Health Structure': 'W14', + 'Supervisor': 'BKL', + 'Start Date (YYYY-MM-DD)': '', + 'End Date (YYYY-MM-DD)': '', + 'Vaccination': 'paediatric' + } + } + + expected_table_names = [ + 'Paediatric vaccination target group', + 'Routine paediatric vaccinations' + ] + + expected_dataframes = [ + pd.DataFrame([['', '0-11m', '12-59m', '5-14y'], ['Paed (0-59m) vacc target population', '', '', '']]), + pd.DataFrame([ + ['', '0-11m', '12-59m', '5-14y'], + ['BCG', '45+29', '-', '-'], + ['HepB (birth dose, within 24h)', '-', '-', '-'], + ['HepB (birth dose, 24h or later)', '-', '-', '-'], + ['Polio (OPV) 0 (birth dose)', '30+18', '-', '-'], + ['Polio (OPV) 1 (from 6 wks)', '55+29', '-', '-'], + ['Polio (OPV) 2', '77+19', '8', '-'], + ['Polio (OPV) 3', '116+8', '15+3', '-'], + ['Polio (IPV)', '342+42', '-', '-'], + ['DTP+Hib+HepB (pentavalent) 1', '88+37', '3', '-'], + ['DTP+Hib+HepB (pentavalent) 2', '125+16', '14+1', '-'], + ['DTP+Hib+HepB (pentavalent) 3', '107+5', '23+6', '-'] + ]) + ] + + table_names, dataframes = ocr_functions.parse_table_data(result) + + assert table_names == expected_table_names + for df, expected_df in zip(dataframes, expected_dataframes): + pd.testing.assert_frame_equal(df, expected_df) + + +def test_rescale_image(): + # Create a simple image for testing + img = Image.new('RGB', (3000, 1500), color='red') + + # Test resizing largest dimension + resized_img = ocr_functions.rescale_image(img, 2048, True) + assert max(resized_img.size) == 2048 + assert resized_img.size == (2048, 1024) # Expected resized dimensions + + # Test resizing smallest dimension + resized_img = ocr_functions.rescale_image(img, 768, False) + assert min(resized_img.size) == 768 + # 768 / 1024 * 2048 = 1536 + assert resized_img.size == (1536, 768) + + +def test_encode_image(): + # Create a simple image for testing + img = Image.new('RGB', (3000, 1500), color = 'red') + buffered = BytesIO() + img.save(buffered, format="PNG") + buffered.seek(0) + + # Encode the image using the encode_image function + encoded_string = ocr_functions.encode_image(buffered) + + # Verify that the encoded string is a valid base64 string + decoded_image = base64.b64decode(encoded_string) + assert decoded_image[:8] == b'\x89PNG\r\n\x1a\n' + + # Optionally, check if the image can be successfully loaded back + img_back = Image.open(BytesIO(decoded_image)) + assert max(img_back.size) == 2048 or min(img_back.size) == 768 + + +def create_test_image_with_orientation(orientation): + # Create a simple image + img = Image.new('RGB', (100, 50), color='red') + buffered = BytesIO() + img.save(buffered, format="JPEG") + buffered.seek(0) + + # Load the image and manually set the orientation EXIF tag + img_with_orientation = Image.open(buffered) + exif = img_with_orientation.getexif() + exif[274] = orientation # 274 is the EXIF tag code for Orientation + exif_bytes = exif.tobytes() + + # Save the image with the new EXIF data + buffered = BytesIO() + img_with_orientation.save(buffered, format="JPEG", exif=exif_bytes) + buffered.seek(0) + return buffered + + +def assert_color_within_tolerance(color1, color2, tolerance=1): + """ + Assert that two colors are within a given tolerance. + + :param color1: The first color as an (R, G, B) tuple. + :param color2: The second color as an (R, G, B) tuple. + :param tolerance: The tolerance for each color component. + """ + for c1, c2 in zip(color1, color2): + assert abs(c1 - c2) <= tolerance + + +def test_correct_image_orientation(): + # Test for orientation 3 (180 degrees) + img_data = create_test_image_with_orientation(3) + corrected_image = ocr_functions.correct_image_orientation(img_data) + assert corrected_image.size == (100, 50) + # Check if top-left pixel is red after rotating 180 degrees + assert_color_within_tolerance(corrected_image.getpixel((0, 0)), (255, 0, 0)) + + # Test for orientation 6 (270 degrees) + img_data = create_test_image_with_orientation(6) + corrected_image = ocr_functions.correct_image_orientation(img_data) + assert corrected_image.size == (50, 100) + # Check if bottom-left pixel is red after rotating 270 degrees + assert_color_within_tolerance(corrected_image.getpixel((0, corrected_image.size[1] - 1)), (255, 0, 0)) + + # Test for orientation 8 (90 degrees) + img_data = create_test_image_with_orientation(8) + corrected_image = ocr_functions.correct_image_orientation(img_data) + assert corrected_image.size == (50, 100) + # Check if top-right pixel is red after rotating 90 degrees + assert_color_within_tolerance(corrected_image.getpixel((corrected_image.size[0] - 1, 0)), (255, 0, 0)) + + +'Part2-testing openai api call' # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) From 60bdc12172327a6efff93b9a292be81d76ecea53 Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Wed, 7 Aug 2024 11:34:05 -0400 Subject: [PATCH 06/10] removed unused packages --- tests/test_llm_ocr_functions.py | 122 ++++++++++++-------------------- 1 file changed, 47 insertions(+), 75 deletions(-) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index 07207b7..b7c4e7b 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -1,6 +1,5 @@ -import unittest +import pytest from unittest.mock import patch -import logging from typing import Optional from requests.models import Response import pandas as pd @@ -14,7 +13,6 @@ from msfocr.llm import ocr_functions 'Part1-testing llm_ocr_function' - def test_parse_table_data(): result = { 'tables': [ @@ -74,7 +72,6 @@ def test_parse_table_data(): ] table_names, dataframes = ocr_functions.parse_table_data(result) - assert table_names == expected_table_names for df, expected_df in zip(dataframes, expected_dataframes): pd.testing.assert_frame_equal(df, expected_df) @@ -98,7 +95,7 @@ def test_rescale_image(): def test_encode_image(): # Create a simple image for testing - img = Image.new('RGB', (3000, 1500), color = 'red') + img = Image.new('RGB', (3000, 1500), color='red') buffered = BytesIO() img.save(buffered, format="PNG") buffered.seek(0) @@ -136,13 +133,6 @@ def create_test_image_with_orientation(orientation): def assert_color_within_tolerance(color1, color2, tolerance=1): - """ - Assert that two colors are within a given tolerance. - - :param color1: The first color as an (R, G, B) tuple. - :param color2: The second color as an (R, G, B) tuple. - :param tolerance: The tolerance for each color component. - """ for c1, c2 in zip(color1, color2): assert abs(c1 - c2) <= tolerance @@ -171,20 +161,10 @@ def test_correct_image_orientation(): 'Part2-testing openai api call' -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - class AIHandler: MODEL = "gpt-4o" def __init__(self, openai_key: str) -> None: - """ - Initialize the class with the OpenAI API key. - Args: - openai_key: The API key for accessing the OpenAI service. - """ self.openai_key = openai_key openai.api_key = self.openai_key @@ -197,8 +177,6 @@ def query_api(self, query: str) -> Optional[str]: Returns: Response message from the API """ - logging.info("Querying API...") - # No need to query the API if there is no query content if not query: return None @@ -211,67 +189,61 @@ def query_api(self, query: str) -> Optional[str]: model=self.MODEL, messages=message ) result = completion.choices[0].message['content'] - except AuthenticationError as ex: - logger.error("Authentication error: %s", ex) - except APIConnectionError as ex: - logger.error("API connection error: %s", ex) - except APIStatusError as ex: - logger.error("API status error: %s", ex) - + except AuthenticationError: + pass + except APIConnectionError: + pass + except APIStatusError: + pass return result -# Mocking classes for testing class MockedChoice: def __init__(self, content: str) -> None: self.message = {"content": content} - class MockedCompletion: def __init__(self, content: str) -> None: self.choices = [MockedChoice(content)] - -class TestAIHandler(unittest.TestCase): - - @patch('openai.ChatCompletion.create') - def test_query_api(self, mock_create): - # Mocking the response from OpenAI API - mock_create.return_value = MockedCompletion("This is a mocked response from AI.") - - # Initialize the handler with a fake API key - handler = AIHandler(openai_key="fake_api_key") - - # Call the method with a test query - response = handler.query_api("Test query") - - # Check that the mocked response is returned - self.assertEqual(response, "This is a mocked response from AI.") - - # Check that the API was called with the correct parameters - mock_create.assert_called_once_with( - model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] - ) - - @patch('openai.ChatCompletion.create') - def test_query_api_authentication_error(self, mock_create): - # Create a mock response object - mock_response = Response() - mock_response.status_code = 401 # Unauthorized status code - - # Mock the error - mock_create.side_effect = AuthenticationError( - message="Invalid API key", - response=mock_response, - body={} - ) - - handler = AIHandler(openai_key="fake_api_key") - response = handler.query_api("Test query") - - self.assertIsNone(response) - mock_create.assert_called_once_with( - model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] - ) +@pytest.fixture +def ai_handler(): + return AIHandler(openai_key="fake_api_key") + +@patch('openai.ChatCompletion.create') +def test_query_api(mock_create, ai_handler): + # Mocking the response from OpenAI API + mock_create.return_value = MockedCompletion("This is a mocked response from AI.") + + # Call the method with a test query + response = ai_handler.query_api("Test query") + + # Check that the mocked response is returned + assert response == "This is a mocked response from AI." + + # Check that the API was called with the correct parameters + mock_create.assert_called_once_with( + model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] + ) + +@patch('openai.ChatCompletion.create') +def test_query_api_authentication_error(mock_create, ai_handler): + # Create a mock response object + mock_response = Response() + mock_response.status_code = 401 # Unauthorized status code + + # Mock the error + mock_create.side_effect = AuthenticationError( + message="Invalid API key", + response=mock_response, + body={} + ) + + response = ai_handler.query_api("Test query") + + assert response is None + mock_create.assert_called_once_with( + model="gpt-4o", messages=[{"role": "user", "content": "Test query"}] + ) From c0830dbe8bfd5c663ae3b0237b867b3961e4ac6e Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Wed, 7 Aug 2024 11:42:08 -0400 Subject: [PATCH 07/10] removed unused import --- tests/test_llm_ocr_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_llm_ocr_functions.py b/tests/test_llm_ocr_functions.py index b7c4e7b..27a0fd7 100644 --- a/tests/test_llm_ocr_functions.py +++ b/tests/test_llm_ocr_functions.py @@ -3,7 +3,7 @@ from typing import Optional from requests.models import Response import pandas as pd -from PIL import Image, ImageOps,ExifTags +from PIL import Image from io import BytesIO import base64 From 5f1155b4845ae76668f008997da83362ee842fc5 Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Wed, 7 Aug 2024 14:24:25 -0400 Subject: [PATCH 08/10] Fix image open issue and remove unused library. --- pyproject.toml | 1 - src/msfocr/llm/ocr_functions.py | 32 +++++++++++++++++--------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0144ad..382f5b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ test = [ "pytest", "pytest-datadir", "requests_mock", - "pytest_mock", "openai" ] diff --git a/src/msfocr/llm/ocr_functions.py b/src/msfocr/llm/ocr_functions.py index 246d4a3..ca1f5ad 100644 --- a/src/msfocr/llm/ocr_functions.py +++ b/src/msfocr/llm/ocr_functions.py @@ -169,21 +169,23 @@ def correct_image_orientation(image_path): :param image_path: The path to the image file. :return: PIL.Image.Image: The image with corrected orientation. """ - with Image.open(image_path) as image: - orientation = None + with Image.open(image_path) as image: try: - for orientation in ExifTags.TAGS.keys(): - if ExifTags.TAGS[orientation] == 'Orientation': - break - exif = dict(image.getexif().items()) - if exif.get(orientation) == 3: - image = image.rotate(180, expand=True) - elif exif.get(orientation) == 6: - image = image.rotate(270, expand=True) - elif exif.get(orientation) == 8: - image = image.rotate(90, expand=True) - except (AttributeError, KeyError, IndexError): - pass - return image + Image.open(image_path) + exif = image._getexif() + if exif is not None: + orientation_key = next( + key for key, value in ExifTags.TAGS.items() if value == 'Orientation' + ) + orientation = exif.get(orientation_key) + if orientation == 3: + image = image.rotate(180, expand=True) + elif orientation == 6: + image = image.rotate(270, expand=True) + elif orientation == 8: + image = image.rotate(90, expand=True) + except (AttributeError, KeyError, IndexError, StopIteration) as e: + print(f"Error correcting image orientation: {e}") + return image.copy() From b17f98b0594085a5f1eadcd258f06282101a2444 Mon Sep 17 00:00:00 2001 From: ByteYJ Date: Wed, 7 Aug 2024 15:17:09 -0400 Subject: [PATCH 09/10] fix the correct_image_orientation function --- src/msfocr/llm/ocr_functions.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/msfocr/llm/ocr_functions.py b/src/msfocr/llm/ocr_functions.py index ca1f5ad..1c6ed59 100644 --- a/src/msfocr/llm/ocr_functions.py +++ b/src/msfocr/llm/ocr_functions.py @@ -170,22 +170,20 @@ def correct_image_orientation(image_path): :return: PIL.Image.Image: The image with corrected orientation. """ with Image.open(image_path) as image: + orientation = None try: - Image.open(image_path) - exif = image._getexif() - if exif is not None: - orientation_key = next( - key for key, value in ExifTags.TAGS.items() if value == 'Orientation' - ) - orientation = exif.get(orientation_key) - if orientation == 3: - image = image.rotate(180, expand=True) - elif orientation == 6: - image = image.rotate(270, expand=True) - elif orientation == 8: - image = image.rotate(90, expand=True) - except (AttributeError, KeyError, IndexError, StopIteration) as e: - print(f"Error correcting image orientation: {e}") + for orientation in ExifTags.TAGS.keys(): + if ExifTags.TAGS[orientation] == 'Orientation': + break + exif = dict(image.getexif().items()) + if exif.get(orientation) == 3: + image = image.rotate(180, expand=True) + elif exif.get(orientation) == 6: + image = image.rotate(270, expand=True) + elif exif.get(orientation) == 8: + image = image.rotate(90, expand=True) + except (AttributeError, KeyError, IndexError): + pass return image.copy() From 035a0e181c782fe67e2be3f5270b6101fc8280cf Mon Sep 17 00:00:00 2001 From: Virginia Partridge Date: Wed, 7 Aug 2024 17:39:04 -0400 Subject: [PATCH 10/10] Fixed image loading --- src/msfocr/llm/ocr_functions.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/msfocr/llm/ocr_functions.py b/src/msfocr/llm/ocr_functions.py index 1c6ed59..f4df03e 100644 --- a/src/msfocr/llm/ocr_functions.py +++ b/src/msfocr/llm/ocr_functions.py @@ -170,20 +170,22 @@ def correct_image_orientation(image_path): :return: PIL.Image.Image: The image with corrected orientation. """ with Image.open(image_path) as image: - orientation = None - try: - for orientation in ExifTags.TAGS.keys(): - if ExifTags.TAGS[orientation] == 'Orientation': - break - exif = dict(image.getexif().items()) - if exif.get(orientation) == 3: - image = image.rotate(180, expand=True) - elif exif.get(orientation) == 6: - image = image.rotate(270, expand=True) - elif exif.get(orientation) == 8: - image = image.rotate(90, expand=True) - except (AttributeError, KeyError, IndexError): - pass - return image.copy() + image.load() + + orientation = None + try: + for orientation in ExifTags.TAGS.keys(): + if ExifTags.TAGS[orientation] == 'Orientation': + break + exif = dict(image.getexif().items()) + if exif.get(orientation) == 3: + image = image.rotate(180, expand=True) + elif exif.get(orientation) == 6: + image = image.rotate(270, expand=True) + elif exif.get(orientation) == 8: + image = image.rotate(90, expand=True) + except (AttributeError, KeyError, IndexError): + pass + return image