-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathvisualize.py
328 lines (270 loc) · 12 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# 可视化模块
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def show_pnl(data, result, init_balance, indicators=None):
df = pd.DataFrame(data)
df_result = pd.DataFrame(result)
if df_result.empty:
return
# 计算余额变化
df_result['cumulative_pnl'] = df_result['pnl'].cumsum() + init_balance
# 获取初始日期
initial_date = df.index[0]
# 创建包含初始余额的DataFrame
initial_balance_df = pd.DataFrame({
'close_date': [initial_date],
'cumulative_pnl': [init_balance]
})
# 合并初始余额和交易结果
df_balance = pd.concat([initial_balance_df, df_result[['close_date', 'cumulative_pnl']]], ignore_index=True)
df_balance = df_balance.sort_values('close_date').reset_index(drop=True)
# 设置图像属性
fig, ax1 = plt.subplots(figsize=(15, 8))
# 价格曲线
color = 'tab:blue'
ax1.set_xlabel('Time')
ax1.set_ylabel('Price', color=color)
ax1.plot(df.index, df['close'], color=color, label='Price')
ax1.tick_params(axis='y', labelcolor=color)
# 已添加图例的标记
legend_added = {"long_open": False, "long_close": False, "short_open": False, "short_close": False}
# 开平仓标记
for _, row in df_result.iterrows():
if row['dir'] == 'long':
if not legend_added["long_open"]:
ax1.plot(row['open_date'], row['open_price'], '^', markersize=10, color='green', label='Long Open')
legend_added["long_open"] = True
else:
ax1.plot(row['open_date'], row['open_price'], '^', markersize=10, color='green', label="_nolegend_")
if not legend_added["long_close"]:
ax1.plot(row['close_date'], row['close_price'], 'v', markersize=10, color='red', label='Long Close')
legend_added["long_close"] = True
else:
ax1.plot(row['close_date'], row['close_price'], 'v', markersize=10, color='red', label="_nolegend_")
elif row['dir'] == 'short':
if not legend_added["short_open"]:
ax1.plot(row['open_date'], row['open_price'], '^', markersize=10, color='red', label='Short Open')
legend_added["short_open"] = True
else:
ax1.plot(row['open_date'], row['open_price'], '^', markersize=10, color='red', label="_nolegend_")
if not legend_added["short_close"]:
ax1.plot(row['close_date'], row['close_price'], 'v', markersize=10, color='green', label='Short Close')
legend_added["short_close"] = True
else:
ax1.plot(row['close_date'], row['close_price'], 'v', markersize=10, color='green', label="_nolegend_")
# 画余额变化曲线
ax2 = ax1.twinx()
color = 'tab:red'
ax2.set_ylabel('Balance', color=color)
ax2.plot(df_balance['close_date'], df_balance['cumulative_pnl'], color=color, label='Balance')
ax2.tick_params(axis='y', labelcolor=color)
# 画指标曲线
if indicators:
for indicator_name, indicator_values in indicators.items():
if indicator_name not in ['rsi', 'volumn', 'macd']:
ax1.plot(df.index, indicator_values, label=indicator_name)
# 图例
ax1.legend()
fig.tight_layout()
plt.show()
def show_total_pnl(results, init_balance):
"""
显示多 symbol 的总 PnL 曲线。
参数:
- results: 多 symbol 的交易结果字典,键为 symbol,值为交易结果列表。
- init_balance: 初始资金。
"""
# 创建一个空的 DataFrame,用于存储所有 symbol 的余额时间序列
all_balances = pd.DataFrame()
# 遍历每个 symbol 的交易结果
for symbol, trades in results.items():
df_trades = pd.DataFrame(trades)
if df_trades.empty:
print(f'No trading result for {symbol}')
continue
# 提取 'close_date' 和 'balance' 列,构建余额时间序列
balance_series = df_trades[['close_date', 'balance']].copy()
balance_series['close_date'] = pd.to_datetime(balance_series['close_date'])
balance_series.set_index('close_date', inplace=True)
# 添加初始余额点
# 获取该 symbol 的最早交易日期
earliest_date = balance_series.index.min()
# 创建包含初始余额的 DataFrame
initial_balance_df = pd.DataFrame({
'balance': [init_balance]
}, index=[earliest_date])
# 将初始余额 DataFrame 与余额时间序列 DataFrame 合并
balance_series = pd.concat([initial_balance_df, balance_series], axis=0)
balance_series = balance_series.sort_index()
# 去重(如果初始余额和第一个交易在同一天,会有重复的索引)
balance_series = balance_series[~balance_series.index.duplicated(keep='first')]
# 将该 symbol 的余额时间序列添加到 all_balances DataFrame 中
balance_series = balance_series.rename(columns={'balance': symbol})
all_balances = pd.concat([all_balances, balance_series], axis=1)
if all_balances.empty:
print('No trading data available.')
return
# 对齐所有 symbol 的日期索引
all_balances = all_balances.sort_index()
# 使用前向填充填充缺失值
all_balances = all_balances.fillna(method='ffill')
# 将初始缺失值填充为初始余额
all_balances = all_balances.fillna(init_balance)
# 获取所有的 symbol 列(排除可能存在的非 symbol 列)
symbol_columns = [col for col in all_balances.columns if col != 'Total Balance']
# 计算组合的总余额
all_balances['Total Balance'] = all_balances[symbol_columns].sum(axis=1)
# 绘制组合的总 PnL 曲线
plt.figure(figsize=(15, 8))
plt.plot(all_balances.index, all_balances['Total Balance'], label='Total Balance', color='black', linewidth=2)
# 绘制每个 symbol 的余额曲线(可选)
# for symbol in symbol_columns:
# plt.plot(all_balances.index, all_balances[symbol], label=f'{symbol} Balance', linestyle='--')
plt.title('Total PnL Over Time')
plt.xlabel('Date')
plt.ylabel('Balance')
plt.legend()
plt.grid(True)
plt.show()
def show_multi_symbol_pnl(results, init_balance):
"""
显示多个 symbol 的 PnL 曲线
"""
plt.figure(figsize=(15, 8))
symbol_colors = {} # 存储每个 symbol 的颜色
# 获取所有交易的最早日期
initial_dates = []
for history in results.values():
df = pd.DataFrame(history)
if not df.empty:
initial_dates.append(df['close_date'].min())
if not initial_dates:
print('No trading data available.')
return
initial_date = min(initial_dates)
# 创建包含初始余额的 DataFrame
initial_balance_df = pd.DataFrame({
'close_date': [initial_date],
'cumulative_pnl': [init_balance]
})
# 遍历每个 symbol 的回测结果
for symbol, history in results.items():
df = pd.DataFrame(history)
if df.empty:
print(f'No trading result for {symbol}')
continue
# 计算每个 symbol 的累计 PnL
df['cumulative_pnl'] = df['pnl'].cumsum() + init_balance
# 添加初始余额点
df_symbol_balance = pd.concat([initial_balance_df, df[['close_date', 'cumulative_pnl']]], ignore_index=True)
df_symbol_balance = df_symbol_balance.sort_values('close_date').reset_index(drop=True)
# 为每个 symbol 分配颜色,并绘制 PnL 曲线
color = plt.cm.tab10(len(symbol_colors) % 10) # 从10种颜色中选择
symbol_colors[symbol] = color
plt.plot(df_symbol_balance['close_date'], df_symbol_balance['cumulative_pnl'],
label=f'{symbol} PnL', color=color)
# 图表美化
plt.title('PnL Curves for Multiple Symbols')
plt.xlabel('Date')
plt.ylabel('PnL')
plt.legend()
plt.grid(True)
# 显示图表
plt.show()
def show_indicators(data, indicators):
# 确定有多少个指标需要放在副图中
subplots_needed = sum(1 for indicator_name in indicators if indicator_name in ['rsi', 'volume', 'macd'])
# 创建足够的子图来容纳所有指标
# 添加squeeze=False确保axes始终是数组形式
fig, axes = plt.subplots(subplots_needed + 1, 1, figsize=(15, 8), sharex=True, squeeze=False)
fig.subplots_adjust(hspace=0) # 调整子图之间的间距
# 主图显示价格和可能的一些其他指标
axes[0,0].plot(data.index, data['close'], label='Price') # 修改为axes[0,0]访问第一个子图
# 遍历所有指标,决定它们应该放在主图还是副图
subplot_index = 1 # 副图的索引从1开始
for indicator_name, indicator_values in indicators.items():
if indicator_name in ['rsi', 'volume', 'macd']:
# 放在副图
ax = axes[subplot_index, 0] # 修改为axes[subplot_index, 0]
subplot_index += 1
if indicator_name == 'volume':
ax.bar(data.index, indicator_values, label=indicator_name)
else:
ax.plot(data.index, indicator_values, label=indicator_name)
ax.legend(loc='upper left')
else:
# 放在主图
axes[0,0].plot(data.index, indicator_values, label=indicator_name) # 修改为axes[0,0]
axes[0,0].legend(loc='upper left') # 修改为axes[0,0]
# 允许用户放大缩小图表来观察细节
# 在jupter notebook中会有bug
# plt.get_current_fig_manager().toolbar.zoom()
# 显示图表
plt.show()
def show_return_distribution(results, bins=50):
"""
绘制收益率分布的直方图,支持单 symbol 和多 symbol。
参数:
- results: 单 symbol 时为列表,多 symbol 时为字典。
- bins: 直方图的分箱数量,默认为50。
"""
# 判断是单 symbol 还是多 symbol
if isinstance(results, list):
# 单 symbol 情况
df = pd.DataFrame(results)
if df.empty or 'pnl' not in df.columns or 'open_price' not in df.columns or 'amount' not in df.columns or 'dir' not in df.columns:
print('交易数据缺少必要的字段。')
return
# 计算收益率
df['return'] = df.apply(_calculate_trade_return, axis=1)
return_data = df['return']
title = 'Return Distribution'
elif isinstance(results, dict):
# 多 symbol 情况
return_list = []
for _, trades in results.items():
df = pd.DataFrame(trades)
if not df.empty and all(col in df.columns for col in ['pnl', 'open_price', 'amount', 'dir']):
df['return'] = df.apply(_calculate_trade_return, axis=1)
return_list.extend(df['return'].tolist())
if not return_list:
print('没有可用的收益率数据。')
return
return_data = pd.Series(return_list)
title = 'Combined Return Distribution'
else:
print('Invalid results format.')
return
# 转换收益率为百分比
return_data = return_data * 100
# 绘制直方图和核密度估计曲线
plt.figure(figsize=(10, 6))
sns.histplot(return_data, bins=bins, kde=True, stat="density", edgecolor='k', alpha=0.7)
# 计算并绘制平均收益率
mean_return = return_data.mean()
plt.axvline(mean_return, color='r', linestyle='dashed', linewidth=1)
plt.text(mean_return, plt.ylim()[1]*0.9, f'Mean: {mean_return:.2f}%')
plt.title(title)
plt.xlabel('Return (%)')
plt.ylabel('Density')
plt.grid(True)
plt.show()
def _calculate_trade_return(trade):
"""
计算单笔交易的收益率。
参数:
- trade: 包含单笔交易数据的 Series。
返回:
- 收益率(浮点数)。
"""
pnl = trade['pnl']
dir = trade['dir']
open_price = trade['open_price']
amount = trade['amount']
# 投入资金的计算
invested_capital = open_price * amount
if invested_capital == 0:
return 0 # 避免除以零
trade_return = pnl / invested_capital
return trade_return