# Spaces: Running
import streamlit as st | |
from datetime import date, datetime, timedelta | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
from prophet import Prophet | |
from prophet.plot import plot_plotly | |
import plotly.graph_objects as go | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
import plotly.express as px | |
# Configure Streamlit page settings
st.set_page_config(
    page_title="Stock & Crypto Forecast",
    # Chart emoji. The previous value "๐" was UTF-8 emoji bytes mis-decoded
    # as a Thai code page, so it was not an emoji at all.
    page_icon="📈",
    layout="wide",
)
# Constants
# First day of price history requested by load_data.
START = "2015-01-01"
# Today's date in YYYY-MM-DD form, used as the end of the download range.
TODAY = date.today().isoformat()

# Asset categories: sidebar category label -> yfinance ticker symbols.
ASSETS = dict(
    Stocks=['GOOG', 'AAPL', 'MSFT', 'GME'],
    Cryptocurrencies=['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD'],
)
# Custom CSS: full-width buttons, page background and a .custom-date panel style.
_CUSTOM_CSS = """
<style>
.stButton>button {
width: 100%;
}
.reportview-container {
background: #f0f2f6
}
.custom-date {
margin-top: 1rem;
padding: 1rem;
background-color: #f8f9fa;
border-radius: 0.5rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def load_data(ticker):
    """Download and validate price history for *ticker*.

    Fetches data from ``START`` to ``TODAY`` with ``yf.download``, checks that
    the OHLCV columns this app relies on are present, coerces them to numbers
    and drops rows where any of them is missing.

    Parameters
    ----------
    ticker : str
        Symbol understood by yfinance, e.g. ``"AAPL"`` or ``"BTC-USD"``.

    Returns
    -------
    pandas.DataFrame or None
        Cleaned frame with a ``Date`` column plus numeric ``Open``, ``High``,
        ``Low``, ``Close`` and ``Volume`` columns. Returns ``None`` if anything
        went wrong; in that case the error has already been shown with ``st.error``.
    """
    required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
    try:
        data = yf.download(ticker, START, TODAY)
        if data.empty:
            raise ValueError(f"No data found for {ticker}")
        # Newer yfinance releases can return (field, ticker) MultiIndex columns
        # even for a single ticker. data['Close'] would then be a DataFrame and
        # pd.to_numeric would raise, so keep only the field level.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)
        data = data.reset_index()
        for col in required_columns:
            if col not in data.columns:
                raise ValueError(f"Missing required column: {col}")
            if col != 'Date':
                # Non-numeric entries become NaN and are removed below.
                data[col] = pd.to_numeric(data[col], errors='coerce')
        # Only the columns the app actually uses decide whether a row is usable.
        data = data.dropna(subset=required_columns)
        if data.empty:
            raise ValueError(f"No usable data found for {ticker}")
        return data
    except Exception as e:
        # UI boundary: report the failure and let the caller skip rendering.
        st.error(f"Error loading data: {str(e)}")
        return None
def calculate_rsi(prices, period=14):
    """Calculate the Relative Strength Index of *prices*.

    Uses simple rolling means of gains and losses over ``period`` rows, not
    Wilder's exponential smoothing.

    Parameters
    ----------
    prices : pandas.Series
        Price series in chronological order.
    period : int, default 14
        Rolling window length, in rows.

    Returns
    -------
    pandas.Series
        RSI values in [0, 100], aligned with *prices*. The first ``period``
        values are NaN because a full window of price changes is needed.
        Windows with no price movement at all are also NaN (0 / 0).
    """
    delta = prices.diff()
    # clip() keeps the leading NaN produced by diff(). The previous
    # where(..., 0) turned it into a fake zero change that was averaged into
    # the first window, biasing the first RSI value and emitting it a row early.
    gain = delta.clip(lower=0).rolling(window=period).mean()
    loss = (-delta).clip(lower=0).rolling(window=period).mean()
    rs = gain / loss
    # loss == 0 gives rs == inf, and 100 - 100 / inf == 100 as expected.
    return 100 - (100 / (1 + rs))
def prepare_prophet_data(data):
    """Build the two-column frame Prophet trains on.

    Parameters
    ----------
    data : pandas.DataFrame
        Price history with ``Date`` and ``Close`` columns (as returned by
        ``load_data``). It is not modified.

    Returns
    -------
    pandas.DataFrame
        New frame with ``ds`` (dates) and ``y`` (closing prices), keeping the
        index of *data*.
    """
    dates = data['Date']
    # Prophet rejects timezone-aware 'ds' values; keep the local wall-clock
    # dates and drop the timezone if one is attached.
    if isinstance(dates.dtype, pd.DatetimeTZDtype):
        dates = dates.dt.tz_localize(None)
    return pd.DataFrame({'ds': dates, 'y': data['Close']})
def train_prophet_model(data, period):
    """Fit a Prophet model to *data* and build the frame of dates to predict.

    Parameters
    ----------
    data : pandas.DataFrame
        Training frame with Prophet's ``ds`` and ``y`` columns, as produced by
        ``prepare_prophet_data``.
    period : int
        Number of future periods to append after the last training date. These
        are make_future_dataframe's default daily steps.

    Returns
    -------
    tuple
        ``(model, future)``: the fitted Prophet model and the ``ds`` frame
        covering the training dates plus ``period`` future steps.
    """
    model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        # load_data does not override yfinance's default interval, so rows are
        # presumably daily bars. Without sub-daily observations, the daily
        # Fourier terms are constant whenever ds values fall at midnight. They
        # would only absorb an offset that belongs to the trend, so leave
        # daily seasonality off.
        daily_seasonality=False,
        changepoint_prior_scale=0.05,
        seasonality_prior_scale=10.0,
        changepoint_range=0.9,
    )
    # Custom roughly-monthly seasonality on top of the built-in ones.
    model.add_seasonality(
        name='monthly',
        period=30.5,
        fourier_order=5,
    )
    model.fit(data)
    future = model.make_future_dataframe(periods=period)
    return model, future
def plot_technical_analysis(data, selected_asset):
    """Return a dark-themed candlestick chart with two moving-average overlays.

    *data* must provide ``Date``, ``Open``, ``High``, ``Low``, ``Close``,
    ``SMA_20`` and ``SMA_50`` columns. main() adds the two SMA columns before
    calling this function.
    """
    candles = go.Candlestick(
        x=data['Date'],
        open=data['Open'],
        high=data['High'],
        low=data['Low'],
        close=data['Close'],
        name='Price',
    )
    overlay_specs = (
        ('SMA_20', 'SMA 20', 'orange'),
        ('SMA_50', 'SMA 50', 'blue'),
    )
    overlays = [
        go.Scatter(x=data['Date'], y=data[column], name=label, line=dict(color=colour))
        for column, label, colour in overlay_specs
    ]
    figure = go.Figure(data=[candles, *overlays])
    figure.update_layout(
        title=f'{selected_asset} Technical Analysis',
        yaxis_title='Price',
        template='plotly_dark',
    )
    return figure
def plot_forecast_components(model, forecast):
    """Plot the trend and seasonal components of a Prophet forecast.

    Parameters
    ----------
    model : Prophet
        The fitted model. Currently unused; kept for interface compatibility.
    forecast : pandas.DataFrame
        Output of ``model.predict``. It must contain ``ds`` and ``trend``; each
        seasonal column (``yearly``, ``monthly``, ``weekly``) is plotted only
        when present.

    Returns
    -------
    plotly.graph_objects.Figure
        Trend on the left y-axis, seasonal components on a secondary right axis.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=forecast['ds'],
        y=forecast['trend'],
        name='Trend',
        line=dict(color='blue'),
    ))
    # train_prophet_model adds a custom 'monthly' seasonality. It used to be
    # fitted but never shown here.
    seasonal_components = (
        ('yearly', 'Yearly Seasonality', 'green'),
        ('monthly', 'Monthly Seasonality', 'purple'),
        ('weekly', 'Weekly Seasonality', 'red'),
    )
    for column, label, colour in seasonal_components:
        if column in forecast.columns:
            # Additive seasonal effects can be far smaller than the price-level
            # trend. A separate axis keeps them from rendering as flat lines.
            fig.add_trace(go.Scatter(
                x=forecast['ds'],
                y=forecast[column],
                name=label,
                line=dict(color=colour),
                yaxis='y2',
            ))
    fig.update_layout(
        title='Forecast Components',
        template='plotly_dark',
        height=800,
        showlegend=True,
        yaxis=dict(title='Trend'),
        yaxis2=dict(title='Seasonality', overlaying='y', side='right'),
    )
    return fig
