Add image generation functionality #407

Merged
collindutter merged 28 commits into dev from french/image-generation
Nov 16, 2023

Member

@andrewfrench andrewfrench commented Nov 2, 2023

This PR adds image generation functionality, including driver support for the following providers/models:

  • OpenAI DALLE2
  • Leonardo
  • Stable Diffusion via Amazon Bedrock

along with unit tests for the above drivers. A new tool, ImageGenerator, accepts an ImageGenerationEngine configured to use the desired driver:

import boto3

from griptape.drivers import AmazonBedrockStableDiffusionImageGenerationDriver
from griptape.engines import ImageGenerationEngine
from griptape.tools import ImageGenerator
from griptape.structures import Agent

...

driver = AmazonBedrockStableDiffusionImageGenerationDriver(
  session=boto3.Session(),
  style_preset="cinematic",
  sampler="K_EULER",
  # ..., etc.
)

image_generator = ImageGenerator(
  image_generation_engine=ImageGenerationEngine(
    image_generation_driver=driver,
  )
)

Agent(tools=[image_generator])
...

Resolves #332


@define(frozen=True)
class ImageArtifact(BlobArtifact):
    mime_type: str = field(default="image/png", kw_only=True)
Member

Not sure this should be the default. Can we make it more generic? Seems like we might have options.

Member Author

Agree that a default doesn't quite make sense here. Updated to require the mime_type value.

from griptape.artifacts import ImageArtifact


class BaseImageGenerationDriver:
Member

Add (ABC).
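
A short sketch of what inheriting from ABC buys here: the base driver can no longer be instantiated directly, while concrete subclasses that implement the abstract method can. The `bytes` return type and the `FakeImageGenerationDriver` class are assumptions made to keep the snippet self-contained.

```python
from abc import ABC, abstractmethod


class BaseImageGenerationDriver(ABC):
    @abstractmethod
    def generate_image(self, prompts: list[str]) -> bytes:
        """Return raw image bytes for the given prompts."""


class FakeImageGenerationDriver(BaseImageGenerationDriver):
    # Hypothetical driver used only to show that concrete subclasses instantiate.
    def generate_image(self, prompts: list[str]) -> bytes:
        return ", ".join(prompts).encode()


try:
    BaseImageGenerationDriver()
    base_is_instantiable = True
except TypeError:
    # ABC blocks instantiation while abstract methods remain unimplemented.
    base_is_instantiable = False

image_bytes = FakeImageGenerationDriver().generate_image(["a red fox", "watercolor"])
```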



@define
class Dalle2ImageGenerationDriver(BaseImageGenerationDriver):
Member

Let's make it a generic DalleImageGenerationDriver and manage version at the class property level.

Member

Should we prepend with OpenAi to match other driver naming schemas?

import requests
from attr import field, Factory, define
from griptape.artifacts import ImageArtifact
from griptape.drivers.image_generation.base_image_generation_driver import (
Member

Drop base_image_generation_driver.



@define
class PromptImageGenerationEngine(BaseImageGenerationEngine):
Member

Should we make it a generic ImageGenerationEngine and add multiple methods for text-to-image and image-to-image modes? Or should we have different engines that encapsulate text-to-image (i.e., PromptImageGenerationEngine) and image-to-image?

Member Author

I've gone back and forth on it. This implementation assumes that the interface and mechanics for text-to-image and image-to-image will be substantially different, or that providers (and their drivers) might only provide a subset of that functionality. We could overcome each of those obstacles; I just chose a path here.

Comment on lines 18 to 20
image_generation_driver: BaseImageGenerationDriver = field(
    default=Factory(lambda: Dalle2ImageGenerationDriver()), kw_only=True
)
Member

Should we remove this to simplify the interface? The user can always pass a custom driver into image_generation_engine.

Comment on lines 1 to 3
Description: """{{ description }}"""

Generate an image based on the description provided.
Member

Let's sync up with @shhlife, @Amaru-Zeas, or @averyroche about what the best base prompt is.

Member Author

+1. I think we should leave the flowery specifics to the description field, but I can imagine reinforcing instructions like 'stick to the description' or optionally 'be creative', etc.
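
A rough sketch of the idea of an optional reinforcing instruction appended to the base prompt. The template text is the one quoted above; plain string replacement stands in for whatever template engine the project uses, and `render_prompt` and its `reinforcement` parameter are hypothetical.

```python
BASE_TEMPLATE = (
    'Description: """{{ description }}"""\n\n'
    "Generate an image based on the description provided."
)


def render_prompt(description: str, reinforcement: str = "") -> str:
    # Substitute the description, then optionally append a reinforcing instruction.
    prompt = BASE_TEMPLATE.replace("{{ description }}", description)
    if reinforcement:
        prompt = f"{prompt} {reinforcement}"
    return prompt


plain_prompt = render_prompt("a red fox in snow")
strict_prompt = render_prompt("a red fox in snow", "Stick to the description.")
```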

Comment on lines 20 to 22
image_generation_driver: BaseImageGenerationDriver = field(
    default=Factory(lambda: Dalle2ImageGenerationDriver()), kw_only=True
)
Member

Same comment as in the task: could probably remove this.

@andrewfrench andrewfrench changed the base branch from main to dev November 8, 2023 17:40
@andrewfrench andrewfrench force-pushed the french/image-generation branch 2 times, most recently from 04fcba4 to ff54a45 Compare November 9, 2023 05:37
@andrewfrench andrewfrench marked this pull request as ready for review November 9, 2023 05:37
@andrewfrench andrewfrench force-pushed the french/image-generation branch from 6538b39 to b17d023 Compare November 9, 2023 16:15
Member

@collindutter collindutter left a comment

Great work, this is going to be awesome!

Comment on lines 42 to 44
image_size: Union[
Literal["256x256"], Literal["512x512"], Literal["1024x1024"], Literal["1024x1792"], Literal["1792x1024"]
] = field(default=Literal["512x512"], kw_only=True)
Member Author

@andrewfrench andrewfrench Nov 14, 2023

Not sure how I feel about this -- the patterns we surface to users seem to be heavily inspired by the underlying SDKs being interfaced with. In this case (and in OpenAI Chat Prompt Driver) we expose the same Literal values the OpenAI SDK expects, but we don't see this elsewhere in the framework. Should we define a more concrete Griptape style and transform inputs to what's expected by the dependencies? Are the utilities provided by typing the way we want to go?

Member

I think it's somewhat unavoidable to expose SDK patterns in Drivers since their primary purpose is to sit right on top of the SDK/API. Even in the case of image_size, a seemingly universal field, it seems non-straightforward to implement it at the BaseImageGenerationDriver level. I think providing helpful type hints is a good enough solution.

Member Author

It's certainly less maintenance to do it this way. If we define and enforce a standard Griptape interface style, the intermediate layer would be a constant source of frustration, as we'd have to re-define (or somehow handle) all possible options to some degree. On the other hand, we require the user to do some research and have knowledge of the underlying dependency (OpenAI SDK, Leonardo API) before use. Not at all unreasonable, but a bit annoying.
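
A small sketch of the "type hints only" approach discussed in this thread, with an optional runtime check. The size values are the ones quoted in the diff above; the `DalleImageSize` alias and `validate_image_size` helper are hypothetical. Note that `Literal[...]` belongs in the annotation only, while the default itself is a plain string such as "512x512".

```python
from typing import Literal, get_args

# Alias for the sizes listed in the reviewed diff.
DalleImageSize = Literal["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]


def validate_image_size(image_size: str) -> str:
    # Literal is only a hint for type checkers, so check the value at runtime too.
    allowed = get_args(DalleImageSize)
    if image_size not in allowed:
        raise ValueError(f"image_size must be one of {allowed}, got {image_size!r}")
    return image_size


default_size = validate_image_size("512x512")
```

The user still needs to know which sizes a given model accepts, but an invalid value fails early with a message listing the options.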

@@ -10,5 +10,5 @@ class BaseImageGenerationDriver(ABC):
     model: str = field(kw_only=True)

     @abstractmethod
-    def generate_image(self, prompts: list[str], negative_prompts: list[str], **kwargs) -> ImageArtifact:
+    def generate_image(self, prompts: list[str], negative_prompts: list[str] = list, **kwargs) -> ImageArtifact:
Member

@@ -10,5 +10,5 @@ class BaseImageGenerationDriver(ABC):
     model: str = field(kw_only=True)

     @abstractmethod
-    def generate_image(self, prompts: list[str], negative_prompts: list[str] = list, **kwargs) -> ImageArtifact:
+    def generate_image(self, prompts: list[str], negative_prompts: list[str] = None, **kwargs) -> ImageArtifact:
Member

list[str] should be Optional[list[str]]
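
A brief demonstration of why the `= list` default in the earlier revision was replaced, and of the Optional form requested here. Both function names are hypothetical.

```python
from typing import Optional


def with_type_default(negative_prompts: list[str] = list) -> object:
    # The default here is the `list` class itself, not an empty list.
    return negative_prompts


def with_none_default(negative_prompts: Optional[list[str]] = None) -> list[str]:
    # A fresh list is created per call, avoiding a shared mutable default.
    return [] if negative_prompts is None else negative_prompts


type_default = with_type_default()
none_default = with_none_default()
```

Iterating over `type_default` would raise a TypeError, because a class object is not iterable in the way an empty list is.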

@@ -31,7 +31,10 @@ class LeonardoImageGenerationDriver(BaseImageGenerationDriver):
     image_width: int = field(default=512, kw_only=True)
     image_height: int = field(default=512, kw_only=True)

-    def generate_image(self, prompts: list[str], negative_prompts: list[str], **kwargs) -> ImageArtifact:
+    def generate_image(self, prompts: list[str], negative_prompts: list[str] = None, **kwargs) -> ImageArtifact:
Member

Optional[list[str]]

-    negative_prompts: list[str] = list,
-    rulesets: list[Ruleset] = list,
-    negative_rulesets: list[Ruleset] = list,
+    negative_prompts: list[str] = None,
Member

Optional

    **kwargs
):
    if not negative_prompts:
        negative_prompts = []

    for ruleset in rulesets:
Member

Need to initialize rulesets and negative_rulesets to empty lists if None.
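
A sketch of the normalization being asked for, so that the loops over rulesets never iterate over None. `Rule` and `Ruleset` are simplified stand-ins for the Griptape classes, and the way rule values are folded into prompts is an assumption for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Rule:
    value: str


@dataclass
class Ruleset:
    name: str
    rules: list[Rule] = field(default_factory=list)


def build_prompts(
    prompts: list[str],
    negative_prompts: Optional[list[str]] = None,
    rulesets: Optional[list[Ruleset]] = None,
    negative_rulesets: Optional[list[Ruleset]] = None,
) -> tuple[list[str], list[str]]:
    # Copy the inputs so the caller's lists are never mutated.
    all_prompts = list(prompts)
    all_negative_prompts = list(negative_prompts or [])

    # `or []` turns None into an empty list before iterating.
    for ruleset in rulesets or []:
        all_prompts.extend(rule.value for rule in ruleset.rules)
    for ruleset in negative_rulesets or []:
        all_negative_prompts.extend(rule.value for rule in ruleset.rules)

    return all_prompts, all_negative_prompts


style = Ruleset(name="style", rules=[Rule("watercolor")])
avoid = Ruleset(name="avoid", rules=[Rule("text")])
combined = build_prompts(["a red fox"], rulesets=[style], negative_rulesets=[avoid])
```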

negative_prompts: Optional[list[str]] = None,
rulesets: Optional[list[Ruleset]] = None,
negative_rulesets: Optional[list[Ruleset]] = None,
**kwargs
Member

Do we need kwargs? As far as I can tell it's not being used anywhere.

Member Author

That was needed when additional parameters came through the method call and not the driver setup. Removed.

https://docs.leonardo.ai/reference/creategeneration
"""

api_key: str = field(default=Factory(lambda: os.environ.get("LEONARDO_API_KEY")), kw_only=True)
Member

Looks like this change didn't get implemented

def _get_image_url(self, generation_id: str):
    for attempt in range(self.max_attempts):
        response = self.requests_session.get(
            url=f"{self.api_base}/generations/{generation_id}", headers={"authorization": f"Bearer {self.api_key}"}
Member

Capitalize authorization
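
A dependency-free sketch of the polling pattern in the excerpt above, with the header name capitalized. HTTP header names are case-insensitive, so this is a matter of convention. The `fetch_status` callable and its response shape (`status`, `url`) are hypothetical and do not reflect the Leonardo API's actual response format.

```python
import time
from typing import Callable


def build_headers(api_key: str) -> dict[str, str]:
    # "Authorization" is the conventional spelling of the header name.
    return {"Authorization": f"Bearer {api_key}"}


def poll_for_image_url(
    fetch_status: Callable[[], dict],
    max_attempts: int = 10,
    delay_seconds: float = 0.0,
) -> str:
    # Ask for the generation status until it completes or attempts run out.
    for _ in range(max_attempts):
        response = fetch_status()
        if response.get("status") == "COMPLETE":
            return response["url"]
        time.sleep(delay_seconds)
    raise TimeoutError(f"image not ready after {max_attempts} attempts")


# Simulated responses: one pending poll, then a completed generation.
responses = iter(
    [{"status": "PENDING"}, {"status": "COMPLETE", "url": "https://example.com/image.png"}]
)
image_url = poll_for_image_url(lambda: next(responses), max_attempts=3)
```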

) = field(default="1024x1024", kw_only=True)
response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True)

def generate_image(self, prompts: list[str], **kwargs) -> ImageArtifact:
Member

Missing negative_prompts?

Member Author

OpenAI's API doesn't support negative prompts. The prompt is rewritten, so the specifics of any 'do not' instructions can be warped by rewriting or summarization. In my experiments, negative prompts folded into a unified prompt seem to be counterproductive: for example, adding "negative prompts: text, clouds" results in more text and clouds in the image than omitting the negative prompts altogether.

Member

@collindutter collindutter Nov 15, 2023

I think the method signature should still include it though, right? Maybe we throw an error if it's provided.

kwargs is masking the fact that the parameters do not line up properly.

Member Author

I realize now that you might just mean in the signature (i.e. instead of **kwargs). Updated to make that change. If I'm mistaken, let me know.
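
A sketch of the option raised in this thread: keep negative_prompts in the signature so every driver lines up with the base class, and raise if a caller supplies them to a provider that cannot honor them. Whether the merged driver raises or ignores the argument is not stated here; `DalleLikeDriver` and its string return value are hypothetical.

```python
from typing import Optional


class DalleLikeDriver:
    # Hypothetical driver for a provider without negative-prompt support.
    def generate_image(
        self, prompts: list[str], negative_prompts: Optional[list[str]] = None
    ) -> str:
        if negative_prompts:
            # Fail loudly rather than silently dropping the caller's intent.
            raise ValueError("negative_prompts are not supported by this driver")
        # Return the unified prompt instead of calling a real API.
        return ", ".join(prompts)


driver = DalleLikeDriver()
prompt = driver.generate_image(["a quiet harbor", "oil painting"])
```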

prompt=prompt,
)

@staticmethod
Member

Should this be a static method? I don't think we really use @staticmethod anywhere else.

Member Author

Interesting. It is a static method, but if we don't use that internally I'll remove it.

Member

@vasinov vasinov left a comment

Great work!

@collindutter collindutter force-pushed the french/image-generation branch from 3b119cf to 0907d69 Compare November 16, 2023 20:15
@collindutter collindutter merged commit 828b3f5 into dev Nov 16, 2023
@andrewfrench andrewfrench deleted the french/image-generation branch November 16, 2023 21:25