Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of verattr tag and xml dependence. #137

Merged
merged 6 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 80 additions & 13 deletions codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def __init__(self, format_name):
self.encoding='utf-8'

# elements for versions
self.version_string = None
self.versions = None

# ordered (!) list of tuples ({tokens}, (target_attribs)) for each <token>
self.tokens = []
self.versions = [([], ("versions", "until", "since")), ]

# maps version attribute name to [access, type]
self.verattrs = {}
# maps each type to its generated py file's relative path
self.path_dict = {}
# enum name -> storage name
Expand All @@ -48,7 +49,7 @@ def generate_module_paths(self, root):
"""preprocessing - generate module paths for imports relative to the output dir"""
for child in root:
# only check stuff that has a name - ignore version tags
if child.tag not in ("version", "token"):
if child.tag.split('}')[-1] not in ("version", "token", "include"):
base_segments = os.path.join("formats", self.format_name)
if child.tag == "module":
# for modules, set the path to base/module_name
Expand All @@ -74,17 +75,35 @@ def generate_module_paths(self, root):
self.path_dict["UintEnum"] = "base_enum"
self.path_dict["Uint64Enum"] = "base_enum"

def load_xml(self, xml_file):
def register_tokens(self, root):
"""Register tokens before anything else"""
for child in root:
if child.tag == "token":
self.read_token(child)

def load_xml(self, xml_file, parsed_xmls=None):
"""Loads an XML (can be filepath or open file) and does all parsing
Goes over all children of the root node and calls the appropriate function depending on type of the child"""
try:
# try for case where xml_file is a passed file object
xml_path = xml_file.name
except AttributeError:
# if attribute error, assume it was a file path
xml_path = xml_file
xml_path = os.path.realpath(xml_path)
tree = ET.parse(xml_file)
root = tree.getroot()
self.generate_module_paths(root)
versions = Versions(self)
self.register_tokens(root)
self.versions = Versions(self)

# dictionary of xml file: XmlParser
if parsed_xmls is None:
parsed_xmls = {}

for child in root:
self.replace_tokens(child)
if child.tag not in ('version', 'module'):
if child.tag not in ('version', 'verattr', 'module'):
self.apply_conventions(child)
try:
if child.tag in self.struct_types:
Expand All @@ -98,14 +117,17 @@ def load_xml(self, xml_file):
elif child.tag == "module":
Module(self, child)
elif child.tag == "version":
versions.read(child)
elif child.tag == "token":
self.read_token(child)
self.versions.read(child)
elif child.tag == "verattr":
self.read_verattr(child)
elif child.tag.split('}')[-1] == "include":
self.read_xinclude(child, xml_path, parsed_xmls)
except Exception as err:
logging.error(err)
traceback.print_exc()
out_file = os.path.join(os.getcwd(), "generated", "formats", self.format_name, "versions.py")
versions.write(out_file)
self.versions.write(out_file)
parsed_xmls[xml_path] = self

# the following constructs do not create classes
def read_token(self, token):
Expand All @@ -114,6 +136,31 @@ def read_token(self, token):
for sub_token in token],
token.attrib["attrs"].split(" ")))

def read_verattr(self, verattr):
"""Reads an xml <verattr> and stores it in the verattrs dict"""
name = verattr.attrib['name']
assert name not in self.verattrs, f"verattr {name} already defined!"
access = '.'.join(convention.name_attribute(comp) for comp in verattr.attrib["access"].split('.'))
attr_type = verattr.attrib.get("type")
if attr_type:
attr_type = convention.name_class(attr_type)
self.verattrs[name] = [access, attr_type]

def read_xinclude(self, xinclude, xml_path, parsed_xmls):
"""Reads an xi:include element, and parses the linked xml if it doesn't exist yet in parsed xmls"""
# convert the linked relative path to an absolute one
new_path = os.path.realpath(os.path.join(os.path.dirname(xml_path), xinclude.attrib['href']))
# check if the xml file was already parsed
if new_path not in parsed_xmls:
# if not, parse it now
format_name = os.path.splitext(os.path.basename(new_path))[0]
new_parser = XmlParser(format_name)
new_parser.load_xml(new_path, parsed_xmls)
else:
new_parser = parsed_xmls[new_path]
# append all pertinent information (file paths etc) to self for access
self.copy_xml_dicts(new_parser)

