Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Commit

Permalink
wgsl-in <-> msl-out are working and tests are passing!
Browse files Browse the repository at this point in the history
try via
```
cargo nextest run -p naga --no-default-features --features wgsl-in,msl-out,validate,span,serialize,deserialize
```
  • Loading branch information
teoxoy committed Mar 23, 2023
1 parent 03c3307 commit 7c64222
Show file tree
Hide file tree
Showing 35 changed files with 1,189 additions and 1,565 deletions.
173 changes: 87 additions & 86 deletions src/back/msl/writer.rs

Large diffs are not rendered by default.

26 changes: 9 additions & 17 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,14 +707,12 @@ impl<W: Write> Writer<W> {
Statement::Store { pointer, value } => {
write!(self.out, "{level}")?;

let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Atomic { .. } => true,
_ => false,
},
_ => false,
};
if is_atomic {
let is_atomic_pointer = func_ctx.info[pointer]
.ty
.inner_with(&module.types)
.is_atomic_pointer(&module.types);

if is_atomic_pointer {
write!(self.out, "atomicStore(")?;
self.write_expr(module, pointer, Some(func_ctx))?;
write!(self.out, ", ")?;
Expand Down Expand Up @@ -1447,18 +1445,12 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?;
}
Expression::Load { pointer } => {
let is_atomic = match *func_ctx.expect("non-global context").info[pointer]
let is_atomic_pointer = func_ctx.expect("non-global context").info[pointer]
.ty
.inner_with(&module.types)
{
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Atomic { .. } => true,
_ => false,
},
_ => false,
};
.is_atomic_pointer(&module.types);

if is_atomic {
if is_atomic_pointer {
write!(self.out, "atomicLoad(")?;
self.write_expr(module, pointer, func_ctx)?;
write!(self.out, ")")?;
Expand Down
13 changes: 13 additions & 0 deletions src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ pub enum Error<'a> {
},
FunctionReturnsVoid(Span),
Other,
ExpectedArraySize(Span),
NonPositiveArrayLength(Span),
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -685,6 +687,17 @@ impl<'a> Error<'a> {
labels: vec![],
notes: vec![],
},
Error::ExpectedArraySize(span) => ParseError {
message: "array element count must resolve to an integer scalar (u32, i32)"
.to_string(),
labels: vec![(span, "must resolve to u32/i32".into())],
notes: vec![],
},
Error::NonPositiveArrayLength(span) => ParseError {
message: "array element count must be positive".to_string(),
labels: vec![(span, "must be positive".into())],
notes: vec![],
},
}
}
}
14 changes: 7 additions & 7 deletions src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use crate::front::wgsl::parse::ast;
use crate::{Handle, Span};

Expand Down Expand Up @@ -452,14 +454,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let base = ctx.register_type(components[0])?;

let size_expr = ctx.module.const_expressions.append(
crate::Expression::Literal(crate::Literal::U32(components.len() as _)),
span,
);

let inner = crate::TypeInner::Array {
base,
size: crate::ArraySize::Constant(size_expr),
size: crate::ArraySize::Constant(
NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
),
stride: {
self.layouter.update(ctx.module.to_ctx()).unwrap();
self.layouter[base].to_stride()
Expand Down Expand Up @@ -583,7 +582,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let base = self.resolve_ast_type(base, ctx.as_global())?;
let size = match size {
ast::ArraySize::Constant(expr) => {
crate::ArraySize::Constant(self.expression(expr, ctx.as_const())?)
let const_expr = self.expression(expr, ctx.as_const())?;
crate::ArraySize::Constant(ctx.array_length(const_expr)?)
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
};
Expand Down
183 changes: 113 additions & 70 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
use crate::front::wgsl::index::Index;
use crate::front::wgsl::parse::number::Number;
Expand Down Expand Up @@ -261,6 +263,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
}

fn get_expression(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => &ctx.naga_expressions[handle],
ExpressionContextType::Constant => &self.module.const_expressions[handle],
}
}

fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.naga_expressions.get_span(handle),
Expand All @@ -282,6 +291,43 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
}

