Skip to content

Commit

Permalink
add /viz visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
dmhliu committed Jul 30, 2020
1 parent 827d1bb commit 34fb37a
Show file tree
Hide file tree
Showing 4 changed files with 48,205 additions and 44 deletions.
79 changes: 42 additions & 37 deletions app/api/viz.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,50 @@
from fastapi import APIRouter, HTTPException
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from .predict import Item
from joblib import load

# Router on which this module's request handlers are registered.
router = APIRouter()

# Serialized artifacts are read once, at import time, and shared by all
# requests. The `sfw_*` names are kept alongside the generic aliases the
# handlers use.
sfw_tfidf = load('tfidf_cleaned.joblib')
sfw_model = load('nn_cleaned.joblib')
model, tfidf = sfw_model, sfw_tfidf

# Only the CSV's second column is read; it supplies the `subreddit` series
# that neighbour row indices are looked up in.
df = pd.read_csv('cleaned_subs.csv', usecols=[1])
subreddits = df['subreddit']

@router.post('/viz')
async def viz(postbody: Item):
    """Plot the nearest-neighbour subreddits for a post as a Plotly scatter.

    The post's title and selftext are vectorized with the module-level
    ``tfidf`` and passed to ``model.kneighbors``. Each returned row index is
    mapped to a name through ``subreddits``. At most six distinct subreddits
    are plotted: the neighbour value on the x axis (also used for marker
    colour and size) and the subreddit name on the y axis.

    Args:
        postbody: Request body providing ``title`` and ``selftext``.

    Returns:
        The Plotly figure serialized as a JSON string.
    """
    # Title and selftext are concatenated without a separator, unchanged from
    # the original. NOTE(review): confirm the vectorizer was fitted on text
    # built the same way before altering this.
    query = tfidf.transform([postbody.title + postbody.selftext])  # use sfw

    # `.toarray()` gives a plain ndarray; `.todense()` gives `np.matrix`,
    # which recent scikit-learn releases refuse as estimator input.
    query_results = model.kneighbors(query.toarray())
    # Element 0 holds the per-neighbour values, element 1 the row indices
    # into `subreddits`; each has a single row because one query was sent.
    neighbor_values, neighbor_rows = query_results[0][0], query_results[1][0]

    predictions = []
    values = []
    for row, value in zip(neighbor_rows, neighbor_values):
        name = subreddits[row]
        # Keep only the first occurrence of each subreddit.
        if name not in predictions:
            predictions.append(name)
            values.append(value)

    predictions = predictions[:6]
    values = values[:6]
    predictions.reverse()
    values.reverse()
    # Sizes are derived from the final `values` so every marker's size
    # belongs to the point it is drawn on (same truncation, same order).
    size = [(value + 1) * 10 for value in values]

    fig = go.Figure(data=[go.Scatter(
        x=values, y=predictions,
        mode='markers',
        marker=dict(
            color=values,
            size=size
        )
    )])

    # Return Plotly figure as JSON string
    return fig.to_json()


# Postal code -> state name, built once instead of on every request.
_STATECODES = {
    'AL': 'Alabama', 'AK': 'Alaska', 'AZ': 'Arizona', 'AR': 'Arkansas',
    'CA': 'California', 'CO': 'Colorado', 'CT': 'Connecticut',
    'DE': 'Delaware', 'DC': 'District of Columbia', 'FL': 'Florida',
    'GA': 'Georgia', 'HI': 'Hawaii', 'ID': 'Idaho', 'IL': 'Illinois',
    'IN': 'Indiana', 'IA': 'Iowa', 'KS': 'Kansas', 'KY': 'Kentucky',
    'LA': 'Louisiana', 'ME': 'Maine', 'MD': 'Maryland',
    'MA': 'Massachusetts', 'MI': 'Michigan', 'MN': 'Minnesota',
    'MS': 'Mississippi', 'MO': 'Missouri', 'MT': 'Montana',
    'NE': 'Nebraska', 'NV': 'Nevada', 'NH': 'New Hampshire',
    'NJ': 'New Jersey', 'NM': 'New Mexico', 'NY': 'New York',
    'NC': 'North Carolina', 'ND': 'North Dakota', 'OH': 'Ohio',
    'OK': 'Oklahoma', 'OR': 'Oregon', 'PA': 'Pennsylvania',
    'RI': 'Rhode Island', 'SC': 'South Carolina', 'SD': 'South Dakota',
    'TN': 'Tennessee', 'TX': 'Texas', 'UT': 'Utah', 'VT': 'Vermont',
    'VA': 'Virginia', 'WA': 'Washington', 'WV': 'West Virginia',
    'WI': 'Wisconsin', 'WY': 'Wyoming'
}


# Both handlers keep the name `viz`. Each is registered with the router by
# its decorator when it is defined, so rebinding the module-level name here
# does not remove the POST route above.
@router.get('/viz/{statecode}')
async def viz(statecode: str):
    """Visualize state unemployment rate from Federal Reserve Economic Data.

    Args:
        statecode: Two-letter US state postal code; matched
            case-insensitively.

    Returns:
        The Plotly line figure serialized as a JSON string.

    Raises:
        HTTPException: 404 when ``statecode`` is not a known postal code.
    """
    # Validate the state code
    statecode = statecode.upper()
    if statecode not in _STATECODES:
        raise HTTPException(status_code=404, detail=f'State code {statecode} not found')

    # Get the state's unemployment rate data from FRED
    url = f'https://fred.stlouisfed.org/graph/fredgraph.csv?id={statecode}UR'
    df = pd.read_csv(url, parse_dates=['DATE'])
    df.columns = ['Date', 'Percent']

    # Make Plotly figure
    statename = _STATECODES[statecode]
    fig = px.line(df, x='Date', y='Percent', title=f'{statename} Unemployment Rate')

    # Return Plotly figure as JSON string
    return fig.to_json()

2 changes: 1 addition & 1 deletion app/tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_predict_routes():
"""
routes_to_test = ['/predict','/nsfw_predict', '/test_predict']
post_body = {'title': 'foo bar bar barrrr',
'selftext': 'banjo didjeridoo djembe khomuz igil',
'selftext': 'banjo didjeridoo djembe khomuz igil'
}
for route in routes_to_test:
response = client.post(route,json=post_body)
Expand Down
17 changes: 11 additions & 6 deletions app/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@

# Test client wrapping `app`; shared by every test in this module.
client = TestClient(app)

# Routes that test_all_routes_200 exercises with a POST body.
post_routes_to_test = ['/viz']

def test_valid_input():
    """A valid two-letter state code yields 200 and the state's chart title."""
    resp = client.get('/viz/IL')
    assert resp.status_code == 200
    assert 'Illinois Unemployment Rate' in resp.text

def test_all_routes_200():
    """Return 200 Success with a JSON body for every POST route given valid input."""
    post_body = {
        'title': 'foo bar bar barrrr',
        'selftext': 'banjo didjeridoo djembe khomuz igil',
    }
    for route in post_routes_to_test:
        response = client.post(route, json=post_body)
        # Assert the status first, naming the route, so a failing endpoint is
        # reported as such rather than as a JSON decoding error.
        assert response.status_code == 200, (
            f'{route} returned {response.status_code}'
        )
        # The body must still decode as JSON; the decoded value is not
        # inspected further.
        response.json()


def test_invalid_input():
    """An unknown state postal code is rejected with 404 and a descriptive detail."""
    resp = client.get('/viz/ZZ')
    payload = resp.json()
    assert resp.status_code == 404
    assert payload['detail'] == 'State code ZZ not found'
Loading

0 comments on commit 34fb37a

Please sign in to comment.