Skip to content

Commit

Permalink
apply mypy and fix minor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
skasberger committed Apr 8, 2021
1 parent e63da94 commit 4c5b39d
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 72 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ to ``upstream/develop``, so you have to branch-off from it too.
There is one exception: If you
want to suggest a change to the docs in the folder
``app/docs/`` (e. g. fix a typo in
:ref:`User Guide - Basic Usage <user_basic-usage>`),
:ref:`Installation <user_installation>`),
you can also pull to ``upstream/master``. This means, you have also to
branch-off from the ``master`` branch.

Expand Down
29 changes: 5 additions & 24 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Type
from typing import Union
from typing import Any

from flask import Flask
from pydantic import BaseSettings


Expand Down Expand Up @@ -47,7 +47,7 @@ class Config:
env_file = os.path.join(ROOT_DIR, "env/development.env")

@classmethod
def init_app(cls, app):
def init_app(cls, app: Flask):
BaseConfig.init_app(app)

from flask_debugtoolbar import DebugToolbarExtension
Expand Down Expand Up @@ -144,27 +144,8 @@ def init_app(cls, app):
DockerConfig.init_app(app)


ConfigTypes: Type[
Union[
DevelopmentConfig,
TestingConfig,
ProductionConfig,
UnixConfig,
DockerConfig,
DockerComposeConfig,
]
] = Union[
DevelopmentConfig,
TestingConfig,
ProductionConfig,
UnixConfig,
DockerConfig,
DockerComposeConfig,
]


