Skip to content

Commit

Permalink
Ignore local aws config (#1230)
Browse files Browse the repository at this point in the history
* Ignore local aws config.

* Ignore local aws config.

* Temporarily override env vars.

* remove cli override.

* Pr feedback.

* Revert change.
  • Loading branch information
squidarth authored Nov 14, 2024
1 parent 895b1ea commit 7b89ce9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 23 deletions.
50 changes: 27 additions & 23 deletions truss/remote/baseten/utils/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,37 @@
import boto3
from boto3.s3.transfer import TransferConfig
from rich.progress import Progress
from truss.util.env_vars import override_env_vars


def base64_encoded_json_str(obj):
return base64.b64encode(str.encode(json.dumps(obj))).decode("utf-8")


def multipart_upload_boto3(file_path, bucket_name, key, credentials):
s3_resource = boto3.resource("s3", **credentials)
filesize = os.stat(file_path).st_size

# Create a new progress bar
progress = Progress()

# Add a new task to the progress bar
task_id = progress.add_task("[cyan]Uploading...", total=filesize)

with progress:

def callback(bytes_transferred):
# Update the progress bar
progress.update(task_id, advance=bytes_transferred)

s3_resource.Object(bucket_name, key).upload_file(
file_path,
Config=TransferConfig(
max_concurrency=10,
use_threads=True,
),
Callback=callback,
)
# In the CLI flow, ignore any local ~/.aws/config files,
# which can interfere with uploading the Truss to S3.
with override_env_vars({"AWS_CONFIG_FILE": ""}):
s3_resource = boto3.resource("s3", **credentials)
filesize = os.stat(file_path).st_size

# Create a new progress bar
progress = Progress()

# Add a new task to the progress bar
task_id = progress.add_task("[cyan]Uploading...", total=filesize)

with progress:

def callback(bytes_transferred):
# Update the progress bar
progress.update(task_id, advance=bytes_transferred)

s3_resource.Object(bucket_name, key).upload_file(
file_path,
Config=TransferConfig(
max_concurrency=10,
use_threads=True,
),
Callback=callback,
)
14 changes: 14 additions & 0 deletions truss/tests/util/test_env_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os

from truss.util.env_vars import override_env_vars


def test_override_env_vars():
os.environ["API_KEY"] = "original_key"

with override_env_vars({"API_KEY": "new_key", "DEBUG": "true"}):
assert os.environ["API_KEY"] == "new_key"
assert os.environ["DEBUG"] == "true"

assert os.environ["API_KEY"] == "original_key"
assert "DEBUG" not in os.environ
41 changes: 41 additions & 0 deletions truss/util/env_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from typing import Dict, Optional


class override_env_vars:
"""A context manager for temporarily overwriting environment variables.
Usage:
with override_env_vars({'API_KEY': 'test_key', 'DEBUG': 'true'}):
# Environment variables are modified here
...
# Original environment is restored here
"""

def __init__(self, env_vars: Dict[str, str]):
"""
Args:
env_vars: Dictionary of environment variables to set
"""
self.env_vars = env_vars
self.original_vars: Dict[str, Optional[str]] = {}

def __enter__(self):
for key in self.env_vars:
self.original_vars[key] = os.environ.get(key)

for key, value in self.env_vars.items():
os.environ[key] = value

return self

def __exit__(self, exc_type, exc_val, exc_tb):
# Restore original environment
for key, value in self.original_vars.items():
if value is None:
# Variable didn't exist originally
if key in os.environ:
del os.environ[key]
else:
# Restore original value
os.environ[key] = value

0 comments on commit 7b89ce9

Please sign in to comment.