Skip to content

Commit

Permalink
removing unnecessary type calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
joker2411 committed Feb 26, 2025
1 parent db134bb commit c990d44
Showing 1 changed file with 16 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def preprocess_and_predict(
stage_name = results["model_info"]["file_location"]["stage"]
pkl_model_file_name = results["model_info"]["file_location"]["file_name"]

score_column = trainer.outputs.column_names.get("score")
percentile_column = trainer.outputs.column_names.get("percentile")
output_column = trainer.outputs.column_names.get("output_label_column")
model_id_column = "model_id"

input_column_types = results["column_names"]["input_column_types"]
numeric_columns = results["column_names"]["input_column_types"]["numeric"]
categorical_columns = results["column_names"]["input_column_types"]["categorical"]
Expand Down Expand Up @@ -257,9 +262,9 @@ def end_partition(self, df):
prediction_udf,
trainer.entity_column,
trainer.index_timestamp,
trainer.outputs.column_names.get("score"),
trainer.outputs.column_names.get("percentile"),
trainer.outputs.column_names.get("output_label_column"),
score_column,
percentile_column,
output_column,
train_model_id,
input_df,
trainer.pred_output_df_columns,
Expand Down Expand Up @@ -327,16 +332,16 @@ def predict_scores_rs(df: pd.DataFrame) -> pd.DataFrame:
batch_result[trainer.index_timestamp] = batch_predict_data[
trainer.index_timestamp
]
batch_result["model_id"] = train_model_id
batch_result[model_id_column] = train_model_id

# Add predictions
batch_result[trainer.outputs.column_names.get("score")] = batch_predictions[
batch_result[score_column] = batch_predictions[
trainer.pred_output_df_columns["score"]
].round(5)
if "label" in trainer.pred_output_df_columns:
batch_result[
trainer.outputs.column_names.get("output_label_column")
] = batch_predictions[trainer.pred_output_df_columns["label"]]
batch_result[output_column] = batch_predictions[
trainer.pred_output_df_columns["label"]
]

logger.get().debug("Writing predictions to warehouse")
if first_batch:
Expand All @@ -350,30 +355,14 @@ def predict_scores_rs(df: pd.DataFrame) -> pd.DataFrame:
)
first_batch = False
else:
columns = ", ".join(batch_result.columns)
values_list = []
for _, row in batch_result.iterrows():
formatted_values = []
for val in row:
if isinstance(val, str):
formatted_values.append(f"'{val}'")
elif isinstance(
val, (pd.Timestamp, datetime.datetime, datetime.date)
):
formatted_values.append(f"'{val}'")
elif isinstance(val, bool):
formatted_values.append(str(val).upper())
elif pd.isna(val) or val is None:
formatted_values.append("NULL")
else:
formatted_values.append(str(val))

row_values = f"({', '.join(formatted_values)})"
row_values = f"('{row[trainer.entity_column]}', '{row[trainer.index_timestamp]}', '{row[model_id_column]}', {row[score_column]}, {row[output_column]})"
values_list.append(row_values)

values_clause = ",\n ".join(values_list)

insert_query = f"""INSERT INTO {output_tablename} ({columns})
insert_query = f"""INSERT INTO {output_tablename} ({trainer.entity_column}, {trainer.index_timestamp}, {model_id_column}, {score_column}, {output_column})
VALUES
{values_clause}"""
connector.run_query(insert_query, response=False)
Expand All @@ -387,8 +376,6 @@ def predict_scores_rs(df: pd.DataFrame) -> pd.DataFrame:
)

# Calculate percentiles using SQL directly in the warehouse
score_column = trainer.outputs.column_names.get("score")
percentile_column = trainer.outputs.column_names.get("percentile")
percentile_column_type = "FLOAT" if creds["type"] == "redshift" else "FLOAT64"

percentile_column_create_query = f"""
Expand Down Expand Up @@ -451,7 +438,7 @@ def predict_scores_rs(df: pd.DataFrame) -> pd.DataFrame:
metrics_df = pd.DataFrame(
{
"model_name": [trainer.output_profiles_ml_model],
"model_id": [model_id],
model_id_column: [model_id],
"prediction_date": [prediction_date],
"label_date": [end_ts],
"prediction_table_name": [prev_prediction_table],
Expand Down

0 comments on commit c990d44

Please sign in to comment.