fn array_length(
&self,
const_expr: Handle<crate::Expression>,
) -> Result<NonZeroU32, Error<'source>> {
let span = self.module.const_expressions.get_span(const_expr);
let len = self
.module
.to_ctx()
.to_array_length(const_expr, None)
.map_err(|err| match err {
crate::proc::ArrayLengthError::Invalid => Error::ExpectedArraySize(span),
crate::proc::ArrayLengthError::NotPositive => Error::NonPositiveArrayLength(span),
})?;
NonZeroU32::new(len).ok_or(Error::NonPositiveArrayLength(span))
}

fn gather(
&self,
expr: Handle<crate::Expression>,
span: Span,
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
let index = self
.module
.to_ctx()
.to_array_length(expr, Some(ctx.naga_expressions))
.map_err(|_| Error::InvalidGatherComponent(span))?;
crate::SwizzleComponent::XYZW
.get(index as usize)
.copied()
.ok_or(Error::InvalidGatherComponent(span))
}
ExpressionContextType::Constant => panic!(),
}
}

/// Determine the type of `handle`, and add it to the module's arena.
///
/// If you just need a `TypeInner` for `handle`'s type, use
Expand Down Expand Up @@ -330,15 +376,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
&mut self,
handle: Handle<crate::Expression>,
) -> Result<&mut Self, Error<'source>> {
let local_vars = match self.expr_type {
ExpressionContextType::Runtime(ctx) => ctx.local_vars,
ExpressionContextType::Constant => &Arena::new(),
};
let arguments = match self.expr_type {
ExpressionContextType::Runtime(ctx) => ctx.arguments,
ExpressionContextType::Constant => &[],
let empty_arena = Arena::new();
let resolve_ctx = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
ResolveContext::with_locals(self.module, ctx.local_vars, ctx.arguments)
}
ExpressionContextType::Constant => {
ResolveContext::with_locals(self.module, &empty_arena, &[])
}
};
let resolve_ctx = ResolveContext::with_locals(self.module, local_vars, arguments);
let (typifier, expressions) = match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => {
(&mut *ctx.typifier, &*ctx.naga_expressions)
Expand Down Expand Up @@ -606,11 +652,6 @@ enum LoweredGlobalDecl {
EntryPoint,
}

// enum ConstantOrInner {
// Constant(Handle<crate::Constant>),
// Inner(crate::ConstantInner),
// }

enum Texture {
Gather,
GatherCompare,
Expand Down Expand Up @@ -1352,36 +1393,32 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
));
}

// if let crate::Expression::Constant(constant) = ctx.naga_expressions[index] {
// let span = ctx.naga_expressions.get_span(index);
// let index = match ctx.module.constants[constant].inner {
// crate::ConstantInner::Scalar {
// value: crate::ScalarValue::Uint(int),
// ..
// } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
// crate::ConstantInner::Scalar {
// value: crate::ScalarValue::Sint(int),
// ..
// } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
// _ => Err(Error::BadU32Constant(span)),
// }?;