def get_config_class(config_name: str = "default") -> ConfigTypes:
configs = {
def get_config_class(config_name: str = "default",) -> Any:
configs: dict = {
"development": DevelopmentConfig,
"testing": TestingConfig,
"production": ProductionConfig,
Expand Down
14 changes: 7 additions & 7 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from app.models import Url


DatabaseModels = Type[Union[Import, Doi, Url, Request, FBRequest]]
DatabaseModelTypes = Type[Union[Import, Doi, Url, Request, FBRequest]]


def get_all(
db: Session, model: DatabaseModels, skip: int = 0, limit: int = 100,
db: Session, model: DatabaseModelTypes, skip: int = 0, limit: int = 100,
) -> Query:
"""Get all entries of a model.
Expand All @@ -42,7 +42,7 @@ def get_all(
return db.session.query(model).offset(skip).limit(limit).all()


def get_first(db: Session, model: DatabaseModels, kwargs: dict) -> Query:
def get_first(db: Session, model: DatabaseModelTypes, kwargs: dict) -> Query:
"""Get first entry of a model.
Parameters
Expand All @@ -62,8 +62,8 @@ def get_first(db: Session, model: DatabaseModels, kwargs: dict) -> Query:


def create_entity(
db: Session, model: DatabaseModels, kwargs: dict = {}
) -> DatabaseModels:
db: Session, model: DatabaseModelTypes, kwargs: dict = {}
) -> DatabaseModelTypes:
"""Create one entry of a model.
Parameters
Expand All @@ -85,8 +85,8 @@ def create_entity(


def create_entities(
db: Session, model: DatabaseModels, iterable: list, kwargs: dict = {},
) -> DatabaseModels:
db: Session, model: DatabaseModelTypes, iterable: list, kwargs: dict = {},
) -> DatabaseModelTypes:
"""Create entities of a model.
Parameters
Expand Down
65 changes: 34 additions & 31 deletions app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,46 @@
import os
from datetime import datetime
from json import dumps
from typing import Any
from typing import List
from urllib.parse import quote

from flask import g
from flask_sqlalchemy import SQLAlchemy
from pandas import read_csv
from tqdm import tqdm

try:
from urllib.parse import quote
except ImportError:
from urlparse import quote

from app.config import get_config_class, ConfigTypes
from app.crud import create_entity, create_entities, get_all
from app.models import db, Doi, Import, Url, Request, FBRequest
from app.requests import (
request_doi_landingpage,
request_ncbi_api,
request_unpaywall_api,
get_graph_api,
get_graph_api_urls,
get_graph_api_token,
)
from app.config import get_config_class
from app.crud import create_entities
from app.crud import create_entity
from app.crud import get_all
from app.models import db
from app.models import Doi
from app.models import FBRequest
from app.models import Import
from app.models import Request
from app.models import Url
from app.requests import get_graph_api
from app.requests import get_graph_api_token
from app.requests import get_graph_api_urls
from app.requests import request_doi_landingpage
from app.requests import request_ncbi_api
from app.requests import request_unpaywall_api
from app.utils import is_valid_doi


def get_config() -> ConfigTypes:
DATABASE = db


def get_config() -> Any:
"""Get config."""
config_name = os.getenv("FLASK_CONFIG") or "default"
config_class = get_config_class(config_name)

if os.getenv("ENV_FILE"):
config = config_class(_env_file=os.getenv("ENV_FILE"))
return config_class(_env_file=os.getenv("ENV_FILE"))
else:
config = config_class()

return config
return config_class()


def get_db() -> SQLAlchemy:
Expand All @@ -51,7 +54,7 @@ def get_db() -> SQLAlchemy:
again.
"""
if "db" not in g:
g.db = db
g.db = DATABASE
return g.db


Expand All @@ -68,14 +71,14 @@ def close_db(e=None) -> None:

def init_db() -> None:
"""Connect to database and create new tables."""
db = get_db()
db.create_all()
database = get_db()
database.create_all()


def drop_db() -> None:
"""Drop database."""
db = get_db()
db.drop_all()
database = get_db()
database.drop_all()


def dev() -> None:
Expand Down Expand Up @@ -412,15 +415,15 @@ def create_unpaywall_urls() -> None:
"""
num_urls_unpaywall_added = 0
num_requests_added = 0
urls_added = []
urls_added: List[str] = []

db = get_db()
database = get_db()
config = get_config()
# batch_size = config.URL_BATCH_SIZE # TODO: identify default and best practice values
email = config.APP_EMAIL
batch_size = 20

url_list = [url.url for url in get_all(db, Url)]
url_list = [url.url for url in get_all(database, Url)]

list_dois = Doi.query.filter(Doi.url_unpaywall is False).all()
print("Found DOI's: {0}".format(len(list_dois)))
Expand Down Expand Up @@ -468,8 +471,8 @@ def create_unpaywall_urls() -> None:
url_dict = {"url": url, "doi": row.doi, "url_type": url_type}
urls_added.append(url)
db_urls_added.append(url_dict)
create_entities(db, Request, db_requests_added)
create_entities(db, Url, db_urls_added)
create_entities(database, Request, db_requests_added)
create_entities(database, Url, db_urls_added)
num_urls_unpaywall_added += len(db_urls_added)
num_requests_added += len(db_requests_added)
print('{0} "Unpaywall" URL\'s added to database.'.format(num_urls_unpaywall_added))
Expand Down
3 changes: 1 addition & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from app.db import create_unpaywall_urls
from app.db import dev
from app.db import drop_db
from app.db import get_config
from app.db import get_fb_data
from app.db import import_basedata
from app.db import init_db
Expand All @@ -43,7 +42,7 @@ def create_app() -> Flask:
print("* Start FHE Collector...")

app = Flask("fhe_collector", root_path=ROOT_DIR)
env_name = os.getenv("FLASK_ENV")
env_name = os.getenv("FLASK_ENV", "default")
config = get_config_class(env_name)
app.config.from_object(config())
config.init_app(app)
Expand Down
4 changes: 3 additions & 1 deletion app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from datetime import timezone

from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.model import DefaultMeta
from sqlalchemy.exc import IntegrityError


db = SQLAlchemy()
BaseModelType: DefaultMeta = db.Model


class BaseModel(db.Model):
class BaseModel(BaseModelType):
__abstract__ = True

created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
Expand Down
5 changes: 3 additions & 2 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from csv import DictReader
from csv import DictWriter
from csv import reader
from csv import writer
from json import dump
from json import dumps
from json import load
Expand Down Expand Up @@ -171,9 +172,9 @@ def write_csv(
"""
with open(filename, "w", newline=newline, encoding=encoding) as csvfile:
writer = writer(csvfile, delimiter=delimiter, quotechar=quotechar)
csv_writer = writer(csvfile, delimiter=delimiter, quotechar=quotechar)
for row in data:
writer.writerow(row)
csv_writer.writerow(row)


def read_csv_as_dicts(
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ exclude =

[pylint]
max-line-length = 88

[mypy]
ignore_missing_imports = True
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import sys
from typing import List

from setuptools import find_packages
from setuptools import setup
Expand Down Expand Up @@ -46,8 +47,7 @@ def run_tests(self):
# import here, cause outside the eggs aren't loaded
import tox

errcode = tox.cmdline(self.test_args)
sys.exit(errcode)
tox.cmdline(self.test_args)


INSTALL_REQUIREMENTS = [
Expand All @@ -65,7 +65,7 @@ def run_tests(self):
"pydantic==1.7.2",
]

TESTS_REQUIREMENTS = []
TESTS_REQUIREMENTS: List = []

CLASSIFIERS = [
# How mature is this project? Common values are
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = py36,py37,py38,coverage,docs,packaging,dist_install
envlist = py36,py37,py38,mypy,docs
skip_missing_interpreters = True
ignore_basepython_conflict = True

Expand Down

0 comments on commit 4c5b39d

Please sign in to comment.