Skip to content

Commit

Permalink
1)fix sh stock list col issue 2)fix wrong short signals issue 3)add m…
Browse files Browse the repository at this point in the history
…ore tests

Former-commit-id: af3cf9a
  • Loading branch information
foolcage committed Jun 30, 2019
1 parent 015c0d6 commit a57b0ac
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 157 deletions.
64 changes: 4 additions & 60 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

**Read this in other languages: [English](README-en.md).**

ZVT是在[fooltrader](https://github.com/foolcage/fooltrader)的基础上重新思考后编写的量化项目,其包含可扩展的数据recorder,api,因子计算,选股,回测,定位为日线级别全市场分析和交易框架
ZVT是在[fooltrader](https://github.com/foolcage/fooltrader)的基础上重新思考后编写的量化项目,其包含可扩展的数据recorder,api,因子计算,选股,回测,定位为**中低频** **多级别** **多标的** 全市场分析和交易框架

# 使用方式
```
Expand All @@ -16,71 +16,15 @@ pip install -U zvt

# 使用展示
[*参考代码*](./zvt/trader/examples)
### 单标的回测 ###
### 单标的单factor(cross ma) ###

#### 股票单标的单factor(cross ma) ####
```
class SingleStockTrader(StockTrader):
def __init__(self,
security: str = 'stock_sz_000338',
start_timestamp: Union[str, pd.Timestamp] = '2005-01-01',
end_timestamp: Union[str, pd.Timestamp] = '2019-06-30',
provider: Union[str, Provider] = 'joinquant',
level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
trader_name: str = None,
real_time: bool = False,
kdata_use_begin_time: bool = True) -> None:
super().__init__([security], SecurityType.stock, None, None, start_timestamp, end_timestamp, provider,
level, trader_name, real_time, kdata_use_begin_time=kdata_use_begin_time)
def init_selectors(self, security_list, security_type, exchanges, codes, start_timestamp, end_timestamp):
self.selectors = []
technical_selector = TechnicalSelector(security_list=security_list, security_type=security_type,
exchanges=exchanges, codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp, level=TradingLevel.LEVEL_1DAY,
provider='joinquant')
technical_selector.run()
self.selectors.append(technical_selector)
```
<p align="center"><img src='./docs/single-stock-cross-ma.gif'/></p>

#### 股票多单标的单factor(macd) ####
```
class MultipleStockTrader(StockTrader):
def __init__(self,
security_list: List[str] = None,
exchanges: List[str] = ['sh', 'sz'],
codes: List[str] = None,
start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None,
provider: Union[str, Provider] = 'joinquant',
level: Union[str, TradingLevel] = TradingLevel.LEVEL_1DAY,
trader_name: str = None,
real_time: bool = False,
kdata_use_begin_time: bool = False) -> None:
super().__init__(security_list, SecurityType.stock, exchanges, codes, start_timestamp, end_timestamp, provider,
level, trader_name, real_time, kdata_use_begin_time)
#### 多单标的单factor(macd) ####

def init_selectors(self, security_list, security_type, exchanges, codes, start_timestamp, end_timestamp):
my_selector = TargetSelector(security_list=security_list, security_type=security_type, exchanges=exchanges,
codes=codes, start_timestamp=start_timestamp,
end_timestamp=end_timestamp)
# add the factors
my_selector \
.add_filter_factor(BullFactor(security_list=security_list,
security_type=security_type,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
level=TradingLevel.LEVEL_1DAY))
self.selectors.append(my_selector)
```
<p align="center"><img src='./docs/multiple-stock-macd.gif'/></p>

#### 1分钟级别实时交易信号 ####

# 联系方式
QQ群:300911873 加群请备注github用户名
1 change: 1 addition & 0 deletions tests/recorders/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
15 changes: 15 additions & 0 deletions tests/recorders/common/test_china_stock_list_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
from ...context import init_context

init_context()

from zvt.recorders.common.china_stock_list_spider import ChinaStockListSpider


def test_coin_meta_recorder():
recorder = ChinaStockListSpider()

try:
recorder.run()
except:
assert False
59 changes: 47 additions & 12 deletions tests/selectors/test_selector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,34 @@
# -*- coding: utf-8 -*-
from zvt.selectors.examples.technical_selector import TechnicalSelector
from zvt.utils.pd_utils import df_is_not_null
from zvt.factors.technical_factor import CrossMaFactor
from ..context import init_context

init_context()

from zvt.domain import SecurityType, TradingLevel, Provider
from zvt.selectors.examples.technical_selector import TechnicalSelector
from zvt.selectors.selector import TargetSelector


def test_cross_ma_selector():
security_list = ['stock_sz_000338']
security_type = 'stock'
start_timestamp = '2018-01-01'
end_timestamp = '2019-06-30'
my_selector = TargetSelector(security_list=security_list,
security_type=security_type,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp)
# add the factors
my_selector \
.add_filter_factor(CrossMaFactor(security_list=security_list,
security_type=security_type,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
level=TradingLevel.LEVEL_1DAY))
my_selector.run()
print(my_selector.open_long_df)
print(my_selector.open_short_df)
assert 'stock_sz_000338' in my_selector.get_open_short_targets('2018-01-29')


def test_technical_selector():
Expand All @@ -18,17 +41,29 @@ def test_technical_selector():

print(selector.get_result_df())

targets = selector.get_targets('2019-06-04')
if df_is_not_null(targets):
assert 'stock_sz_000338' not in targets['security_id'].tolist()
assert 'stock_sz_000338' not in targets['security_id'].tolist()
assert 'stock_sz_002572' not in targets['security_id'].tolist()
assert 'stock_sz_002572' not in targets['security_id'].tolist()
targets = selector.get_open_long_targets('2019-06-04')

assert 'stock_sz_000338' not in targets
assert 'stock_sz_000338' not in targets
assert 'stock_sz_002572' not in targets
assert 'stock_sz_002572' not in targets

targets = selector.get_open_short_targets('2019-06-04')
assert 'stock_sz_000338' in targets
assert 'stock_sz_000338' in targets
assert 'stock_sz_002572' in targets
assert 'stock_sz_002572' in targets

selector.move_on(timeout=0)

targets = selector.get_targets('2019-06-19')
if df_is_not_null(targets):
assert 'stock_sz_000338' in targets['security_id'].tolist()
targets = selector.get_open_long_targets('2019-06-19')

assert 'stock_sz_000338' in targets

assert 'stock_sz_002572' not in targets

targets = selector.get_keep_long_targets('2019-06-19')

assert 'stock_sz_000338' not in targets

assert 'stock_sz_002572' not in targets['security_id'].tolist()
assert 'stock_sz_002572' not in targets
2 changes: 1 addition & 1 deletion zvt/api/technical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def init_securities(df, security_type='stock', provider=Provider.EASTMONEY):
db_engine = get_db_engine(provider, store_category=store_category)
security_schema = get_security_schema(security_type)

current = get_securities(security_type=security_type, columns=[security_schema.id], provider=provider)
current = get_securities(security_type=security_type, columns=[security_schema.id,security_schema.code], provider=provider)
df = df[~df['id'].isin(current['id'])]

df.to_sql(security_schema.__tablename__, db_engine, index=False, if_exists='append')
Expand Down
2 changes: 1 addition & 1 deletion zvt/recorders/common/china_stock_list_spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def download_stock_list(self, response, exchange):
df = pd.read_csv(io.BytesIO(response.content), sep='\s+', encoding='GB2312', dtype=str,
parse_dates=['上市日期'])
if df is not None:
df = df.loc[:, ['A股代码', 'A股简称', '上市日期']]
df = df.loc[:, ['公司代码', '公司简称', '上市日期']]

elif exchange == 'sz':
df = pd.read_excel(io.BytesIO(response.content), sheet_name='A股列表', dtype=str, parse_dates=['A股上市日期'])
Expand Down
4 changes: 2 additions & 2 deletions zvt/selectors/examples/technical_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

class TechnicalSelector(TargetSelector):
def __init__(self, security_list=None, security_type=SecurityType.stock, exchanges=['sh', 'sz'], codes=None,
the_timestamp=None, start_timestamp=None, end_timestamp=None, threshold=0.8,
the_timestamp=None, start_timestamp=None, end_timestamp=None, long_threshold=0.8, short_threshold=-0.8,
level=TradingLevel.LEVEL_1DAY,
provider='joinquant') -> None:
super().__init__(security_list, security_type, exchanges, codes, the_timestamp, start_timestamp, end_timestamp,
threshold, level, provider)
long_threshold, short_threshold, level, provider)

def init_factors(self, security_list, security_type, exchanges, codes, the_timestamp, start_timestamp,
end_timestamp):
Expand Down
97 changes: 76 additions & 21 deletions zvt/selectors/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __init__(self,
the_timestamp=None,
start_timestamp=None,
end_timestamp=None,
threshold=0.8,
long_threshold=0.8,
short_threshold=-0.8,
level=TradingLevel.LEVEL_1DAY,
provider='eastmoney') -> None:
self.security_list = security_list
Expand All @@ -39,14 +40,19 @@ def __init__(self,
else:
assert False

self.threshold = threshold
self.long_threshold = long_threshold
self.short_threshold = short_threshold
self.level = level

self.filter_factors: List[FilterFactor] = []
self.score_factors: List[ScoreFactor] = []
self.must_result = None
self.filter_result = None
self.score_result = None
self.result_df: DataFrame = None

self.open_long_df: DataFrame = None
self.open_short_df: DataFrame = None
self.keep_long_df: DataFrame = None
self.keep_short_df: DataFrame = None

self.init_factors(security_list=security_list, security_type=security_type, exchanges=exchanges, codes=codes,
the_timestamp=the_timestamp, start_timestamp=start_timestamp, end_timestamp=end_timestamp)
Expand All @@ -73,10 +79,17 @@ def move_on(self, to_timestamp=None, kdata_use_begin_time=False, timeout=20):
self.run()

def run(self):
"""
"""
if self.filter_factors:
musts = []
for factor in self.filter_factors:
df = factor.get_result_df()

if not df_is_not_null(df):
raise Exception('no data for factor:{},{}'.format(factor.factor_name, factor))

if len(df.columns) > 1:
s = df.agg("and", axis="columns")
s.name = 'score'
Expand All @@ -85,12 +98,15 @@ def run(self):
df.columns = ['score']
musts.append(df)

self.must_result = list(accumulate(musts, func=operator.__and__))[-1]
self.filter_result = list(accumulate(musts, func=operator.__and__))[-1]

if self.score_factors:
scores = []
for factor in self.score_factors:
df = factor.get_result_df()
if not df_is_not_null(df):
raise Exception('no data for factor:{],{}'.format(factor.factor_name, factor))

if len(df.columns) > 1:
s = df.agg("mean", axis="columns")
s.name = 'score'
Expand All @@ -100,25 +116,64 @@ def run(self):
scores.append(df)
self.score_result = list(accumulate(scores, func=operator.__add__))[-1]

if df_is_not_null(self.must_result) and df_is_not_null(self.score_result):
result1 = self.must_result[self.must_result.score]
result2 = self.score_result[self.score_result.score >= self.threshold]
result = result2.loc[result1.index, :]

self.generate_targets()

def get_targets(self, timestamp, target_type='open_long') -> pd.DataFrame:
if target_type == 'open_long':
df = self.open_long_df
if target_type == 'open_short':
df = self.open_short_df
if target_type == 'keep_long':
df = self.keep_long_df
if target_type == 'keep_short':
df = self.keep_short_df

if df_is_not_null(df):
if timestamp in df.index:
target_df = df.loc[[to_pd_timestamp(timestamp)], :]
return target_df['security_id'].tolist()
return []

def get_open_long_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type='open_long')

def get_open_short_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type='open_short')

def get_keep_long_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type='keep_long')

def get_keep_short_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type='keep_short')

# overwrite it to generate targets
def generate_targets(self):
if df_is_not_null(self.filter_result) and df_is_not_null(self.score_result):
# for long
result1 = self.filter_result[self.filter_result.score]
result2 = self.score_result[self.score_result.score >= self.long_threshold]
long_result = result2.loc[result1.index, :]
# for short
result1 = self.filter_result[~self.filter_result.score]
result2 = self.score_result[self.score_result.score <= self.short_threshold]
short_result = result2.loc[result1.index, :]
elif df_is_not_null(self.score_result):
result = self.score_result[self.score_result.score >= self.threshold]
long_result = self.score_result[self.score_result.score >= self.long_threshold]
short_result = self.score_result[self.score_result.score <= self.short_threshold]
else:
result = self.must_result[self.must_result.score]

self.result_df = result.reset_index()
long_result = self.filter_result[self.filter_result.score]
short_result = self.filter_result[~self.filter_result.score]

self.result_df = index_df(self.result_df)
self.open_long_df = self.normalize_result_df(long_result)
self.open_short_df = self.normalize_result_df(short_result)

def get_targets(self, timestamp) -> pd.DataFrame:
if timestamp in self.result_df.index:
return self.result_df.loc[[to_pd_timestamp(timestamp)], :]
else:
return pd.DataFrame()
# TODO:keep_long,keep_short algorithm

def get_result_df(self):
return self.result_df
return self.open_long_df

def normalize_result_df(self, df):
df = df.reset_index()
df = index_df(df)
df = df.sort_values(by=['score', 'security_id'])
return df
17 changes: 8 additions & 9 deletions zvt/trader/examples/stock_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas as pd

from zvt.domain import TradingLevel, Provider, SecurityType
from zvt.factors.finance_factor import FinanceGrowthFactor
from zvt.factors.technical_factor import BullFactor, CrossMaFactor
from zvt.selectors.selector import TargetSelector
from zvt.settings import SAMPLE_STOCK_CODES
Expand Down Expand Up @@ -44,13 +43,13 @@ def init_selectors(self, security_list, security_type, exchanges, codes, start_t
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
level=TradingLevel.LEVEL_1DAY))
# .add_score_factor(FinanceGrowthFactor(security_list=security_list,
# security_type=security_type,
# exchanges=exchanges,
# codes=codes,
# start_timestamp=start_timestamp,
# end_timestamp=end_timestamp,
# level=TradingLevel.LEVEL_1DAY))
# .add_score_factor(FinanceGrowthFactor(security_list=security_list,
# security_type=security_type,
# exchanges=exchanges,
# codes=codes,
# start_timestamp=start_timestamp,
# end_timestamp=end_timestamp,
# level=TradingLevel.LEVEL_1DAY))
self.selectors.append(my_selector)


Expand Down Expand Up @@ -95,4 +94,4 @@ def get_constructor_meta(cls):
# just get hs300 securities
# security_list = get_securities_in_blocks(block_names=['HS300_'])
# MultipleStockTrader(security_list=security_list, start_timestamp='2018-01-01', end_timestamp='2019-06-25').run()
MultipleStockTrader(codes=SAMPLE_STOCK_CODES, start_timestamp='2018-01-01', end_timestamp='2019-06-25').run()
MultipleStockTrader(codes=['000338','000783'], start_timestamp='2018-01-01', end_timestamp='2019-06-25').run()
Loading

0 comments on commit a57b0ac

Please sign in to comment.