# %%
"""
・
・参考サイト：
  - gym-trading-env: https://gym-trading-env.readthedocs.io/en/latest/index.html
  - 複数データセット: https://gym-trading-env.readthedocs.io/en/latest/multi_datasets.html
"""

# %%
import os

import gym_trading_env
import gymnasium as gym
import matplotlib.pyplot as plt
import mplfinance as mpf
import numpy as np
import pandas as pd
import torch
from gym_trading_env.renderer import Renderer
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv

# %%
# フォルダ設定
output_dir_path = f"./data/999_check/強化学習"
os.makedirs(output_dir_path, exist_ok=True)


# %%
# 株価データの長さ
n_points = 150

# 時系列データを作成
time = np.arange(n_points)

# 正弦波（sin波）を生成
sin_wave = np.sin(time * 0.05)

# 正弦波を右斜めに傾けるために線形項を加える
slope = 0.01  # 傾き
sin_wave_slope = sin_wave + slope * time + 10

# ランダムノイズを加える
# noise = np.random.normal(0, 0.2, n_points)

# 最終的な株価データ
stock_prices = sin_wave_slope  # + noise

# 開始時間を設定（例: 2024年1月1日 9:00）
start_time = pd.Timestamp("2024-01-01 09:00:00")

# 1分ごとの時間リストを作成
datetime_series = [start_time + pd.Timedelta(minutes=i) for i in range(n_points)]

# データフレームを作成（datetime列を含める）
df = pd.DataFrame(
    {
        "datetime": datetime_series,
        "open": stock_prices,
        "high": stock_prices + np.random.uniform(0.1, 0.2, n_points),
        "low": stock_prices - np.random.uniform(0.1, 0.2, n_points),
        "close": stock_prices,
        "volume": [10] * n_points,
    }
)
df = df[80:].reset_index(drop=True)

# 最初の5行を表示
print(len(df), df["datetime"].min(), df["datetime"].max())
display(df.head())

# datetime を index に設定
df.set_index("datetime", inplace=True)

# mplfinance のローソク足チャートを表示
mpf.plot(df, type="candle", volume=False, ylabel="Price")


# %%
# 「feature_」と付く変数名は全て状態としてみなされる
# df["feature_pct_change"] = df["close"].pct_change()
df["feature_close"] = df["close"]
# df["feature_open"] = df["open"]/df["close"]
# df["feature_high"] = df["high"]/df["close"]
# df["feature_low"] = df["low"]/df["close"]
df.dropna(inplace=True)
display(df)


# %%
def reward_function(history):
    return np.log(
        history["portfolio_valuation", -1] / history["portfolio_valuation", -2]
    )


def dynamic_feature_last_position_taken(history):
    return history["position", -1]


def dynamic_feature_real_position(history):
    return history["real_position", -1]


# %%
# env = gym.make(
#     "TradingEnv",
#     reward_function=reward_function,
#     df=df,
#     positions=[-1, 0, 1],
#     # dynamic_feature_functions=[
#     #     dynamic_feature_last_position_taken,
#     #     dynamic_feature_real_position,
#     # ],
# )

# env.unwrapped.add_metric(
#     "Position Changes", lambda history: np.sum(np.diff(history["position"]) != 0)
# )
# env.unwrapped.add_metric("Episode Length", lambda history: len(history["position"]))
# observation, info = env.reset()
# print(f"{observation=}")


# %%
# ----- 環境の作成 -----
env = gym.make(
    "TradingEnv",
    df=df,
    reward_function=reward_function,
    positions=[-1, 0, 1],  # -1: 売り, 0: 何もしない, 1: 買い
)

# ラップしてベクトル環境化（Stable Baselines3 では VecEnv が必要）
env = DummyVecEnv([lambda: env])

# ----- DQN エージェントの作成 -----
model = DQN(
    "MlpPolicy",
    env,
    learning_rate=0.0005,  # 学習率を下げる
    buffer_size=100000,  # リプレイバッファを大きく
    learning_starts=5000,  # 学習開始ステップを増やす
    batch_size=64,  # バッチサイズを増やす
    gamma=0.99,
    train_freq=4,
    exploration_fraction=0.2,  # 探索割合を増やす
    exploration_final_eps=0.01,  # 最終的なεを小さくする
    target_update_interval=1000,  # ターゲットネットワークの更新頻度を増やす
    verbose=1,
)


# ----- 学習 -----
model.learn(total_timesteps=10000)  # 5万ステップ学習

# ----- 学習済みモデルの保存 -----
# model_path = "dqn_trading_model"
# model.save(model_path)

# # ----- 学習済みモデルのロード -----
# loaded_model = DQN.load(model_path, env=env)

# %%
# ----- 環境の作成 -----
env = gym.make(
    "TradingEnv",
    df=df,  # 対象データ (Datetime index)
    reward_function=reward_function,
    positions=[-1, 0, 1],  # -1: 売り, 0: 何もしない, 1: 買い
)
env = DummyVecEnv([lambda: env])  # DQN に必要なラップ

# ----- モデルを適用（シミュレーション） -----
obs = env.reset()
df["position"] = np.nan  # 初期化 (欠損値 NaN で埋める)

done = False
step = 0
while not done:
    action, _ = model.predict(obs, deterministic=True)  # 学習済みモデルで行動予測
    obs, reward, done, info = env.step(action)  # 環境を1ステップ進める

    # DataFrame の index に対応する position を記録
    current_time = df.index[step]  # 現在の時間 (DatetimeIndex)
    df.loc[current_time, "position"] = info[0]["position"]

    print(f"{step=}, {current_time=}, {obs=}, {info=}")

    step += 1

display(df)

# NaN を埋める（前の値を継続する）
# df["position"].fillna(method="ffill", inplace=True)

# ----- 可視化（Renderer を使う） -----
renderer = Renderer(render_logs_dir="render_logs")

# 移動平均線（SMA5 & SMA20）の表示
renderer.add_line(
    name="sma05",
    function=lambda df: df["close"].rolling(5).mean(),
    line_options={"width": 1, "color": "purple"},
)
renderer.add_line(
    name="sma20",
    function=lambda df: df["close"].rolling(20).mean(),
    line_options={"width": 1, "color": "blue"},
)

# 描画実行
renderer.run()


# %%
