Skip to content

Commit

Permalink
refac/fix(example): better reporting of runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 1, 2023
1 parent 470ff50 commit 0c2b335
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions example/learn-saturated-mixing-ratio.f90
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ program train_saturated_mixture_ratio
type(trainable_engine_t) trainable_engine
type(bin_t), allocatable :: bins(:)
real, allocatable :: cost(:), random_numbers(:)
integer io_status, network_unit, plot_unit, previous_epoch
integer, parameter :: io_success=0
integer io_status, network_unit, plot_unit
integer, parameter :: io_success=0, diagnostics_print_interval = 1000, network_save_interval = 10000
integer, parameter :: nodes_per_layer(*) = [2, 31, 31, 1]
real, parameter :: cost_tolerance = 1.E-08

call random_init(image_distinct=.true., repeatable=.true.)

Expand Down Expand Up @@ -73,7 +74,7 @@ program train_saturated_mixture_ratio
end block

block
integer e, b, stop_unit
integer e, b, stop_unit, previous_epoch
real previous_clock_time

call open_plot_file_for_appending("cost.plt", plot_unit, previous_epoch, previous_clock_time)
Expand All @@ -85,30 +86,38 @@ program train_saturated_mixture_ratio
call shuffle(input_output_pairs, random_numbers)
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
call trainable_engine%train(mini_batches, cost, adam=.true.)
associate(cost_avg => sum(cost)/size(cost))
if (mod(e, 1000)==0) then
call system_clock(counter_end, clock_rate)
associate(cumulative_clock_time => previous_clock_time + real(counter_end - counter_start) / real(clock_rate))
write(output_unit,fmt='(3(g13.5,2x))', advance='no') e, cost_avg, cumulative_clock_time
write(output_unit, fmt=csv) nodes_per_layer
write(plot_unit,fmt='(3(g13.5,2x))', advance='no') e, cost_avg, cumulative_clock_time
write(plot_unit, fmt=csv) nodes_per_layer
end associate
end if
if (mod(e, 10000)==0) call output(trainable_engine%to_inference_engine(), network_file)
if (cost_avg < 1.E-08) exit
call system_clock(counter_end, clock_rate)

associate( &
cost_avg => sum(cost)/size(cost), &
cumulative_clock_time => previous_clock_time + real(counter_end - counter_start) / real(clock_rate), &
loop_ending => e == previous_epoch + num_epochs &
)
write_and_exit_if_converged: &
if (cost_avg < cost_tolerance) then
call print_diagnostics(plot_unit, e, cost_avg, cumulative_clock_time, nodes_per_layer)
call output(trainable_engine%to_inference_engine(), network_file)
exit
end if write_and_exit_if_converged

open(newunit=stop_unit, file="stop", form='formatted', status='old', iostat=io_status)
if (io_status==0) exit

write_and_exit_if_stop_file_exists: &
if (io_status==0) then
call print_diagnostics(plot_unit, e, cost_avg, cumulative_clock_time, nodes_per_layer)
call output(trainable_engine%to_inference_engine(), network_file)
exit
end if write_and_exit_if_stop_file_exists

if (mod(e,diagnostics_print_interval)==0 .or. loop_ending) &
call print_diagnostics(plot_unit, e, cost_avg, cumulative_clock_time, nodes_per_layer)
if (mod(e,network_save_interval)==0 .or. loop_ending) call output(trainable_engine%to_inference_engine(), network_file)
end associate
end do

call system_clock(counter_end, clock_rate)

write(output_unit,fmt='(3(g13.5,2x))', advance='no') e, sum(cost)/size(cost), real(counter_end - counter_start) / real(clock_rate)
write(output_unit, fmt=csv) nodes_per_layer
write(plot_unit,fmt='(3(g13.5,2x))', advance='no') e, sum(cost)/size(cost), real(counter_end - counter_start) / real(clock_rate)
write(plot_unit, fmt=csv) nodes_per_layer
close(plot_unit)

report_network_performance: &
block
integer p

Expand All @@ -118,12 +127,10 @@ program train_saturated_mixture_ratio
print "(4(G13.5,2x))", inputs(p)%values(), network_outputs(p)%values(), desired_outputs(p)%values()
end do
end associate
end block
end block report_network_performance

end block

close(plot_unit)

end associate

call output(trainable_engine%to_inference_engine(), network_file)
Expand All @@ -132,6 +139,16 @@ program train_saturated_mixture_ratio

contains

subroutine print_diagnostics(plot_file_unit, epoch, cost, clock, nodes)
integer, intent(in) :: plot_file_unit, epoch, nodes(:)
real, intent(in) :: cost, clock

write(unit=output_unit, fmt='(3(g13.5,2x))', advance='no') epoch, cost, clock
write(unit=output_unit, fmt=csv) nodes
write(unit=plot_file_unit, fmt='(3(g13.5,2x))', advance='no') epoch, cost, clock
write(unit=plot_file_unit, fmt=csv) nodes
end subroutine

subroutine output(inference_engine, file_name)
type(inference_engine_t), intent(in) :: inference_engine
type(string_t), intent(in) :: file_name
Expand Down Expand Up @@ -194,12 +211,16 @@ subroutine open_plot_file_for_appending(plot_file_name, plot_unit, previous_epoc
if (.not. preexisting_plot_file) then
write(plot_unit,*) header
previous_epoch = 0
previous_clock = 0
else
plot_file = file_t(string_t(plot_file_name))
lines = plot_file%lines()
last_line = lines(size(lines))%string()
read(last_line,*, iostat=io_status) previous_epoch, cost, previous_clock
if ((io_status /= io_success .and. last_line == header) .or. len(trim(last_line))==0) previous_epoch = 0
if ((io_status /= io_success .and. last_line == header) .or. len(trim(last_line))==0) then
previous_epoch = 0
previous_clock = 0
end if
end if
end associate

Expand Down

0 comments on commit 0c2b335

Please sign in to comment.