def convert_df_to_csv(df):
    """Serialize *df* to CSV without its index and return the UTF-8 bytes.

    The bytes are suitable for ``st.download_button``.
    """
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, encoding='utf-8')
def get_specific_date_prediction(model, date_input, forecast): | |
"""Get prediction for a specific date.""" | |
try: | |
date_prediction = forecast[forecast['ds'] == pd.to_datetime(date_input)].iloc[0] | |
return { | |
'Predicted Value': f"${date_prediction['yhat']:.2f}", | |
'Lower Bound': f"${date_prediction['yhat_lower']:.2f}", | |
'Upper Bound': f"${date_prediction['yhat_upper']:.2f}", | |
'Trend': f"${date_prediction['trend']:.2f}" | |
} | |
except IndexError: | |
return None | |
def _filter_assets(search_term):
    """Return the ASSETS categories restricted to tickers containing *search_term*.

    Matching is case-insensitive and ignores surrounding whitespace. Categories
    left with no matching tickers are omitted.
    """
    needle = search_term.strip().upper()
    filtered = {}
    for category, tickers in ASSETS.items():
        matches = [ticker for ticker in tickers if needle in ticker.upper()]
        if matches:
            filtered[category] = matches
    return filtered


def _in_sample_metrics(df_prophet, forecast):
    """Return (MAE, RMSE, MAPE in percent) of the forecast over the training dates."""
    # Join on the date column instead of relying on pandas index alignment.
    # df_prophet keeps the index left by load_data's dropna(), which can have
    # gaps, while the Prophet forecast is indexed independently.
    merged = df_prophet[['ds', 'y']].merge(forecast[['ds', 'yhat']], on='ds', how='inner')
    actual = merged['y'].to_numpy()
    predicted = merged['yhat'].to_numpy()
    mae = mean_absolute_error(actual, predicted)
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    mape = np.mean(np.abs((actual - predicted) / actual)) * 100
    return mae, rmse, mape


