
Add InstructionalPipeline support and fix inference config bug #8252

Merged
BenWilson2 merged 3 commits into mlflow:master from add-transformers-dolly-support on Apr 18, 2023

Conversation

BenWilson2 (Member)

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Adds support for Dolly and resolves a bug in inference configuration overrides

How is this patch tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests (describe details, including test results, below)

Manually validated multiple inference configuration settings for pyfunc predict overrides to the pipeline's generate method, and validated that Dolly, used as a pyfunc, functions correctly and that the output parsing removes question repetition, newline characters, and the stray "A:" artifacts that appear in responses to some types of questions.

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly in the documentation preview.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

(Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

github-actions bot added the area/models (MLmodel format, model serialization/deserialization, flavors) and rn/none (List under Small Changes in Changelogs) labels on Apr 18, 2023
mlflow-automation (Collaborator) commented Apr 18, 2023

Documentation preview for 58379b2 will be available here when this CircleCI job completes successfully.


self._conversation = None

@staticmethod
def _validate_inference_config_keys(pipeline, inference_config):
BenWilson2 (Member Author):

This does nothing. The inference overrides are arguments to the wrapped generate() method (invoked implicitly through the pipeline call) and are not associated with the pipeline object itself.
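
For illustration, a minimal sketch of that point (gpt2 and the override keys here are illustrative, not taken from this PR):

import transformers

pipe = transformers.pipeline("text-generation", model="gpt2")
inference_config = {"max_new_tokens": 32, "do_sample": False}

# The override keys are keyword arguments to the pipeline call, which forwards
# them to generate(); they are not attributes of the pipeline object, so there
# is nothing on the pipeline instance itself to validate them against.
result = pipe("MLflow is", **inference_config)
print(result)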

f"inference configuration keys. Invalid keys: {invalid_keys}",
error_code=INVALID_PARAMETER_VALUE,
)
self._supported_custom_generator_types = {"InstructionTextGenerationPipeline"}
BenWilson2 (Member Author):

This is Dolly.

@harupy (Member) commented Apr 18, 2023:

Can we add a comment and say this is Dolly? InstructionTextGenerationPipeline doesn't sound like Dolly at all. A link to the code also helps.

Collaborator:

Dolly is a general instruction GPT model. InstructionTextGenerationPipeline seems appropriate here.

@harupy (Member) commented Apr 18, 2023:

InstructionTextGenerationPipeline is our own custom class name for Dolly and transformers doesn't have it, right?

https://github.com/databrickslabs/dolly/blob/0eadcb7b0648d496d67243a7d572b413560be661/training/generate.py#L65

BenWilson2 (Member Author):

Yep, it's a custom subclass of TextGenerationPipeline. I modified the logic to do an exact match on the as-saved pipeline instance type so that we don't inadvertently manipulate any other TextGenerationPipeline types.
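
A rough sketch of the exact-match behavior (gpt2 is only a stand-in for a vanilla TextGenerationPipeline):

import transformers

pipe = transformers.pipeline("text-generation", model="gpt2")
supported_custom_generator_types = {"InstructionTextGenerationPipeline"}

# Matching on the concrete class name means a plain TextGenerationPipeline
# (or any subclass other than Dolly's) never triggers the Dolly-specific handling.
print(type(pipe).__name__)  # TextGenerationPipeline
print(type(pipe).__name__ in supported_custom_generator_types)  # False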

BenWilson2 (Member Author):

@harupy I added an NB note with a link to the Hub page (which contains the model card info) with a caveat stating that all variants are covered by this pipeline-type parsing logic.

@@ -1286,7 +1286,7 @@ def _load_pyfunc(path):
"""
local_path = pathlib.Path(path)
flavor_configuration = _get_flavor_configuration(local_path, FLAVOR_NAME)
inference_config = _get_inference_config(local_path)
inference_config = _get_inference_config(local_path.joinpath(_COMPONENTS_BINARY_KEY))
BenWilson2 (Member Author):

This is the bug that prevented the inference config from being applied (part 1)

@@ -1461,12 +1444,24 @@ def _predict(self, data):
conversation_output = self.pipeline(self._conversation)
return conversation_output.generated_responses[-1]
elif isinstance(data, dict):
raw_output = self.pipeline(**data)
if self.inference_config:
raw_output = self.pipeline(**data, **self.inference_config)
BenWilson2 (Member Author):

This was the other bug (part 2) where the inference configuration was not being applied to the wrapped generate() method of the pipeline. Now it is (confirmed with a note in the test suites)
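
A hedged end-to-end sketch of what the fix enables, assuming the flavor's inference_config argument that this PR exercises (model and override values are illustrative):

import mlflow
import transformers

pipe = transformers.pipeline("text-generation", model="gpt2")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipe,
        artifact_path="model",
        inference_config={"max_new_tokens": 32, "do_sample": False},
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)
# With the fix, the saved inference_config is forwarded to the wrapped
# generate() call on every predict instead of being silently dropped.
print(loaded.predict("Tell me about MLflow"))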

@@ -1449,6 +1429,9 @@ def _predict(self, data):
if not self._conversation:
self._conversation = transformers.Conversation()
self._conversation.add_user_input(data)
elif type(self.pipeline).__name__ in self._supported_custom_generator_types:
BenWilson2 (Member Author):

Without this, we can't validate Dolly's input data for signature inference.


# Handle the pipeline outputs
if isinstance(self.pipeline, transformers.FillMaskPipeline):
if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
self.pipeline, transformers.TextGenerationPipeline
BenWilson2 (Member Author):

When loading, since the custom new pipeline type is not installed, the super class will be loaded for Dolly. Dolly's InstructionTextGenerationPipeline inherits from TextGenerationPipeline, so it still works just fine when loaded through the pyfunc loader. We just need to make sure that we're parsing this output when loaded as the super class.

Collaborator:

This change assumes any TextGenerationPipeline output will be stripped of the selected special chars. Is it desirable to do so? Since Dolly is the only OSS instructGPT, can we just check that the model name matches Dolly and only filter for that?

# want to left-trim these types of pipelines output values.
if data_out.startswith(data_in + "\n\n"):
output_string = data_out[len(data_in) :].strip()
else:
BenWilson2 (Member Author):

Normal (standard) TextGenerationPipelines don't have the concept of a "Question"; rather, they generate text continued from a prompt. We don't want to strip the initial prompt from those types of text generators (as it makes the output difficult to understand)
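
For illustration, made-up outputs for the two cases (the strings are not real model output):

data_in = "What does MLflow do?"

# A standard TextGenerationPipeline continues the prompt inline, so stripping
# the prompt would make the output hard to read on its own.
plain_out = "What does MLflow do? It tracks experiments and packages models."

# A Dolly-style instruction pipeline echoes the question, then a blank line,
# then the answer, so only this shape gets the prompt left-trimmed.
dolly_out = "What does MLflow do?\n\nMLflow tracks experiments and packages models."

for data_out in (plain_out, dolly_out):
    if data_out.startswith(data_in + "\n\n"):
        print(data_out[len(data_in):].lstrip())
    else:
        print(data_out)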

Collaborator:

See above; if we restrict stripping to only the Dolly model series, there is less chance of a mistake being made with general text generation pipeline output.

BenWilson2 (Member Author):

100%. Updated to explicitly use the flavor configuration from the MLmodel file to ensure that we only engage this logic if the original pipeline stored at save time was of the Dolly type.

@@ -1172,6 +1172,11 @@ def test_text2text_generation_pipeline_with_inference_configs(
data, mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data)
)

# NB: The result of the 2nd test, without inference config overrides is:
BenWilson2 (Member Author):

This note explains how we're testing that the inference configuration is applied properly. By setting these overrides, we're controlling the output by forcing higher grammatical accuracy (in particular, setting the temperature to a drastically different value than the default produces noticeably more complex output).
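
A minimal sketch of the principle, calling a pipeline directly rather than through the logged model (t5-small and the override values are placeholders, not the fixture used in the suite):

import transformers

pipe = transformers.pipeline("text2text-generation", model="t5-small")
prompt = "translate English to German: MLflow makes model management easier."

default_out = pipe(prompt)
# If the overrides actually reach generate(), sampling with a drastically
# different temperature produces visibly different output than the default.
overridden_out = pipe(prompt, do_sample=True, temperature=1.8, top_k=50)
print(default_out, overridden_out)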

Member:

Can we move this comment above line 1163 and replace the 2nd test with this test case?

BenWilson2 (Member Author):

excellent point. moved



@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
@pytest.mark.skipcacheclean
BenWilson2 (Member Author):

This test definitely cannot run on GitHub Actions (the size of the raw model weights is 3x the size of available RAM on a GHA runner).

Signed-off-by: Ben Wilson <[email protected]>
# return statements, followed by the start of the response to the prompt. We only
# want to left-trim these types of pipelines output values.
if data_out.startswith(data_in + "\n\n"):
output_string = data_out[len(data_in) :].strip()
Member:

Should we keep the input question if a user requests that Dolly do so?

https://github.com/databrickslabs/dolly/blob/0eadcb7b0648d496d67243a7d572b413560be661/training/generate.py#L207

            # If the full text is requested, then append the decoded text to the original instruction.
            # This technically isn't the full text, as we format the instruction in the prompt the model has been
            # trained on, but to the client it will appear to be the full text.
            if return_full_text:
                decoded = f"{instruction_text}\n{decoded}"

BenWilson2 (Member Author):

Added inference_config pop functionality to check whether prompt inclusion should be disabled.
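
A minimal sketch of the pop-based toggle; the key name "include_prompt" and its default are assumptions for illustration, not necessarily the exact key added here:

# Hypothetical inference_config carrying the prompt-inclusion toggle:
inference_config = {"include_prompt": False, "max_new_tokens": 64}
include_prompt = inference_config.pop("include_prompt", False)

data_in = "What is MLflow?"
data_out = "What is MLflow?\n\nMLflow is an open source MLOps platform."

# When prompt inclusion is disabled, the echoed question is left-trimmed off;
# the remaining keys are still forwarded to generate().
if not include_prompt and data_out.startswith(data_in + "\n\n"):
    data_out = data_out[len(data_in):].lstrip()
print(data_out)
print(inference_config)  # {'max_new_tokens': 64} -- the toggle is consumed, not forwarded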

Parse the output from instruction pipelines to conform with other text generator
pipeline types and remove line feed characters and other confusing outputs
"""
replacements = {"\n\n": " ", "A:": ""}
@harupy (Member) commented Apr 18, 2023:

Is it really safe to remove "\n\n" or A:?

@harupy (Member) commented Apr 18, 2023:

Leading and trailing whitespace are probably ok to remove. Not sure if Dolly can understand python, but if we ask:

Is `X[A:B]` a valid python code?

then Dolly answers:

Yes, `X[A:B]` a valid python code.

In this case, removing A: is not ok.

Collaborator:

Is there an actual fix being tracked? I do agree this change is rather hacky.

@harupy (Member) commented Apr 18, 2023:

I'm looking at Dolly's training data. As you can see, some responses start with A:. Is this the reason we need to strip (not replace) A:?

[screenshot: sample rows from Dolly's training data in which responses begin with "A:"]

Since Dolly is trained on this dataset, it might respond with "A: ..." for certain questions?

BenWilson2 (Member Author):

If anyone uses this in batch mode, they're going to have to parse the output. This is the only generation pipeline that I've found that inserts newline characters in the output. It works for demos and interactive mode, but it's jarring when compared to everything else on HuggingFace

BenWilson2 (Member Author):

I restricted this logic to only be accessible if the underlying pipeline type logged in the MLmodel file is "InstructionTextGenerationPipeline", so that we don't inadvertently modify any other TextGenerationPipeline outputs.

BenWilson2 (Member Author):

Also changing this parse logic to only look at the start of the response string after removing the prompt instead of replacing everywhere.
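
A sketch of the narrowed behavior using the `X[A:B]` example above (strings are illustrative; the actual implementation may differ):

data_in = "Is `X[A:B]` valid python code?"
data_out = "Is `X[A:B]` valid python code?\n\nA: Yes, `X[A:B]` is valid python code."

# Left-trim the echoed prompt, then drop an "A:" artifact only if it leads the
# response; an "A:" inside the answer text (as in `X[A:B]`) is preserved.
response = data_out[len(data_in):].lstrip()
if response.startswith("A:"):
    response = response[len("A:"):].lstrip()
print(response)  # Yes, `X[A:B]` is valid python code.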

if data_out.startswith(data_in + "\n\n"):
output_string = data_out[len(data_in) :].strip()
else:
output_string = data_out.strip()
Member:

Suggested change
output_string = data_out.strip()
output_string = data_out.lstrip()

If we just want to left-trim, we can use lstrip.

BenWilson2 (Member Author):

changed to lstrip

return output_string

if isinstance(input_data, list) and isinstance(output, list):
zipped = list(zip(input_data, output))
Member:

Suggested change
zipped = list(zip(input_data, output))
zipped = zip(input_data, output)

list can be removed, or you can iterate directly, e.g. [x for x in zip(input_data, output)]

BenWilson2 (Member Author):

great point. simplified!

f"inference configuration keys. Invalid keys: {invalid_keys}",
error_code=INVALID_PARAMETER_VALUE,
)
self._supported_custom_generator_types = {"InstructionTextGenerationPipeline"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dolly is a general instruction GPT model. InstructionTextGenerationPipeline seems appropriate here.

Parse the output from instruction pipelines to conform with other text generator
pipeline types and remove line feed characters and other confusing outputs
"""
replacements = {"\n\n": " ", "A:": ""}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an actual fix being tracked? I do agree this change is rather hacky.


# Handle the pipeline outputs
if isinstance(self.pipeline, transformers.FillMaskPipeline):
if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
self.pipeline, transformers.TextGenerationPipeline
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change assumes any TextGenerationPipeline will be stripped with the selected special chars. Is it desirable to do so? Since Dolly is the only OSS instructGPT, can we just check the model name matches Dolly and only filter for that?

# want to left-trim these types of pipelines output values.
if data_out.startswith(data_in + "\n\n"):
output_string = data_out[len(data_in) :].strip()
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above, if we restrict stripping to only Dolly model series, there are less chance a mistake is made for the general text generation pipeline output.

@WeichenXu123 (Collaborator) left a review:

LGTM after addressing remaining comments

Signed-off-by: Ben Wilson <[email protected]>
@BenWilson2 BenWilson2 merged commit 8c0027c into mlflow:master Apr 18, 2023
@BenWilson2 BenWilson2 deleted the add-transformers-dolly-support branch April 18, 2023 17:20
Labels
area/models (MLmodel format, model serialization/deserialization, flavors), rn/none (List under Small Changes in Changelogs)