Skip to content

Commit

Permalink
Modern_diag_manager: Fix error messages + Fixes for opemp (#1432)
Browse files Browse the repository at this point in the history
* Changes function or subroutine

* Fix 'race condition'

---------

Co-authored-by: Uriel Ramirez <[email protected]>
  • Loading branch information
uramirez8707 and Uriel Ramirez authored Jan 9, 2024
1 parent ebb3264 commit 17cd78e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 26 deletions.
12 changes: 5 additions & 7 deletions diag_manager/diag_manager.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1781,8 +1781,8 @@ LOGICAL FUNCTION diag_send_data(diag_field_id, field, time, is_in, js_in, ks_in,
! Split old and modern2023 here
modern_if: iF (use_modern_diag) then
field_name = fms_diag_object%fms_get_field_name_from_id(diag_field_id)
field_remap = copy_3d_to_4d(field, trim(field_name)//"'s data")
if (present(rmask)) rmask_remap = copy_3d_to_4d(rmask, trim(field_name)//"'s mask")
call copy_3d_to_4d(field, field_remap, trim(field_name)//"'s data")
if (present(rmask)) call copy_3d_to_4d(rmask, rmask_remap, trim(field_name)//"'s mask")
if (present(mask)) then
allocate(mask_remap(1:size(mask,1), 1:size(mask,2), 1:size(mask,3), 1))
mask_remap(:,:,:,1) = mask
Expand Down Expand Up @@ -4586,12 +4586,10 @@ SUBROUTINE diag_field_add_cell_measures(diag_field_id, area, volume)
END SUBROUTINE diag_field_add_cell_measures

!> @brief Copies a 3d buffer to a 4d buffer
!> @return a 4d buffer
function copy_3d_to_4d(data_in, field_name) &
result(data_out)
subroutine copy_3d_to_4d(data_in, data_out, field_name)
class (*), intent(in) :: data_in(:,:,:) !< Data to copy
character(len=*), intent(in) :: field_name !< Name of the field copying (for error messages)
class (*), allocatable :: data_out(:,:,:,:)
class (*), allocatable, intent(out) :: data_out(:,:,:,:) !< 4D version of the data

!TODO this should be extended to integers
select type(data_in)
Expand All @@ -4617,7 +4615,7 @@ function copy_3d_to_4d(data_in, field_name) &
call mpp_error(FATAL, "The data for "//trim(field_name)//&
&" is not a valid type. Currently only r4 and r8 are supported")
end select
end function copy_3d_to_4d
end subroutine copy_3d_to_4d

END MODULE diag_manager_mod
!> @}
Expand Down
10 changes: 8 additions & 2 deletions diag_manager/fms_diag_field_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1676,15 +1676,21 @@ subroutine allocate_mask(this, mask_in, omp_axis)
end subroutine allocate_mask

!> Sets previously allocated mask to mask_in at given index ranges
subroutine set_mask(this, mask_in, is, js, ks, ie, je, ke)
subroutine set_mask(this, mask_in, field_info, is, js, ks, ie, je, ke)
class(fmsDiagField_type), intent(inout) :: this
logical, intent(in) :: mask_in(:,:,:,:)
character(len=*), intent(in) :: field_info !< Field info to add to error message
integer, optional, intent(in) :: is, js, ks, ie, je, ke
if(present(is)) then
if(is .lt. lbound(this%mask,1) .or. ie .gt. ubound(this%mask,1) .or. &
js .lt. lbound(this%mask,2) .or. je .gt. ubound(this%mask,2) .or. &
ks .lt. lbound(this%mask,3) .or. ke .gt. ubound(this%mask,3)) then
print *, mpp_pe(), "alloc'd", SHAPE(this%mask), "passed:", is,ie,js,je,ks,ke
print *, "PE:", int2str(mpp_pe()), "The size of the mask is", &
SHAPE(this%mask), &
"But the indices passed in are is=", int2str(is), " ie=", int2str(ie),&
" js=", int2str(js), " je=", int2str(je), &
" ks=", int2str(ks), " ke=", int2str(ke), &
" ", trim(field_info)
call mpp_error(FATAL,"set_mask:: given indices out of bounds for allocated mask")
endif
this%mask(is:ie, js:je, ks:ke, :) = mask_in
Expand Down
2 changes: 1 addition & 1 deletion diag_manager/fms_diag_file_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ pure function get_buffer_ids (this)
integer, allocatable :: get_buffer_ids(:) !< returned buffer ids for this file

allocate(get_buffer_ids(this%number_of_buffers))
get_buffer_ids = this%buffer_ids
get_buffer_ids = this%buffer_ids(1:this%number_of_buffers)
end function get_buffer_ids

!> Gets the stored number of buffers from the file object
Expand Down
39 changes: 26 additions & 13 deletions diag_manager/fms_diag_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,8 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
#ifndef use_yaml
CALL MPP_ERROR(FATAL,"You can not use the modern diag manager without compiling with -Duse_yaml")
#else
field_info = " Check send data call for field:"//trim(this%FMS_diag_fields(diag_field_id)%get_varname())
field_info = " Check send data call for field:"//trim(this%FMS_diag_fields(diag_field_id)%get_varname())//&
" and module:"//trim(this%FMS_diag_fields(diag_field_id)%get_modname())

!< Check if time should be present for this field
if (.not.this%FMS_diag_fields(diag_field_id)%is_static() .and. .not.present(time)) &
Expand All @@ -539,10 +540,6 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
!< Set the field_weight. If "weight" is not present it will be set to 1.0_r8_kind
field_weight = set_weight(weight)

!< Set the variable type based off passed in field data
if(.not. this%FMS_diag_fields(diag_field_id)%has_vartype()) &
call this%FMS_diag_fields(diag_field_id)%set_type(field_data(1,1,1,1))

!< Check that the indices are present in the correct combination
error_string = check_indices_order(is_in, ie_in, js_in, je_in)
if (trim(error_string) .ne. "") call mpp_error(FATAL, trim(error_string)//". "//trim(field_info))
Expand All @@ -555,16 +552,11 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
if ((present(is_in) .and. present(ie_in)) .or. (present(js_in) .and. present(je_in))) &
has_halos = .true.

if(has_halos) call this%FMS_diag_fields(diag_field_id)%set_halo_present()

!< If the field has `mask_variant=.true.`, check that mask OR rmask are present
if (this%FMS_diag_fields(diag_field_id)%is_mask_variant()) then
if (.not. allocated(mask) .and. .not. allocated(rmask)) call mpp_error(FATAL, &
"The field was registered with mask_variant, but mask or rmask are not present in the send_data call. "//&
trim(field_info))
else
if (allocated(mask) .or. allocated(rmask)) &
call this%FMS_diag_fields(diag_field_id)%set_mask_variant(.True.)
endif

!< Check that mask and rmask are not both present
Expand Down Expand Up @@ -606,6 +598,17 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
main_if: if (buffer_the_data) then
!> Only 1 thread allocates the output buffer and sets set_math_needs_to_be_done
!$omp critical
!< These set_* calls need to be done inside an omp_critical to avoid any race conditions
!! and allocation issues
if(has_halos) call this%FMS_diag_fields(diag_field_id)%set_halo_present()

!< Set the variable type based off passed in field data
if(.not. this%FMS_diag_fields(diag_field_id)%has_vartype()) &
call this%FMS_diag_fields(diag_field_id)%set_type(field_data(1,1,1,1))

if (allocated(mask) .or. allocated(rmask)) &
call this%FMS_diag_fields(diag_field_id)%set_mask_variant(.True.)

if (.not. this%FMS_diag_fields(diag_field_id)%is_data_buffer_allocated()) then
data_buffer_is_allocated = &
this%FMS_diag_fields(diag_field_id)%allocate_data_buffer(field_data, this%diag_axis)
Expand All @@ -617,10 +620,21 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
!$omp end critical
call this%FMS_diag_fields(diag_field_id)%set_data_buffer(field_data, field_weight, &
is, js, ks, ie, je, ke)
call this%FMS_diag_fields(diag_field_id)%set_mask(oor_mask, is, js, ks, ie, je, ke)
call this%FMS_diag_fields(diag_field_id)%set_mask(oor_mask, field_info, is, js, ks, ie, je, ke)
fms_diag_accept_data = .TRUE.
return
else
!< At this point if we are no longer in an openmp region or running with 1 thread
!! so it is safe to have these set_* calls
if(has_halos) call this%FMS_diag_fields(diag_field_id)%set_halo_present()

!< Set the variable type based off passed in field data
if(.not. this%FMS_diag_fields(diag_field_id)%has_vartype()) &
call this%FMS_diag_fields(diag_field_id)%set_type(field_data(1,1,1,1))

if (allocated(mask) .or. allocated(rmask)) &
call this%FMS_diag_fields(diag_field_id)%set_mask_variant(.True.)

error_string = bounds%set_bounds(field_data, is, ie, js, je, ks, ke, has_halos)
if (trim(error_string) .ne. "") call mpp_error(FATAL, trim(error_string)//". "//trim(field_info))

Expand All @@ -631,7 +645,7 @@ logical function fms_diag_accept_data (this, diag_field_id, field_data, mask, rm
call this%FMS_diag_fields(diag_field_id)%set_math_needs_to_be_done(.FALSE.)
if(.not. this%FMS_diag_fields(diag_field_id)%has_mask_allocated()) &
call this%FMS_diag_fields(diag_field_id)%allocate_mask(oor_mask)
call this%FMS_diag_fields(diag_field_id)%set_mask(oor_mask)
call this%FMS_diag_fields(diag_field_id)%set_mask(oor_mask, field_info)
return
end if main_if
!> Return false if nothing is done
Expand Down Expand Up @@ -757,7 +771,6 @@ subroutine fms_diag_do_io(this, is_end_of_run)

! finish reduction method if its time to write
buff_reduct: if (is_writing) then
allocate(buff_ids(diag_file%FMS_diag_file%get_number_of_buffers()))
buff_ids = diag_file%FMS_diag_file%get_buffer_ids()
! loop through the buffers and finish reduction if needed
buff_loop: do ibuff=1, SIZE(buff_ids)
Expand Down
8 changes: 5 additions & 3 deletions diag_manager/fms_diag_yaml.F90
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ module fms_diag_yaml_mod
get_block_ids, get_key_value, get_key_ids, get_key_name
use mpp_mod, only: mpp_error, FATAL, mpp_pe, mpp_root_pe, stdout
use, intrinsic :: iso_c_binding, only : c_ptr, c_null_char
use fms_string_utils_mod, only: fms_array_to_pointer, fms_find_my_string, fms_sort_this, fms_find_unique
use fms_string_utils_mod, only: fms_array_to_pointer, fms_find_my_string, fms_sort_this, fms_find_unique, string
use platform_mod, only: r4_kind, i4_kind
use fms_mod, only: lowercase

Expand Down Expand Up @@ -1445,10 +1445,12 @@ function get_diag_files_id(indices) &
& trim(filename)//c_null_char)

if (size(file_indices) .ne. 1) &
& call mpp_error(FATAL, "get_diag_files_id: Error getting the correct number of file indices!")
& call mpp_error(FATAL, "get_diag_files_id: Error getting the correct number of file indices!"//&
" The diag file "//trim(filename)//" was defined "//string(size(file_indices))&
// " times")

if (file_indices(1) .eq. diag_null) &
& call mpp_error(FATAL, "get_diag_files_id: Error finding the filename in the diag_files yaml")
& call mpp_error(FATAL, "get_diag_files_id: Error finding the file "//trim(filename)//" in the diag_files yaml")

!< Get the index of the file in the diag_yaml file
file_id(i) = file_list%diag_file_indices(file_indices(1))
Expand Down

0 comments on commit 17cd78e

Please sign in to comment.