@staticmethod
def apply_convention(struct, func, params):
for k in params:
Expand Down Expand Up @@ -177,7 +224,7 @@ def map_type(self, in_type):
def replace_tokens(self, xml_struct):
"""Update xml_struct's (and all of its children's) attrib dict with content of tokens+versions list."""
# replace versions after tokens because tokens include versions
for tokens, target_attribs in self.tokens + self.versions:
for tokens, target_attribs in self.tokens:
for target_attrib in target_attribs:
if target_attrib in xml_struct.attrib:
expr_str = xml_struct.attrib[target_attrib]
Expand All @@ -195,6 +242,22 @@ def replace_tokens(self, xml_struct):
for xml_child in xml_struct:
self.replace_tokens(xml_child)

@staticmethod
def copy_dict_info(own_dict, other_dict):
"""Add information from other dict if we didn't have it yet"""
for key in other_dict.keys():
if key not in own_dict:
own_dict[key] = other_dict[key]

def copy_xml_dicts(self, other_parser):
"""Copy information necessary for linking and generation from another parser as if we'd read the file"""
[self.versions.read(version) for version in other_parser.versions.versions]
self.tokens.extend(other_parser.tokens)
self.copy_dict_info(self.verattrs, other_parser.verattrs)
self.copy_dict_info(self.path_dict, other_parser.path_dict)
self.copy_dict_info(self.storage_dict, other_parser.storage_dict)
self.copy_dict_info(self.tag_dict, other_parser.tag_dict)


def copy_src_to_generated():
"""copies the files from the source folder to the generated folder"""
Expand All @@ -220,14 +283,18 @@ def generate_classes():
cwd = os.getcwd()
root_dir = os.path.join(cwd, "source\\formats")
copy_src_to_generated()
parsed_xmls = {}
for format_name in os.listdir(root_dir):
dir_path = os.path.join(root_dir, format_name)
if os.path.isdir(dir_path):
xml_path = os.path.join(dir_path, format_name+".xml")
if os.path.isfile(xml_path):
logging.info(f"Reading {format_name} format")
if os.path.realpath(xml_path) in parsed_xmls:
logging.info(f"Already read {format_name}, skipping")
else:
logging.info(f"Reading {format_name} format")
xmlp = XmlParser(format_name)
xmlp.load_xml(xml_path)
xmlp.load_xml(xml_path, parsed_xmls)
create_inits()


Expand Down
6 changes: 5 additions & 1 deletion codegen/Module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ def read(self, element):
self.custom = bool(eval(element.attrib.get("custom","true").replace("true","True").replace("false","False"),{}))

def write(self, rel_path):
with open(os.path.join(os.getcwd(), "generated", rel_path, "__init__.py"), "w", encoding=self.parser.encoding) as file:
abs_path = os.path.join(os.getcwd(), "generated", rel_path, "__init__.py")
out_dir = os.path.dirname(abs_path)
if not os.path.isdir(out_dir):
os.makedirs(out_dir)
with open(abs_path, "w", encoding=self.parser.encoding) as file:
file.write(self.comment_str)
file.write(f'\n\n__priority__ = {repr(self.priority)}')
file.write(f'\n__depends__ = {repr(self.depends)}')
Expand Down
14 changes: 10 additions & 4 deletions codegen/Union.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def write_init(self, f):
def write_defaults(self, f, condition=""):
base_indent = "\n\t\t"
for field in self.members:
field_debug_str = clean_comment_str(field.text, indent="\t\t")
arg, template, arr1, arr2, conditionals, field_name, field_type, pad_mode = get_params(field)

indent, condition = condition_indent(base_indent, conditionals, condition)
Expand Down Expand Up @@ -274,7 +273,14 @@ def write_io(self, f, method_type, condition=""):
else:
f.write(
f"{indent}{self.compound.parser.method_for_type(field_type, mode=method_type, attr=f'self.{field_name}', arg=arg, template=template)}")
# store version related stuff on self.context
if "version" in field_name:
f.write(f"{indent}{CONTEXT}.{field_name} = self.{field_name}")
if method_type == 'read':
# store version related fields on self.context on read
for k, (access, dtype) in self.compound.parser.verattrs.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably comment this , maybe pull out into a function too

# check all version-related global variables registered with the verattr tag
attr_path = access.split('.')
if field_name == attr_path[0]:
if dtype is None or len(attr_path) > 1 or field_type == dtype:
# the verattr type isn't known, we can't check it or it matches
f.write(f"{indent}{CONTEXT}.{field_name} = self.{field_name}")
break
return condition
34 changes: 18 additions & 16 deletions codegen/Versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from codegen.expression import Version


base_ver_attrs = ("id", "supported", "custom", "ext")

class Versions:
"""Creates and writes a version block"""

