Add benchmarking logging (#725)
* Add get_memory_usage function

* Rewrite memory usage functions

* Added memory usage logging to finalise.py

* Added logging to forced_extraction.py

* Added debug logging to loading.py

* Added debug logging to model_generator.py

* Added debug logging to main.py

* Added debug logging to new_sources.py

* Import psutil

* Fix truncated variable name

* Fixed logging in calculate_n_partitions

* Add some logging to partially address #718

* ?

* PEP8

* PEP8

* More PEP8

* PEP8 hell

* Updated changelog
ddobie authored Aug 9, 2024
1 parent cc690e8 commit 266f5a7
Showing 9 changed files with 297 additions and 77 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

#### Added

- Added further memory usage and timing debug logging [#725](https://github.com/askap-vast/vast-pipeline/pull/725)
- Add support for python 3.10 [#740](https://github.com/askap-vast/vast-pipeline/pull/740)
- Added calculate_n_partitions support for sensible dask dataframe partitioning [#724](https://github.com/askap-vast/vast-pipeline/pull/724)
- Added support for compressed FITS files [#694](https://github.com/askap-vast/vast-pipeline/pull/694)
@@ -126,6 +127,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

#### List of PRs

- [#725](https://github.com/askap-vast/vast-pipeline/pull/725): feat: Added further memory usage and timing debug logging
- [#740](https://github.com/askap-vast/vast-pipeline/pull/740): feat: Add support for python 3.10
- [#728](https://github.com/askap-vast/vast-pipeline/pull/728): fix: Adjust package versions and fix mkdocs serve issues
- [#728](https://github.com/askap-vast/vast-pipeline/pull/728): fix: Adjust package versions and fix python 3.9 tests breaking on github actions
114 changes: 89 additions & 25 deletions vast_pipeline/pipeline/finalise.py
@@ -14,8 +14,9 @@
update_sources
)
from vast_pipeline.pipeline.pairs import calculate_measurement_pair_metrics
from vast_pipeline.pipeline.utils import parallel_groupby

from vast_pipeline.pipeline.utils import (
parallel_groupby, get_df_memory_usage, log_total_memory_usage
)

logger = logging.getLogger(__name__)

@@ -67,7 +68,8 @@ def calculate_measurement_pair_aggregate_metrics(
check_df
.groupby("source")
.agg(m_abs_max_idx=(f"m_{flux_type}", lambda x: x.abs().idxmax()),)
.astype(np.int32)["m_abs_max_idx"] # cast row indices to int and select them
# cast row indices to int and select them
.astype(np.int32)["m_abs_max_idx"]
.reset_index(drop=True) # keep only the row indices
][[f"vs_{flux_type}", f"m_{flux_type}", "source"]]

@@ -107,8 +109,8 @@ def final_operations(
The new sources dataframe, only contains the
'new_source_high_sigma' column (source_id is the index).
calculate_pairs:
Whether to calculate the measurement pairs and their 2-epoch metrics, Vs and
m.
Whether to calculate the measurement pairs and their 2-epoch
metrics, Vs and m.
source_aggregate_pair_metrics_min_abs_vs:
Only measurement pairs where the Vs metric exceeds this value
are selected for the aggregate pair metrics that are stored in
@@ -120,10 +122,10 @@
in the previous run in add mode.
Returns:
The number of sources contained in the pipeline run (used in the next steps
of main.py).
The number of new sources contained in the pipeline run (used in the next steps
of main.py).
The number of sources contained in the pipeline run (used in the next
steps of main.py).
The number of new sources contained in the pipeline run (used in the
next steps of main.py).
"""
timer = StopWatch()

@@ -132,8 +134,14 @@
'Calculating statistics for %i sources...',
sources_df.source.unique().shape[0]
)
log_total_memory_usage()

srcs_df = parallel_groupby(sources_df)

mem_usage = get_df_memory_usage(srcs_df)
logger.info('Groupby-apply time: %.2f seconds', timer.reset())
logger.debug(f"Initial srcs_df memory: {mem_usage}MB")
log_total_memory_usage()

# add new sources
srcs_df["new"] = srcs_df.index.isin(new_sources_df.index)
@@ -146,6 +154,10 @@
)
srcs_df["new_high_sigma"] = srcs_df["new_high_sigma"].fillna(0.0)

mem_usage = get_df_memory_usage(srcs_df)
logger.debug(f"srcs_df memory after adding new sources: {mem_usage}MB")
log_total_memory_usage()

# calculate nearest neighbour
srcs_skycoord = SkyCoord(
srcs_df['wavg_ra'].values,
@@ -160,38 +172,58 @@
# add the separation distance in degrees
srcs_df['n_neighbour_dist'] = d2d.deg

mem_usage = get_df_memory_usage(srcs_df)
logger.debug(f"srcs_df memory after nearest-neighbour: {mem_usage}MB")
log_total_memory_usage()

# create measurement pairs, aka 2-epoch metrics
if calculate_pairs:
timer.reset()
measurement_pairs_df = calculate_measurement_pair_metrics(sources_df)
logger.info('Measurement pair metrics time: %.2f seconds', timer.reset())

# calculate measurement pair metric aggregates for sources by finding the row indices
# of the aggregate max of the abs(m) metric for each flux type.
logger.info(
'Measurement pair metrics time: %.2f seconds',
timer.reset())
mem_usage = get_df_memory_usage(measurement_pairs_df)
logger.debug(f"measurment_pairs_df memory: {mem_usage}MB")
log_total_memory_usage()

# calculate measurement pair metric aggregates for sources by finding
# the row indices of the aggregate max of the abs(m) metric for each
# flux type.
pair_agg_metrics = pd.merge(
calculate_measurement_pair_aggregate_metrics(
measurement_pairs_df, source_aggregate_pair_metrics_min_abs_vs, flux_type="peak",
measurement_pairs_df,
source_aggregate_pair_metrics_min_abs_vs,
flux_type="peak",
),
calculate_measurement_pair_aggregate_metrics(
measurement_pairs_df, source_aggregate_pair_metrics_min_abs_vs, flux_type="int",
measurement_pairs_df,
source_aggregate_pair_metrics_min_abs_vs,
flux_type="int",
),
how="outer",
left_index=True,
right_index=True,
)

# join with sources and replace agg metrics NaNs with 0 as the DataTables API JSON
# serialization doesn't like them
# join with sources and replace agg metrics NaNs with 0 as the
# DataTables API JSON serialization doesn't like them
srcs_df = srcs_df.join(pair_agg_metrics).fillna(value={
"vs_abs_significant_max_peak": 0.0,
"m_abs_significant_max_peak": 0.0,
"vs_abs_significant_max_int": 0.0,
"m_abs_significant_max_int": 0.0,
})
logger.info("Measurement pair aggregate metrics time: %.2f seconds", timer.reset())
logger.info(
"Measurement pair aggregate metrics time: %.2f seconds",
timer.reset())
mem_usage = get_df_memory_usage(srcs_df)
logger.debug(f"srcs_df memory after calculate_pairs: {mem_usage}MB")
log_total_memory_usage()
else:
logger.info(
"Skipping measurement pair metric calculation as specified in the run configuration."
"Skipping measurement pair metric calculation as specified in "
"the run configuration."
)

# upload sources to DB, column 'id' with DB id is contained in return
@@ -201,18 +233,39 @@
# upload new ones first (new id's are fetched)
src_done_mask = srcs_df.index.isin(done_source_ids)
srcs_df_upload = srcs_df.loc[~src_done_mask].copy()

mem_usage = get_df_memory_usage(srcs_df_upload)
logger.debug(f"srcs_df_upload initial memory: {mem_usage}MB")
log_total_memory_usage()

srcs_df_upload = make_upload_sources(srcs_df_upload, p_run, add_mode)

mem_usage = get_df_memory_usage(srcs_df_upload)
logger.debug(f"srcs_df_upload memory after upload: {mem_usage}MB")
log_total_memory_usage()

# And now update
srcs_df_update = srcs_df.loc[src_done_mask].copy()
logger.info(
f"Updating {srcs_df_update.shape[0]} sources with new metrics.")
mem_usage = get_df_memory_usage(srcs_df_update)
logger.debug(f"srcs_df_update memory: {mem_usage}MB")
log_total_memory_usage()

srcs_df = update_sources(srcs_df_update, batch_size=1000)
mem_usage = get_df_memory_usage(srcs_df_update)
logger.debug(f"srcs_df_update memory: {mem_usage}MB")
log_total_memory_usage()
# Add back together
if not srcs_df_upload.empty:
srcs_df = pd.concat([srcs_df, srcs_df_upload])
else:
srcs_df = make_upload_sources(srcs_df, p_run, add_mode)

mem_usage = get_df_memory_usage(srcs_df)
logger.debug(f"srcs_df memory after uploading sources: {mem_usage}MB")
log_total_memory_usage()

# gather the related df, upload to db and save to parquet file
# the df will look like
#
@@ -228,13 +281,17 @@
related_df = (
srcs_df.loc[srcs_df["related_list"] != -1, ["id", "related_list"]]
.explode("related_list")
.rename(columns={"id": "from_source_id", "related_list": "to_source_id"})
.rename(columns={"id": "from_source_id",
"related_list": "to_source_id"
})
)

# for the column 'from_source_id', replace relation source ids with db id
related_df["to_source_id"] = related_df["to_source_id"].map(srcs_df["id"].to_dict())
related_df["to_source_id"] = related_df["to_source_id"].map(
srcs_df["id"].to_dict())
# drop relationships with the same source
related_df = related_df[related_df["from_source_id"] != related_df["to_source_id"]]
related_df = related_df[related_df["from_source_id"]
!= related_df["to_source_id"]]

# write symmetrical relations to parquet
related_df.to_parquet(
@@ -263,16 +320,21 @@
# write sources to parquet file
srcs_df = srcs_df.drop(["related_list", "img_list"], axis=1)
(
srcs_df.set_index('id') # set the index to db ids, dropping the source idx
# set the index to db ids, dropping the source idx
srcs_df.set_index('id')
.to_parquet(os.path.join(p_run.path, 'sources.parquet'))
)

# update measurments with sources to get associations
# update measurements with sources to get associations
sources_df = (
sources_df.drop('related', axis=1)
.merge(srcs_df.rename(columns={'id': 'source_id'}), on='source')
)

mem_usage = get_df_memory_usage(sources_df)
logger.debug(f"sources_df memory after srcs_df merge: {mem_usage}MB")
log_total_memory_usage()

if add_mode:
# Load old associations so the already uploaded ones can be removed
old_assoications = (
@@ -317,7 +379,9 @@
os.path.join(p_run.path, "measurement_pairs.parquet"), index=False
)

logger.info("Total final operations time: %.2f seconds", timer.reset_init())
logger.info(
"Total final operations time: %.2f seconds",
timer.reset_init())

nr_sources = srcs_df["id"].count()
nr_new_sources = srcs_df['new'].sum()
4 changes: 4 additions & 0 deletions vast_pipeline/pipeline/forced_extraction.py
@@ -166,6 +166,8 @@ def extract_from_image(
Dictionary with input dataframe with added columns (flux_int,
flux_int_err, chi_squared_fit) and image name.
"""
timer = StopWatch()

# create the skycoord obj to pass to the forced extraction
# see usage https://github.com/dlakaplan/forced_phot
P_islands = SkyCoord(
@@ -193,6 +195,8 @@
df['flux_int_err'] = flux_err * 1.e3
df['chi_squared_fit'] = chisq

logger.debug(f"Time to measure FP for {image}: {timer.reset()}s")

return {'df': df, 'image': df['image_name'].iloc[0]}


42 changes: 33 additions & 9 deletions vast_pipeline/pipeline/loading.py
@@ -18,7 +18,10 @@
Association, Band, Measurement, SkyRegion, Source, RelatedSource,
Run, Image
)
from vast_pipeline.pipeline.utils import get_create_img, get_create_img_band
from vast_pipeline.pipeline.utils import (
get_create_img, get_create_img_band,
get_df_memory_usage, log_total_memory_usage
)
from vast_pipeline.utils.utils import StopWatch


@@ -29,8 +32,8 @@
def bulk_upload_model(
djmodel: models.Model,
generator: Iterable[Generator[models.Model, None, None]],
batch_size: int=10_000,
return_ids: bool=False,
batch_size: int = 10_000,
return_ids: bool = False,
) -> List[int]:
'''
Bulk upload a list of generator objects of django models to db.
@@ -51,7 +54,7 @@ def bulk_upload_model(
'''
reset_queries()

bulk_ids = []
while True:
items = list(islice(generator, batch_size))
@@ -168,6 +171,12 @@ def make_upload_sources(
Returns:
The input dataframe with the 'id' column added.
'''

logger.debug("Uploading sources...")
mem_usage = get_df_memory_usage(sources_df)
logger.debug(f"sources_df memory usage: {mem_usage}MB")
log_total_memory_usage()

# create sources in DB
with transaction.atomic():
if (add_mode is False and
@@ -207,6 +216,9 @@ def make_upload_related_sources(related_df: pd.DataFrame) -> None:
None.
"""
logger.info('Populate "related" field of sources...')
mem_usage = get_df_memory_usage(related_df)
logger.debug(f"related_df memory usage: {mem_usage}MB")
log_total_memory_usage()
bulk_upload_model(RelatedSource, related_models_generator(related_df))


@@ -223,11 +235,17 @@ def make_upload_associations(associations_df: pd.DataFrame) -> None:
None.
"""
logger.info('Upload associations...')

mem_usage = get_df_memory_usage(associations_df)
logger.debug(f"associations_df memory usage: {mem_usage}MB")
log_total_memory_usage()

assoc_chunk_size = 100000
for i in range(0,len(associations_df),assoc_chunk_size):
for i in range(0, len(associations_df), assoc_chunk_size):
bulk_upload_model(
Association,
association_models_generator(associations_df[i:i+assoc_chunk_size])
association_models_generator(
associations_df[i:i + assoc_chunk_size])
)


@@ -243,6 +261,12 @@ def make_upload_measurements(measurements_df: pd.DataFrame) -> pd.DataFrame:
Returns:
Original DataFrame with the database ID attached to each row.
"""

logger.info("Upload measurements...")
mem_usage = get_df_memory_usage(measurements_df)
logger.debug(f"measurements_df memory usage: {mem_usage}MB")
log_total_memory_usage()

meas_dj_ids = bulk_upload_model(
Measurement,
measurement_models_generator(measurements_df),
@@ -287,7 +311,7 @@ def update_sources(

sources_df['id'] = sources_df.index.values

batches = np.ceil(len(sources_df)/batch_size)
batches = np.ceil(len(sources_df) / batch_size)
dfs = np.array_split(sources_df, batches)
with connection.cursor() as cursor:
for df_batch in dfs:
@@ -332,8 +356,8 @@ def SQL_update(

# get names
table = model._meta.db_table
new_columns = ', '.join('new_'+c for c in columns)
set_columns = ', '.join(c+'=new_'+c for c in columns)
new_columns = ', '.join('new_' + c for c in columns)
set_columns = ', '.join(c + '=new_' + c for c in columns)

# get index values and new values
column_headers = [index]
[Diffs for the remaining five changed files, including vast_pipeline/pipeline/utils.py where the new memory-usage helpers are defined, did not load and are not shown.]
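Since the utils.py diff is not shown, here is a minimal sketch of what the two helpers called throughout finalise.py and loading.py (`get_df_memory_usage` and `log_total_memory_usage`) might look like. It is an assumption based on the commit messages ("Add get_memory_usage function", "Import psutil") and on how the callers format the logged values, not the actual implementation; the signatures, units, and log wording are guesses.

```python
# Hypothetical sketch only -- the real implementations live in
# vast_pipeline/pipeline/utils.py, which is not shown in this diff.
import logging
import os

import pandas as pd
import psutil

logger = logging.getLogger(__name__)


def get_df_memory_usage(df: pd.DataFrame) -> float:
    """Return the deep memory usage of a DataFrame in MB (assumed unit)."""
    return df.memory_usage(deep=True).sum() / 1024 ** 2


def log_total_memory_usage() -> None:
    """Log the resident memory of the current process (assumed behaviour)."""
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    logger.debug(f"Total memory usage: {rss_mb:.1f}MB")
```

The callers in the diff above apply the same pattern after each expensive step: compute `mem_usage = get_df_memory_usage(df)`, emit a `logger.debug(...)` line with the DataFrame name, then call `log_total_memory_usage()` to capture the process-wide figure.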
