Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a #[spirv(...)] attribute checking pass and remove #[allow(unused_attributes)]. #461

Merged
merged 6 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ fn trans_image<'tcx>(
attr: SpirvAttribute,
) -> Option<Word> {
match attr {
SpirvAttribute::Image {
SpirvAttribute::ImageType {
dim,
depth,
arrayed,
Expand Down
316 changes: 316 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
//! `#[spirv(...)]` attribute support.
//!
//! The attribute-checking parts of this try to follow `rustc_passes::check_attr`.

use crate::symbols::{SpirvAttribute, Symbols};
use rustc_ast::Attribute;
use rustc_hir as hir;
use rustc_hir::def_id::LocalDefId;
use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{HirId, MethodKind, Target, CRATE_HIR_ID};
use rustc_middle::hir::map::Map;
use rustc_middle::ty::query::Providers;
use rustc_middle::ty::TyCtxt;
use std::fmt;
use std::rc::Rc;

// FIXME(eddyb) make this reusable from somewhere in `rustc`.
pub(crate) fn target_from_impl_item<'tcx>(
tcx: TyCtxt<'tcx>,
impl_item: &hir::ImplItem<'_>,
) -> Target {
match impl_item.kind {
hir::ImplItemKind::Const(..) => Target::AssocConst,
hir::ImplItemKind::Fn(..) => {
let parent_hir_id = tcx.hir().get_parent_item(impl_item.hir_id);
let containing_item = tcx.hir().expect_item(parent_hir_id);
let containing_impl_is_for_trait = match &containing_item.kind {
hir::ItemKind::Impl { of_trait, .. } => of_trait.is_some(),
_ => unreachable!("parent of an ImplItem must be an Impl"),
};
if containing_impl_is_for_trait {
Target::Method(MethodKind::Trait { body: true })
} else {
Target::Method(MethodKind::Inherent)
}
}
hir::ImplItemKind::TyAlias(..) => Target::AssocTy,
}
}

// HACK(eddyb) current `Target` (after rust-lang/rust#80641 + rust-lang/rust#80920),
// emulated before we can rustup to that point and use the new variants directly.
enum TargetNew {
Old(Target),

// Added by rust-lang/rust#80641.
Field,
Arm,
MacroDef,

// Added by rust-lang/rust#80920.
Param,
}

impl From<Target> for TargetNew {
fn from(target: Target) -> Self {
Self::Old(target)
}
}

impl fmt::Display for TargetNew {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let description = match self {
Self::Old(target) => return write!(f, "{}", target),

Self::Field => "struct field",
Self::Arm => "match arm",
Self::MacroDef => "macro def",

Self::Param => "function param",
};
f.write_str(description)
}
}

struct CheckSpirvAttrVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
sym: Rc<Symbols>,
}

impl CheckSpirvAttrVisitor<'_> {
fn check_spirv_attributes(
&self,
hir_id: HirId,
attrs: &[Attribute],
target: impl Into<TargetNew>,
) {
let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);

let target = target.into();
for (attr, parse_attr_result) in parse_attrs(attrs) {
// Make sure to mark the whole `#[spirv(...)]` attribute as used,
// to avoid warnings about unused attributes.
self.tcx.sess.mark_attr_used(attr);

let (span, parsed_attr) = match parse_attr_result {
Ok(span_and_parsed_attr) => span_and_parsed_attr,
Err((span, msg)) => {
self.tcx.sess.span_err(span, &msg);
continue;
}
};

/// Error newtype marker used below for readability.
struct Expected<T>(T);

let valid_target = match parsed_attr {
SpirvAttribute::Builtin(_)
| SpirvAttribute::DescriptorSet(_)
| SpirvAttribute::Binding(_)
| SpirvAttribute::Flat => match target {
TargetNew::Param => {
let parent_hir_id = self.tcx.hir().get_parent_node(hir_id);
let parent_is_entry_point =
parse_attrs(self.tcx.hir().attrs(parent_hir_id))
.filter_map(|(_, r)| r.ok())
.any(|(_, attr)| matches!(attr, SpirvAttribute::Entry(_)));
if !parent_is_entry_point {
self.tcx.sess.span_err(
span,
"attribute is only valid on a parameter of an entry-point function",
);
}
Ok(())
}

_ => Err(Expected("function parameter")),
},

SpirvAttribute::Entry(_) => match target {
TargetNew::Old(Target::Fn)
| TargetNew::Old(Target::Method(MethodKind::Trait { body: true }))
| TargetNew::Old(Target::Method(MethodKind::Inherent)) => {
// FIXME(eddyb) further check entry-point attribute validity,
// e.g. signature, shouldn't have `#[inline]` or generics, etc.
Ok(())
}

_ => Err(Expected("function")),
},

SpirvAttribute::UnrollLoops => match target {
TargetNew::Old(Target::Fn)
| TargetNew::Old(Target::Closure)
| TargetNew::Old(Target::Method(MethodKind::Trait { body: true }))
| TargetNew::Old(Target::Method(MethodKind::Inherent)) => Ok(()),

_ => Err(Expected("function or closure")),
},

SpirvAttribute::StorageClass(_)
| SpirvAttribute::ImageType { .. }
| SpirvAttribute::Sampler
| SpirvAttribute::SampledImage
| SpirvAttribute::Block => match target {
TargetNew::Old(Target::Struct) => {
// FIXME(eddyb) further check type attribute validity,
// e.g. layout, generics, other attributes, etc.
Ok(())
}

_ => Err(Expected("struct")),
},
};
match valid_target {
Ok(()) => {}
Err(Expected(expected_target)) => self.tcx.sess.span_err(
span,
&format!(
"attribute is only valid on a {}, not on a {}",
expected_target, target
),
),
}
}
}
}

// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
impl<'tcx> Visitor<'tcx> for CheckSpirvAttrVisitor<'tcx> {
type Map = Map<'tcx>;

fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
NestedVisitorMap::OnlyBodies(self.tcx.hir())
}

fn visit_item(&mut self, item: &'tcx hir::Item<'tcx>) {
let target = Target::from_item(item);
self.check_spirv_attributes(item.hir_id, item.attrs, target);
intravisit::walk_item(self, item)
}

fn visit_generic_param(&mut self, generic_param: &'tcx hir::GenericParam<'tcx>) {
let target = Target::from_generic_param(generic_param);
self.check_spirv_attributes(generic_param.hir_id, generic_param.attrs, target);
intravisit::walk_generic_param(self, generic_param)
}

fn visit_trait_item(&mut self, trait_item: &'tcx hir::TraitItem<'tcx>) {
let target = Target::from_trait_item(trait_item);
self.check_spirv_attributes(trait_item.hir_id, trait_item.attrs, target);
intravisit::walk_trait_item(self, trait_item)
}

fn visit_struct_field(&mut self, struct_field: &'tcx hir::StructField<'tcx>) {
self.check_spirv_attributes(struct_field.hir_id, struct_field.attrs, TargetNew::Field);
intravisit::walk_struct_field(self, struct_field);
}

fn visit_arm(&mut self, arm: &'tcx hir::Arm<'tcx>) {
self.check_spirv_attributes(arm.hir_id, arm.attrs, TargetNew::Arm);
intravisit::walk_arm(self, arm);
}

fn visit_foreign_item(&mut self, f_item: &'tcx hir::ForeignItem<'tcx>) {
let target = Target::from_foreign_item(f_item);
self.check_spirv_attributes(f_item.hir_id, f_item.attrs, target);
intravisit::walk_foreign_item(self, f_item)
}

fn visit_impl_item(&mut self, impl_item: &'tcx hir::ImplItem<'tcx>) {
let target = target_from_impl_item(self.tcx, impl_item);
self.check_spirv_attributes(impl_item.hir_id, impl_item.attrs, target);
intravisit::walk_impl_item(self, impl_item)
}

fn visit_stmt(&mut self, stmt: &'tcx hir::Stmt<'tcx>) {
// When checking statements ignore expressions, they will be checked later.
if let hir::StmtKind::Local(l) = stmt.kind {
self.check_spirv_attributes(l.hir_id, &l.attrs, Target::Statement);
}
intravisit::walk_stmt(self, stmt)
}

fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) {
let target = match expr.kind {
hir::ExprKind::Closure(..) => Target::Closure,
_ => Target::Expression,
};

self.check_spirv_attributes(expr.hir_id, &expr.attrs, target);
intravisit::walk_expr(self, expr)
}

