add tensorflow hub extractor #433
Conversation
@adelavega @tyarkoni the current implementation could be one way to go about this. There's a generic extractor (TFHubExtractor) to which you can in principle pass any TFHub model; it will pack the output into a number of feature columns with generic names, or into columns with custom names if you pass them (if only one name is passed, it packs everything into a single column). It's a class that allows you to use any model, but at your own risk, as it remains rather abstract in terms of input type, input preprocessing, and output postprocessing. Then, there are two modality-specific extractors. These should work for most image classification and embedding models and text embedding models. In terms of output, these extractors are also pretty flexible: if you do not pass any labels/feature names, the extractor will just split the output into a number of features/columns with generic names (feature_*). We could also decide to be more specific and always constrain the dimensionality of the input. Let me know what you think - I'll go ahead and write tests if you like the current approach.
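The output-packing behavior described above (generic `feature_*` names, a single column when one name is passed, one column per name otherwise) could be sketched roughly like this. This is a hypothetical standalone helper for illustration, not the actual pliers code; the function name `pack_output` is an assumption.

```python
def pack_output(values, features=None):
    """Map a model's output vector to named feature columns.

    - no feature names: generic names (feature_0, feature_1, ...)
    - one feature name: the whole output packed into that single column
    - several names: one column per output dimension
    """
    if features is None:
        return {f'feature_{i}': v for i, v in enumerate(values)}
    if len(features) == 1:
        return {features[0]: list(values)}
    if len(features) != len(values):
        raise ValueError('Number of feature names must match output size')
    return dict(zip(features, values))
```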
one more note: right now the generic extractor crashes for some inputs. Is this expected behavior? If it is, and it's me misunderstanding something, shall I create modality-specific classes for audio and video?
Very minor comments I leave to your discretion, otherwise, looks great!
```python
url_or_path (str): url or path to TFHub model. You can
    browse models at https://tfhub.dev/.
task (str): model task/domain identifier
features (optional): list of labels (for classification)
```
This might need more explanation; some examples might be helpful, especially if the semantics depend on the model being loaded.
added more explanation in new commits
pliers/extractors/models.py (Outdated)

```python
def __init__(self, url_or_path, features=None, task=None,
             transform_out=None, **kwargs):
    verify_dependencies(['tensorflow', 'tensorflow_hub',
```
Are all of these dependencies needed here? At least from the code, it looks like it might only be tensorflow_hub (for KerasLayer).
only hub is required, indeed.
Removed tensorflow dependency, and moved attempt to import tensorflow_text to the text extractor.
It is only needed for some models and there's no way to know if it is needed until the model is called, so I've added a warning at initialization when import fails.
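The "warn at initialization when import fails" pattern described above could look roughly like this. A minimal sketch with a hypothetical helper name (`soft_import`), not the pliers implementation:

```python
import importlib
import warnings


def soft_import(name):
    """Try to import an optional dependency; warn instead of raising.

    As with tensorflow_text, there is no way to know whether a given
    model needs the dependency until the model is actually called, so
    failure at import time is downgraded to a warning.
    """
    try:
        return importlib.import_module(name)
    except ImportError:
        warnings.warn(
            f'Cannot import {name}; some models may fail when called.')
        return None
```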
@tyarkoni @adelavega I've added tests and implemented Tal's suggestions, so this is ready for new review. Tests pass locally but they are not triggered automatically on Travis, and I'm not sure if I can start them manually, and how. As for the BERT extractor, I expect these tests may cause some memory issues on Travis (each extractor is tested on more than one model) - but let's see what happens. By the way, there is probably also some updating to do to the BERT extractors; I can take care of that in a separate PR.
Looks good, just one minor suggestion.
```python
self.transform_inp = transform_inp
super().__init__()

def get_feature_names(self, out):
```
This is currently publicly exposed, so we might want to move the check for self.features inside here (instead of doing it in _extract) and use that if available. Otherwise a user might naively call get_feature_names expecting to get the stored feature names, and instead they'll get the naive enumeration of feature_*. Alternatively, if it's not meant to be public, maybe rename to _get_feature_names.
thanks for the feedback and suggestion, Tal! Moved the check for self.features inside get_feature_names.
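The suggested change above could be sketched as follows. The class name is hypothetical and the logic is simplified; the point is only that get_feature_names itself falls back to the generic enumeration when no stored names exist, so it is safe to call publicly:

```python
class HubOutputNamer:
    """Sketch: get_feature_names checks self.features internally,
    returning stored names when available, else feature_* names."""

    def __init__(self, features=None):
        self.features = features

    def get_feature_names(self, out):
        # Prefer user-supplied names; fall back to generic enumeration.
        if self.features is not None:
            return list(self.features)
        return [f'feature_{i}' for i in range(len(out))]
```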
@rbroc I think when running in forked mode, the pytest printout becomes very ugly and it looks like it hangs, but it doesn't. It looks like there is one failure on 3.7 and 3.8. I pushed a commit to address it.
I'm curious to see how much slower it will be.
@rbroc it's a bit out of the context of this PR, so perhaps we could merge this and deal with it later; the only thing left is to change the ...
thank you, @adelavega! 🙏 I've opened a separate issue for it.
Sounds good, let's merge!
First draft of a (hierarchy of?) Tensorflow Hub extractor(s).
In the current implementation, a generic TFHub extractor is implemented which is agnostic to stimulus type, plus two subclasses for embedding models and classification models - at the moment, these are pretty much the only two task types for which input and output seem fairly standardized (see https://www.tensorflow.org/hub/common_saved_model_apis/images and https://www.tensorflow.org/hub/common_saved_model_apis/text).
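The hierarchy described above (one stimulus-agnostic base class plus two task-specific subclasses) could be sketched like this. All class names and the `model_fn`/`transform_inp`/`transform_out` callables are hypothetical stand-ins for the real TFHub model and pliers machinery:

```python
class GenericHubExtractor:
    """Base: wraps an arbitrary model callable, agnostic to stimulus
    type; pre/postprocessing are left to optional user hooks."""

    def __init__(self, model_fn, transform_inp=None, transform_out=None):
        self.model_fn = model_fn
        self.transform_inp = transform_inp
        self.transform_out = transform_out

    def extract(self, stim):
        inp = self.transform_inp(stim) if self.transform_inp else stim
        out = self.model_fn(inp)
        return self.transform_out(out) if self.transform_out else out


class EmbeddingHubExtractor(GenericHubExtractor):
    """Embedding task: output is a flat vector of floats."""

    def extract(self, stim):
        return [float(v) for v in super().extract(stim)]


class ClassificationHubExtractor(GenericHubExtractor):
    """Classification task: output is a mapping of label -> score."""

    def __init__(self, model_fn, labels, **kwargs):
        super().__init__(model_fn, **kwargs)
        self.labels = labels

    def extract(self, stim):
        return dict(zip(self.labels, super().extract(stim)))
```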
Closes #428.