Expand All @@ -25,34 +27,34 @@ def write(self, out_file):
stream.write(f"def is_{self.format_id(version.attrib['id'])}(context):")
conds_list = []
for k, v in version.attrib.items():
if k != "id":
name = k.lower()
if k not in base_ver_attrs:
if k in self.parent.verattrs:
name = self.parent.verattrs[k][0]
else:
name = k.lower()
val = v.strip()
if name == 'num':
val = str(Version(val))
if " " in val:
conds_list.append(f"context.{name} in ({val.replace(' ', ', ')})")
conds_list.append(f"context.{name} in ({', '.join([str(Version(nr)) for nr in val.split(' ')])})")
else:
conds_list.append(f"context.{name} == {val}")
conds_list.append(f"context.{name} == {str(Version(val))}")
stream.write("\n\tif " + " and ".join(conds_list) + ":")
stream.write("\n\t\treturn True")
stream.write("\n\n\n")

stream.write(f"def set_{self.format_id(version.attrib['id'])}(context):")
for k, v in version.attrib.items():
if k != "id":
name = k.lower()
if k not in base_ver_attrs:
suffix = ""
if k in self.parent.verattrs:
name, attr_type = self.parent.verattrs[k]
if attr_type and self.parent.tag_dict[attr_type.lower()] == 'bitfield':
suffix = "._value"
else:
name = k.lower()
val = v.strip()
if " " in val:
val = val.split(" ")[0]
# todo - this should instead be detected by field type
if name == "user_version":
suffix = "._value"
else:
suffix = ""
if name == "num":
val = str(Version(val))
stream.write(f"\n\tcontext.{name}{suffix} = {val}")
stream.write(f"\n\tcontext.{name}{suffix} = {str(Version(val))}")
stream.write("\n\n\n")

# go through all the games, record them and map defaults to versions
Expand Down
3 changes: 2 additions & 1 deletion codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, expr_str: str):
byte_number_strs = expr_str.split(".")
self.value = sum(int(n) << shift for n, shift in zip(byte_number_strs, self.shifts))
else:
self.value = int(expr_str)
# use int(x, 0) to evaluate x as an int literal, allowing for non-decimal (e.g. hex) values to be read
self.value = int(expr_str, 0)
# print(self)

def version_number(version_str):
Expand Down
62 changes: 2 additions & 60 deletions generated/formats/bani/bani.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,7 @@
<!DOCTYPE niftoolsxml>
<niftoolsxml version="0.7.1.0">

<token name="operator" attrs="cond vercond arr1 arr2 arg">
All Operators except for unary not (!), parentheses, and member of (\)
NOTE: These can be ignored entirely by string substitution and dealt with directly.
NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens.
<operator token="#ADD#" string="+" />
<operator token="#SUB#" string="-" />
<operator token="#MUL#" string="*" />
<operator token="#DIV#" string="/" />
<operator token="#AND#" string="&amp;&amp;" />
<operator token="#OR#" string="||" />
<operator token="#LT#" string="&lt;" />
<operator token="#GT#" string="&gt;" />
<operator token="#LTE#" string="&lt;=" />
<operator token="#GTE#" string="&gt;=" />
<operator token="#EQ#" string="==" />
<operator token="#NEQ#" string="!=" />
<operator token="#RSH#" string="&gt;&gt;" />
<operator token="#LSH#" string="&lt;&lt;" />
<operator token="#BITAND#" string="&amp;" />
<operator token="#BITOR#" string="|" />
<operator token="#MOD#" string="%" />
</token>
<!--Basic Types-->

<basic name="ubyte" count="1">
An unsigned 8-bit integer.
</basic>

<basic name="byte" count="1">
A signed 8-bit integer.
</basic>

<basic name="uint" count="1">
An unsigned 32-bit integer.
</basic>

<basic name="uint64" count="1">
An unsigned 64-bit integer.
</basic>

<basic name="ushort" count="1" >
An unsigned 16-bit integer.
</basic>

<basic name="int" count="1" >
A signed 32-bit integer.
</basic>

<basic name="short" count="1" >
A signed 16-bit integer.
</basic>

<basic name="char" count="0" >
An 8-bit character.
</basic>

<basic name="float" count="0" >
A standard 32-bit floating point number.
</basic>
<xi:include href="../base/base.xml" xmlns:xi="http://www.w3.org/2001/XInclude" xpointer="xpointer(*/*)" />

<basic name="string">
A string of given length.
Expand Down Expand Up @@ -101,7 +43,7 @@
<field name="m33" type="float" default="1.0">Member 3,3 (bottom left)</field>
</compound>

<compound name="Matrix24" size="32">
<compound name="Matrix24" size="32">
A 4x4 transformation matrix.
<field name="m11" type="float" default="1.0">The (1,1) element.</field>
<field name="m21" type="float" default="0.0">The (2,1) element.</field>
Expand Down
Loading