fn visit_variant(
&mut self,
variant: &'tcx hir::Variant<'tcx>,
generics: &'tcx hir::Generics<'tcx>,
item_id: HirId,
) {
self.check_spirv_attributes(variant.id, variant.attrs, Target::Variant);
intravisit::walk_variant(self, variant, generics, item_id)
}

fn visit_macro_def(&mut self, macro_def: &'tcx hir::MacroDef<'tcx>) {
self.check_spirv_attributes(macro_def.hir_id, macro_def.attrs, TargetNew::MacroDef);
intravisit::walk_macro_def(self, macro_def);
}

fn visit_param(&mut self, param: &'tcx hir::Param<'tcx>) {
self.check_spirv_attributes(param.hir_id, param.attrs, TargetNew::Param);

intravisit::walk_param(self, param);
}
}

fn check_invalid_macro_level_spirv_attr(tcx: TyCtxt<'_>, sym: &Symbols, attrs: &[Attribute]) {
for attr in attrs {
if tcx.sess.check_name(attr, sym.spirv) {
tcx.sess
.span_err(attr.span, "#[spirv(..)] cannot be applied to a macro");
}
}
}

// FIXME(eddyb) DRY this somehow and make it reusable from somewhere in `rustc`.
fn check_mod_attrs(tcx: TyCtxt<'_>, module_def_id: LocalDefId) {
let check_spirv_attr_visitor = &mut CheckSpirvAttrVisitor {
tcx,
sym: Symbols::get(),
};
tcx.hir().visit_item_likes_in_module(
module_def_id,
&mut check_spirv_attr_visitor.as_deep_visitor(),
);
// FIXME(eddyb) use `tcx.hir().visit_exported_macros_in_krate(...)` after rustup.
for id in tcx.hir().krate().exported_macros {
check_spirv_attr_visitor.visit_macro_def(match tcx.hir().find(id.hir_id) {
Some(hir::Node::MacroDef(macro_def)) => macro_def,
_ => unreachable!(),
});
}
check_invalid_macro_level_spirv_attr(
tcx,
&check_spirv_attr_visitor.sym,
tcx.hir().krate().non_exported_macro_attrs,
);
if module_def_id.is_top_level_module() {
check_spirv_attr_visitor.check_spirv_attributes(
CRATE_HIR_ID,
tcx.hir().krate_attrs(),
Target::Mod,
);
}
}

pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers {
check_mod_attrs: |tcx, def_id| {
// Run both the default checks, and our `#[spirv(...)]` ones.
(rustc_interface::DEFAULT_QUERY_PROVIDERS.check_mod_attrs)(tcx, def_id);
check_mod_attrs(tcx, def_id)
},
..*providers
};
}
5 changes: 3 additions & 2 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use rustc_target::spec::{HasTargetSpec, Target};
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::iter::once;
use std::rc::Rc;

pub struct CodegenCx<'tcx> {
pub tcx: TyCtxt<'tcx>,
Expand All @@ -56,7 +57,7 @@ pub struct CodegenCx<'tcx> {
unroll_loops_decorations: RefCell<HashMap<Word, UnrollLoopsDecoration>>,
pub kernel_mode: bool,
/// Cache of all the builtin symbols we need
pub sym: Box<Symbols>,
pub sym: Rc<Symbols>,
pub instruction_table: InstructionTable,
pub libm_intrinsics: RefCell<HashMap<Word, super::builder::libm_intrinsics::LibmIntrinsic>>,

Expand All @@ -72,7 +73,7 @@ pub struct CodegenCx<'tcx> {

impl<'tcx> CodegenCx<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, codegen_unit: &'tcx CodegenUnit<'tcx>) -> Self {
let sym = Box::new(Symbols::new());
let sym = Symbols::get();
let mut spirv_version = None;
let mut memory_model = None;
let mut kernel_mode = false;
Expand Down
4 changes: 4 additions & 0 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ macro_rules! assert_ty_eq {
}

mod abi;
mod attr;
mod builder;
mod builder_spirv;
mod codegen_cx;
Expand Down Expand Up @@ -299,6 +300,9 @@ impl CodegenBackend for SpirvCodegenBackend {
inner
})
};

// Extra hooks provided by other parts of `rustc_codegen_spirv`.
crate::attr::provide(providers);
}

fn provide_extern(&self, providers: &mut query::Providers) {
Expand Down
Loading