def main():
    """Render the dashboard: asset picker, technical chart, Prophet forecast and fit metrics."""
    st.title('📈 Advanced Stock & Cryptocurrency Forecast')

    # Search bar for assets
    search_term = st.text_input('🔍 Search for assets (e.g., "AAPL" for Apple Inc.)', '')
    filtered_assets = _filter_assets(search_term)

    # Sidebar configuration
    st.sidebar.title("⚙️ Configuration")
    if not filtered_assets:
        # An empty selectbox returns None, which would otherwise reach
        # yf.download as the ticker and fail with an unhelpful message.
        st.warning(f'No assets match "{search_term}". Try a different search term.')
        return
    asset_type = st.sidebar.radio("Select Asset Type", list(filtered_assets.keys()))
    selected_asset = st.sidebar.selectbox('Select Asset', filtered_assets[asset_type])

    # Main content layout
    col1, col2 = st.columns(2)
    with col1:
        n_years = st.slider('Forecast Period (Years):', 1, 4)
    with col2:
        confidence_level = st.slider('Confidence Level:', 0.8, 0.99, 0.95)
    period = n_years * 365

    # Date-specific prediction section
    st.subheader('🎯 Get Prediction for Specific Date')
    prediction_date = st.date_input(
        "Select a date for prediction",
        min_value=date.today(),
        max_value=date.today() + timedelta(days=period),
        value=date.today() + timedelta(days=30),
    )

    # Load and process data
    with st.spinner('Loading data...'):
        data = load_data(selected_asset)
    if data is None:
        # load_data has already reported the problem via st.error.
        return

    # Calculate technical indicators
    data['SMA_20'] = data['Close'].rolling(window=20).mean()
    data['SMA_50'] = data['Close'].rolling(window=50).mean()
    data['RSI'] = calculate_rsi(data['Close'])

    # Display technical analysis
    st.subheader('📊 Technical Analysis')
    fig_technical = plot_technical_analysis(data, selected_asset)
    st.plotly_chart(fig_technical, use_container_width=True)

    # Prepare and train Prophet model
    df_prophet = prepare_prophet_data(data)
    try:
        with st.spinner('Training forecast model...'):
            model, future = train_prophet_model(df_prophet, period)
            # train_prophet_model never sets interval_width, so the confidence
            # slider previously had no effect. Prophet reads this attribute
            # when predict() computes yhat_lower / yhat_upper.
            model.interval_width = confidence_level
            forecast = model.predict(future)

        # Get specific date prediction
        specific_prediction = get_specific_date_prediction(model, prediction_date, forecast)
        if specific_prediction:
            st.subheader(f"Prediction for {prediction_date}")
            cols = st.columns(len(specific_prediction))
            for col, (metric, value) in zip(cols, specific_prediction.items()):
                col.metric(metric, value)
        else:
            # The forecast is anchored at the last downloaded date, not today,
            # so the date picker's upper limit can lie past the forecast's end.
            last_day = forecast['ds'].max().date()
            st.info(
                f"No forecast is available for {prediction_date}. "
                f"The forecast covers dates up to {last_day}."
            )

        # Calculate and display in-sample fit metrics
        mae, rmse, mape = _in_sample_metrics(df_prophet, forecast)
        st.subheader('📏 Model Performance Metrics')
        mae_col, rmse_col, mape_col = st.columns(3)
        mae_col.metric("MAE", f"${mae:.2f}")
        rmse_col.metric("RMSE", f"${rmse:.2f}")
        mape_col.metric("MAPE", f"{mape:.2f}%")

        # Display forecast
        st.subheader('🔮 Price Forecast')
        fig_forecast = plot_plotly(model, forecast)
        fig_forecast.update_layout(template='plotly_dark')
        st.plotly_chart(fig_forecast, use_container_width=True)

        # Display components using custom plotting function
        st.subheader("📉 Forecast Components")
        fig_components = plot_forecast_components(model, forecast)
        st.plotly_chart(fig_components, use_container_width=True)

        # Add download button
        csv = convert_df_to_csv(forecast)
        st.download_button(
            label="Download Forecast Data",
            data=csv,
            file_name=f'{selected_asset}_forecast.csv',
            mime='text/csv',
        )
    except Exception as e:
        # UI boundary: show the message and the traceback instead of crashing the page.
        st.error(f"Error in prediction: {str(e)}")
        st.exception(e)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()