
Use shared memory in DivergenceF2C stencil operators #2184

Open · charleskawczynski wants to merge 2 commits into main from ck/fd_shmem
Conversation

@charleskawczynski (Member) commented Feb 12, 2025

I sketched out applying shared memory to DivergenceF2C stencil operators a while ago, and I wanted to open a PR with this branch before I accidentally clean it up and delete it locally.

However, the performance improvement is not nearly what we should expect. Currently, per thread, we do:

  • 1 read for J on centers
  • 2 reads for Jinv on faces
  • 2 reads for arg on faces

with a total of 5 reads per thread per point. Using shared memory, this should be:

  • 1 read for J on centers
  • 1 read for Jinv on faces
  • 1 read for arg on faces

with a total of 3 reads per thread per point. So, we should see a 40% performance improvement, but I'm only seeing ~15%.
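For concreteness, here is a minimal CUDA.jl sketch of the shared-memory read pattern described above. This is not ClimaCore's actual operator code: the function name, the flux form (Jinv times arg on faces, scaled by J on centers), and the single-column launch are assumptions chosen only to match the read counts listed here.

```julia
using CUDA

# One block per column, one thread per face; nfaces - 1 centers are updated.
function div_f2c_shmem!(out, arg, Jinv, J)
    f = threadIdx().x                          # face index within this column
    nfaces = length(arg)
    fflux = CuDynamicSharedArray(eltype(out), nfaces)
    if f <= nfaces
        # One global read of Jinv and one of arg per thread, staged in shmem.
        fflux[f] = Jinv[f] * arg[f]
    end
    sync_threads()
    if f <= nfaces - 1
        # The center update reuses the neighboring face value from shared memory
        # instead of re-reading Jinv and arg from global memory (3 reads total).
        out[f] = (fflux[f + 1] - fflux[f]) * J[f]
    end
    return nothing
end

# Example launch for a single 64-level column (sizes are assumptions):
# nfaces = 65
# @cuda threads = nfaces shmem = nfaces * sizeof(Float64) div_f2c_shmem!(out, arg, Jinv, J)
```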

There are also still some edge cases that need to be fixed.

@charleskawczynski (Member Author) commented

Note to self: need to look at threading pattern

@charleskawczynski force-pushed the ck/fd_shmem branch 7 times, most recently from 82ff728 to f4b175f on February 13, 2025 at 17:23
@Sbozzolo (Member) commented

I think the expectation of a 40% improvement relies on the assumption of complete L1 cache misses. If different threads within a block share faces and not much more data is needed, the effective number of global-memory reads will average fewer than 5, because some values will already be cached in the streaming multiprocessor's L1 cache (which has the same latency as shared memory).
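To put rough numbers on this (illustrative assumptions, not measurements):

```julia
# Of the 5 global reads per point on main, 2 are duplicates of a neighboring
# thread's face reads (Jinv and arg at the shared face). If a fraction `h` of
# those duplicated reads hit the SM's L1 cache, the effective expensive reads
# per point are:
effective_global_reads(h) = 3 + 2 * (1 - h)

effective_global_reads(0.0)  # 5.0 — the assumption behind the 40% estimate
effective_global_reads(1.0)  # 3.0 — shmem then adds little beyond L1 caching
```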

@charleskawczynski force-pushed the ck/fd_shmem branch 4 times, most recently from ba4f5dc to 3f85c1b on February 17, 2025 at 05:30
@charleskawczynski (Member Author) commented

Even slightly complicated cases are showing a nice improvement (~2x), so the smaller gain may just come down to additional factors (e.g., register pressure, or errors/traps emitted by LLVM).

@charleskawczynski (Member Author) commented

In this PR, shared memory (shmem) is supported through a single layer of operator composition. That is, for composed operators like div(grad(f)), the result of grad(f) will be put into shared memory per thread (i.e., per point), but duplicate reads of f will still occur (thread i will read f at i+1, and thread i+1 will read f at i). To improve on this, we can write combined operators (e.g., divgrad) that manage shmem themselves, which would result in a minimal number of required reads/writes.

Unfortunately, there are probably a lot of possible combined operators. However, we may be able to automatically transform composed operators into combined operators on the back end, so that we don't need to introduce a slew of new terminology to users.
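As a schematic of the two forms (ClimaCore-style pseudocode; `f` and `out` are assumed to be existing center Fields on a finite-difference space, space construction is omitted, and the fused operator name is hypothetical):

```julia
import ClimaCore: Geometry, Operators

# Composed form (what this PR supports): grad(f) is staged in shared memory,
# but neighboring threads still issue duplicate global reads of f.
grad = Operators.GradientC2F(
    bottom = Operators.SetGradient(Geometry.WVector(0.0)),
    top = Operators.SetGradient(Geometry.WVector(0.0)),
)
div = Operators.DivergenceF2C()
@. out = div(grad(f))

# Combined form (hypothetical follow-up, not defined in this PR): a fused
# operator would manage shmem for f itself and avoid the duplicate reads.
# divgrad = Operators.DivergenceGradientC2F(...)   # name is an assumption
# @. out = divgrad(f)
```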

@charleskawczynski (Member Author) commented Feb 17, 2025

I think this is in nearly merge-worthy shape; the only remaining issues I see are:

  • This PR changes the default threading pattern for all stencil operations, which may not be best when shmem is not supported. We could add a check before the kernel launch to see whether any operators support shmem, but we can't really do that here: this approach requires overloading getidx, and there's no way to determine whether shmem should be used in the local scope of a broadcasted object without dispatching on a new argument in getidx, which would be a pretty invasive change.
  • I think this may change the limit of vertical resolution that we support (the 1,000-level case in the test was lowered). Should we dispatch to a shmem-free version to support higher vertical resolution? I don't think we'll ever need more than 256 levels, which is what this implementation supports.

@charleskawczynski (Member Author) commented Feb 20, 2025

I've fixed the main bug I was running into, and I think this is ready to go. Here are preliminary results for some of the relevant benchmarks (from `test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl`):

|       | main (Float64) | shmem (Float64) | main (Float32) | shmem (Float32) |
|-------|----------------|-----------------|----------------|-----------------|
| ᶜout1 | 245.218 μs     | 200.758 μs      | 116.570 μs     | 112.719 μs      |
| ᶜout2 | 287.298 μs     | 228.957 μs      | 178.039 μs     | 125.649 μs      |
| ᶜout3 | 578.095 μs     | 281.637 μs      | 412.717 μs     | 194.179 μs      |
| ᶜout4 | 320.777 μs     | 225.958 μs      | 268.588 μs     | 158.699 μs      |
| ᶜout5 | 335.117 μs     | 249.007 μs      | 290.848 μs     | 170.639 μs      |
| ᶜout6 | 568.145 μs     | 261.858 μs      | 497.586 μs     | 180.769 μs      |
| ᶜout7 | 439.927 μs     | 249.397 μs      | 388.677 μs     | 173.019 μs      |
| ᶜout8 | 440.407 μs     | 239.448 μs      | 382.127 μs     | 172.348 μs      |

It's notable that some of these kernels can be optimized further. For example, ᶜout6 looks like `@. ᶜout6 = div_bcs(Geometry.WVector(ᶠwinterp(ϕ, ρ)))`, and we only apply shared memory to `div_bcs`; we could also add shared-memory support for `ᶠwinterp`. Finally, we could define a DivergenceF2CWInterp operator that combines the stencils into one, which should, in theory, get us closer to ~100 μs (on Float32), a ~5x speedup.

This PR only adds shared memory for an F2C operator; we should probably also add one C2F operator so that we exercise the cell-center-filling shared-memory branch (which is currently not exercised and is therefore likely not correct).

I've fixed two more out-of-bounds issues, and there is still one inference issue in a case where the BCs are specified as fields. Hoping that is the last one. I've added all of the problematic cases I found to the tests.

@charleskawczynski force-pushed the ck/fd_shmem branch 3 times, most recently from 4a4b7c7 to bfa8df9 on February 21, 2025 at 05:19
Try dont_limit on recursive resolve_shmem methods

Fixes + more dont limit

Matrix field fixes

Matrix field fixes

DivergenceF2C fix

MatrixField fixes

Qualify DivergenceF2C

wip

Refactor + fixed space bug. All seems good.

More tests..

Fixes

Test updates

Fixes
@@ -541,6 +541,16 @@ Required for statically infering the result type of the divergence operation for
) where {FT, A1, A2 <: LocalAxis, S <: StaticMatrix{S1, S2}} where {S1, S2} =
AxisVector{FT, A2, SVector{S2, FT}}

# TODO: can we better generalize this?
@charleskawczynski (Member Author) commented on this line:

@dennisYatunin, can we better generalize this?

@Sbozzolo (Member) commented

The build that you ran in ClimaAtmos shows a significant regression in the "Benchmark: GPU prog edmf" test compared to main, so maybe that's something you want to look into before we review this?

I'll also leave my first impression here:

> I think this may change the limit of vertical resolution that we support (the 1,000-level case in the test was lowered). Should we dispatch to a shmem-free version to support higher vertical resolution? I don't think we'll ever need more than 256 levels, which is what this implementation supports.

I think that this should be addressed. We will not need more than 256 levels for a global simulation, but we might want more in other settings. For example, I used more than 256 levels to study self-convergence in ClimaTimeSteppers. A restriction of 256 points for a finite-difference method is very strong for most applications that are not global simulations. Moreover, this would further differentiate our CPU and GPU capabilities. (I think it'd be perfectly acceptable to have a slower path when the number of points exceeds 256, but users should still be able to run such a configuration.)

@charleskawczynski (Member Author) commented

I'll of course make sure we address the performance before merging.

Allowing a larger number of vertical levels would be nice too. I'm going to think about this; if we could preserve both code paths somehow, it might fix both issues.

Do you have any other comments?

@Sbozzolo (Member) commented

> Do you have any other comments?

Yes, but it will be much more efficient to talk in person, so I'd suggest you first look at the problem with that job, and then we can schedule a call to chat about this.

@charleskawczynski (Member Author) commented

I just pushed a fix that should address all of these issues. We now check ahead of time and transform the broadcasted style to disable the shmem broadcasted object if shmem is not supported (which should fix the regression) or if the resolution is too high (to maintain support for very high vertical resolution).
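Schematically, the ahead-of-time check looks something like this (all names here are stand-ins, not the actual implementation):

```julia
# Stand-in types for the two broadcast/launch styles.
struct ShmemStyle end
struct PointwiseStyle end

# Assumed predicate: does any operator in the broadcasted object support shmem?
supports_shmem(bc) = true   # placeholder

# Decide before the kernel launch: fall back to the original pointwise pattern
# when no operator supports shmem or the column has too many levels.
launch_style(bc, nlevels; max_shmem_levels = 256) =
    (supports_shmem(bc) && nlevels <= max_shmem_levels) ? ShmemStyle() : PointwiseStyle()

launch_style(nothing, 128)  # ShmemStyle()
launch_style(nothing, 512)  # PointwiseStyle()
```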

@charleskawczynski (Member Author) commented

> The build that you ran in ClimaAtmos shows a significant regression in the "Benchmark: GPU prog edmf" test compared to main, so maybe that's something you want to look into before we review this?

Which build are you comparing against?
