Skip to content

Commit

Permalink
protobuf: tests implemented, full coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
suidpit authored and ytraven committed Jul 17, 2018
1 parent 204faa1 commit d5da746
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 64 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ MANIFEST
*.py[cdo]
*.swp
*.swo
*.sqlite
*.egg-info/
.coverage*
.idea
Expand Down
4 changes: 2 additions & 2 deletions mitmproxy/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_paths
from .db import DbHandler
from .db import DBHandler


__all__ = [
"FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths", "DbHandler"
"FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths", "DBHandler"
]
9 changes: 5 additions & 4 deletions mitmproxy/io/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import os

from mitmproxy.io import protobuf
from mitmproxy.http import HTTPFlow
from mitmproxy import exceptions


class DbHandler:
class DBHandler:

"""
This class is wrapping up connection to SQLITE DB.
"""

def __init__(self, db_path="/tmp/tmp.sqlite"):
def __init__(self, db_path, mode='load'):
if mode == 'write':
if os.path.isfile(db_path):
os.remove(db_path)
self.db_path = db_path
self._con = sqlite3.connect(self.db_path)
self._c = self._con.cursor()
Expand Down
1 change: 0 additions & 1 deletion mitmproxy/io/proto/http.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ message HTTPFlow {
optional bool marked = 7;
optional string mode = 8;
optional string id = 9;
optional int32 version = 10;
}

message HTTPRequest {
Expand Down
43 changes: 18 additions & 25 deletions mitmproxy/io/proto/http_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 24 additions & 29 deletions mitmproxy/io/protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from mitmproxy import flow
from mitmproxy import exceptions
from mitmproxy import ctx
from mitmproxy.http import HTTPFlow, HTTPResponse, HTTPRequest
from mitmproxy.certs import Cert
from mitmproxy.connections import ClientConnection, ServerConnection
Expand All @@ -16,7 +15,8 @@ def _move_attrs(s_obj, d_obj, attrs):
setattr(d_obj, attr, getattr(s_obj, attr))
else:
if hasattr(s_obj, attr) and getattr(s_obj, attr) is not None:
if not getattr(s_obj, attr):
# ugly fix to set None in empty str or bytes fields
if getattr(s_obj, attr) == "" or getattr(s_obj, attr) == b"":
d_obj[attr] = None
else:
d_obj[attr] = getattr(s_obj, attr)
Expand Down Expand Up @@ -87,12 +87,12 @@ def _dump_http_error(e: flow.Error) -> http_pb2.HTTPError:
return pe


def dump_http(f: HTTPFlow) -> http_pb2.HTTPFlow():
def dump_http(f: flow.Flow) -> http_pb2.HTTPFlow:
pf = http_pb2.HTTPFlow()
for p in ['request', 'response', 'client_conn', 'server_conn', 'error']:
if hasattr(f, p):
if hasattr(f, p) and getattr(f, p):
getattr(pf, p).MergeFrom(eval(f"_dump_http_{p}")(getattr(f, p)))
_move_attrs(f, pf, ['intercepted', 'marked', 'mode', 'id', 'version'])
_move_attrs(f, pf, ['intercepted', 'marked', 'mode', 'id'])
return pf


Expand All @@ -105,9 +105,9 @@ def dumps(f: flow.Flow) -> bytes:


def _load_http_request(o: http_pb2.HTTPRequest) -> HTTPRequest:
d = {}
d: dict = {}
_move_attrs(o, d, ['first_line_format', 'method', 'scheme', 'host', 'port', 'path', 'http_version', 'content',
'timestamp_start', 'timestamp_end', 'is_replay'])
'timestamp_start', 'timestamp_end', 'is_replay'])
if d['content'] is None:
d['content'] = b""
d["headers"] = []
Expand All @@ -118,9 +118,9 @@ def _load_http_request(o: http_pb2.HTTPRequest) -> HTTPRequest:


def _load_http_response(o: http_pb2.HTTPResponse) -> HTTPResponse:
d = {}
d: dict = {}
_move_attrs(o, d, ['http_version', 'status_code', 'reason',
'content', 'timestamp_start', 'timestamp_end', 'is_replay'])
'content', 'timestamp_start', 'timestamp_end', 'is_replay'])
if d['content'] is None:
d['content'] = b""
d["headers"] = []
Expand All @@ -131,14 +131,12 @@ def _load_http_response(o: http_pb2.HTTPResponse) -> HTTPResponse:


def _load_http_client_conn(o: http_pb2.ClientConnection) -> ClientConnection:
d = {}
_move_attrs(o, d, ['id', 'tls_established', 'sni', 'alpn_proto_negotiated', 'tls_version',
'timestamp_start', 'timestamp_tcp_setup', 'timestamp_tls_setup', 'timestamp_end'])
d: dict = {}
_move_attrs(o, d, ['id', 'tls_established', 'sni', 'cipher_name', 'alpn_proto_negotiated', 'tls_version',
'timestamp_start', 'timestamp_tcp_setup', 'timestamp_tls_setup', 'timestamp_end'])
for cert in ['clientcert', 'mitmcert']:
if hasattr(o, cert) and getattr(o, cert):
c = Cert("")
c.from_pem(getattr(o, cert))
d[cert] = c
d[cert] = Cert.from_pem(getattr(o, cert))
if o.tls_extensions:
d['tls_extensions'] = []
for extension in o.tls_extensions:
Expand All @@ -152,18 +150,17 @@ def _load_http_client_conn(o: http_pb2.ClientConnection) -> ClientConnection:


