diff --git a/qlib/rl/contrib/train_opdt.py b/qlib/rl/contrib/train_opdt.py
new file mode 100644
index 0000000000..5af351c66a
--- /dev/null
+++ b/qlib/rl/contrib/train_opdt.py
@@ -0,0 +1,252 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import argparse
+import os
+import random
+from pathlib import Path
+from typing import cast, List, Optional
+
+import numpy as np
+import pandas as pd
+import qlib
+import torch
+import yaml
+from qlib.backtest import Order
+from qlib.backtest.decision import OrderDir
+from qlib.constant import ONE_MIN
+from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data
+from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
+from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
+from qlib.rl.reward import Reward
+from qlib.rl.trainer import Checkpoint, backtest, train
+from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
+from qlib.rl.utils.log import ActionWriter, CsvWriter
+from qlib.utils import init_instance_by_config
+from tianshou.policy import BasePolicy
+from torch import nn
+from torch.utils.data import Dataset
+
+
+def seed_everything(seed: int) -> None:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def _read_orders(order_dir: Path) -> pd.DataFrame:
+    if os.path.isfile(order_dir):
+        return pd.read_pickle(order_dir)
+    else:
+        orders = []
+        for file in order_dir.iterdir():
+            order_data = pd.read_pickle(file)
+            orders.append(order_data)
+        return pd.concat(orders)
+
+
+class LazyLoadDataset(Dataset):
+    def __init__(
+        self,
+        order_file_path: Path,
+        data_dir: Path,
+        default_start_time_index: int,
+        default_end_time_index: int,
+    ) -> None:
+        self._default_start_time_index = default_start_time_index
+        self._default_end_time_index = default_end_time_index
+
+        self._order_file_path = order_file_path
+        self._order_df = _read_orders(order_file_path).reset_index()
+
+        self._data_dir = data_dir
+        self._ticks_index: Optional[pd.DatetimeIndex] = None
+
+    def __len__(self) -> int:
+        return len(self._order_df)
+
+    def __getitem__(self, index: int) -> Order:
+        row = self._order_df.iloc[index]
+        date = pd.Timestamp(str(row["date"]))
+
+        if self._ticks_index is None:
+            # TODO: We only load the ticks index once, based on the assumption that the ticks index of
+            # TODO: different dates in one experiment is always the same. If that assumption does not hold,
+            # TODO: we need to load the ticks index of every date.
+            backtest_data = load_simple_intraday_backtest_data(
+                data_dir=self._data_dir,
+                stock_id=row["instrument"],
+                date=date,
+            )
+            self._ticks_index = [t - date for t in backtest_data.get_time_index()]
+
+        order = Order(
+            stock_id=row["instrument"],
+            amount=row["amount"],
+            direction=OrderDir(int(row["order_type"])),
+            start_time=date + self._ticks_index[self._default_start_time_index],
+            end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN,
+        )
+
+        return order
+
+
+def train_and_test(
+    env_config: dict,
+    simulator_config: dict,
+    trainer_config: dict,
+    data_config: dict,
+    state_interpreter: StateInterpreter,
+    action_interpreter: ActionInterpreter,
+    policy: BasePolicy,
+    reward: Reward,
+    run_backtest: bool,
+) -> None:
+    qlib.init()
+
+    order_root_path = Path(data_config["source"]["order_dir"])
+
+    data_granularity = simulator_config.get("data_granularity", 1)
+
+    def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
+        return SingleAssetOrderExecutionSimple(
+            order=order,
+            data_dir=Path(data_config["source"]["data_dir"]),
+            ticks_per_step=simulator_config["time_per_step"],
+            data_granularity=data_granularity,
+            deal_price_type=data_config["source"].get("deal_price_column", "close"),
+            vol_threshold=simulator_config["vol_limit"],
+        )
+
+    assert data_config["source"]["default_start_time_index"] % data_granularity == 0
+    assert data_config["source"]["default_end_time_index"] % data_granularity == 0
+
+    train_dataset, valid_dataset, test_dataset = [
+        LazyLoadDataset(
+            order_file_path=order_root_path / tag,
+            data_dir=Path(data_config["source"]["data_dir"]),
+            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
+            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
+        )
+        for tag in ("train", "valid", "all")
+    ]
+
+    # Initialize callbacks unconditionally so that ``trainer_kwargs`` below never references an undefined name.
+    callbacks: List[Callback] = []
+    if "checkpoint_path" in trainer_config:
+        callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
+        callbacks.append(
+            Checkpoint(
+                dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
+                every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
+                save_latest="copy",
+            ),
+        )
+    if "earlystop_patience" in trainer_config:
+        callbacks.append(
+            EarlyStopping(
+                patience=trainer_config["earlystop_patience"],
+                monitor="val/pa",
+            )
+        )
+
+    trainer_kwargs = {
+        "max_iters": trainer_config["max_epoch"],
+        "finite_env_type": env_config["parallel_mode"],
+        "concurrency": env_config["concurrency"],
+        "val_every_n_iters": trainer_config.get("val_every_n_epoch", None),
+        "callbacks": callbacks,
+    }
+    vessel_kwargs = {
+        "episode_per_iter": trainer_config["episode_per_collect"],
+        "update_kwargs": {
+            "batch_size": trainer_config["batch_size"],
+            "repeat": trainer_config["repeat_per_collect"],
+        },
+        "val_initial_states": valid_dataset,
+    }
+
+    train(
+        simulator_fn=_simulator_factory_simple,
+        state_interpreter=state_interpreter,
+        action_interpreter=action_interpreter,
+        policy=policy,
+        reward=reward,
+        initial_states=cast(List[Order], train_dataset),
+        trainer_kwargs=trainer_kwargs,
+        vessel_kwargs=vessel_kwargs,
+    )
+
+    if run_backtest:
+        backtest(
+            simulator_fn=_simulator_factory_simple,
+            state_interpreter=state_interpreter,
+            action_interpreter=action_interpreter,
+            initial_states=test_dataset,
+            policy=policy,
+            logger=[
+                CsvWriter(Path(trainer_config["checkpoint_path"])),
+                ActionWriter(Path(trainer_config["checkpoint_path"])),
+            ],
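+            # ActionWriter dumps the per-step policy actions to ``action.pkl``, indexed by (stock_id, date)
+            # with ``policy_act`` and ``step`` columns; this matches the format read by
+            # TeacherActionDataProvider, so a teacher backtest can produce the teacher actions used for OPD training.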
+            reward=reward,
+            finite_env_type=trainer_kwargs["finite_env_type"],
+            concurrency=trainer_kwargs["concurrency"],
+        )
+
+
+def main(config: dict, run_backtest: bool) -> None:
+    if "seed" in config["runtime"]:
+        seed_everything(config["runtime"]["seed"])
+
+    state_config = config["state_interpreter"]
+    state_interpreter: StateInterpreter = init_instance_by_config(state_config)
+
+    action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
+    reward: Reward = init_instance_by_config(config["reward"])
+
+    # Create torch network
+    if "kwargs" not in config["network"]:
+        config["network"]["kwargs"] = {}
+    config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space})
+    network: nn.Module = init_instance_by_config(config["network"])
+
+    # Create policy
+    config["policy"]["kwargs"].update(
+        {
+            "network": network,
+            "obs_space": state_interpreter.observation_space,
+            "action_space": action_interpreter.action_space,
+        }
+    )
+    policy: BasePolicy = init_instance_by_config(config["policy"])
+
+    use_cuda = config["runtime"].get("use_cuda", False)
+    if use_cuda:
+        policy.cuda()
+
+    train_and_test(
+        env_config=config["env"],
+        simulator_config=config["simulator"],
+        data_config=config["data"],
+        trainer_config=config["trainer"],
+        action_interpreter=action_interpreter,
+        state_interpreter=state_interpreter,
+        policy=policy,
+        reward=reward,
+        run_backtest=run_backtest,
+    )
+
+
+if __name__ == "__main__":
+    import warnings
+
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    warnings.filterwarnings("ignore", category=RuntimeWarning)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
+    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
+    args = parser.parse_args()
+
+    with open(args.config_path, "r") as input_stream:
+        config = yaml.safe_load(input_stream)
+
+    main(config, run_backtest=args.run_backtest)
diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 3f21c08550..87d00ff1d4 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -249,6 +249,28 @@ def get_data(
        )


+class TeacherActionData:
+    """Teacher (oracle) actions of one stock on one day, read from a pickled DataFrame indexed by
+    (stock_id, date) with ``policy_act`` and ``step`` columns (e.g. the ``action.pkl`` dumped by ``ActionWriter``)."""
+
+    teacher_action: pd.DataFrame
+    step: int
+
+    def __init__(self, teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> None:  # type: ignore
+        data = pd.read_pickle(teacher_action_file).loc[pd.IndexSlice[stock_id, date.date()]]  # type: ignore
+        self.teacher_action = data["policy_act"]
+        self.step = data["step"]
+
+
+def load_teacher_action_data(teacher_action_file: Path, stock_id: str, date: pd.Timestamp) -> TeacherActionData:  # type: ignore
+    return TeacherActionData(teacher_action_file, stock_id, date)
+
+
+class TeacherActionDataProvider:
+    """Provider that serves :class:`TeacherActionData` from a single teacher-action file."""
+
+    def __init__(self, teacher_action_file: Path) -> None:
+        self._teacher_action_file = teacher_action_file
+
+    def get_data(self, stock_id: str, date: pd.Timestamp) -> TeacherActionData:
+        return load_teacher_action_data(self._teacher_action_file, stock_id, date)
+
+
 def load_orders(
     order_path: Path,
     start_time: pd.Timestamp = None,
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index 01b0811530..77ce488f84 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -12,6 +12,7 @@
 from qlib.constant import EPS
 from qlib.rl.data.base import ProcessedDataProvider
+from qlib.rl.data.pickle_styled import TeacherActionDataProvider
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution.state import SAOEState
 from qlib.typehint import TypedDict
@@ -53,6 +54,10 @@ class FullHistoryObs(TypedDict):
     position_history: Any


+class OPDObs(FullHistoryObs):
+    teacher_action: Any
+
+
 class DummyStateInterpreter(StateInterpreter[SAOEState, dict]):
     """Dummy interpreter for policies that do not need inputs (for example, AllOne)."""

@@ -153,6 +158,110 @@ def _mask_future_info(arr: pd.DataFrame, current: pd.Timestamp) -> pd.DataFrame:
         return arr


+class OracleObsInterpreter(FullHistoryStateInterpreter):
+    """Same as :class:`FullHistoryStateInterpreter`, except that future market data is *not* masked,
+    so the whole trading day is visible. Intended for training the oracle (teacher) policy."""
+
+    def interpret(self, state: SAOEState) -> FullHistoryObs:
+        processed = self.processed_data_provider.get_data(
+            stock_id=state.order.stock_id,
+            date=pd.Timestamp(state.order.start_time.date()),
+            feature_dim=self.data_dim,
+            time_index=state.ticks_index,
+        )
+
+        position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
+        position_history[0] = state.order.amount
+        position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
+
+        return cast(
+            FullHistoryObs,
+            canonicalize(
+                {
+                    # Unlike FullHistoryStateInterpreter, ``processed.today`` is not passed through
+                    # ``_mask_future_info``: the oracle is allowed to see future information.
+                    "data_processed": np.array(processed.today),
+                    "data_processed_prev": np.array(processed.yesterday),
+                    "acquiring": _to_int32(state.order.direction == state.order.BUY),
+                    "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
+                    "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)),
+                    "num_step": _to_int32(self.max_step),
+                    "target": _to_float32(state.order.amount),
+                    "position": _to_float32(state.position),
+                    "position_history": _to_float32(position_history[: self.max_step]),
+                },
+            ),
+        )
+
+
+class OPDObsInterpreter(FullHistoryStateInterpreter):
+    """Observation interpreter for Oracle Policy Distillation (OPD): a full-history observation
+    plus the teacher's action at the current step."""
+
+    def __init__(
+        self,
+        max_step: int,
+        data_ticks: int,
+        data_dim: int,
+        processed_data_provider: dict | ProcessedDataProvider,
+        teacher_action_data_provider: dict | TeacherActionDataProvider,
+    ) -> None:
+        super().__init__(max_step, data_ticks, data_dim, processed_data_provider)
+        self.teacher_action_data_provider = init_instance_by_config(
+            teacher_action_data_provider, accept_types=TeacherActionDataProvider
+        )
+
+    def interpret(self, state: SAOEState) -> OPDObs:
+        processed = self.processed_data_provider.get_data(
+            stock_id=state.order.stock_id,
+            date=pd.Timestamp(state.order.start_time.date()),
+            feature_dim=self.data_dim,
+            time_index=state.ticks_index,
+        )
+
+        position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
+        position_history[0] = state.order.amount
+        position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy()
+
+        teacher_action = self.teacher_action_data_provider.get_data(
+            stock_id=state.order.stock_id, date=pd.Timestamp(state.order.start_time.date())
+        ).teacher_action
+        try:
+            this_teacher_action = teacher_action.values[state.cur_step]
+        except IndexError:
+            # The teacher trajectory has no record for this step (e.g., in the done step).
+            this_teacher_action = 0
+
+        # The min / slice below make sure the indices fit into the valid range,
+        # even after the final step of the simulator (in the done step),
+        # to keep the network in the policy happy.
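+        # ``teacher_action`` is the discrete action index the teacher (oracle) policy chose at this step;
+        # it is only used as the distillation target in ``OPD.learn`` and is not consumed by the network itself.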
+        return cast(
+            OPDObs,
+            canonicalize(
+                {
+                    "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)),
+                    "data_processed_prev": np.array(processed.yesterday),
+                    "acquiring": _to_int32(state.order.direction == state.order.BUY),
+                    "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
+                    "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)),
+                    "num_step": _to_int32(self.max_step),
+                    "target": _to_float32(state.order.amount),
+                    "position": _to_float32(state.position),
+                    "position_history": _to_float32(position_history[: self.max_step]),
+                    "teacher_action": _to_int32(this_teacher_action),
+                },
+            ),
+        )
+
+    @property
+    def observation_space(self) -> spaces.Dict:
+        space = {
+            "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
+            "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
+            "acquiring": spaces.Discrete(2),
+            "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32),
+            "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
+            # TODO: support arbitrary length index
+            "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32),
+            "target": spaces.Box(-EPS, np.inf, shape=()),
+            "position": spaces.Box(-EPS, np.inf, shape=()),
+            "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)),
+            "teacher_action": spaces.Box(-EPS, np.inf, shape=(), dtype=np.int32),
+        }
+        return spaces.Dict(space)
+
+
 class CurrentStateObs(TypedDict):
     acquiring: bool
     cur_step: int
diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py
index d6a11189cf..f0116e52cc 100644
--- a/qlib/rl/order_execution/network.py
+++ b/qlib/rl/order_execution/network.py
@@ -119,6 +119,37 @@ def forward(self, batch: Batch) -> torch.Tensor:
         return self.fc(out)


+class OracleRecurrent(Recurrent):
+    """Variant of :class:`Recurrent` for the oracle (teacher) policy: the public-data branch reads the
+    RNN output of the last minute of the day instead of the current tick, so the full day is visible."""
+
+    def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
+        bs, _, data_dim = obs["data_processed"].size()
+        data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
+        cur_step = obs["cur_step"].long()
+        bs_indices = torch.arange(bs, device=device)
+
+        position = obs["position_history"] / obs["target"].unsqueeze(-1)  # [bs, num_step]
+        steps = (
+            torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
+            / obs["num_step"].unsqueeze(-1).float()
+        )  # [bs, num_step]
+        priv = torch.stack((position.float(), steps), -1)
+
+        data_in = self.raw_fc(data)
+        data_out, _ = self.raw_rnn(data_in)
+        # Take the output of the last minute (the oracle sees the whole day).
+        data_out_slice = data_out[bs_indices, -1]
+
+        priv_in = self.pri_fc(priv)
+        priv_out = self.pri_rnn(priv_in)[0]
+        priv_out = priv_out[bs_indices, cur_step]
+
+        sources = [data_out_slice, priv_out]
+
+        dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float())
+        sources.append(dir_out)
+
+        return sources, data_out
+
+
 class Attention(nn.Module):
     def __init__(self, in_dim, out_dim):
         super().__init__()
diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py
index a46b587aa1..22e8f893f4 100644
--- a/qlib/rl/order_execution/policy.py
+++ b/qlib/rl/order_execution/policy.py
@@ -4,7 +4,7 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast
+from typing import Any, Dict, Generator, Iterable, List, Optional, OrderedDict, Tuple, cast

 import gym
 import numpy as np
@@ -12,7 +12,7 @@
 import torch.nn as nn
 from gym.spaces import Discrete
 from tianshou.data import Batch, ReplayBuffer, to_torch
-from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
+from tianshou.policy import BasePolicy, DQNPolicy, PPOPolicy

 from qlib.rl.trainer.trainer import Trainer
@@ -158,6 +158,110 @@ def __init__(
             set_weight(self, Trainer.get_policy_state_dict(weight_file))


+class OPD(PPO):
+    """Oracle Policy Distillation.
+
+    Reference:
+    Universal Trading for Order Execution with Oracle Policy Distillation. https://arxiv.org/abs/2103.10860
+    """
+
+    def __init__(
+        self,
+        network: nn.Module,
+        obs_space: gym.Space,
+        action_space: gym.Space,
+        lr: float,
+        weight_decay: float = 0.0,
+        discount_factor: float = 1.0,
+        max_grad_norm: float = 100.0,
+        reward_normalization: bool = True,
+        eps_clip: float = 0.3,
+        value_clip: bool = True,
+        vf_coef: float = 1.0,
+        gae_lambda: float = 1.0,
+        max_batch_size: int = 256,
+        deterministic_eval: bool = True,
+        dis_coef: float = 0.01,
+        weight_file: Optional[Path] = None,
+    ) -> None:
+        self._weight_dis = dis_coef
+        super().__init__(
+            network,
+            obs_space,
+            action_space,
+            lr,
+            weight_decay,
+            discount_factor,
+            max_grad_norm,
+            reward_normalization,
+            eps_clip,
+            value_clip,
+            vf_coef,
+            gae_lambda,
+            max_batch_size,
+            deterministic_eval,
+            weight_file,
+        )
+
+    def learn(  # type: ignore
+        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+    ) -> Dict[str, List[float]]:
+        losses, clip_losses, vf_losses, dis_losses, ent_losses = [], [], [], [], []
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indices)
+            for minibatch in batch.split(batch_size, merge_last=True):
+                # Calculate the loss for the actor
+                out = self(minibatch)
+                dist = out.dist
+                if self._norm_adv:
+                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
+                    minibatch.adv = (minibatch.adv - mean) / std  # per-batch norm
+                ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float()
+                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+                surr1 = ratio * minibatch.adv
+                surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * minibatch.adv
+                if self._dual_clip:
+                    clip1 = torch.min(surr1, surr2)
+                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
+                else:
+                    clip_loss = -torch.min(surr1, surr2).mean()
+                # Calculate the loss for the critic
+                value = self.critic(minibatch.obs).flatten()
+                if self._value_clip:
+                    v_clip = minibatch.v_s + (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+                    vf1 = (minibatch.returns - value).pow(2)
+                    vf2 = (minibatch.returns - v_clip).pow(2)
+                    vf_loss = torch.max(vf1, vf2).mean()
+                else:
+                    vf_loss = (minibatch.returns - value).pow(2).mean()
+                # Calculate the distillation loss: negative log-likelihood of the teacher's action under
+                # the student's action distribution (PPOActor outputs softmax probabilities, hence ``.log()``).
+                logits = out.logits
+                teacher_action = to_torch(minibatch.obs["teacher_action"], device=logits.device).long()
+                dis_loss = nn.functional.nll_loss(logits.log(), teacher_action)
+                # Calculate the entropy regularization and the overall loss
+                ent_loss = dist.entropy().mean()
+                loss = clip_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss + self._weight_dis * dis_loss
+                self.optim.zero_grad()
+                loss.backward()
+                if self._grad_norm:  # clip large gradient
+                    nn.utils.clip_grad_norm_(self._actor_critic.parameters(), max_norm=self._grad_norm)
+                self.optim.step()
+                clip_losses.append(clip_loss.item())
+                vf_losses.append(vf_loss.item())
+                dis_losses.append(dis_loss.item())
+                ent_losses.append(ent_loss.item())
+                losses.append(loss.item())
+
+        return {
+            "loss": losses,
+            "loss/clip": clip_losses,
+            "loss/vf": vf_losses,
+            "loss/dis": dis_losses,
+            "loss/ent": ent_losses,
+        }
+
+
 DQNModel = PPOActor  # Reuse PPOActor.
diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py
index 75aab20688..ff8f4dd08d 100644
--- a/qlib/rl/utils/log.py
+++ b/qlib/rl/utils/log.py
@@ -514,6 +514,66 @@ def on_env_all_done(self) -> None:
 class PickleWriter(LogWriter):
     """Dump logs to pickle files."""

+    # Everything that can be pickled is supported, so no real type filtering is applied.
+    SUPPORTED_TYPES = (object,)
+
+    all_records: list[dict[str, Any]]
+
+    def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
+        super().__init__(loglevel)
+        self.output_dir = output_dir
+        self.output_dir.mkdir(exist_ok=True, parents=True)
+
+    def clear(self) -> None:
+        super().clear()
+        self.all_records = []
+
+    def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
+        # FIXME: Same as ConsoleWriter; needs a refactor to eliminate code duplication.
+        episode_wise_contents: dict[str, list] = defaultdict(list)
+
+        for step_contents in contents:
+            for name, value in step_contents.items():
+                if isinstance(value, self.SUPPORTED_TYPES):
+                    episode_wise_contents[name].append(value)
+
+        self.all_records.append(episode_wise_contents)
+
+    def on_env_all_done(self) -> None:
+        # FIXME: this is temporary
+        pd.DataFrame.from_records(self.all_records).to_pickle(self.output_dir / "result.pkl")
+
+
+class ActionWriter(LogWriter):
+    """Dump per-step policy actions to a pickle file."""
+
+    all_records: dict[str, list[Any]]
+
+    def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.DEBUG) -> None:
+        super().__init__(loglevel)
+        self.output_dir = output_dir
+        self.output_dir.mkdir(exist_ok=True)
+
+    def clear(self) -> None:
+        super().clear()
+        self.all_records = defaultdict(list)
+
+    def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
+        for step_index, step_contents in enumerate(contents):
+            for name, value in step_contents.items():
+                if name == "policy_act":
+                    self.all_records[name].append(value)
+                if name == "datetime":
+                    # ``datetime`` and ``stock_id`` are constant per episode, so replicate them per step.
+                    self.all_records["date"].extend([value.date()] * len(contents))
+                if name == "stock_id":
+                    self.all_records[name].extend([value] * len(contents))
+            self.all_records["step"].append(step_index)
+
+    def on_env_all_done(self) -> None:
+        # FIXME: this is temporary
+        pd.DataFrame.from_dict(self.all_records).set_index(["stock_id", "date"]).sort_index().to_pickle(
+            self.output_dir / "action.pkl"
+        )
+
+
 class TensorboardWriter(LogWriter):
     """Write logs to event files that can be visualized with tensorboard."""
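
For reference, below is a minimal, hypothetical sketch of the config that ``train_opdt.py`` expects, written as the Python dict that ``yaml.safe_load`` would produce. ``OPD``, ``OPDObsInterpreter`` and ``TeacherActionDataProvider`` come from this diff; ``PickleProcessedDataProvider``, ``CategoricalActionInterpreter``, ``PAPenaltyReward`` and ``Recurrent`` are assumed to be the pre-existing qlib components of those names, and all paths and hyper-parameters are placeholders rather than values prescribed by this change.

# Hypothetical config sketch for train_opdt.py (the student / OPD phase). All paths are placeholders.
config = {
    "runtime": {"seed": 42, "use_cuda": False},
    "env": {"parallel_mode": "dummy", "concurrency": 4},
    "simulator": {"time_per_step": 30, "vol_limit": None, "data_granularity": 1},
    "data": {
        "source": {
            "order_dir": "./orders",  # expects train/valid/all entries (directories or pickles)
            "data_dir": "./backtest_data",
            "default_start_time_index": 0,
            "default_end_time_index": 240,
            "deal_price_column": "close",
        },
    },
    "state_interpreter": {
        "class": "OPDObsInterpreter",
        "module_path": "qlib.rl.order_execution.interpreter",
        "kwargs": {
            "max_step": 8,
            "data_ticks": 240,
            "data_dim": 6,
            "processed_data_provider": {
                "class": "PickleProcessedDataProvider",
                "module_path": "qlib.rl.data.pickle_styled",
                "kwargs": {"data_dir": "./feature_data"},
            },
            "teacher_action_data_provider": {
                "class": "TeacherActionDataProvider",
                "module_path": "qlib.rl.data.pickle_styled",
                # e.g. the action.pkl dumped by ActionWriter during a teacher backtest
                "kwargs": {"teacher_action_file": "./outputs/teacher/action.pkl"},
            },
        },
    },
    "action_interpreter": {
        "class": "CategoricalActionInterpreter",
        "module_path": "qlib.rl.order_execution.interpreter",
        "kwargs": {"values": 14, "max_step": 8},
    },
    "reward": {"class": "PAPenaltyReward", "module_path": "qlib.rl.order_execution.reward"},
    "network": {"class": "Recurrent", "module_path": "qlib.rl.order_execution.network"},
    "policy": {
        "class": "OPD",
        "module_path": "qlib.rl.order_execution.policy",
        # network / obs_space / action_space are injected by main()
        "kwargs": {"lr": 0.0001, "dis_coef": 0.01},
    },
    "trainer": {
        "max_epoch": 100,
        "episode_per_collect": 10000,
        "batch_size": 1024,
        "repeat_per_collect": 25,
        "val_every_n_epoch": 4,
        "checkpoint_path": "./outputs/opd",
        "checkpoint_every_n_iters": 1,
        "earlystop_patience": 10,
    },
}

main(config, run_backtest=True)  # equivalent to: python train_opdt.py --config_path config.yml --run_backtest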