Skip to content

Commit

Permalink
Cleanup_complex (#1555)
Browse files Browse the repository at this point in the history
* Clean up the complex conversion between floating-point types

We awkwardly introduced the constructors of the regular floating-point complex types from half or bfloat through a declared constructor whose definition lives in the specialized header.

Instead, we can just use a constructor template and completely separate the classical header file from the other ones.
  • Loading branch information
miscco authored Apr 3, 2024
1 parent 4673d1c commit 2bd685f
Show file tree
Hide file tree
Showing 9 changed files with 1,241 additions and 1,578 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ set(files
__cccl/system_header.h
__cccl/version.h
__cccl/visibility.h
__complex/nvbf16.h
__complex/nvfp16.h
__complex/vector_support.h
__concepts/__concept_macros.h
__concepts/_One_of.h
__concepts/all_of.h
Expand Down Expand Up @@ -132,8 +135,6 @@ set(files
__cuda/climits_prelude.h
__cuda/cmath_nvbf16.h
__cuda/cmath_nvfp16.h
__cuda/complex_nvbf16.h
__cuda/complex_nvfp16.h
__cuda/cstddef_prelude.h
__cuda/cstdint_prelude.h
__cuda/latch.h
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___CUDA_COMPLEX_NVBF16_H
#define _LIBCUDACXX___CUDA_COMPLEX_NVBF16_H
#ifndef _LIBCUDACXX___COMPLEX_NVBF16_H
#define _LIBCUDACXX___COMPLEX_NVBF16_H

#ifndef __cuda_std__
# include <config>
Expand All @@ -30,12 +30,13 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP

# include <cuda/std/cmath>
# include <cuda/std/complex>
# include <cuda/std/detail/libcxx/include/__complex/vector_support.h>
# include <cuda/std/detail/libcxx/include/__cuda/cmath_nvbf16.h>
# include <cuda/std/detail/libcxx/include/__type_traits/integral_constant.h>
# include <cuda/std/detail/libcxx/include/__type_traits/enable_if.h>
# include <cuda/std/detail/libcxx/include/__type_traits/is_arithmetic.h>
# include <cuda/std/detail/libcxx/include/__type_traits/is_same.h>
# include <cuda/std/detail/libcxx/include/cmath>
# include <cuda/std/detail/libcxx/include/__type_traits/integral_constant.h>
# include <cuda/std/detail/libcxx/include/__type_traits/is_constructible.h>

# if !defined(_CCCL_COMPILER_NVRTC)
# include <sstream> // for std::basic_ostringstream
Expand All @@ -47,6 +48,10 @@ template <>
struct __is_nvbf16<__nv_bfloat16> : true_type
{};

template <>
struct __complex_alignment<__nv_bfloat16> : integral_constant<size_t, alignof(__nv_bfloat162)>
{};

template <>
struct __type_to_vector<__nv_bfloat16>
{
Expand All @@ -61,150 +66,137 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
};

template <>
class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
class _LIBCUDACXX_TEMPLATE_VIS _ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
{
__nv_bfloat162 __repr;
__nv_bfloat162 __repr_;

template <class _Up>
friend class complex;

public:
typedef __nv_bfloat16 value_type;
using value_type = __nv_bfloat16;

_LIBCUDACXX_INLINE_VISIBILITY complex(__nv_bfloat16 __re = 0.0f, __nv_bfloat16 __im = 0.0f)
: __repr(__re, __im)
{}
template <class _Int, typename = __enable_if_t<is_arithmetic<_Int>::value>>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(_Int __re = _Int(), _Int __im = _Int())
: __repr(__re, __im)
_LIBCUDACXX_INLINE_VISIBILITY complex(const value_type& __re = value_type(), const value_type& __im = value_type())
: __repr_(__re, __im)
{}

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions

_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<float>& __c)
: __repr(__c.real(), __c.imag())
template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
{}
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<double>& __c)
: __repr(__c.real(), __c.imag())

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_LIBCUDACXX_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
{}

_CCCL_DIAG_POP
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
{
__repr_.x = __re;
__repr_.y = value_type();
return *this;
}

template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
return *this;
}

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex(const ::std::complex<_Up>& __other)
: __repr(_LIBCUDACXX_ACCESS_STD_COMPLEX_REAL(__other), _LIBCUDACXX_ACCESS_STD_COMPLEX_IMAG(__other))
: __repr_(_LIBCUDACXX_ACCESS_STD_COMPLEX_REAL(__other), _LIBCUDACXX_ACCESS_STD_COMPLEX_IMAG(__other))
{}

template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const ::std::complex<_Up>& __other)
{
__repr.x = _LIBCUDACXX_ACCESS_STD_COMPLEX_REAL(__other);
__repr.y = _LIBCUDACXX_ACCESS_STD_COMPLEX_IMAG(__other);
__repr_.x = _LIBCUDACXX_ACCESS_STD_COMPLEX_REAL(__other);
__repr_.y = _LIBCUDACXX_ACCESS_STD_COMPLEX_IMAG(__other);
return *this;
}
# endif // !defined(_CCCL_COMPILER_NVRTC)

_LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 real() const
_LIBCUDACXX_HOST operator ::std::complex<value_type>() const
{
return __repr.x;
}
_LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 imag() const
{
return __repr.y;
return {__repr_.x, __repr_.y};
}
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re)
_LIBCUDACXX_INLINE_VISIBILITY value_type real() const
{
__repr.x = __re;
return __repr_.x;
}
_LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im)
_LIBCUDACXX_INLINE_VISIBILITY value_type imag() const
{
__repr.y = __im;
return __repr_.y;
}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(__nv_bfloat16 __re)
{
__repr.x = __re;
__repr.y = value_type();
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(__nv_bfloat16 __re)
_LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re)
{
__repr.x += __re;
return *this;
__repr_.x = __re;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(__nv_bfloat16 __re)
_LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im)
{
__repr.x -= __re;
return *this;
__repr_.y = __im;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(__nv_bfloat16 __re)

// Those additional volatile overloads are meant to help with reductions in thrust
_LIBCUDACXX_INLINE_VISIBILITY value_type real() const volatile
{
__repr.x *= __re;
__repr.y *= __re;
return *this;
return __repr_.x;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(__nv_bfloat16 __re)
_LIBCUDACXX_INLINE_VISIBILITY value_type imag() const volatile
{
__repr.x /= __re;
__repr.y /= __re;
return *this;
return __repr_.y;
}

template <class _Xp>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Xp>& __c)
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr.x = __c.real();
__repr.y = __c.imag();
__repr_.x += __re;
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex& __c)
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr += __c.__repr;
__repr_.x -= __re;
return *this;
}
template <class _Xp>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex<_Xp>& __c)
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr.x += __c.real();
__repr.y += __c.imag();
__repr_.x *= __re;
__repr_.y *= __re;
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex& __c)
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr -= __c.__repr;
__repr_.x /= __re;
__repr_.y /= __re;
return *this;
}
template <class _Xp>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex<_Xp>& __c)

// We can utilize vectorized operations for those operators
_LIBCUDACXX_INLINE_VISIBILITY friend complex& operator+=(complex& __lhs, const complex& __rhs) noexcept
{
__repr.x -= __c.real();
__repr.y -= __c.imag();
return *this;
__lhs.__repr_ += __rhs.__repr_;
return __lhs;
}
template <class _Xp>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const complex<_Xp>& __c)

_LIBCUDACXX_INLINE_VISIBILITY friend complex& operator-=(complex& __lhs, const complex& __rhs) noexcept
{
*this = *this * complex(__c.real(), __c.imag());
return *this;
__lhs.__repr_ -= __rhs.__repr_;
return __lhs;
}
template <class _Xp>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const complex<_Xp>& __c)

_LIBCUDACXX_INLINE_VISIBILITY friend bool operator==(const complex& __x, const complex& __y)
{
*this = *this / complex(__c.real(), __c.imag());
return *this;
return __x.__repr_ == __y.__repr_;
}
};

inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__c.real())
, __im_(__c.imag())
{}

inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__c.real())
, __im_(__c.imag())
{}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
{
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
Expand All @@ -214,22 +206,22 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
template <>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> asinh(const complex<__nv_bfloat16>& __x)
{
return complex<__nv_bfloat16>{_CUDA_VSTD::asinh(complex<float>{__x.real(), __x.imag()})};
return complex<__nv_bfloat16>{_CUDA_VSTD::asinh(complex<float>{__x})};
}
template <>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> acosh(const complex<__nv_bfloat16>& __x)
{
return complex<__nv_bfloat16>{_CUDA_VSTD::acosh(complex<float>{__x.real(), __x.imag()})};
return complex<__nv_bfloat16>{_CUDA_VSTD::acosh(complex<float>{__x})};
}
template <>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> atanh(const complex<__nv_bfloat16>& __x)
{
return complex<__nv_bfloat16>{_CUDA_VSTD::atanh(complex<float>{__x.real(), __x.imag()})};
return complex<__nv_bfloat16>{_CUDA_VSTD::atanh(complex<float>{__x})};
}
template <>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> acos(const complex<__nv_bfloat16>& __x)
{
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x.real(), __x.imag()})};
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x})};
}

# if !defined(_CCCL_COMPILER_NVRTC)
Expand All @@ -247,12 +239,12 @@ template <class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<__nv_bfloat16>& __x)
{
return __os << complex<float>{__x.real(), __x.imag()};
return __os << complex<float>{__x};
}
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_END_NAMESPACE_STD

#endif /// _LIBCUDACXX_HAS_NVBF16

#endif // _LIBCUDACXX___CUDA_COMPLEX_NVBF16_H
#endif // _LIBCUDACXX___COMPLEX_NVBF16_H
Loading

0 comments on commit 2bd685f

Please sign in to comment.