Skip to content
This repository has been archived by the owner on Jul 27, 2022. It is now read-only.

Integrate machine translation to Darma chat #6

Open
wants to merge 1 commit into
base: darma
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion parlai/crowdsourcing/tasks/darma_chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This task is adapted from `https://github.com/isi-nlp/ParlAI/tree/main/parlai/cr
3. Go to ParlAI main directory (i.e. `cd ~/ParlAI`) and install ParlAI in development mode `pip3 install -e . `
4. Go back to the Mephisto directory and install all the required packages: `pip install -r requirements.txt`
5. Manually install the pip incompatibilities for Mephisto by running the following command
```
```bash
pip3 install zipp==3.1.0
pip3 install importlib-metadata==1.6.0
pip3 install atomicwrites==1.3.0
Expand Down Expand Up @@ -76,6 +76,39 @@ Here, we map frontend customizations and the corresponding scripts that need to
- A static variable keeps track of the index.
- Currently, only one assignment is created for a single conversation seed.

## Enabling MT

> see `translator.py` for the code

Add this config block as `mephisto.blueprint.translator`

```yaml
translator:
activation: 'pre' # pre, post, pre+post, null
preprocess: rtg_api
preprocess_args:
# TODO: change the URL to DARMA hosted service
api_url: http://rtg.isi.edu/many-eng/v1/translate
postprocess: huggingface
postprocess_args:
model: Helsinki-NLP/opus-mt-en-fr
```
The key `activation` takes the following values
* `pre` - Only translate human input (via `preprocess` config)
* `post` - Only translate bot output (via `postprocess` config)
* `pre+post` - Enable both `pre` and `post`
* `null` - Disable MT. This has the same effect as deleting the whole `translator` config block.

`preprocess` and `postprocess` take the MT backend __name__,
whereas `{pre,post}process_args` take a dictionary of arguments for that MT backend.

__The following MT backends are supported__
* `rtg_api` which calls RTG over a REST API. See http://rtg.isi.edu/many-eng/
* `huggingface` calls `transformers` library. Requires `model` argument which can be obtained from https://huggingface.co/models?pipeline_tag=translation




## Debug logs/tips

- Q: I'm making changes to the front end and it seems like they are not reflected in my task.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ mephisto:
block_qualification: darma_onboarding_block
final_rating_question: "How coherent was E? | How responsive was E to what you wrote? | Did E understand your point of view (as B)? | Did E convince you to change your behavior?"
max_onboard_time: 6000
translator:
activation: 'pre' # pre, post, pre+post, null
preprocess: rtg_api
preprocess_args:
# TODO: change the URL to DARMA hosted service
api_url: http://rtg.isi.edu/many-eng/v1/translate
postprocess: huggingface
postprocess_args:
model: Helsinki-NLP/opus-mt-en-fr

task:
allowed_concurrent: 1
assignment_duration_in_seconds: 6000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:
- /mephisto/provider: mturk_sandbox
mephisto:
provider:
requester_name: jon_iam_sb_sandbox # Or whatever ID you provided with `mephisto register mturk_sandbox`
requester_name: tg_sandbox # Or whatever ID you provided with `mephisto register mturk_sandbox`
blueprint:
chat_data_folder: ${task_dir}/model_chat/
consent_data_folder: ${task_dir}/consent/
Expand All @@ -25,6 +25,16 @@ mephisto:
block_qualification: darma_onboarding_block
final_rating_question: "How coherent was E? | How responsive was E to what you wrote? | Did E understand your point of view (as B)? | Did E convince you to change your behavior?"
max_onboard_time: 6000
translator:
activation: 'pre' # pre, post, pre+post, null
preprocess: rtg_api
preprocess_args:
# TODO: change the URL to DARMA hosted service
api_url: http://rtg.isi.edu/many-eng/v1/translate
postprocess: huggingface
postprocess_args:
model: Helsinki-NLP/opus-mt-en-fr

task:
allowed_concurrent: 1
assignment_duration_in_seconds: 6000
Expand Down
5 changes: 5 additions & 0 deletions parlai/crowdsourcing/tasks/darma_chat/model_chat_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
'check_acceptability': args.blueprint.check_acceptability,
'chat_data_folder': args.blueprint.chat_data_folder,
'consent_data_folder': args.blueprint.consent_data_folder,
'translator': args.blueprint.translator if 'translator' in args.blueprint else None,
}
)

Expand Down Expand Up @@ -342,6 +343,10 @@ class ModelChatBlueprintArgs(BaseModelChatBlueprintArgs):
metadata={"help": "Path to file containing parlai world"},
)

translator: Dict[str, Any] = field(
default_factory=dict,
metadata = {"help": "settings to enable machine translation integration"}
)

@register_mephisto_abstraction()
class ModelChatBlueprint(BaseModelChatBlueprint):
Expand Down
89 changes: 89 additions & 0 deletions parlai/crowdsourcing/tasks/darma_chat/translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
import requests
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import logging as log


class BaseTranslator(ABC):
    """Common interface that every machine-translation backend implements."""

    @abstractmethod
    def translate(self, text: str, src_lang: str, tgt_lang: str, **args) -> str:
        """Translate ``text`` from ``src_lang`` into ``tgt_lang``."""
        ...

    def __call__(self, text, *args, **kwds) -> str:
        # Convenience: instances can be used directly as callables.
        return self.translate(text, *args, **kwds)


class RtgApiTranslator(BaseTranslator):
    """Translator backed by an RTG REST service (many-languages-to-English)."""

    def __init__(self, api_url='http://rtg.isi.edu/many-eng/v1/translate') -> None:
        # Endpoint of the RTG translation service.
        self.api_url = api_url

    def translate(self, text: str, src_lang='mul', tgt_lang='eng') -> str:
        """POST ``text`` to the RTG API; on any failure, fall back to ``text``."""
        source = {'source': [text]}
        log.debug(f'Sending source to RTG for translation: {source}')
        try:
            response = requests.post(self.api_url, json=source)
            if not response.ok:
                # Non-2xx reply: log details and return the untranslated text.
                log.warning(f'Translation failed {response.status_code} -> {response.reason}')
                log.warning(f'Response Body from RTG:\n{response.json()}')
                return text
            response = response.json()
            log.info(f'RTG translation: {response["translation"]}')
            return response["translation"][0]
        except Exception as e:
            # Deliberately broad: MT is best-effort and must never crash the chat.
            log.error(f'Error connecting to RTG API: {e}. Returning source.')
            return text


class HuggingFaceTranslator(BaseTranslator):
    """Translator backed by a local HuggingFace `transformers` seq2seq model."""

    def __init__(self, model='Helsinki-NLP/opus-mt-en-fr') -> None:
        self.model_id = model
        # BUG FIX: the two assignments were swapped — `self.model` held the
        # tokenizer and `self.tokenizer` held the model, so translate() crashed
        # by calling the model as a tokenizer and tokenizer.generate().
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_id)

    def translate(self, text: str, src_lang='mul', tgt_lang='eng') -> str:
        """Translate ``text`` with the configured model; returns the top hypothesis."""
        log.debug(f"Translating {text}...")
        batch = self.tokenizer([text], return_tensors="pt")
        gen = self.model.generate(**batch)
        fr_response = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
        return fr_response[0]



# TODO add other MTs
# Maps backend names (as used in the `translator` config block) to classes.
registry = dict(
    rtg_api=RtgApiTranslator,
    huggingface=HuggingFaceTranslator,
)


def get_translator(name, args):
    """Instantiate the MT backend registered under ``name``.

    :param name: backend key, e.g. 'rtg_api' or 'huggingface'
    :param args: dict of keyword arguments for the backend constructor, or None
    :return: a constructed BaseTranslator subclass instance
    :raises AssertionError: if ``name`` is not a registered backend
    """
    # BUG FIX: the message interpolated the bound method `registry.keys`
    # (rendering "<built-in method keys ...>") instead of the supported names.
    assert name in registry, f'{name} is invalid; supported: {list(registry.keys())}'
    log.info(f"creating MT: name={name} args={args}")
    args = args or dict()
    return registry[name](**args)


class DialogTranslator:
    """Wraps up to two translators for a dialog: ``pre_translator`` is applied
    to human input before it reaches the bot, and ``post_translator`` to bot
    output before it reaches the human. Either may be None to disable that
    direction, but not both.
    """

    def __init__(self, pre_translator, post_translator=None) -> None:
        # Typo fix in the message: "atleast" -> "at least".
        assert pre_translator or post_translator,\
            'Both pre- and post- processing MTs are None. Expected at least one.'
        log.info(f"Preprocess MT: {pre_translator}")
        log.info(f"Postprocess MT: {post_translator}")
        self.pre_translator = pre_translator
        self.post_translator = post_translator

    def maybe_preprocess(self, text):
        """Translate human input if a pre-translator is configured; else pass through."""
        if not self.pre_translator:
            return text
        return self.pre_translator(text)

    def maybe_postprocess(self, text):
        """Translate bot output if a post-translator is configured; else pass through."""
        if not self.post_translator:
            return text
        return self.post_translator(text)
Loading