Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix long output prints #1513

Closed
17 changes: 10 additions & 7 deletions ersilia/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ... import ErsiliaModel
from ...core.session import Session
from ...utils.terminal import print_result_table
from ...utils.terminal import print_result_table, truncate_output
from .. import echo
from . import ersilia_cli

Expand Down Expand Up @@ -69,23 +69,26 @@ def run(input, output, batch_size, as_table):
batch_size=batch_size,
track_run=track_runs,
)

if isinstance(result, types.GeneratorType):
for result in mdl.run(input=input, output=output, batch_size=batch_size):
if result is not None:
formatted = json.dumps(result, indent=4)
for result_item in result:
if result_item is not None:
truncated_result = truncate_output(result_item, max_length=10)
formatted = json.dumps(truncated_result, indent=4)
if as_table:
print_result_table(formatted)
else:
echo(formatted)
else:
echo("Something went wrong", fg="red")
else:
truncated_result = truncate_output(result, max_length=10)
if as_table:
print_result_table(result)
print_result_table(truncated_result)
else:
try:
echo(result)
echo(json.dumps(truncated_result, indent=4))
except:
print_result_table(result)
print_result_table(truncated_result)

return run
61 changes: 61 additions & 0 deletions ersilia/utils/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,67 @@ def yes_no_input(prompt, default_answer, timeout=5):
return True


def truncate_output(output, max_length=10):
"""
Truncate the given output for better readability in the terminal.

Handles strings, lists, dictionaries, integers, floats, CSV, HDF5 files, and matrices.

Parameters
----------
output : str, list, dict, int, float, or file object
The output to truncate.
max_length : int, optional
The maximum number of characters or items to display. Default is 10.

Returns
-------
str
The truncated output as a string.
"""
if isinstance(output, str):
return (
output if len(output) <= 100 else output[:100] + "..."
) # Strings: 100 chars
elif isinstance(output, (int, float)):
return str(output) # Numbers: Show as-is
elif isinstance(output, list):
if all(isinstance(row, list) for row in output): # Matrix
truncated_rows = [
row[:max_length] for row in output[:max_length]
] # Truncate rows/columns
return json.dumps(truncated_rows, indent=2) + (
"\n..." if len(output) > max_length else ""
)
else: # Regular list
return json.dumps(output[:max_length], indent=2) + (
"..." if len(output) > max_length else ""
)
elif isinstance(output, dict): # Serializable Object
truncated = list(output.items())[:max_length]
return json.dumps(truncated, indent=2) + (
"..." if len(output) > max_length else ""
)
elif hasattr(output, "read"): # Handle file-like objects
if output.name.endswith(".csv"):
output.seek(0)
reader = csv.reader(output)
rows = [row for _, row in zip(range(max_length), reader)]
return "\n".join([",".join(row) for row in rows]) + (
"\n..." if len(list(reader)) > max_length else ""
)
elif output.name.endswith((".h5", ".hdf5")):
import h5py

with h5py.File(output.name, "r") as f:
keys = list(f.keys())[:max_length]
return f"HDF5 file: {output.name}\nDatasets: {keys}" + (
"\n..." if len(keys) > max_length else ""
)
else:
return str(output)


def print_result_table(data):
"""
Print a result table from CSV or JSON-like data.
Expand Down
46 changes: 46 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
from ersilia.utils.terminal import truncate_output


def test_truncate_output_list():
# Test truncation for a long list
output = list(range(20)) # List with 20 items
truncated = truncate_output(output, max_items=10)
assert truncated == "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ... (and 10 more items)"


def test_truncate_output_dict():
# Test truncation for a long dictionary
output = {f"key{i}": i for i in range(20)} # Dictionary with 20 key-value pairs
truncated = truncate_output(output, max_items=5)
assert truncated.startswith('{\n "key0": 0,')
assert truncated.endswith("... (and 15 more lines)")


def test_truncate_output_short_list():
# Test for a short list that doesn't need truncation
output = [1, 2, 3]
truncated = truncate_output(output, max_items=10)
assert truncated == "[1, 2, 3]"


def test_truncate_output_short_dict():
# Test for a short dictionary that doesn't need truncation
output = {"key1": 1, "key2": 2}
truncated = truncate_output(output, max_items=10)
assert "key1" in truncated
assert "key2" in truncated


def test_truncate_output_short_string():
# Test for a short string that doesn't need truncation
output = "Short string"
truncated = truncate_output(output, max_chars=50)
assert truncated == "Short string"


def test_truncate_output_other_types():
# Test for non-list, non-dict, non-string types
output = 12345
truncated = truncate_output(output)
assert truncated == "12345"
Loading