def _load_http_server_conn(o: http_pb2.ServerConnection) -> ServerConnection:
d = {}
d: dict = {}
_move_attrs(o, d, ['id', 'tls_established', 'sni', 'alpn_proto_negotiated', 'tls_version',
'timestamp_start', 'timestamp_tcp_setup', 'timestamp_tls_setup', 'timestamp_end'])
'timestamp_start', 'timestamp_tcp_setup', 'timestamp_tls_setup', 'timestamp_end'])
for addr in ['address', 'ip_address', 'source_address']:
if hasattr(o, addr):
d[addr] = (getattr(o, addr).host, getattr(o, addr).port)
if o.cert:
c = Cert("")
c.from_pem(o.cert)
c = Cert.from_pem(o.cert)
d['cert'] = c
if len(o.via.id):
d['via'] = _load_http_server_conn(d['via'])
if o.HasField('via'):
d['via'] = _load_http_server_conn(o.via)
sc = ServerConnection(tuple())
for k, v in d.items():
setattr(sc, k, v)
Expand All @@ -181,9 +178,11 @@ def _load_http_error(o: http_pb2.HTTPError) -> typing.Optional[flow.Error]:
def load_http(hf: http_pb2.HTTPFlow) -> HTTPFlow:
parts = {}
for p in ['request', 'response', 'client_conn', 'server_conn', 'error']:
if hasattr(hf, p) and getattr(hf, p):
if hf.HasField(p):
parts[p] = eval(f"_load_http_{p}")(getattr(hf, p))
_move_attrs(hf, parts, ['intercepted', 'marked', 'mode', 'id', 'version'])
else:
parts[p] = None
_move_attrs(hf, parts, ['intercepted', 'marked', 'mode', 'id'])
f = HTTPFlow(ClientConnection(None, tuple(), None), ServerConnection(tuple()))
for k, v in parts.items():
setattr(f, k, v)
Expand All @@ -195,9 +194,5 @@ def loads(b: bytes, typ="http") -> flow.Flow:
raise exceptions.TypeError("Flow types different than HTTP not supported yet!")
else:
p = http_pb2.HTTPFlow()
try:
p.ParseFromString(b)
return load_http(p)
except Exception as e:
ctx.log(str(e))

p.ParseFromString(b)
return load_http(p)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
max-line-length = 140
max-complexity = 25
ignore = E251,C901,W503,W292,E722,E741
exclude = mitmproxy/contrib/*,test/mitmproxy/data/*,release/build/*
exclude = mitmproxy/contrib/*,test/mitmproxy/data/*,release/build/*,mitmproxy/io/proto/*
addons = file,open,basestring,xrange,unicode,long,cmp

[tool:pytest]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@
"kaitaistruct>=0.7,<0.9",
"ldap3>=2.5,<2.6",
"passlib>=1.6.5, <1.8",
"protobuf>=3.6.0, <3.7",
"pyasn1>=0.3.1,<0.5",
"pyOpenSSL>=17.5,<18.1",
"pyparsing>=2.1.3, <2.3",
"pyperclip>=1.6.0, <1.7",
"protobuf>=3.6.0, <3.7",
"ruamel.yaml>=0.13.2, <0.16",
"sortedcontainers>=1.5.4,<2.1",
"tornado>=4.3,<5.1",
Expand Down
3 changes: 2 additions & 1 deletion test/filename_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def check_src_files_have_test():
missing_test_files = []

excluded = ['mitmproxy/contrib/', 'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/']
excluded = ['mitmproxy/contrib/', 'mitmproxy/io/proto/',
'mitmproxy/test/', 'mitmproxy/tools/', 'mitmproxy/platform/']
src_files = glob.glob('mitmproxy/**/*.py', recursive=True) + glob.glob('pathod/**/*.py', recursive=True)
src_files = [f for f in src_files if os.path.basename(f) != '__init__.py']
src_files = [f for f in src_files if not any(os.path.normpath(p) in f for p in excluded)]
Expand Down
26 changes: 26 additions & 0 deletions test/mitmproxy/io/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from mitmproxy.io import db
from mitmproxy.test import tflow


class TestDB:

def test_create(self, tdata):
dh = db.DBHandler(db_path=tdata.path("mitmproxy/data") + "/tmp.sqlite")
with dh._con as c:
cur = c.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='FLOWS';")
assert cur.fetchall() == [('FLOWS',)]

def test_roundtrip(self, tdata):
dh = db.DBHandler(db_path=tdata.path("mitmproxy/data") + "/tmp.sqlite", mode='write')
flows = []
for i in range(10):
flows.append(tflow.tflow())
dh.store(flows)
dh = db.DBHandler(db_path=tdata.path("mitmproxy/data") + "/tmp.sqlite")
with dh._con as c:
cur = c.cursor()
cur.execute("SELECT count(*) FROM FLOWS;")
assert cur.fetchall()[0][0] == 10
loaded_flows = dh.load()
assert len(loaded_flows) == len(flows)
Loading

0 comments on commit d5da746

Please sign in to comment.