Skip to content

Commit

Permalink
Further edits - batching
Browse files Browse the repository at this point in the history
  • Loading branch information
kjaisingh committed Dec 2, 2024
1 parent d822ecb commit d75af10
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions scripts/notebooks/Batching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@
" None.\n",
" \"\"\"\n",
" for i in input_vals:\n",
" if not isinstance(i, str) or not i:\n",
" if not isinstance(i, str):\n",
" raise Exception('Value input must be a string.')\n",
" \n",
" if log:\n",
Expand Down Expand Up @@ -427,6 +427,41 @@
" print(\"Inputs are valid - please proceed to the next cell.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"def validate_batch_sizes(df, target_batch_size, min_batch_size, max_batch_size):\n",
"    \"\"\"\n",
"    Validates that the user-provided batch sizes are consistent with each other and with the data.\n",
"    \n",
"    The three sizes are first passed to validate_numeric_inputs (defined elsewhere in this\n",
"    notebook), then checked for the ordering min <= target <= max and against len(df).\n",
"    \n",
"    Args:\n",
"        df (pd.DataFrame): Dataframe with sample data.\n",
"        target_batch_size (int): Target size of each batch.\n",
"        min_batch_size (int): Minimum size of each batch.\n",
"        max_batch_size (int): Maximum size of each batch.\n",
"    \n",
"    Returns:\n",
"        None.\n",
"    \n",
"    Raises:\n",
"        Exception: If target_batch_size < min_batch_size, if max_batch_size < target_batch_size,\n",
"            or if min_batch_size is larger than the number of rows in df.\n",
"    \"\"\"\n",
"    validate_numeric_inputs([target_batch_size, min_batch_size, max_batch_size], log=False)\n",
"    \n",
"    # Equal values are accepted, so the messages say 'must not be less than' rather than 'must exceed'.\n",
"    if target_batch_size < min_batch_size:\n",
"        raise Exception(\"TARGET_BATCH_SIZE must not be less than MIN_BATCH_SIZE.\")\n",
"    \n",
"    if max_batch_size < target_batch_size:\n",
"        raise Exception(\"MAX_BATCH_SIZE must not be less than TARGET_BATCH_SIZE.\")\n",
"    \n",
"    # A minimum larger than the number of samples can never be satisfied by any batch.\n",
"    if len(df) < min_batch_size:\n",
"        raise Exception(\"MIN_BATCH_SIZE must not exceed the number of samples.\")\n",
"    \n",
"    print(\"Inputs are valid - please proceed to the next cell.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -883,9 +918,13 @@
" Tuple of batches dictionary and batches metadata dictionary.\n",
" \"\"\"\n",
" # Validation\n",
" if not min_batch_size <= len(df):\n",
" raise Exception('MIN_BATCH_SIZE must not exceed the number of samples.')\n",
" \n",
" estimated_samples_per_split = len(df)\n",
" for bins in include_bins:\n",
" estimated_samples_per_split /= bins\n",
" if estimated_samples_per_split < min_batch_size:\n",
" raise Exception(f\"Based on INCLUDE_BINS, expected samples per split ({estimated_samples_per_split}) \" +\n",
" f\"is less than MIN_BATCH_SIZE ({min_batch_size}). Please adjust parameters accordingly.\")\n",
" \n",
" # Group related samples and select family representatives\n",
" related_groups = group_related_samples(df, reference_ped)\n",
" family_representatives = select_family_representatives(reference_ped, related_groups) if reference_ped is not None else set(df['sample_id'])\n",
Expand Down Expand Up @@ -1021,7 +1060,7 @@
"duplicates_dict = id_counts[id_counts > 1].to_dict()\n",
"\n",
"if (len(duplicates_dict) > 0):\n",
" print(f\"{len(duplicates_dict)} duplicate 'sample_id' exist in the dataset.\")\n",
" print(f\"{len(duplicates_dict)} duplicate 'sample_id' exist in the dataset.\")\n",
" for sample_id, count in duplicates_dict.items():\n",
" print(f\"Sample ID: {sample_id}, Count: {count}\")\n",
" raise Exception(\"Batching requires unique 'sample_id' - please resolve this before proceeding.\")\n",
Expand Down Expand Up @@ -1118,7 +1157,7 @@
"MINIMUM_BATCH_SIZE = None\n",
"MAXIMUM_BATCH_SIZE = None\n",
"\n",
"validate_numeric_inputs([TARGET_BATCH_SIZE, MINIMUM_BATCH_SIZE, MAXIMUM_BATCH_SIZE])"
"validate_batch_sizes(pass_df, TARGET_BATCH_SIZE, MINIMUM_BATCH_SIZE, MAXIMUM_BATCH_SIZE)"
]
},
{
Expand Down

0 comments on commit d75af10

Please sign in to comment.