-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add image generation functionality #407
Conversation
griptape/artifacts/image_artifact.py
Outdated
|
||
@define(frozen=True) | ||
class ImageArtifact(BlobArtifact): | ||
mime_type: str = field(default="image/png", kw_only=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure this should be the default. Can we make it more generic? Seems like we might have options.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree that a default doesn't quite make sense here. Updated to require the mime_type value.
from griptape.artifacts import ImageArtifact | ||
|
||
|
||
class BaseImageGenerationDriver: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add (ABC)
.
|
||
|
||
@define | ||
class Dalle2ImageGenerationDriver(BaseImageGenerationDriver): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make it a generic DalleImageGenerationDriver
and manage version at the class property level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we prepend with OpenAi
to match other driver naming schemas?
import requests | ||
from attr import field, Factory, define | ||
from griptape.artifacts import ImageArtifact | ||
from griptape.drivers.image_generation.base_image_generation_driver import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop base_image_generation_driver
.
|
||
|
||
@define | ||
class PromptImageGenerationEngine(BaseImageGenerationEngine): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make it a generic ImageGenerationEngine
engine and add multiple methods for text-to-image and image-to-image modes? Or should we have different engines that encapsulate text-to-image (i.e., PromptImageGenerationEngine
) and image-to-image?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've gone back and forth on it. This implementation represents the assumption that the interface and mechanics for text-to-image and image-to-image will be substantially different or that providers (and their drivers) might only provide a subset of that functionality. We could overcome each of those obstacles, just chose a path here.
image_generation_driver: BaseImageGenerationDriver = field( | ||
default=Factory(lambda: Dalle2ImageGenerationDriver()), kw_only=True | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove this to simplify the interface? The user can always pass a custom driver into image_generation_engine
.
Description: """{{ description }}""" | ||
|
||
Generate an image based on the description provided. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's sync up with @shhlife, @Amaru-Zeas, or @averyroche about what the best base prompt is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. I think we should leave the flowery specifics to the description field, but I can imagine reinforcing instructions like 'stick to the description' or optionally 'be creative', etc.
image_generation_driver: BaseImageGenerationDriver = field( | ||
default=Factory(lambda: Dalle2ImageGenerationDriver()), kw_only=True | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as in the task: could probably remove this.
04fcba4
to
ff54a45
Compare
6538b39
to
b17d023
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work, this is going to be awesome!
griptape/drivers/image_generation/amazon_bedrock_stable_diffusion_image_generation_driver.py
Outdated
Show resolved
Hide resolved
griptape/drivers/image_generation/amazon_bedrock_stable_diffusion_image_generation_driver.py
Outdated
Show resolved
Hide resolved
griptape/drivers/image_generation/amazon_bedrock_stable_diffusion_image_generation_driver.py
Outdated
Show resolved
Hide resolved
griptape/drivers/image_generation/base_image_generation_driver.py
Outdated
Show resolved
Hide resolved
griptape/drivers/image_generation/openai_dalle_image_generation_driver.py
Show resolved
Hide resolved
griptape/drivers/image_generation/openai_dalle_image_generation_driver.py
Outdated
Show resolved
Hide resolved
griptape/drivers/image_generation/openai_dalle_image_generation_driver.py
Outdated
Show resolved
Hide resolved
image_size: Union[ | ||
Literal["256x256"], Literal["512x512"], Literal["1024x1024"], Literal["1024x1792"], Literal["1792x1024"] | ||
] = field(default=Literal["512x512"], kw_only=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how I feel about this -- the patterns we surface to users seem to be heavily inspired by the underlying SDKs being interfaced with. In this case (and in OpenAI Chat Prompt Driver) we expose the same Literal values the OpenAI SDK expects, but we don't see this elsewhere in the framework. Should we define a more concrete Griptape style and transform inputs to what's expected by the dependencies? Are the utilities provided by typing
the way we want to go?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's somewhat unavoidable to expose SDK patterns in Drivers since their primary purpose is to sit right on top of the SDK/API. Even in the case of image_size
, a seemingly universal field, it seems non-straightforward to implement it at the BaseImageGenerationDriver
level. I think providing helpful type hints is a good enough solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's certainly less maintenance to do it this way. If we define and enforce a standard Griptape interface style, the intermediate layer would be a constant source of frustration as we'd have to re-define (or somehow handle) all possible options to some degree. On the other hand, we require the user do some research and have knowledge of the underlying dependency (OpenAI SDK, Leonardo API) before use. Not at all unreasonable, but a bit annoying.
griptape/drivers/image_generation/openai_dalle_image_generation_driver.py
Outdated
Show resolved
Hide resolved
@@ -10,5 +10,5 @@ class BaseImageGenerationDriver(ABC): | |||
model: str = field(kw_only=True) | |||
|
|||
@abstractmethod | |||
def generate_image(self, prompts: list[str], negative_prompts: list[str], **kwargs) -> ImageArtifact: | |||
def generate_image(self, prompts: list[str], negative_prompts: list[str] = list, **kwargs) -> ImageArtifact: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python gotcha https://docs.python-guide.org/writing/gotchas/
@@ -10,5 +10,5 @@ class BaseImageGenerationDriver(ABC): | |||
model: str = field(kw_only=True) | |||
|
|||
@abstractmethod | |||
def generate_image(self, prompts: list[str], negative_prompts: list[str] = list, **kwargs) -> ImageArtifact: | |||
def generate_image(self, prompts: list[str], negative_prompts: list[str] = None, **kwargs) -> ImageArtifact: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list[str]
should be Optional[list[str]]
@@ -31,7 +31,10 @@ class LeonardoImageGenerationDriver(BaseImageGenerationDriver): | |||
image_width: int = field(default=512, kw_only=True) | |||
image_height: int = field(default=512, kw_only=True) | |||
|
|||
def generate_image(self, prompts: list[str], negative_prompts: list[str], **kwargs) -> ImageArtifact: | |||
def generate_image(self, prompts: list[str], negative_prompts: list[str] = None, **kwargs) -> ImageArtifact: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional[list[str]]
negative_prompts: list[str] = list, | ||
rulesets: list[Ruleset] = list, | ||
negative_rulesets: list[Ruleset] = list, | ||
negative_prompts: list[str] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional
**kwargs | ||
): | ||
if not negative_prompts: | ||
negative_prompts = [] | ||
|
||
for ruleset in rulesets: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to initialize rulesets
and negative_rulesets
to empty lists of None
negative_prompts: Optional[list[str]] = None, | ||
rulesets: Optional[list[Ruleset]] = None, | ||
negative_rulesets: Optional[list[Ruleset]] = None, | ||
**kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need kwargs
? As far as I can tell it's not being used anywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was needed when additional parameters came through the method call and not the driver setup. Removed.
https://docs.leonardo.ai/reference/creategeneration | ||
""" | ||
|
||
api_key: str = field(default=Factory(lambda: os.environ.get("LEONARDO_API_KEY")), kw_only=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this change didn't get implemented
def _get_image_url(self, generation_id: str): | ||
for attempt in range(self.max_attempts): | ||
response = self.requests_session.get( | ||
url=f"{self.api_base}/generations/{generation_id}", headers={"authorization": f"Bearer {self.api_key}"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Capitalize authorization
griptape/drivers/image_generation/leonardo_image_generation_driver.py
Outdated
Show resolved
Hide resolved
) = field(default="1024x1024", kw_only=True) | ||
response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True) | ||
|
||
def generate_image(self, prompts: list[str], **kwargs) -> ImageArtifact: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing negative_prompts
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OpenAI's API doesn't support negative prompts. The prompt is rewritten, so the specifics of any 'do not' instructions have an opportunity to be warped by rewriting or summarization. From my experimentation negative prompts in a unified prompt seem to be counterproductive, like negative prompts: text, clouds
results in more text and clouds in the image than not including the negative prompts at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the method signature should still include it though, right? Maybe we throw an error if it's provided.
kwargs
is masking the fact that the parameters do not line up properly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realize now that you might just mean in the signature (i.e. instead of **kwargs
). Updated to make that change. If I'm mistake, let me know.
prompt=prompt, | ||
) | ||
|
||
@staticmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a static method? I don't think we really use @staticmethod
anywhere else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, it is a static method but if we don't use that internally I'll remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!
3b119cf
to
0907d69
Compare
This PR adds image generation functionality, including driver support for the following providers/models:
along with unit tests for the above drivers. A new tool, ImageGenerationTool, accepts an ImageGenerationEngine configured to use the desired driver:
Resolves #332