Migrate AlpacaProcessor from alpaca_trade_api (tradeapi.REST) to alpaca-py
(StockHistoricalDataClient / StockBarsRequest) and move "wrds" within
requirements.txt.

Review notes:
  * fetch_latest_data: the reshaping code referenced an undefined name `bars`
    (NameError at runtime); it now uses `barset`, the frame fetched in that loop.
  * NOTE(review): the requests shown here hard-code TimeFrame.Minute, so the
    `time_interval` argument is no longer used by these calls - confirm intended.
  * NOTE(review): line breaks, indentation and blank-line placement were
    reconstructed from a whitespace-collapsed copy of this patch; verify
    against the target files before applying.

diff --git a/finrl/meta/data_processors/processor_alpaca.py b/finrl/meta/data_processors/processor_alpaca.py
index 200a7f447..80e0a2f0b 100644
--- a/finrl/meta/data_processors/processor_alpaca.py
+++ b/finrl/meta/data_processors/processor_alpaca.py
@@ -2,33 +2,40 @@
 
 from concurrent.futures import ProcessPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from datetime import timedelta as td
 
-import alpaca_trade_api as tradeapi
 import exchange_calendars as tc
 import numpy as np
 import pandas as pd
 import pytz
+from alpaca.data.historical import StockHistoricalDataClient
+from alpaca.data.requests import StockBarsRequest
+from alpaca.data.timeframe import TimeFrame
 from stockstats import StockDataFrame as Sdf
 
+# import alpaca_trade_api as tradeapi
+
 
 class AlpacaProcessor:
-    def __init__(self, API_KEY=None, API_SECRET=None, API_BASE_URL=None, api=None):
-        if api is None:
+    def __init__(self, API_KEY=None, API_SECRET=None, API_BASE_URL=None, client=None):
+        if client is None:
             try:
-                self.api = tradeapi.REST(API_KEY, API_SECRET, API_BASE_URL, "v2")
+                self.client = StockHistoricalDataClient(API_KEY, API_SECRET)
             except BaseException:
                 raise ValueError("Wrong Account Info!")
         else:
-            self.api = api
+            self.client = client
 
     def _fetch_data_for_ticker(self, ticker, start_date, end_date, time_interval):
-        bars = self.api.get_bars(
-            ticker,
-            time_interval,
-            start=start_date.isoformat(),
-            end=end_date.isoformat(),
-        ).df
-        bars["symbol"] = ticker
+        request_params = StockBarsRequest(
+            symbol_or_symbols=ticker,
+            timeframe=TimeFrame.Minute,
+            start=start_date,
+            end=end_date,
+        )
+        bars = self.client.get_stock_bars(request_params).df
+
         return bars
 
     def download_data(
@@ -53,7 +60,7 @@ def download_data(
         NY = "America/New_York"
         start_date = pd.Timestamp(start_date + " 09:30:00", tz=NY)
         end_date = pd.Timestamp(end_date + " 15:59:00", tz=NY)
-
+        data_list = []
         # Use ThreadPoolExecutor to fetch data for multiple tickers concurrently
         with ThreadPoolExecutor(max_workers=10) as executor:
             futures = [
@@ -66,7 +73,42 @@ def download_data(
                 )
                 for ticker in ticker_list
             ]
-            data_list = [future.result() for future in futures]
+            for future in futures:
+
+                bars = future.result()
+                # fix start
+                # Reorganize the dataframes to be in original alpaca_trade_api structure
+                # Rename the existing 'symbol' column if it exists
+                if not bars.empty:
+
+                    # Now reset the index
+                    bars.reset_index(inplace=True)
+
+                    # Set 'timestamp' as the new index
+                    if "level_1" in bars.columns:
+                        bars.rename(columns={"level_1": "timestamp"}, inplace=True)
+                    if "level_0" in bars.columns:
+                        bars.rename(columns={"level_0": "symbol"}, inplace=True)
+
+                    bars.set_index("timestamp", inplace=True)
+
+                    # Reorder and rename columns as needed
+                    bars = bars[
+                        [
+                            "close",
+                            "high",
+                            "low",
+                            "trade_count",
+                            "open",
+                            "volume",
+                            "vwap",
+                            "symbol",
+                        ]
+                    ]
+
+                    data_list.append(bars)
+                else:
+                    print("empty")
 
         # Combine the data
         data_df = pd.concat(data_list, axis=0)
@@ -371,7 +413,40 @@ def fetch_latest_data(
     ) -> pd.DataFrame:
         data_df = pd.DataFrame()
         for tic in ticker_list:
-            barset = self.api.get_bars([tic], time_interval, limit=limit).df  # [tic]
+            request_params = StockBarsRequest(
+                symbol_or_symbols=[tic], timeframe=TimeFrame.Minute, limit=limit
+            )
+
+            barset = self.client.get_stock_bars(request_params).df
+            # Reorganize the dataframes to be in original alpaca_trade_api structure
+            # Rename the existing 'symbol' column if it exists
+            if "symbol" in barset.columns:
+                barset.rename(columns={"symbol": "symbol_old"}, inplace=True)
+
+            # Now reset the index
+            barset.reset_index(inplace=True)
+
+            # Set 'timestamp' as the new index
+            if "level_0" in barset.columns:
+                barset.rename(columns={"level_0": "symbol"}, inplace=True)
+            if "level_1" in barset.columns:
+                barset.rename(columns={"level_1": "timestamp"}, inplace=True)
+            barset.set_index("timestamp", inplace=True)
+
+            # Reorder and rename columns as needed
+            barset = barset[
+                [
+                    "close",
+                    "high",
+                    "low",
+                    "trade_count",
+                    "open",
+                    "volume",
+                    "vwap",
+                    "symbol",
+                ]
+            ]
+
             barset["tic"] = tic
             barset = barset.reset_index()
             data_df = pd.concat([data_df, barset])
@@ -451,6 +526,9 @@ def fetch_latest_data(
         )
         latest_price = price_array[-1]
         latest_tech = tech_array[-1]
-        turb_df = self.api.get_bars(["VIXY"], time_interval, limit=1).df
+        request_params = StockBarsRequest(
+            symbol_or_symbols="VIXY", timeframe=TimeFrame.Minute, limit=1
+        )
+        turb_df = self.client.get_stock_bars(request_params).df
         latest_turb = turb_df["close"].values
         return latest_price, latest_tech, latest_turb
diff --git a/requirements.txt b/requirements.txt
index 78291dd7c..2c62a4345 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,7 +43,7 @@
 swig
 tensorboardX
 wheel>=0.33.6
+wrds
 
 # market data & paper trading API
 yfinance
-wrds