// (
// crate::Expression::AccessIndex {
// base: expr.handle,
// index,
// },
// expr.is_reference,
// )
// } else {
(
crate::Expression::Access {
base: expr.handle,
index,
},
expr.is_reference,
)
// }
if let crate::Expression::Literal(lit) = *ctx.get_expression(index) {
let span = ctx.get_expression_span(index);
let index = match lit {
crate::Literal::U32(index) => Ok(index),
crate::Literal::I32(index) => {
u32::try_from(index).map_err(|_| Error::BadU32Constant(span))
}
_ => Err(Error::BadU32Constant(span)),
}?;

(
crate::Expression::AccessIndex {
base: expr.handle,
index,
},
expr.is_reference,
)
} else {
(
crate::Expression::Access {
base: expr.handle,
index,
},
expr.is_reference,
)
}
}
ast::Expression::Member { base, ref field } => {
let TypedExpression {
Expand Down Expand Up @@ -1551,11 +1588,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.map(|&arg| self.expression(arg, ctx.reborrow()))
.collect::<Result<Vec<_>, _>>()?;

let result = ctx.module.functions[function]
.result
.is_some()
.then(|| ctx.interrupt_emitter(crate::Expression::CallResult(function), span));
let has_result = ctx.module.functions[function].result.is_some();
let rctx = ctx.runtime_expression_ctx();
// we need to always do this before a fn call since all arguments need to be emitted before the fn call
rctx.block
.extend(rctx.emitter.finish(rctx.naga_expressions));
let result = has_result.then(|| {
rctx.naga_expressions
.append(crate::Expression::CallResult(function), span)
});
rctx.emitter.start(rctx.naga_expressions);
rctx.block.push(
crate::Statement::Call {
function,
Expand Down Expand Up @@ -1904,9 +1946,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
descriptor,
};

ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
ctx.emitter.start(ctx.naga_expressions);
ctx.block
let rctx = ctx.runtime_expression_ctx();
rctx.block
.extend(rctx.emitter.finish(rctx.naga_expressions));
rctx.emitter.start(rctx.naga_expressions);
rctx.block
.push(crate::Statement::RayQuery { query, fun }, span);
return Ok(None);
}
Expand All @@ -1915,14 +1959,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
args.finish()?;

ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
let result = ctx
.naga_expressions
.append(crate::Expression::RayQueryProceedResult, span);
.interrupt_emitter(crate::Expression::RayQueryProceedResult, span);
let fun = crate::RayQueryFunction::Proceed { result };

ctx.emitter.start(ctx.naga_expressions);
ctx.block
let rctx = ctx.runtime_expression_ctx();
rctx.block
.push(crate::Statement::RayQuery { query, fun }, span);
return Ok(Some(result));
}
Expand Down Expand Up @@ -2050,23 +2091,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
match *ctx.resolved_inner(lowered_image_or_component) {
crate::TypeInner::Image { .. } => {
let image_span = ctx.ast_expressions.get_span(image_or_component);
let gather = ctx.append_expression(
crate::Expression::Literal(crate::Literal::U32(0)),
span,
);
(lowered_image_or_component, image_span, Some(gather))
(
lowered_image_or_component,
image_span,
Some(crate::SwizzleComponent::X),
)
}
_ => {
let (image, image_span) = get_image_and_span(self, &mut args, &mut ctx)?;
(image, image_span, Some(lowered_image_or_component))
(
image,
image_span,
Some(ctx.gather(lowered_image_or_component, span)?),
)
}
}
}
Texture::GatherCompare => {
let (image, image_span) = get_image_and_span(self, &mut args, &mut ctx)?;
let gather =
ctx.append_expression(crate::Expression::Literal(crate::Literal::U32(0)), span);
(image, image_span, Some(gather))
(image, image_span, Some(crate::SwizzleComponent::X))
}

_ => {
Expand Down Expand Up @@ -2117,7 +2160,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let offset = args
.next()
.map(|arg| self.expression(arg, ctx.reborrow()))
.map(|arg| self.expression(arg, ctx.as_const()))
.ok()
.transpose()?;

Expand Down Expand Up @@ -2241,8 +2284,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
base,
size: match size {
ast::ArraySize::Constant(constant) => {
let constant = self.expression(constant, ctx.as_const())?;
crate::ArraySize::Constant(constant)
let const_expr = self.expression(constant, ctx.as_const())?;
crate::ArraySize::Constant(ctx.as_const().array_length(const_expr)?)
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
},
Expand All @@ -2268,8 +2311,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
base,
size: match size {
ast::ArraySize::Constant(constant) => {
let constant = self.expression(constant, ctx.as_const())?;
crate::ArraySize::Constant(constant)
let const_expr = self.expression(constant, ctx.as_const())?;
crate::ArraySize::Constant(ctx.as_const().array_length(const_expr)?)
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
},
Expand Down
Loading

0 comments on commit 7c64222

Please sign in to comment.