From 3fa9fa3ca0f3eca79f5aa80bd4d4fbee53eebe73 Mon Sep 17 00:00:00 2001 From: Thamme Gowda Date: Tue, 5 Jul 2022 19:12:25 -0700 Subject: [PATCH] Inetgrate machine translation to darma_chat --- .../crowdsourcing/tasks/darma_chat/README.md | 35 ++- .../darma_chat/hydra_configs/conf/darma.yaml | 10 + .../hydra_configs/conf/mturk_sandbox.yaml | 12 +- .../tasks/darma_chat/model_chat_blueprint.py | 5 + .../tasks/darma_chat/translator.py | 89 ++++++++ .../crowdsourcing/tasks/darma_chat/worlds.py | 208 +++++++++--------- 6 files changed, 252 insertions(+), 107 deletions(-) create mode 100644 parlai/crowdsourcing/tasks/darma_chat/translator.py diff --git a/parlai/crowdsourcing/tasks/darma_chat/README.md b/parlai/crowdsourcing/tasks/darma_chat/README.md index 3daefc6a852..11363fbfff9 100644 --- a/parlai/crowdsourcing/tasks/darma_chat/README.md +++ b/parlai/crowdsourcing/tasks/darma_chat/README.md @@ -14,7 +14,7 @@ This task is adapted from `https://github.com/isi-nlp/ParlAI/tree/main/parlai/cr 3. Go to ParlAI main directory (i.e. `cd ~/ParlAI`) and install ParlAI in development mode `pip3 install -e . ` 4. Go back to the Mephisto directory and install all the required packages: `pip install -r requirements.txt` 5. Manually install the pip incompatibilities for Mephisto by running the following command - ``` + ```bash pip3 install zipp==3.1.0 pip3 install importlib-metadata==1.6.0 pip3 install atomicwrites==1.3.0 @@ -76,6 +76,39 @@ Here, we map frontend customizations and the corresponding scripts that need to - A static variable keeps track of the index. - Currently, only one assignment is created for a single conversation seed. +## Enabling MT + +> see `translator.py` for the code + +Add this config block as `mephisto.blueprint.translator` + +```yaml +translator: + activation: 'pre' # pre, post, pre+post, null + preprocess: rtg_api + preprocess_args: + # TODO: change the URL to DARMA hosted service + api_url: http://rtg.isi.edu/many-eng/v1/translate + postprocess: huggingface + postprocess_args: + model: Helsinki-NLP/opus-mt-en-fr +``` +The key `activation` takes the following values +* `pre` - Only translate human input (via `preprocess` config) +* `post` - Only translate bot output (via `postprocess` config) +* `pre+post` - Enable both `pre` and `post` +* `null` - Disable MT. Which has same effect as deleting the whole `translator` config block + +`preprocess` and `postprocess` takes the MT backend __name__. +Whereas `{pre,post}process_args` take a dictionary of arguments to MT backend. + +__The following MT backends are supported__ +* `rtg_api` which calls RTG over a REST API. See http://rtg.isi.edu/many-eng/ +* `huggingface` calls `transformers` library. Requires `model` argument which can be obtained from https://huggingface.co/models?pipeline_tag=translation + + + + ## Debug logs/tips - Q: I'm making changes to the front end and it seems like they are not reflected in my task. diff --git a/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/darma.yaml b/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/darma.yaml index 29eeda5c5ef..1689f7a4c91 100644 --- a/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/darma.yaml +++ b/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/darma.yaml @@ -22,6 +22,16 @@ mephisto: block_qualification: darma_onboarding_block final_rating_question: "How coherent was E? | How responsive was E to what you wrote? | Did E understand your point of view (as B)? | Did E convince you to change your behavior?" max_onboard_time: 6000 + translator: + activation: 'pre' # pre, post, pre+post, null + preprocess: rtg_api + preprocess_args: + # TODO: change the URL to DARMA hosted service + api_url: http://rtg.isi.edu/many-eng/v1/translate + postprocess: huggingface + postprocess_args: + model: Helsinki-NLP/opus-mt-en-fr + task: allowed_concurrent: 1 assignment_duration_in_seconds: 6000 diff --git a/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/mturk_sandbox.yaml b/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/mturk_sandbox.yaml index e0243db51f7..fd6c00f7d3f 100644 --- a/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/mturk_sandbox.yaml +++ b/parlai/crowdsourcing/tasks/darma_chat/hydra_configs/conf/mturk_sandbox.yaml @@ -5,7 +5,7 @@ defaults: - /mephisto/provider: mturk_sandbox mephisto: provider: - requester_name: jon_iam_sb_sandbox # Or whatever ID you provided with `mephisto register mturk_sandbox` + requester_name: tg_sandbox # Or whatever ID you provided with `mephisto register mturk_sandbox` blueprint: chat_data_folder: ${task_dir}/model_chat/ consent_data_folder: ${task_dir}/consent/ @@ -25,6 +25,16 @@ mephisto: block_qualification: darma_onboarding_block final_rating_question: "How coherent was E? | How responsive was E to what you wrote? | Did E understand your point of view (as B)? | Did E convince you to change your behavior?" max_onboard_time: 6000 + translator: + activation: 'pre' # pre, post, pre+post, null + preprocess: rtg_api + preprocess_args: + # TODO: change the URL to DARMA hosted service + api_url: http://rtg.isi.edu/many-eng/v1/translate + postprocess: huggingface + postprocess_args: + model: Helsinki-NLP/opus-mt-en-fr + task: allowed_concurrent: 1 assignment_duration_in_seconds: 6000 diff --git a/parlai/crowdsourcing/tasks/darma_chat/model_chat_blueprint.py b/parlai/crowdsourcing/tasks/darma_chat/model_chat_blueprint.py index 0265a39ce40..79e00aaed9f 100644 --- a/parlai/crowdsourcing/tasks/darma_chat/model_chat_blueprint.py +++ b/parlai/crowdsourcing/tasks/darma_chat/model_chat_blueprint.py @@ -256,6 +256,7 @@ def __init__( 'check_acceptability': args.blueprint.check_acceptability, 'chat_data_folder': args.blueprint.chat_data_folder, 'consent_data_folder': args.blueprint.consent_data_folder, + 'translator': args.blueprint.translator if 'translator' in args.blueprint else None, } ) @@ -342,6 +343,10 @@ class ModelChatBlueprintArgs(BaseModelChatBlueprintArgs): metadata={"help": "Path to file containing parlai world"}, ) + translator: Dict[str, Any] = field( + default_factory=dict, + metadata = {"help": "settings to enable machine translation integration"} + ) @register_mephisto_abstraction() class ModelChatBlueprint(BaseModelChatBlueprint): diff --git a/parlai/crowdsourcing/tasks/darma_chat/translator.py b/parlai/crowdsourcing/tasks/darma_chat/translator.py new file mode 100644 index 00000000000..ac62bc1a58e --- /dev/null +++ b/parlai/crowdsourcing/tasks/darma_chat/translator.py @@ -0,0 +1,89 @@ +from abc import ABC, abstractmethod +import requests +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +import logging as log + + +class BaseTranslator(ABC): + + @abstractmethod + def translate(self, text: str, src_lang: str, tgt_lang: str, **args) -> str: + pass + + def __call__(self, text, *args, **kwds) -> str: + return self.translate(text, *args, **kwds) + + +class RtgApiTranslator(BaseTranslator): + + def __init__(self, api_url='http://rtg.isi.edu/many-eng/v1/translate') -> None: + self.api_url = api_url + + def translate(self, text: str, src_lang ='mul', tgt_lang='eng') -> str: + source = {'source': [text]} + log.debug(f'Sending source to RTG for translation: {source}') + try: + response = requests.post(self.api_url, json=source) + if response.ok: + response = response.json() + log.info(f'RTG translation: {response["translation"]}') + return response["translation"][0] + else: + log.warning(f'Translation failed {response.status_code} -> {response.reason}') + log.warning(f'Response Body from RTG:\n{response.json()}') + return text + except Exception as e: + log.error(f'Error connecting to RTG API: {e}. Returning source.') + return text + + +class HuggingFaceTranslator(BaseTranslator): + + def __init__(self, model='Helsinki-NLP/opus-mt-en-fr') -> None: + self.model_id = model + self.model = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoModelForSeq2SeqLM.from_pretrained(self.model_id) + + def translate(self, text: str, src_lang ='mul', tgt_lang='eng') -> str: + log.debug(f"Translating {text}...") + batch = self.tokenizer([text], return_tensors="pt") + gen = self.model.generate(**batch) + fr_response = self.tokenizer.batch_decode(gen, skip_special_tokens=True) + return fr_response[0] + + + +# TODO add other MTs +registry = { + 'rtg_api': RtgApiTranslator, + 'huggingface': HuggingFaceTranslator +} + + +def get_translator(name, args): + assert name in registry, f'{name} is invalid; supported: {registry.keys}' + log.info(f"creating MT: name={name} args={args}") + args = args or dict() + return registry[name](**args) + + +class DialogTranslator: + + def __init__(self, pre_translator, post_translator=None) -> None: + assert pre_translator or post_translator,\ + 'Both pre- and post- processing MTs are None. Expected atleast one.' + log.info(f"Preprocess MT: {pre_translator}") + log.info(f"Postprocess MT: {post_translator}") + self.pre_translator = pre_translator + self.post_translator = post_translator + + def maybe_preprocess(self, text): + if not self.pre_translator: + return text + return self.pre_translator(text) + + def maybe_postprocess(self, text): + if not self.post_translator: + return text + return self.post_translator(text) diff --git a/parlai/crowdsourcing/tasks/darma_chat/worlds.py b/parlai/crowdsourcing/tasks/darma_chat/worlds.py index beedd773475..6218ba11904 100644 --- a/parlai/crowdsourcing/tasks/darma_chat/worlds.py +++ b/parlai/crowdsourcing/tasks/darma_chat/worlds.py @@ -9,8 +9,11 @@ import json from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple -import random +import random import numpy as np +import logging as log +from pprint import pprint + from parlai.core.agents import create_agent_from_shared from parlai.core.message import Message @@ -22,6 +25,7 @@ ONBOARD_FAIL, ONBOARD_SUCCESS, ) +from parlai.crowdsourcing.tasks.darma_chat.translator import DialogTranslator, get_translator from parlai.crowdsourcing.tasks.darma_chat.utils import Compatibility, DarmaContextGenerator from parlai.crowdsourcing.utils.mturk import get_mturk_id_from_mephisto_wrapper @@ -29,9 +33,7 @@ if TYPE_CHECKING: from mephisto.abstractions.blueprints.parlai_chat.parlai_chat_task_runner import ( - MephistoAgentWrapper, - ) - + MephistoAgentWrapper) class ModelChatOnboardWorld(CrowdOnboardWorld): """ @@ -63,9 +65,8 @@ def parley(self): if not self.skip_onboarding: - print( - f'{self.__class__.__name__}: starting parley for worker_id: {self.worker_id}' - ) + log.info( + f'{self.__class__.__name__}: starting parley for worker_id: {self.worker_id}') # We are rendering a frontend based on the initial task data, so we just # wait for the results to come in @@ -93,28 +94,31 @@ def parley(self): def _handle_act(self, act): if 'task_data' not in act: - print(f'{self.__class__.__name__}: {self.worker_id} had no data submitted') + log.info( + f'{self.__class__.__name__}: {self.worker_id} had no data submitted') return ONBOARD_FAIL self.annotations = act['task_data'].get('annotations') - print('Onboarding annotation results: ', self.annotations) + log.info('Onboarding annotation results: ', self.annotations) if act['task_data']['success']: - print(f'Worker {self.worker_id} successfully passed the onboarding task.') + log.info( + f'Worker {self.worker_id} successfully passed the onboarding task.') # This will end the onboarding and send them directly to the HIT self.episodeDone = True # save informed consent details (signature and date) consent_data_folder = self.opt['consent_data_folder'] os.makedirs(consent_data_folder, exist_ok=True) - consent_datapath = os.path.join(consent_data_folder, "consent_log.jsonl") + consent_datapath = os.path.join( + consent_data_folder, "consent_log.jsonl") with open(consent_datapath, 'w+') as f_jsonl: act['worker_id'] = self.worker_id data_str = json.dumps(act) - f_jsonl.write(data_str+"\n") + f_jsonl.write(data_str + "\n") return ONBOARD_SUCCESS else: - print(f'Worker {self.worker_id} failed onboarding.') + log.info(f'Worker {self.worker_id} failed onboarding.') # Grant the failed qualification, then sleep as we want worker to return self.agent.mephisto_agent.get_worker().grant_qualification( self.onboarding_qualification, 0 @@ -134,7 +138,7 @@ def get_custom_task_data(self): class BaseModelChatWorld(CrowdTaskWorld, ABC): - def __init__(self, opt, agent, bot): + def __init__(self, opt, agent, bot, mt: Optional[DialogTranslator]=None): super().__init__(opt, agent) # num_turns turns for a single side, and really it appears to be @@ -145,6 +149,7 @@ def __init__(self, opt, agent, bot): self.opt = opt self.bot = bot + self.mt: DialogTranslator = mt self.task_turn_idx = 0 self.num_turns = num_turns @@ -152,7 +157,7 @@ def __init__(self, opt, agent, bot): # self.agent.agent_id = "" self.bot.agent_id = 'BOT' # self.bot.agent_id = "" - self.target_user="" + self.target_user = "" self.dialog = [] self.tag = f'conversation_id {agent.mephisto_agent.db_id}' @@ -169,7 +174,7 @@ def __init__(self, opt, agent, bot): # below are timeout protocols self.max_resp_time = max_resp_time # in secs - print( + log.info( f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.' ) @@ -177,7 +182,7 @@ def __add_problem_data_to_utterance(self, p, turn_idx: int): """ Attach problem data to the bot's prior utterance, given by turn_idx. """ - print(p) + log.info(p) assert ( self.dialog[turn_idx]['agent_idx'] == 1 ), 'Problem data must be attached to a bot utterance.' @@ -187,8 +192,8 @@ def __add_problem_data_to_utterance(self, p, turn_idx: int): self.dialog[turn_idx]['problem_data'] = p def parley(self): - print( - f'{self.__class__.__name__}:{self.tag}: is at turn {self.task_turn_idx}, with {self.num_turns} pairs of turns needed...' + log.info(f'{self.__class__.__name__}:{self.tag}: is at turn' + '{self.task_turn_idx}, with {self.num_turns} pairs of turns needed...' ) if self.task_turn_idx == 0: @@ -197,30 +202,21 @@ def parley(self): return """Otherwise, we proceed accordingly""" - print( - f'{self.__class__.__name__}:{self.tag}: About to act with task turn idx: {self.task_turn_idx}' - ) + log.info(f'{self.__class__.__name__}:{self.tag}: ' + 'About to act with task turn idx: {self.task_turn_idx}') acts = [None, None] self.agent.agent_id = self.target_user for idx, agent in enumerate([self.agent, self.bot]): if not self.chat_done: acts[idx] = agent.act(timeout=self.max_resp_time) - if ( - agent == self.bot - and hasattr(self.bot, 'agent_id') - and self.bot.agent_id - ): + if agent == self.bot and\ + hasattr(self.bot, 'agent_id') and self.bot.agent_id: # Set speaker name as self.bot_agent_id otherwise, at frontend bot name such as "TransformerGenerator" would appear - Compatibility.backward_compatible_force_set( - acts[idx], 'id', self.bot.agent_id - ) - acts[idx] = Message( - Compatibility.maybe_fix_act(acts[idx]) - ).json_safe_payload() - print( - f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.' - ) + Compatibility.backward_compatible_force_set(acts[idx], 'id', self.bot.agent_id) + acts[idx] = Message(Compatibility.maybe_fix_act(acts[idx])).json_safe_payload() + log.info(f'Got act for agent idx {idx}, act was: {acts[idx]} ' + 'and self.task_turn_idx: {self.task_turn_idx}.') if acts[idx].get('task_data', {}).get('final_rating') is not None: @@ -235,24 +231,17 @@ def parley(self): p = acts[idx]['task_data'].get('problem_data_for_prior_message') if p is not None: self.__add_problem_data_to_utterance(p, turn_idx=turn_idx) - self.dialog[turn_idx]['final_rating'] = acts[idx]['task_data'][ - 'final_rating' - ] + self.dialog[turn_idx]['final_rating'] = acts[idx]['task_data']['final_rating'] # Save the final chat data date_folder = time.strftime('%Y_%m_%d') time_string = time.strftime('%Y%m%d_%H%M%S') - chat_data_subfolder = os.path.join( - self.opt['chat_data_folder'], date_folder - ) + chat_data_subfolder = os.path.join(self.opt['chat_data_folder'], date_folder) os.makedirs(chat_data_subfolder, exist_ok=True) - chat_data_path = os.path.join( - chat_data_subfolder, - f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json', - ) + chat_data_path = os.path.join(chat_data_subfolder, + f'{time_string}_{np.random.randint(0, 1000)}_{self.task_type}.json') self.final_chat_data = self.get_final_chat_data() - self.agent.mephisto_agent.state.messages.append( - { + self.agent.mephisto_agent.state.messages.append({ 'final_chat_data': self.final_chat_data, 'data': {}, 'packet_type': None, @@ -266,47 +255,43 @@ def parley(self): with open(chat_data_path, 'w+') as f_json: data_str = json.dumps(self.final_chat_data) f_json.write(data_str) - print( - f'{self.__class__.__name__}:{self.tag}: Data saved at ' - f'{chat_data_path} for model: {self.bot.worker_id}.' - ) + log.info(f'{self.__class__.__name__}:{self.tag}: Data saved at ' + f'{chat_data_path} for model: {self.bot.worker_id}.') # Soft-block the worker if there were acceptability violations - acceptability_violations = self.final_chat_data[ - 'acceptability_violations' - ][0] - if ( - acceptability_violations is not None - and acceptability_violations != '' - ): - print( - f'**NOTE** Acceptability violations detected: {acceptability_violations}' - ) + acceptability_violations = self.final_chat_data['acceptability_violations'][0] + if acceptability_violations is not None and acceptability_violations != '': + log.info(f'**NOTE** Acceptability violations detected: {acceptability_violations}') # Grant the failed qualification self.agent.mephisto_agent.get_worker().grant_qualification( - self.block_qualification, 1 - ) - + self.block_qualification, 1) return - else: - # if idx == 1: - # user_name= "BOT" - # else: + # if idx == 1: + # user_name= "BOT" + # else: # user_name = self.target_user utterance_data = { 'agent_idx': idx, # Get rid of annotations HTML if it's the bot response 'text': acts[idx]['text'].split('
')[0], - 'id': acts[idx].get( - 'id', 'NULL_ID' - ), # In case model doesn't set id + 'id': acts[idx].get('id', 'NULL_ID'), # In case model doesn't set id } + print(f"===={idx} agent{agent}===") + pprint(utterance_data, width=200) + if self.mt: + text = utterance_data['text'] + # preprocess for human/agent, post process for bot + translate_fn = self.mt.maybe_preprocess if agent is self.agent else self.mt.maybe_postprocess + utterance_data['text_orig'] = text + utterance_data['text'] = translate_fn(text) + self.dialog.append(utterance_data) if idx == 0: # Human has just responded. Any problem data received now will be # regarding the bot's prior utterance - p = acts[idx]['task_data'].get('problem_data_for_prior_message') + p = acts[idx]['task_data'].get( + 'problem_data_for_prior_message') if p is not None: turn_idx = -2 # Attach the problem data to the second-to-last utterance, since @@ -318,9 +303,9 @@ def parley(self): if other_agent != agent: other_agent.observe(validate(acts[idx])) - print( - f'[agent {idx}] self.task_turn_idx: {self.task_turn_idx}, self.dialog is: {self.dialog}' - ) + log.info(f'[agent {idx}] self.task_turn_idx: ' + '{self.task_turn_idx}, self.dialog is: {self.dialog}') + self.task_turn_idx += 1 @abstractmethod @@ -328,6 +313,7 @@ def _run_initial_turn(self) -> None: """ Runs logic for the first turn of the human and the bot. """ + pass def _postprocess_acts(self, acts: List[dict], agent_idx: int): """ @@ -336,18 +322,14 @@ def _postprocess_acts(self, acts: List[dict], agent_idx: int): Useful for subclasses. Will be executed after saving act data to self.dialog but before showing the act to the other agent. """ + pass def shutdown(self): if self.chat_done: self.opt['run_statistics'][self.bot.worker_id] += 1 - print( - 'Runs completed per model: ' - + ', '.join( - f'{model}: {count:d}' - for model, count in self.opt['run_statistics'].items() - ) - ) + log.info('Runs completed per model: ' + ', '.join( + f'{model}: {count:d}' for model, count in self.opt['run_statistics'].items())) self.agent.shutdown() @@ -427,8 +409,9 @@ class ModelChatWorld(BaseModelChatWorld): of this task, like personas and BST-style seed utterances. """ - def __init__(self, opt, agent, bot, context_info: Optional[dict] = None): - super().__init__(opt, agent=agent, bot=bot) + def __init__(self, opt, agent, bot, context_info: Optional[dict] = None, + mt: Optional[DialogTranslator]=None): + super().__init__(opt, agent=agent, bot=bot, mt=mt) if context_info is not None: self.context_info = context_info @@ -446,18 +429,19 @@ def _run_initial_turn(self) -> None: """ control_msg = {"episode_done": False} - if self.opt['conversation_start_mode'] == "empty": + if self.opt['conversation_start_mode'] == "empty": pass - elif self.opt['conversation_start_mode'] == "custom": - print("Use custom dialogue seeds to start conversations with some context") - print(f"Context info: {self.context_info}") + elif self.opt['conversation_start_mode'] == "custom": + log.info( + "Use custom dialogue seeds to start conversations with some context") + log.info(f"Context info: {self.context_info}") dialogue = self.context_info["conversation"] - # make each turn in the context be from the bot except for the target user + # make each turn in the context be from the bot except for the target user self.target_user = self.context_info["target_user"] - for idx, turn in enumerate(dialogue): + for idx, turn in enumerate(dialogue): msg = { 'episode_done': False, @@ -466,17 +450,17 @@ def _run_initial_turn(self) -> None: 'fake_start': True, 'agent_idx': 0 if turn['speaker_id'] == self.target_user else 1, } - # if turn["speaker_id"] == self.target_user: + # if turn["speaker_id"] == self.target_user: # msg['id'] = self.agent.agent_id - # else: + # else: # msg['id'] = self.bot.agent_id - + self.dialog.append(msg) self.agent.observe(validate(msg)) self.bot.observe(validate(msg)) - # bot responds to the last turn - if idx == len(dialogue) - 1: + # bot responds to the last turn + if idx == len(dialogue) - 1: first_bot_act = self.bot.act() first_bot_act = Compatibility.backward_compatible_force_set( first_bot_act, 'id', self.bot.agent_id @@ -496,8 +480,6 @@ def _run_initial_turn(self) -> None: f"not recognized!" ) - - def get_final_chat_data(self) -> Dict[str, Any]: """ Add non-image-chat-specific fields to the final chat data. @@ -535,7 +517,7 @@ def validate_onboarding(data): """ Check the contents of the data to ensure they are valid. """ - print(f"Validating onboarding data {data}") + log.info(f"Validating onboarding data {data}") messages = data['outputs']['messages'] if len(messages) == 0: return False @@ -565,6 +547,21 @@ def get_bot_worker(opt: Dict[str, Any], model_name: str) -> TurkLikeAgent: ) return bot_worker +def get_dialog_mt(opt: Dict[str, Any]): + if 'translator' not in opt or opt['translator'].get('activation') == None: + log.info("translator is either disabled or not configured") + return None + args = opt['translator'] + assert args['activation'] in {'pre', 'post', 'pre+post'} + pre_mt, post_mt = None, None + if 'pre' in args['activation']: + pre_mt = get_translator(name=args['preprocess'], + args=args['preprocess_args']) + if 'post' in args['activation']: + post_mt = get_translator(name=args['postprocess'], + args=args['postprocess_args']) + mt = DialogTranslator(pre_translator=pre_mt, post_translator=post_mt) + return mt def make_world(opt, agents): @@ -575,7 +572,7 @@ def make_world(opt, agents): # Get context: personas, previous utterances, etc. if context_generator is not None: context_info = context_generator.get_context(DarmaContextGenerator.idx) - DarmaContextGenerator.idx += 1 + DarmaContextGenerator.idx += 1 else: context_info = None @@ -587,13 +584,14 @@ def make_world(opt, agents): ] remaining_counts_needed.sort(reverse=True, key=lambda x: x[1]) model_name = remaining_counts_needed[0][0] - print(f'Remaining conversation counts needed: {remaining_counts_needed}') - print(f'Choosing the "{model_name}" model for the bot.') + log.info( + f'Remaining conversation counts needed: {remaining_counts_needed}') + log.info(f'Choosing the "{model_name}" model for the bot.') bot_worker = get_bot_worker(opt=opt, model_name=model_name) + mt = get_dialog_mt(opt=opt) - return ModelChatWorld( - opt, agent=agents[0], bot=bot_worker, context_info=context_info - ) + return ModelChatWorld(opt, agent=agents[0], bot=bot_worker, mt=mt, + context_info=context_info) def get_world_params():