Add InstructionalPipeline support and fix inference config bug #8252
Conversation
Signed-off-by: Ben Wilson <[email protected]>
Documentation preview for 58379b2 will be available when this CircleCI job completes successfully.
self._conversation = None

@staticmethod
def _validate_inference_config_keys(pipeline, inference_config):
This does nothing. The inference overrides are arguments to the wrapped generate() method (invoked implicitly when the pipeline object is called) and are not associated with the pipeline object itself.
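For context, a minimal sketch of why key validation against the pipeline object can't work: overrides such as max_new_tokens or temperature only take effect when passed as keyword arguments at call time, where the pipeline forwards them to the model's generate() method. The model and config values below are just illustrative.

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")  # small model for illustration
inference_config = {"max_new_tokens": 30, "temperature": 0.1, "do_sample": True}

# The pipeline object itself has no attribute for these keys; they are consumed
# by the wrapped generate() call, so they must be supplied at call time:
result = generator("MLflow is", **inference_config)
print(result[0]["generated_text"])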
f"inference configuration keys. Invalid keys: {invalid_keys}", | ||
error_code=INVALID_PARAMETER_VALUE, | ||
) | ||
self._supported_custom_generator_types = {"InstructionTextGenerationPipeline"} |
This is Dolly.
Can we add a comment and say this is Dolly? InstructionTextGenerationPipeline
doesn't sound like Dolly at all. A link to the code also helps.
Dolly is a general instruction GPT model. InstructionTextGenerationPipeline
seems appropriate here.
InstructionTextGenerationPipeline
is our own custom class name for Dolly and transformers
doesn't have it, right?
yep, it's a custom subclass of TextGenerationPipeline. I modified the logic to do an exact match of the as-saved pipeline instance type so that we don't have inadvertent manipulation of any other TextGenerationPipeline types
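A rough sketch of the exact-match idea described here (names are illustrative, not the exact MLflow internals):

SUPPORTED_CUSTOM_GENERATOR_TYPES = {"InstructionTextGenerationPipeline"}

def _is_instructional_pipeline(pipeline) -> bool:
    # Exact class-name comparison: a plain transformers.TextGenerationPipeline
    # (or any other subclass of it) will not match, so only Dolly's custom
    # pipeline class triggers the special handling.
    return type(pipeline).__name__ in SUPPORTED_CUSTOM_GENERATOR_TYPES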
@harupy I added an NB note with a link to the Hub page (that contains the model card info) with a caveat to say that all variants are covered in this pipeline type parsing logic
@@ -1286,7 +1286,7 @@ def _load_pyfunc(path):
    """
    local_path = pathlib.Path(path)
    flavor_configuration = _get_flavor_configuration(local_path, FLAVOR_NAME)
-   inference_config = _get_inference_config(local_path)
+   inference_config = _get_inference_config(local_path.joinpath(_COMPONENTS_BINARY_KEY))
This is the bug that prevented the inference config from being applied (part 1)
@@ -1461,12 +1444,24 @@ def _predict(self, data):
        conversation_output = self.pipeline(self._conversation)
        return conversation_output.generated_responses[-1]
    elif isinstance(data, dict):
-       raw_output = self.pipeline(**data)
+       if self.inference_config:
+           raw_output = self.pipeline(**data, **self.inference_config)
This was the other bug (part 2) where the inference configuration was not being applied to the wrapped generate()
method of the pipeline. Now it is (confirmed with a note in the test suites)
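A hedged end-to-end sketch of what this fix enables (parameter names follow mlflow.transformers.save_model as of this PR; the pipeline, path, and config values are illustrative):

import mlflow
from transformers import pipeline

text_gen = pipeline("text-generation", model="gpt2")

mlflow.transformers.save_model(
    transformers_model=text_gen,
    path="text_gen_model",
    inference_config={"max_new_tokens": 64, "temperature": 0.1, "do_sample": True},
)

# With parts 1 and 2 of the fix, the saved overrides are read back at load time
# and forwarded to the wrapped generate() call during predict():
loaded = mlflow.pyfunc.load_model("text_gen_model")
print(loaded.predict("Explain what MLflow does in one sentence."))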
@@ -1449,6 +1429,9 @@ def _predict(self, data):
        if not self._conversation:
            self._conversation = transformers.Conversation()
        self._conversation.add_user_input(data)
+   elif type(self.pipeline).__name__ in self._supported_custom_generator_types:
Without this, we can't validate Dolly's input data for signature inference.
# Handle the pipeline outputs
if isinstance(self.pipeline, transformers.FillMaskPipeline):
if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
    self.pipeline, transformers.TextGenerationPipeline
When loading, since the custom pipeline type is not installed, the superclass will be loaded for Dolly. Dolly's InstructionTextGenerationPipeline inherits from TextGenerationPipeline, so it still works just fine when loaded through the pyfunc loader. We just need to make sure that we're parsing this output when it is loaded as the superclass.
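In other words, something like this simplified sketch of the condition in the diff above (the helper name is illustrative):

import transformers

def _needs_generation_output_parsing(pipeline, supported_custom_generator_types) -> bool:
    # Matches either the exact Dolly class name (when the custom class is importable)
    # or the TextGenerationPipeline superclass that the model loads back as when the
    # custom class isn't installed in the serving environment.
    return (
        type(pipeline).__name__ in supported_custom_generator_types
        or isinstance(pipeline, transformers.TextGenerationPipeline)
    )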
This change assumes any TextGenerationPipeline output will be stripped of the selected special chars. Is it desirable to do so? Since Dolly is the only OSS instructGPT, can we just check that the model name matches Dolly and only filter for that?
# want to left-trim these types of pipelines output values.
if data_out.startswith(data_in + "\n\n"):
    output_string = data_out[len(data_in) :].strip()
else:
Normal (standard) TextGenerationPipelines don't have the concept of a "Question"; rather, they generate text continued from a prompt. We don't want to strip the initial prompt from those types of text generators (as it makes the output difficult to understand)
See above; if we restrict stripping to only the Dolly model series, there is less chance of a mistake being made for general text generation pipeline output.
100%. Updated to explicitly use the flavor configuration from the MLmodel file to ensure that we're only engaging this logic if the pipeline that was originally saved was of the Dolly type.
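A sketch of that gating idea (the helper and the "pipeline_model_type" key name are assumptions for illustration, not the exact MLflow internals):

import pathlib
import yaml

def _saved_pipeline_type(model_path: str) -> str:
    # Read the pipeline class name recorded by the transformers flavor in the
    # MLmodel file at save time; the key name here is an assumption.
    mlmodel = yaml.safe_load(pathlib.Path(model_path, "MLmodel").read_text())
    return mlmodel.get("flavors", {}).get("transformers", {}).get("pipeline_model_type", "")

# Only engage Dolly-specific output parsing when the original pipeline was Dolly's type:
strip_prompt = _saved_pipeline_type("dolly_model") == "InstructionTextGenerationPipeline"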
@@ -1172,6 +1172,11 @@ def test_text2text_generation_pipeline_with_inference_configs(
    data, mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data)
)

# NB: The result of the 2nd test, without inference config overrides is:
This note explains how we're testing the inference configuration properly. By setting these overrides, we're controlling the output by forcing higher grammatical accuracy (in particular, setting the temperature to a drastically different value than the default enforces greater complexity of output).
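For illustration, overrides of the kind this note describes might look like the following (the values are examples, not the exact ones in the test suite):

# Forcing deterministic, beam-searched output with a very low temperature makes the
# generated text measurably different from the pipeline defaults, which is what lets
# the test assert that the overrides were actually forwarded to generate().
inference_config = {
    "temperature": 0.1,
    "do_sample": True,
    "num_beams": 4,
    "max_new_tokens": 60,
}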
Can we move this comment above line 1163 and replace the 2nd test with this test case?
excellent point. moved
@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
@pytest.mark.skipcacheclean
This test definitely cannot run on GitHub Actions (the size of the raw model weights is 3x the size of available RAM on a GHA runner).
Signed-off-by: Ben Wilson <[email protected]>
mlflow/transformers.py
Outdated
# return statements, followed by the start of the response to the prompt. We only
# want to left-trim these types of pipelines output values.
if data_out.startswith(data_in + "\n\n"):
    output_string = data_out[len(data_in) :].strip()
Should we keep the input question if a user requests Dolly to do so?
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text:
decoded = f"{instruction_text}\n{decoded}"
Added inference_config pop functionality to check for a key that disables the prompt inclusion.
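Roughly, the pop-based check described here could look like this (the "include_prompt" key name and helper are assumptions used for illustration):

def _maybe_strip_prompt(inference_config: dict, data_in: str, data_out: str) -> str:
    # Pull the sentinel key out of the config before the remaining keys are
    # forwarded to the pipeline/generate() call.
    include_prompt = inference_config.pop("include_prompt", True)
    if not include_prompt and data_out.startswith(data_in):
        # Drop the echoed prompt from the start of the generated text.
        data_out = data_out[len(data_in):].lstrip()
    return data_out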
mlflow/transformers.py
Outdated
Parse the output from instruction pipelines to conform with other text generator
pipeline types and remove line feed characters and other confusing outputs
"""
replacements = {"\n\n": " ", "A:": ""}
Is it really safe to remove "\n\n" or "A:"?
Leading and trailing whitespace are probably ok to remove. Not sure if Dolly can understand python, but if we ask:
Is `X[A:B]` a valid python code?
then Dolly answers:
Yes, `X[A:B]` a valid python code.
In this case, removing A: is not ok.
Is there an actual fix being tracked? I do agree this change is rather hacky.
If anyone uses this in batch mode, they're going to have to parse the output. This is the only generation pipeline that I've found that inserts newline characters in the output. It works for demos and interactive mode, but it's jarring when compared to everything else on HuggingFace
I restricted this logic to only be engaged if the underlying pipeline type logged in the MLmodel is "InstructionTextGenerationPipeline", so that we don't inadvertently modify any other TextGenerationPipeline outputs.
Also changed this parse logic to only look at the start of the response string after removing the prompt, instead of replacing everywhere.
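A simplified sketch of that revised parsing (the helper name is illustrative):

def _strip_instruction_artifacts(data_in: str, data_out: str) -> str:
    # Remove the echoed prompt, then trim leading newlines and a leading "A:"
    # token only at the start of the response; the same substrings appearing
    # later in the text are left untouched.
    if data_out.startswith(data_in):
        data_out = data_out[len(data_in):]
    data_out = data_out.lstrip("\n")
    if data_out.startswith("A:"):
        data_out = data_out[len("A:"):]
    return data_out.lstrip()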
mlflow/transformers.py
Outdated
if data_out.startswith(data_in + "\n\n"):
    output_string = data_out[len(data_in) :].strip()
else:
    output_string = data_out.strip()
- output_string = data_out.strip()
+ output_string = data_out.lstrip()

If we just want to left-trim, we can use lstrip.
changed to lstrip
mlflow/transformers.py
Outdated
return output_string

if isinstance(input_data, list) and isinstance(output, list):
    zipped = list(zip(input_data, output))
- zipped = list(zip(input_data, output))
+ zipped = zip(input_data, output)

list can be removed, or you can directly iterate on this like [x for x in zip(input_data, output)].
great point. simplified!
f"inference configuration keys. Invalid keys: {invalid_keys}", | ||
error_code=INVALID_PARAMETER_VALUE, | ||
) | ||
self._supported_custom_generator_types = {"InstructionTextGenerationPipeline"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dolly is a general instruction GPT model. InstructionTextGenerationPipeline
seems appropriate here.
mlflow/transformers.py
Outdated
Parse the output from instruction pipelines to conform with other text generator | ||
pipeline types and remove line feed characters and other confusing outputs | ||
""" | ||
replacements = {"\n\n": " ", "A:": ""} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an actual fix being tracked? I do agree this change is rather hacky.
|
||
# Handle the pipeline outputs | ||
if isinstance(self.pipeline, transformers.FillMaskPipeline): | ||
if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance( | ||
self.pipeline, transformers.TextGenerationPipeline |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change assumes any TextGenerationPipeline
will be stripped with the selected special chars. Is it desirable to do so? Since Dolly is the only OSS instructGPT, can we just check the model name matches Dolly and only filter for that?
# want to left-trim these types of pipelines output values. | ||
if data_out.startswith(data_in + "\n\n"): | ||
output_string = data_out[len(data_in) :].strip() | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above, if we restrict stripping to only Dolly model series, there are less chance a mistake is made for the general text generation pipeline output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after addressing remaining comments
Signed-off-by: Ben Wilson <[email protected]>
Related Issues/PRs
#xxx

What changes are proposed in this pull request?
Adds support for Dolly and resolves a bug in inference configuration overrides.

How is this patch tested?
Manually validated multiple inference configuration settings for pyfunc predict overrides to the generate method of pipelines; validated that using Dolly as pyfunc functions correctly and removes question repetition, newline characters, and the random "A:" artifacts that appear in some types of questions' responses.

Does this PR change the documentation?

Release Notes
Is this a user-facing change?
(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)

What component(s), interfaces, languages, and integrations does this PR affect?

Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging

Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support

Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages

Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes