Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dmtrrk committed Nov 27, 2024
1 parent 7609478 commit bfc7424
Showing 1 changed file with 83 additions and 12 deletions.
95 changes: 83 additions & 12 deletions tests/test_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,57 @@ def test_get_transcript_text(self, mock_session, make_mock_response):
mock_session.request.assert_called_once_with("GET",
URL,
headers=expected_headers)

@pytest.mark.parametrize('group_channels_by, group_channels_threshold_ms', [(None, None), ('sentence', 5000), ('word', 2000)])
def test_get_transcript_text(self, mock_session, make_mock_response, group_channels_by, group_channels_threshold_ms):
data = 'Test'
client = RevAiAPIClient(TOKEN)
expected_headers = {'Accept': 'text/plain'}
expected_headers.update(client.default_headers)
response = make_mock_response(url=URL, text=data)
mock_session.request.return_value = response
expected_url = URL
if group_channels_by and group_channels_threshold_ms:
expected_url += f"?group_channels_by={group_channels_by}&group_channels_threshold_ms={group_channels_threshold_ms}"

res = client.get_transcript_text(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)

assert res == data
mock_session.request.assert_called_once_with("GET",
expected_url,
headers=expected_headers)

@pytest.mark.parametrize('id', [None, ''])
def test_get_transcript_text_with_no_job_id(self, id, mock_session):
with pytest.raises(ValueError, match='id_ must be provided'):
RevAiAPIClient(TOKEN).get_transcript_text(id)

def test_get_transcript_text_as_stream(self, mock_session, make_mock_response):
@pytest.mark.parametrize(
'group_channels_by, group_channels_threshold_ms',
[(None, None), ('sentence', 5000), ('word', 2000)]
)
def test_get_transcript_text_as_stream(
self,
mock_session,
make_mock_response,
group_channels_by,
group_channels_threshold_ms
):
data = 'Test'
client = RevAiAPIClient(TOKEN)
expected_headers = {'Accept': 'text/plain'}
expected_headers.update(client.default_headers)
response = make_mock_response(url=URL, text=data)
mock_session.request.return_value = response
expected_url = URL
if group_channels_by and group_channels_threshold_ms:
expected_url += f"?group_channels_by={group_channels_by}&group_channels_threshold_ms={group_channels_threshold_ms}"

res = client.get_transcript_text_as_stream(JOB_ID)
res = client.get_transcript_text_as_stream(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)

assert res.content == data
mock_session.request.assert_called_once_with("GET",
URL,
expected_url,
headers=expected_headers,
stream=True)

Expand All @@ -61,7 +93,17 @@ def test_get_transcript_text_as_stream_with_no_job_id(self, id, mock_session):
with pytest.raises(ValueError, match='id_ must be provided'):
RevAiAPIClient(TOKEN).get_transcript_text_as_stream(id)

def test_get_transcript_json(self, mock_session, make_mock_response):
@pytest.mark.parametrize(
'group_channels_by, group_channels_threshold_ms',
[(None, None), ('sentence', 5000), ('word', 2000)]
)
def test_get_transcript_json(
self,
mock_session,
make_mock_response,
group_channels_by,
group_channels_threshold_ms
):
data = {
'monologues': [{
'speaker': 1,
Expand All @@ -80,19 +122,32 @@ def test_get_transcript_json(self, mock_session, make_mock_response):
expected_headers.update(client.default_headers)
response = make_mock_response(url=URL, json_data=data)
mock_session.request.return_value = response
expected_url = URL
if group_channels_by and group_channels_threshold_ms:
expected_url += f"?group_channels_by={group_channels_by}&group_channels_threshold_ms={group_channels_threshold_ms}"

res = client.get_transcript_json(JOB_ID)
res = client.get_transcript_json(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)

assert res == expected
mock_session.request.assert_called_once_with(
"GET", URL, headers=expected_headers)
"GET", expected_url, headers=expected_headers)

@pytest.mark.parametrize('id', [None, ''])
def test_get_transcript_json_with_no_job_id(self, id, mock_session):
with pytest.raises(ValueError, match='id_ must be provided'):
RevAiAPIClient(TOKEN).get_transcript_json(id)

def test_get_transcript_json_as_stream(self, mock_session, make_mock_response):
@pytest.mark.parametrize(
'group_channels_by, group_channels_threshold_ms',
[(None, None), ('sentence', 5000), ('word', 2000)]
)
def test_get_transcript_json_as_stream(
self,
mock_session,
make_mock_response,
group_channels_by,
group_channels_threshold_ms
):
data = {
'monologues': [{
'speaker': 1,
Expand All @@ -111,19 +166,32 @@ def test_get_transcript_json_as_stream(self, mock_session, make_mock_response):
expected_headers.update(client.default_headers)
response = make_mock_response(url=URL, json_data=data)
mock_session.request.return_value = response
expected_url = URL
if group_channels_by and group_channels_threshold_ms:
expected_url += f"?group_channels_by={group_channels_by}&group_channels_threshold_ms={group_channels_threshold_ms}"

res = client.get_transcript_json_as_stream(JOB_ID)
res = client.get_transcript_json_as_stream(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)

assert json.loads(res.content.decode('utf-8').replace("\'", "\"")) == expected
mock_session.request.assert_called_once_with(
"GET", URL, headers=expected_headers, stream=True)
"GET", expected_url, headers=expected_headers, stream=True)

@pytest.mark.parametrize('id', [None, ''])
def test_get_transcript_json_as_stream_with_no_job_id(self, id, mock_session):
with pytest.raises(ValueError, match='id_ must be provided'):
RevAiAPIClient(TOKEN).get_transcript_json_as_stream(id)

def test_get_transcript_object_with_success(self, mock_session, make_mock_response):
@pytest.mark.parametrize(
'group_channels_by, group_channels_threshold_ms',
[(None, None), ('sentence', 5000), ('word', 2000)]
)
def test_get_transcript_object_with_success(
self,
mock_session,
make_mock_response,
group_channels_by,
group_channels_threshold_ms
):
data = {
'monologues': [{
'speaker': 1,
Expand All @@ -142,12 +210,15 @@ def test_get_transcript_object_with_success(self, mock_session, make_mock_respon
expected_headers.update(client.default_headers)
response = make_mock_response(url=URL, json_data=data)
mock_session.request.return_value = response
expected_url = URL
if group_channels_by and group_channels_threshold_ms:
expected_url += f"?group_channels_by={group_channels_by}&group_channels_threshold_ms={group_channels_threshold_ms}"

res = client.get_transcript_object(JOB_ID)
res = client.get_transcript_object(JOB_ID, group_channels_by=group_channels_by, group_channels_threshold_ms=group_channels_threshold_ms)

assert res == expected
mock_session.request.assert_called_once_with(
"GET", URL, headers=expected_headers)
"GET", expected_url, headers=expected_headers)

@pytest.mark.parametrize('id', [None, ''])
def test_get_transcript_object_with_no_job_id(self, id, mock_session):
Expand Down

0 comments on commit bfc7424

Please sign in to comment.