Merge pull request #2642 from ASFHyP3/type-annotations
Finish adding type annotations
jtherrmann authored Mar 6, 2025
2 parents 1c382dc + ceeaf1d commit f803c97
Showing 20 changed files with 104 additions and 64 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- When the API returns an error for an `INSAR_ISCE_BURST` job because the requested scenes have different polarizations, the error message now always includes the requested polarizations in the same order as the requested scenes (previously, the order of the polarizations was not guaranteed). For example, passing two scenes with `VV` and `HH` polarizations, respectively, results in the error message: `The requested scenes need to have the same polarization, got: VV, HH`
- The API validation behavior for the `INSAR_ISCE_MULTI_BURST` job type is now more closely aligned with the CLI validation for the underlying [HyP3 ISCE2](https://github.com/ASFHyP3/hyp3-isce2/) container. Currently, this only affects the `hyp3-multi-burst-sandbox` deployment.
- The requested scene names are now validated before DEM coverage for both `INSAR_ISCE_BURST` and `INSAR_ISCE_MULTI_BURST`.
- The `lambda_logging.log_exceptions` decorator (for logging unhandled exceptions in AWS Lambda functions) now returns the wrapped function's return value rather than always returning `None`.
- Ruff now enforces that all functions and methods must have type annotations.

## [9.5.2]

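To make the `lambda_logging.log_exceptions` entry above concrete, here is a minimal sketch of a decorator with the described behavior. It is an illustration only, not the repository's actual `lib/lambda_logging` code: it logs any unhandled exception and, per the change, passes the wrapped handler's return value through instead of always returning `None`.

```python
# Minimal sketch of the behavior described in the CHANGELOG entry above;
# the real lib/lambda_logging implementation may differ in its details.
import functools
import logging
from collections.abc import Callable
from typing import Any

logger = logging.getLogger(__name__)


def log_exceptions(lambda_handler: Callable) -> Callable:
    @functools.wraps(lambda_handler)
    def wrapper(event: dict, context: Any) -> Any:
        try:
            # Returning here is the behavioral change: previously the result
            # was discarded, so the decorator always returned None.
            return lambda_handler(event, context)
        except Exception:
            logger.exception('Unhandled exception in Lambda handler')
            raise

    return wrapper
```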
3 changes: 2 additions & 1 deletion Makefile
@@ -11,7 +11,8 @@ DISABLE_PRIVATE_DNS = ${PWD}/apps/disable-private-dns/src
UPDATE_DB = ${PWD}/apps/update-db/src
UPLOAD_LOG = ${PWD}/apps/upload-log/src
DYNAMO = ${PWD}/lib/dynamo
export PYTHONPATH = ${API}:${CHECK_PROCESSING_TIME}:${GET_FILES}:${HANDLE_BATCH_EVENT}:${SET_BATCH_OVERRIDES}:${SCALE_CLUSTER}:${START_EXECUTION_MANAGER}:${START_EXECUTION_WORKER}:${DISABLE_PRIVATE_DNS}:${UPDATE_DB}:${UPLOAD_LOG}:${DYNAMO}:${APPS}
LAMBDA_LOGGING = ${PWD}/lib/lambda_logging
export PYTHONPATH = ${API}:${CHECK_PROCESSING_TIME}:${GET_FILES}:${HANDLE_BATCH_EVENT}:${SET_BATCH_OVERRIDES}:${SCALE_CLUSTER}:${START_EXECUTION_MANAGER}:${START_EXECUTION_WORKER}:${DISABLE_PRIVATE_DNS}:${UPDATE_DB}:${UPLOAD_LOG}:${DYNAMO}:${LAMBDA_LOGGING}:${APPS}


build: render
2 changes: 1 addition & 1 deletion apps/api/src/hyp3_api/handlers.py
@@ -50,7 +50,7 @@ def get_jobs(
except util.TokenDeserializeError:
abort(problem_format(400, 'Invalid start_token value'))
jobs, last_evaluated_key = dynamo.jobs.query_jobs(user, start, end, status_code, name, job_type, start_key)
payload = {'jobs': jobs}
payload: dict = {'jobs': jobs}
if last_evaluated_key is not None:
next_token = util.serialize(last_evaluated_key)
payload['next'] = util.build_next_url(
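The `payload: dict` annotation in `handlers.py` above is typical of the small changes strict annotation checking forces. As a hedged illustration (the variable names below are made up, not the actual handler code): without the explicit annotation, mypy infers the dictionary's value type from the first assignment and then rejects adding a value of a different type.

```python
# Illustrative only; not the repository's handler code.
jobs = [{'job_id': 'abc123'}]

payload = {'jobs': jobs}            # inferred as dict[str, list[dict[str, str]]]
payload['next'] = 'https://example.com/jobs?start_token=...'  # mypy error: str is not a list

payload_widened: dict = {'jobs': jobs}   # explicit annotation widens the value type
payload_widened['next'] = 'https://example.com/jobs?start_token=...'  # accepted
```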
4 changes: 3 additions & 1 deletion apps/api/src/hyp3_api/lambda_handler.py
@@ -1,3 +1,5 @@
from typing import Any

import serverless_wsgi

from hyp3_api import app
@@ -6,5 +8,5 @@
serverless_wsgi.TEXT_MIME_TYPES.append('application/problem+json')


def handler(event, context):
def handler(event: dict, context: Any) -> Any:
return serverless_wsgi.handle_request(app, event, context)
29 changes: 17 additions & 12 deletions apps/api/src/hyp3_api/routes.py
@@ -1,10 +1,12 @@
import datetime
import json
from collections.abc import Iterable
from decimal import Decimal
from os import environ
from pathlib import Path
from typing import Any

import werkzeug
import yaml
from flask import Response, abort, g, jsonify, make_response, redirect, render_template, request
from flask.json.provider import JSONProvider
@@ -27,15 +29,16 @@


@app.before_request
def check_system_available():
def check_system_available() -> Response | None:
if environ['SYSTEM_AVAILABLE'] != 'true':
message = 'HyP3 is currently unavailable. Please try again later.'
error = {'detail': message, 'status': 503, 'title': 'Service Unavailable', 'type': 'about:blank'}
return make_response(jsonify(error), 503)
return None


@app.before_request
def authenticate_user():
def authenticate_user() -> None:
cookie = request.cookies.get('asf-urs')
payload = auth.decode_token(cookie)
if payload is not None:
@@ -47,27 +50,27 @@ def authenticate_user():


@app.route('/')
def redirect_to_ui():
def redirect_to_ui() -> werkzeug.wrappers.response.Response:
return redirect('/ui/')


@app.route('/openapi.json')
def get_open_api_json():
def get_open_api_json() -> Response:
return jsonify(api_spec_dict)


@app.route('/openapi.yaml')
def get_open_api_yaml():
def get_open_api_yaml() -> str:
return yaml.dump(api_spec_dict)


@app.route('/ui/')
def render_ui():
def render_ui() -> str:
return render_template('index.html')


@app.errorhandler(404)
def error404(_):
def error404(_) -> Response:
return handlers.problem_format(
404,
'The requested URL was not found on the server.'
@@ -76,7 +79,9 @@ def error404(_):


class CustomEncoder(json.JSONEncoder):
def default(self, o):
def default(self, o: object) -> object:
# https://docs.python.org/3/library/json.html#json.JSONEncoder.default

if isinstance(o, datetime.datetime):
if o.tzinfo:
# eg: '2015-09-25T23:14:42.588601+00:00'
@@ -94,8 +99,8 @@ def default(self, o):
return int(o)
return float(o)

# Raises a TypeError
json.JSONEncoder.default(self, o)
# Let the base class default method raise the TypeError
return super().default(o)


class CustomJSONProvider(JSONProvider):
@@ -107,10 +112,10 @@ def loads(self, s: str | bytes, **kwargs: Any) -> Any:


class ErrorHandler(FlaskOpenAPIErrorsHandler):
def __init__(self):
def __init__(self) -> None:
super().__init__()

def __call__(self, errors):
def __call__(self, errors: Iterable[Exception]) -> Response:
response = super().__call__(errors)
error = response.json['errors'][0] # type: ignore[index]
return handlers.problem_format(error['status'], error['title'])
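The `CustomEncoder.default` rewrite in `routes.py` follows the standard pattern for custom JSON encoders: handle the types you know, then delegate to `super().default(o)`, which raises `TypeError` for anything else and gives the method a return on every path, matching its declared return type. A rough usage sketch of the same pattern; the class name and values below are illustrative, not the application's encoder.

```python
# Rough sketch of the encoder pattern shown in the diff above.
import datetime
import json
from decimal import Decimal


class SketchEncoder(json.JSONEncoder):
    def default(self, o: object) -> object:
        if isinstance(o, datetime.datetime):
            return o.isoformat()
        if isinstance(o, Decimal):
            return int(o) if o == int(o) else float(o)
        # Let the base class raise TypeError for unsupported types.
        return super().default(o)


print(json.dumps(
    {'when': datetime.datetime(2015, 9, 25, 23, 14, 42), 'cost': Decimal('1.5')},
    cls=SketchEncoder,
))
# {"when": "2015-09-25T23:14:42", "cost": 1.5}
```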
2 changes: 1 addition & 1 deletion apps/api/src/hyp3_api/util.py
@@ -32,7 +32,7 @@ def deserialize(token: str) -> Any:
raise TokenDeserializeError


def build_next_url(url, start_token, x_forwarded_host=None, root_path=''):
def build_next_url(url: str, start_token: str, x_forwarded_host: str | None = None, root_path: str = '') -> str:
url_parts = list(urlparse(url))

if x_forwarded_host:
4 changes: 2 additions & 2 deletions apps/api/src/hyp3_api/validation.py
@@ -141,10 +141,10 @@ def check_bounds_formatting(job: dict, _) -> None:
'Invalid order for bounds. Bounds should be ordered [min lon, min lat, max lon, max lat].'
)

def bad_lat(lat):
def bad_lat(lat: float) -> bool:
return lat > 90 or lat < -90

def bad_lon(lon):
def bad_lon(lon: float) -> bool:
return lon > 180 or lon < -180

if any([bad_lon(bounds[0]), bad_lon(bounds[2]), bad_lat(bounds[1]), bad_lat(bounds[3])]):
9 changes: 5 additions & 4 deletions apps/disable-private-dns/src/disable_private_dns.py
@@ -1,12 +1,13 @@
import os
from typing import Any

import boto3


CLIENT = boto3.client('ec2')


def get_endpoint(vpc_id, endpoint_name):
def get_endpoint(vpc_id: str, endpoint_name: str) -> dict:
response = CLIENT.describe_vpc_endpoints()
endpoints = [endpoint for endpoint in response['VpcEndpoints'] if endpoint['VpcId'] == vpc_id]
if len(endpoints) == 0:
@@ -24,14 +25,14 @@ def get_endpoint(vpc_id, endpoint_name):
return desired_endpoint


def set_private_dns_disabled(endpoint_id):
def set_private_dns_disabled(endpoint_id: str) -> None:
response = CLIENT.modify_vpc_endpoint(VpcEndpointId=endpoint_id, PrivateDnsEnabled=False)
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2/client/modify_vpc_endpoint.html
assert response['Return'] is True, response
print(f'Private DNS disabled for VPC Endpoint: {endpoint_id}.')


def disable_private_dns(vpc_id, endpoint_name):
def disable_private_dns(vpc_id: str, endpoint_name: str) -> None:
endpoint = get_endpoint(vpc_id, endpoint_name)
if endpoint['PrivateDnsEnabled']:
print(f'Private DNS enabled for VPC Endpoint: {endpoint["VpcEndpointId"]}, changing...')
@@ -40,7 +41,7 @@ def disable_private_dns(vpc_id, endpoint_name):
print(f'Private DNS already disabled for VPC Endpoint: {endpoint["VpcEndpointId"]}, doing nothing.')


def lambda_handler(event, context):
def lambda_handler(event: dict, context: Any) -> None:
vpc_id = os.environ['VPCID']
endpoint_name = os.environ['ENDPOINT_NAME']
print(f'VPC ID {vpc_id}')
17 changes: 9 additions & 8 deletions apps/get-files/src/get_files.py
@@ -3,14 +3,15 @@
from datetime import datetime
from os import environ
from pathlib import Path
from typing import Any

import boto3


S3_CLIENT = boto3.client('s3')


def get_download_url(bucket, key):
def get_download_url(bucket: str, key: str) -> str:
if distribution_url := os.getenv('DISTRIBUTION_URL'):
download_url = urllib.parse.urljoin(distribution_url, key)
else:
@@ -19,14 +20,14 @@ def get_download_url(bucket, key):
return download_url


def get_expiration_time(bucket, key):
def get_expiration_time(bucket: str, key: str) -> str:
s3_object = S3_CLIENT.get_object(Bucket=bucket, Key=key)
expiration_string = s3_object['Expiration'].split('"')[1]
expiration_datetime = datetime.strptime(expiration_string, '%a, %d %b %Y %H:%M:%S %Z')
return expiration_datetime.isoformat(timespec='seconds') + '+00:00'


def get_object_file_type(bucket, key):
def get_object_file_type(bucket: str, key: str) -> str | None:
response = S3_CLIENT.get_object_tagging(Bucket=bucket, Key=key)
for tag in response['TagSet']:
if tag['Key'] == 'file_type':
@@ -38,7 +39,7 @@ def visible_product(product_path: str | Path) -> bool:
return Path(product_path).suffix in ('.zip', '.nc', '.geojson')


def get_products(files):
def get_products(files: list[dict]) -> list[dict]:
return [
{
'url': item['download_url'],
@@ -51,17 +52,17 @@ def get_products(files):
]


def get_file_urls_by_type(file_list, file_type):
def get_file_urls_by_type(file_list: list[dict], file_type: str) -> list[str]:
files = [item for item in file_list if file_type in item['file_type']]
sorted_files = sorted(files, key=lambda x: x['file_type'])
urls = [item['download_url'] for item in sorted_files]
return urls


def organize_files(files_dict, bucket):
def organize_files(s3_objects: list[dict], bucket: str) -> dict:
all_files = []
expiration = None
for item in files_dict:
for item in s3_objects:
download_url = get_download_url(bucket, item['Key'])
file_type = get_object_file_type(bucket, item['Key'])
all_files.append(
@@ -88,7 +89,7 @@ def organize_files(files_dict, bucket):
}


def lambda_handler(event, context):
def lambda_handler(event: dict, context: Any) -> dict:
bucket = environ['BUCKET']

response = S3_CLIENT.list_objects_v2(Bucket=bucket, Prefix=event['job_id'])
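`get_expiration_time` above depends on the format of the S3 `Expiration` response field, which typically looks like `expiry-date="Fri, 21 Dec 2029 00:00:00 GMT", rule-id="..."`, so `split('"')[1]` extracts the date string between the first pair of quotes. A self-contained sketch of that parsing, with an assumed example value rather than a real S3 response:

```python
# Standalone sketch of the Expiration parsing used in get_expiration_time;
# the header value below is a made-up example in the documented format.
from datetime import datetime

expiration_header = 'expiry-date="Fri, 21 Dec 2029 00:00:00 GMT", rule-id="delete-old-products"'

expiration_string = expiration_header.split('"')[1]
expiration_datetime = datetime.strptime(expiration_string, '%a, %d %b %Y %H:%M:%S %Z')
print(expiration_datetime.isoformat(timespec='seconds') + '+00:00')
# 2029-12-21T00:00:00+00:00
```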
2 changes: 1 addition & 1 deletion apps/render_cf.py
@@ -247,7 +247,7 @@ def validate_job_spec(job_type: str, job_spec: dict) -> None:
raise ValueError(f'{job_type} has image {step["image"]} but docker requires the image to be all lowercase')


def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument('-j', '--job-spec-files', required=True, nargs='+', type=Path)
parser.add_argument('-e', '--compute-environment-file', required=True, type=Path)
14 changes: 10 additions & 4 deletions apps/scale-cluster/src/scale_cluster.py
@@ -1,6 +1,7 @@
import calendar
from datetime import date
from os import environ
from typing import Any

import boto3
import dateutil.relativedelta
@@ -24,7 +25,7 @@ def get_month_to_date_spending(today: date) -> float:
return float(response['ResultsByTime'][0]['Total']['UnblendedCost']['Amount'])


def get_current_desired_vcpus(compute_environment_arn):
def get_current_desired_vcpus(compute_environment_arn: str) -> int:
response = BATCH.describe_compute_environments(computeEnvironments=[compute_environment_arn])
return response['computeEnvironments'][0]['computeResources']['desiredvCpus']

@@ -50,8 +51,13 @@ def set_max_vcpus(compute_environment_arn: str, target_max_vcpus: int, current_d


def get_target_max_vcpus(
today, monthly_budget, month_to_date_spending, default_max_vcpus, expanded_max_vcpus, required_surplus
):
today: date,
monthly_budget: int,
month_to_date_spending: float,
default_max_vcpus: int,
expanded_max_vcpus: int,
required_surplus: int,
) -> int:
days_in_month = calendar.monthrange(today.year, today.month)[1]
month_to_date_budget = monthly_budget * today.day / days_in_month
available_surplus = month_to_date_budget - month_to_date_spending
@@ -68,7 +74,7 @@ def get_target_max_vcpus(
return max_vcpus


def lambda_handler(event, context):
def lambda_handler(event: dict, context: Any) -> None:
target_max_vcpus = get_target_max_vcpus(
today=date.today(),
monthly_budget=int(environ['MONTHLY_BUDGET']),
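The budget arithmetic in `get_target_max_vcpus` above is easier to follow with numbers. In the worked example below all values are made up, and the final comparison against `required_surplus` is an assumption inferred from the function's parameters, since that part of the diff is collapsed.

```python
# Worked example of the budget arithmetic in get_target_max_vcpus;
# every value here is invented for illustration.
import calendar
from datetime import date

today = date(2025, 3, 6)
monthly_budget = 3000          # dollars per month
month_to_date_spending = 400.0
default_max_vcpus = 640
expanded_max_vcpus = 1600
required_surplus = 100

days_in_month = calendar.monthrange(today.year, today.month)[1]     # 31
month_to_date_budget = monthly_budget * today.day / days_in_month   # 3000 * 6 / 31 ≈ 580.65
available_surplus = month_to_date_budget - month_to_date_spending   # ≈ 180.65

# Spending is comfortably under the prorated budget, so the expanded ceiling applies.
max_vcpus = expanded_max_vcpus if available_surplus > required_surplus else default_max_vcpus
print(max_vcpus)  # 1600
```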
4 changes: 3 additions & 1 deletion apps/update-db/src/main.py
@@ -1,5 +1,7 @@
from typing import Any

from dynamo import jobs


def lambda_handler(event, context):
def lambda_handler(event: dict, context: Any) -> None:
jobs.update_job(event)
7 changes: 4 additions & 3 deletions apps/upload-log/src/upload_log.py
@@ -1,5 +1,6 @@
import json
from os import environ
from typing import Any

import boto3
from botocore.config import Config
@@ -16,7 +17,7 @@ def get_log_stream(result: dict) -> str | None:
return result['Container'].get('LogStreamName')


def get_log_content(log_group, log_stream):
def get_log_content(log_group: str, log_stream: str) -> str:
response = CLOUDWATCH.get_log_events(logGroupName=log_group, logStreamName=log_stream, startFromHead=True)
messages = [event['message'] for event in response['events']]

@@ -44,7 +45,7 @@ def get_log_content_from_failed_attempts(cause: dict) -> str:
return content


def write_log_to_s3(bucket, prefix, content):
def write_log_to_s3(bucket: str, prefix: str, content: str) -> None:
key = f'{prefix}/{prefix}.log'
S3.put_object(Bucket=bucket, Key=key, Body=content, ContentType='text/plain')
tag_set = {
@@ -58,7 +59,7 @@ def write_log_to_s3(bucket, prefix, content):
S3.put_object_tagging(Bucket=bucket, Key=key, Tagging=tag_set)


def lambda_handler(event, context):
def lambda_handler(event: dict, context: Any) -> None:
results_dict = event['processing_results']
result = results_dict[max(results_dict.keys())]
