This repository has been archived by the owner on Aug 28, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathdb_connector.py
149 lines (117 loc) · 4.41 KB
/
db_connector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
try:
import urlparse
except ImportError:
import urllib.parse as urlparse
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from config import Config
from utils import write_records_to_csv
def get_db_connection_string(db_url):
    """Turn a database URL into the connection string given to SQLAlchemy.

    ``mysql://`` URLs are rewritten to ``mysql+pymysql://`` so that the
    PyMySQL driver is selected. Every other URL is returned unchanged.

    :param db_url: Database URL, e.g. ``mysql://user:pw@host/db`` or
                   ``sqlite:///path/to.db``
    :type db_url: String
    :returns: Connection string for ``create_engine``
    :rtype: String
    """
    parsed = urlparse.urlparse(db_url)
    if parsed.scheme != 'mysql':
        # Do not round-trip through urlunparse: for schemes urllib does not
        # know, an empty-authority URL such as ``sqlite:///x.db`` or
        # ``sqlite+pysqlite:///x.db`` would come back as ``scheme:/x.db``.
        return db_url
    with_driver = parsed._replace(scheme='{}+pymysql'.format(parsed.scheme))
    return urlparse.urlunparse(with_driver)
def get_engine():
    """Create the SQLAlchemy engine from environment configuration.

    When ``ENV`` is not ``'prod'``, ``Config.init_environment()`` is called
    first and the URL is read from ``MYSQL_DB_URL``, falling back to
    ``SQLLITE_DB_URL``. When ``ENV`` is ``'prod'`` the URL is read from
    ``CLEARDB_DATABASE_URL``.

    :returns: The engine built by ``create_engine`` for the resolved URL
    :raises KeyError: if ``ENV`` is ``'prod'`` and ``CLEARDB_DATABASE_URL``
                      is not set
    :raises RuntimeError: if the resolved URL is empty
    """
    if os.environ.get('ENV') != 'prod':  # We are not in Heroku
        Config.init_environment()
        # ``or`` instead of a nested default: a MYSQL_DB_URL that is set but
        # blank must not hide a usable SQLLITE_DB_URL.
        db_url = (os.environ.get('MYSQL_DB_URL') or
                  os.environ.get('SQLLITE_DB_URL', ''))
    else:
        db_url = os.environ['CLEARDB_DATABASE_URL']

    if not db_url:
        raise RuntimeError('missing database environment configuration')

    return create_engine(get_db_connection_string(db_url))
# Module-level engine, created once at import time from the environment
# configuration resolved by get_engine().
engine = get_engine()
# Declarative base class for the models below. The engine is passed as the
# first positional argument -- presumably to bind the metadata so the
# autoloaded tables can be reflected; confirm against the installed
# SQLAlchemy version.
Base = declarative_base(engine)
class GitHubData(Base):
    """Model mapped onto the ``github_data`` table.

    No columns are declared here; ``autoload`` is enabled in the table
    arguments instead. See db/data_schema.sql for details.

    This table stores GitHub data (e.g. Pull Requests, Forks, Stars, etc.)
    """
    __table_args__ = dict(autoload=True)
    __tablename__ = 'github_data'
class PackageManagerData(Base):
    """Model mapped onto the ``package_manager_data`` table.

    No columns are declared here; ``autoload`` is enabled in the table
    arguments instead. See db/data_schema.sql for details.

    This table stores Package Manager (e.g. Nuget, Packagist, npm, etc.)
    download data.
    """
    __table_args__ = dict(autoload=True)
    __tablename__ = 'package_manager_data'
def loadSession():
    """Create a new DB session bound to the module-level engine.

    :returns: A SQLAlchemy DB session
    :rtype: session
    """
    session_factory = sessionmaker(bind=engine)
    return session_factory()
class DBConnector(object):
    """CRUD helper around a SQLAlchemy session obtained from loadSession()."""

    def __init__(self):
        self.session = loadSession()

    def add_data(self, data_object):
        """Merge an item into the DB and commit it.

        If the merge or the commit raises, the session is rolled back so it
        stays usable, and the original exception is re-raised.

        :param data_object: An object that represents a table in the DB
        :type data_object: Object
        :returns: The object returned by the session's ``merge`` call
        :rtype: Data object
        """
        try:
            merged = self.session.merge(data_object)
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        return merged

    def get_data(self, data_object):
        """Retrieve all rows from a table in the DB

        :param data_object: An object that represents a table in the DB
        :type data_object: Object
        :returns: All data objects (GitHubData or PackageManagerData)
                  in the DB
        :rtype: List
        """
        return self.session.query(data_object).all()

    def delete_data(self, id, table):
        """Delete a record from the DB

        If the delete or the commit raises, the session is rolled back so it
        stays usable, and the original exception is re-raised.

        :param id: ID of the record
        :param table: Table to delete it from, 'github_data'
                      or 'package_manager_data'
        :type id: Integer
        :type table: String
        :returns: True once the delete has been committed for a known
                  table, False if ``table`` is not one of the two names
        :rtype: Bool
        """
        # Resolve the model lazily, per branch, exactly like the original
        # if/elif did; unknown table names never touch the session.
        if table == 'github_data':
            model = GitHubData
        elif table == 'package_manager_data':
            model = PackageManagerData
        else:
            return False

        try:
            self.session.query(model).filter(model.id == id).delete()
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        return True

    def export_table_to_csv(self, data_object, header=True):
        """Write CSV to local directory of table, using table name as filename.

        The file is written to ``./csv/<table name>.csv`` through
        ``write_records_to_csv``.

        :param data_object: An object that represents a table in the DB
        :type data_object: Object
        :param header: If true, include column names as first row.
                       NOTE(review): this flag is currently not consulted;
                       the column names are always passed to
                       ``write_records_to_csv``.
        :type header: Bool
        :returns: True if succeeded
        :rtype: Bool
        """
        column_names = [col.name for col in data_object.__mapper__.columns]
        # Convert SQLAlchemy objects to lists of column values
        rows = [[getattr(record, col) for col in column_names]
                for record in self.get_data(data_object)]
        filename = "./csv/{}.csv".format(data_object.__tablename__)
        write_records_to_csv(filename, rows, column_names)
        # The documented contract is a Bool; previously None was returned.
        return True