diff --git a/.github/workflows/codegen.yml b/.github/workflows/python-tooling.yml similarity index 76% rename from .github/workflows/codegen.yml rename to .github/workflows/python-tooling.yml index 24422eba10f8..75fdd75299d4 100644 --- a/.github/workflows/codegen.yml +++ b/.github/workflows/python-tooling.yml @@ -1,10 +1,11 @@ -name: Codegen +name: Python tooling on: pull_request: paths: - "misc/bazel/**" - "misc/codegen/**" + - "misc/scripts/models-as-data/bulk_generate_mad.py" - "*.bazel*" - .github/workflows/codegen.yml - .pre-commit-config.yaml @@ -17,17 +18,14 @@ permissions: contents: read jobs: - codegen: + check-python-tooling: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version-file: 'misc/codegen/.python-version' - uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 name: Check that python code is properly formatted with: - extra_args: autopep8 --all-files + extra_args: black --all-files - name: Run codegen tests shell: bash run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42333e91289e..bc07fb789873 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,11 +14,11 @@ repos: hooks: - id: clang-format - - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v2.0.4 + - repo: https://github.com/psf/black + rev: 25.1.0 hooks: - - id: autopep8 - files: ^misc/codegen/.*\.py + - id: black + files: ^(misc/codegen/.*|misc/scripts/models-as-data/bulk_generate_mad)\.py$ - repo: local hooks: diff --git a/misc/codegen/codegen.py b/misc/codegen/codegen.py index ae3a67d3fba6..7510405cd7fb 100755 --- a/misc/codegen/codegen.py +++ b/misc/codegen/codegen.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -""" Driver script to run all code generation """ +"""Driver script to run all code generation""" import argparse import logging @@ -9,7 +9,7 @@ import typing import shlex -if 'BUILD_WORKSPACE_DIRECTORY' not in os.environ: +if "BUILD_WORKSPACE_DIRECTORY" not in os.environ: # we are not running with `bazel run`, set up module search path _repo_root = pathlib.Path(__file__).resolve().parents[2] sys.path.append(str(_repo_root)) @@ -29,57 +29,105 @@ def _parse_args() -> argparse.Namespace: conf = None p = argparse.ArgumentParser(description="Code generation suite") - p.add_argument("--generate", type=lambda x: x.split(","), - help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, " - "trap, cpp and rust") - p.add_argument("--verbose", "-v", action="store_true", help="print more information") + p.add_argument( + "--generate", + type=lambda x: x.split(","), + help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, " + "trap, cpp and rust", + ) + p.add_argument( + "--verbose", "-v", action="store_true", help="print more information" + ) p.add_argument("--quiet", "-q", action="store_true", help="only print errors") - p.add_argument("--configuration-file", "-c", type=_abspath, default=conf, - help="A configuration file to load options from. By default, the first codegen.conf file found by " - "going up directories from the current location. If present all paths provided in options are " - "considered relative to its directory") - p.add_argument("--root-dir", type=_abspath, - help="the directory that should be regarded as the root of the language pack codebase. Used to " - "compute QL imports and in some comments and as root for relative paths provided as options. 
" - "If not provided it defaults to the directory of the configuration file, if any") + p.add_argument( + "--configuration-file", + "-c", + type=_abspath, + default=conf, + help="A configuration file to load options from. By default, the first codegen.conf file found by " + "going up directories from the current location. If present all paths provided in options are " + "considered relative to its directory", + ) + p.add_argument( + "--root-dir", + type=_abspath, + help="the directory that should be regarded as the root of the language pack codebase. Used to " + "compute QL imports and in some comments and as root for relative paths provided as options. " + "If not provided it defaults to the directory of the configuration file, if any", + ) path_arguments = [ - p.add_argument("--schema", - help="input schema file (default schema.py)"), - p.add_argument("--dbscheme", - help="output file for dbscheme generation, input file for trap generation"), - p.add_argument("--ql-output", - help="output directory for generated QL files"), - p.add_argument("--ql-stub-output", - help="output directory for QL stub/customization files. Defines also the " - "generated qll file importing every class file"), - p.add_argument("--ql-test-output", - help="output directory for QL generated extractor test files"), - p.add_argument("--ql-cfg-output", - help="output directory for QL CFG layer (optional)."), - p.add_argument("--cpp-output", - help="output directory for generated C++ files, required if trap or cpp is provided to " - "--generate"), - p.add_argument("--rust-output", - help="output directory for generated Rust files, required if rust is provided to " - "--generate"), - p.add_argument("--generated-registry", - help="registry file containing information about checked-in generated code. A .gitattributes" - "file is generated besides it to mark those files with linguist-generated=true. Must" - "be in a directory containing all generated code."), + p.add_argument("--schema", help="input schema file (default schema.py)"), + p.add_argument( + "--dbscheme", + help="output file for dbscheme generation, input file for trap generation", + ), + p.add_argument("--ql-output", help="output directory for generated QL files"), + p.add_argument( + "--ql-stub-output", + help="output directory for QL stub/customization files. Defines also the " + "generated qll file importing every class file", + ), + p.add_argument( + "--ql-test-output", + help="output directory for QL generated extractor test files", + ), + p.add_argument( + "--ql-cfg-output", help="output directory for QL CFG layer (optional)." + ), + p.add_argument( + "--cpp-output", + help="output directory for generated C++ files, required if trap or cpp is provided to " + "--generate", + ), + p.add_argument( + "--rust-output", + help="output directory for generated Rust files, required if rust is provided to " + "--generate", + ), + p.add_argument( + "--generated-registry", + help="registry file containing information about checked-in generated code. A .gitattributes" + "file is generated besides it to mark those files with linguist-generated=true. Must" + "be in a directory containing all generated code.", + ), ] - p.add_argument("--script-name", - help="script name to put in header comments of generated files. 
By default, the path of this " - "script relative to the root directory") - p.add_argument("--trap-library", - help="path to the trap library from an include directory, required if generating C++ trap bindings"), - p.add_argument("--ql-format", action="store_true", default=True, - help="use codeql to autoformat QL files (which is the default)") - p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files") - p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)") - p.add_argument("--force", "-f", action="store_true", - help="generate all files without skipping unchanged files and overwriting modified ones") - p.add_argument("--use-current-directory", action="store_true", - help="do not consider paths as relative to --root-dir or the configuration directory") + p.add_argument( + "--script-name", + help="script name to put in header comments of generated files. By default, the path of this " + "script relative to the root directory", + ) + p.add_argument( + "--trap-library", + help="path to the trap library from an include directory, required if generating C++ trap bindings", + ), + p.add_argument( + "--ql-format", + action="store_true", + default=True, + help="use codeql to autoformat QL files (which is the default)", + ) + p.add_argument( + "--no-ql-format", + action="store_false", + dest="ql_format", + help="do not format QL files", + ) + p.add_argument( + "--codeql-binary", + default="codeql", + help="command to use for QL formatting (default %(default)s)", + ) + p.add_argument( + "--force", + "-f", + action="store_true", + help="generate all files without skipping unchanged files and overwriting modified ones", + ) + p.add_argument( + "--use-current-directory", + action="store_true", + help="do not consider paths as relative to --root-dir or the configuration directory", + ) opts = p.parse_args() if opts.configuration_file is not None: with open(opts.configuration_file) as config: @@ -97,7 +145,15 @@ def _parse_args() -> argparse.Namespace: for arg in path_arguments: path = getattr(opts, arg.dest) if path is not None: - setattr(opts, arg.dest, _abspath(path) if opts.use_current_directory else (opts.root_dir / path)) + setattr( + opts, + arg.dest, + ( + _abspath(path) + if opts.use_current_directory + else (opts.root_dir / path) + ), + ) if not opts.script_name: opts.script_name = paths.exe_file.relative_to(opts.root_dir) return opts @@ -115,7 +171,7 @@ def run(): log_level = logging.ERROR else: log_level = logging.INFO - logging.basicConfig(format="{levelname} {message}", style='{', level=log_level) + logging.basicConfig(format="{levelname} {message}", style="{", level=log_level) for target in opts.generate: generate(target, opts, render.Renderer(opts.script_name)) diff --git a/misc/codegen/generators/cppgen.py b/misc/codegen/generators/cppgen.py index 1a9a64663c19..cf99167fa46d 100644 --- a/misc/codegen/generators/cppgen.py +++ b/misc/codegen/generators/cppgen.py @@ -49,7 +49,11 @@ def _get_trap_name(cls: schema.Class, p: schema.Property) -> str | None: return inflection.pluralize(trap_name) -def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field: +def _get_field( + cls: schema.Class, + p: schema.Property, + add_or_none_except: typing.Optional[str] = None, +) -> cpp.Field: args = dict( field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""), base_type=_get_type(p.type, add_or_none_except), @@ -83,14 +87,15 @@ def 
_get_class(self, name: str) -> cpp.Class: bases=[self._get_class(b) for b in cls.bases], fields=[ _get_field(cls, p, self._add_or_none_except) - for p in cls.properties if "cpp_skip" not in p.pragmas and not p.synth + for p in cls.properties + if "cpp_skip" not in p.pragmas and not p.synth ], final=not cls.derived, trap_name=trap_name, ) def get_classes(self): - ret = {'': []} + ret = {"": []} for k, cls in self._classmap.items(): if not cls.synth: ret.setdefault(cls.group, []).append(self._get_class(cls.name)) @@ -102,6 +107,12 @@ def generate(opts, renderer): processor = Processor(schemaloader.load_file(opts.schema)) out = opts.cpp_output for dir, classes in processor.get_classes().items(): - renderer.render(cpp.ClassList(classes, opts.schema, - include_parent=bool(dir), - trap_library=opts.trap_library), out / dir / "TrapClasses") + renderer.render( + cpp.ClassList( + classes, + opts.schema, + include_parent=bool(dir), + trap_library=opts.trap_library, + ), + out / dir / "TrapClasses", + ) diff --git a/misc/codegen/generators/dbschemegen.py b/misc/codegen/generators/dbschemegen.py index f861972cdd68..c28fce746b97 100755 --- a/misc/codegen/generators/dbschemegen.py +++ b/misc/codegen/generators/dbschemegen.py @@ -13,6 +13,7 @@ as columns The type hierarchy will be translated to corresponding `union` declarations. """ + import typing import inflection @@ -29,7 +30,7 @@ class Error(Exception): def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str: - """ translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes. + """translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes. For class types, appends an underscore followed by `null` if provided """ if typename[0].isupper(): @@ -42,12 +43,18 @@ def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> st return typename -def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], add_or_none_except: typing.Optional[str] = None): - """ Yield all dbscheme entities needed to model class `cls` """ +def cls_to_dbscheme( + cls: schema.Class, + lookup: typing.Dict[str, schema.Class], + add_or_none_except: typing.Optional[str] = None, +): + """Yield all dbscheme entities needed to model class `cls`""" if cls.synth: return if cls.derived: - yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth)) + yield Union( + dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth) + ) dir = pathlib.Path(cls.group) if cls.group else None # output a table specific to a class only if it is a leaf class or it has 1-to-1 properties # Leaf classes need a table to bind the `@` ids @@ -61,9 +68,11 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a name=inflection.tableize(cls.name), columns=[ Column("id", type=dbtype(cls.name), binding=binding), - ] + [ + ] + + [ Column(f.name, dbtype(f.type, add_or_none_except)) - for f in cls.properties if f.is_single and not f.synth + for f in cls.properties + if f.is_single and not f.synth ], dir=dir, ) @@ -74,28 +83,37 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a continue if f.is_unordered: yield Table( - name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name + or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), - Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)), + Column( 
+ inflection.singularize(f.name), + dbtype(f.type, add_or_none_except), + ), ], dir=dir, ) elif f.is_repeated: yield Table( keyset=KeySet(["id", "index"]), - name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name + or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), Column("index", type="int"), - Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)), + Column( + inflection.singularize(f.name), + dbtype(f.type, add_or_none_except), + ), ], dir=dir, ) elif f.is_optional: yield Table( keyset=KeySet(["id"]), - name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"), + name=overridden_table_name + or inflection.tableize(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), Column(f.name, dbtype(f.type, add_or_none_except)), @@ -105,7 +123,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a elif f.is_predicate: yield Table( keyset=KeySet(["id"]), - name=overridden_table_name or inflection.underscore(f"{cls.name}_{f.name}"), + name=overridden_table_name + or inflection.underscore(f"{cls.name}_{f.name}"), columns=[ Column("id", type=dbtype(cls.name)), ], @@ -119,33 +138,46 @@ def check_name_conflicts(decls: list[Table | Union]): match decl: case Table(name=name): if name in names: - raise Error(f"Duplicate table name: { - name}, you can use `@ql.db_table_name` on a property to resolve this") + raise Error( + f"Duplicate table name: { + name}, you can use `@ql.db_table_name` on a property to resolve this" + ) names.add(name) def get_declarations(data: schema.Schema): add_or_none_except = data.root_class.name if data.null else None - declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls, - data.classes, add_or_none_except)] + declarations = [ + d + for cls in data.classes.values() + if not cls.imported + for d in cls_to_dbscheme(cls, data.classes, add_or_none_except) + ] if data.null: property_classes = { - prop.type for cls in data.classes.values() for prop in cls.properties + prop.type + for cls in data.classes.values() + for prop in cls.properties if cls.name != data.null and prop.type and prop.type[0].isupper() } declarations += [ - Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes) + Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) + for t in sorted(property_classes) ] check_name_conflicts(declarations) return declarations -def get_includes(data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path): +def get_includes( + data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path +): includes = [] for inc in data.includes: inc = include_dir / inc with open(inc) as inclusion: - includes.append(SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read())) + includes.append( + SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read()) + ) return includes @@ -155,8 +187,10 @@ def generate(opts, renderer): data = schemaloader.load_file(input) - dbscheme = Scheme(src=input.name, - includes=get_includes(data, include_dir=input.parent, root_dir=input.parent), - declarations=get_declarations(data)) + dbscheme = Scheme( + src=input.name, + includes=get_includes(data, include_dir=input.parent, root_dir=input.parent), + declarations=get_declarations(data), + ) renderer.render(dbscheme, out) diff --git a/misc/codegen/generators/qlgen.py b/misc/codegen/generators/qlgen.py index 
7e898135d01f..991c21990d46 100755 --- a/misc/codegen/generators/qlgen.py +++ b/misc/codegen/generators/qlgen.py @@ -19,6 +19,7 @@ * one `.ql` test query for all single properties and on `_.ql` test query for each optional or repeated property """ + # TODO this should probably be split in different generators now: ql, qltest, maybe qlsynth import logging @@ -70,7 +71,7 @@ class NoClasses(Error): abbreviations.update({f"{k}s": f"{v}s" for k, v in abbreviations.items()}) -_abbreviations_re = re.compile("|".join(fr"\b{abbr}\b" for abbr in abbreviations)) +_abbreviations_re = re.compile("|".join(rf"\b{abbr}\b" for abbr in abbreviations)) def _humanize(s: str) -> str: @@ -98,11 +99,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None): return format.format(**{noun: transform(noun) for noun in nouns}) prop_name = _humanize(prop.name) - class_name = cls.pragmas.get("ql_default_doc_name", _humanize(inflection.underscore(cls.name))) + class_name = cls.pragmas.get( + "ql_default_doc_name", _humanize(inflection.underscore(cls.name)) + ) if prop.is_predicate: return f"this {class_name} {prop_name}" if plural is not None: - prop_name = inflection.pluralize(prop_name) if plural else inflection.singularize(prop_name) + prop_name = ( + inflection.pluralize(prop_name) + if plural + else inflection.singularize(prop_name) + ) return f"{prop_name} of this {class_name}" @@ -114,8 +121,12 @@ def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> boo return False -def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase], - prev_child: str = "") -> ql.Property: +def get_ql_property( + cls: schema.Class, + prop: schema.Property, + lookup: typing.Dict[str, schema.ClassBase], + prev_child: str = "", +) -> ql.Property: args = dict( type=prop.type if not prop.is_predicate else "predicate", @@ -133,12 +144,15 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic ql_name = prop.pragmas.get("ql_name", prop.name) db_table_name = prop.pragmas.get("ql_db_table_name") if db_table_name and prop.is_single: - raise Error(f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it") + raise Error( + f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it" + ) if prop.is_single: args.update( singular=inflection.camelize(ql_name), tablename=inflection.tableize(cls.name), - tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single], + tableparams=["this"] + + ["result" if p is prop else "_" for p in cls.properties if p.is_single], doc=_get_doc(cls, prop), ) elif prop.is_repeated: @@ -146,7 +160,11 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic singular=inflection.singularize(inflection.camelize(ql_name)), plural=inflection.pluralize(inflection.camelize(ql_name)), tablename=db_table_name or inflection.tableize(f"{cls.name}_{prop.name}"), - tableparams=["this", "index", "result"] if not prop.is_unordered else ["this", "result"], + tableparams=( + ["this", "index", "result"] + if not prop.is_unordered + else ["this", "result"] + ), doc=_get_doc(cls, prop, plural=False), doc_plural=_get_doc(cls, prop, plural=True), ) @@ -169,7 +187,9 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic return ql.Property(**args) -def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class: +def get_ql_class( + cls: 
schema.Class, lookup: typing.Dict[str, schema.ClassBase] +) -> ql.Class: if "ql_name" in cls.pragmas: raise Error("ql_name is not supported yet for classes, only for properties") prev_child = "" @@ -195,12 +215,14 @@ def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) ) -def get_ql_cfg_class(cls: schema.Class, lookup: typing.Dict[str, ql.Class]) -> ql.CfgClass: +def get_ql_cfg_class( + cls: schema.Class, lookup: typing.Dict[str, ql.Class] +) -> ql.CfgClass: return ql.CfgClass( name=cls.name, bases=[base for base in cls.bases if lookup[base.base].cfg], properties=cls.properties, - doc=cls.doc + doc=cls.doc, ) @@ -214,24 +236,33 @@ def _to_db_type(x: str) -> str: def get_ql_synth_class_db(name: str) -> ql.Synth.FinalClassDb: - return _final_db_class_lookup.setdefault(name, ql.Synth.FinalClassDb(name=name, - params=[ - ql.Synth.Param("id", _to_db_type(name))])) + return _final_db_class_lookup.setdefault( + name, + ql.Synth.FinalClassDb( + name=name, params=[ql.Synth.Param("id", _to_db_type(name))] + ), + ) def get_ql_synth_class(cls: schema.Class): if cls.derived: - return ql.Synth.NonFinalClass(name=cls.name, derived=sorted(cls.derived), - root=not cls.bases) + return ql.Synth.NonFinalClass( + name=cls.name, derived=sorted(cls.derived), root=not cls.bases + ) if cls.synth and cls.synth.from_class is not None: source = cls.synth.from_class get_ql_synth_class_db(source).subtract_type(cls.name) - return ql.Synth.FinalClassDerivedSynth(name=cls.name, - params=[ql.Synth.Param("id", _to_db_type(source))]) + return ql.Synth.FinalClassDerivedSynth( + name=cls.name, params=[ql.Synth.Param("id", _to_db_type(source))] + ) if cls.synth and cls.synth.on_arguments is not None: - return ql.Synth.FinalClassFreshSynth(name=cls.name, - params=[ql.Synth.Param(k, _to_db_type(v)) - for k, v in cls.synth.on_arguments.items()]) + return ql.Synth.FinalClassFreshSynth( + name=cls.name, + params=[ + ql.Synth.Param(k, _to_db_type(v)) + for k, v in cls.synth.on_arguments.items() + ], + ) return get_ql_synth_class_db(cls.name) @@ -250,7 +281,13 @@ def get_types_used_by(cls: ql.Class, is_impl: bool) -> typing.Iterable[str]: def get_classes_used_by(cls: ql.Class, is_impl: bool) -> typing.List[str]: - return sorted(set(t for t in get_types_used_by(cls, is_impl) if t[0].isupper() and (is_impl or t != cls.name))) + return sorted( + set( + t + for t in get_types_used_by(cls, is_impl) + if t[0].isupper() and (is_impl or t != cls.name) + ) + ) def format(codeql, files): @@ -265,7 +302,8 @@ def format(codeql, files): codeql_path = shutil.which(codeql) if not codeql_path: raise FormatError( - f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path") + f"`{codeql}` not found in PATH. 
Either install it, or pass `-- --codeql-binary` with a full path" + ) codeql = codeql_path res = subprocess.run(format_cmd, stderr=subprocess.PIPE, text=True) if res.returncode: @@ -281,16 +319,22 @@ def _get_path(cls: schema.Class) -> pathlib.Path: def _get_path_impl(cls: schema.Class) -> pathlib.Path: - return pathlib.Path(cls.group or "", "internal", cls.name+"Impl").with_suffix(".qll") + return pathlib.Path(cls.group or "", "internal", cls.name + "Impl").with_suffix( + ".qll" + ) def _get_path_public(cls: schema.Class) -> pathlib.Path: - return pathlib.Path(cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name).with_suffix(".qll") + return pathlib.Path( + cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name + ).with_suffix(".qll") -def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class], - already_seen: typing.Optional[typing.Set[int]] = None) -> \ - typing.Iterable[typing.Tuple[schema.Class, schema.Property]]: +def _get_all_properties( + cls: schema.Class, + lookup: typing.Dict[str, schema.Class], + already_seen: typing.Optional[typing.Set[int]] = None, +) -> typing.Iterable[typing.Tuple[schema.Class, schema.Property]]: # deduplicate using ids if already_seen is None: already_seen = set() @@ -304,14 +348,19 @@ def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class yield cls, p -def _get_all_properties_to_be_tested(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> \ - typing.Iterable[ql.PropertyForTest]: +def _get_all_properties_to_be_tested( + cls: schema.Class, lookup: typing.Dict[str, schema.Class] +) -> typing.Iterable[ql.PropertyForTest]: for c, p in _get_all_properties(cls, lookup): if not ("qltest_skip" in c.pragmas or "qltest_skip" in p.pragmas): # TODO here operations are duplicated, but should be better if we split ql and qltest generation p = get_ql_property(c, p, lookup) - yield ql.PropertyForTest(p.getter, is_total=p.is_single or p.is_predicate, - type=p.type if not p.is_predicate else None, is_indexed=p.is_indexed) + yield ql.PropertyForTest( + p.getter, + is_total=p.is_single or p.is_predicate, + type=p.type if not p.is_predicate else None, + is_indexed=p.is_indexed, + ) if p.is_repeated and not p.is_optional: yield ql.PropertyForTest(f"getNumberOf{p.plural}", type="int") elif p.is_optional and not p.is_repeated: @@ -324,33 +373,45 @@ def _partition_iter(x, pred): def _partition(l, pred): - """ partitions a list according to boolean predicate """ + """partitions a list according to boolean predicate""" return map(list, _partition_iter(l, pred)) -def _is_in_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]): - return "qltest_collapse_hierarchy" in cls.pragmas or _is_under_qltest_collapsed_hierarchy(cls, lookup) +def _is_in_qltest_collapsed_hierarchy( + cls: schema.Class, lookup: typing.Dict[str, schema.Class] +): + return ( + "qltest_collapse_hierarchy" in cls.pragmas + or _is_under_qltest_collapsed_hierarchy(cls, lookup) + ) -def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]): +def _is_under_qltest_collapsed_hierarchy( + cls: schema.Class, lookup: typing.Dict[str, schema.Class] +): return "qltest_uncollapse_hierarchy" not in cls.pragmas and any( - _is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases) + _is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases + ) def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]): 
- return "qltest_skip" in cls.pragmas or not ( - cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy( - cls, lookup) + return ( + "qltest_skip" in cls.pragmas + or not (cls.final or "qltest_collapse_hierarchy" in cls.pragmas) + or _is_under_qltest_collapsed_hierarchy(cls, lookup) + ) -def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str) -> ql.Stub: +def _get_stub( + cls: schema.Class, base_import: str, generated_import_prefix: str +) -> ql.Stub: if isinstance(cls.synth, schema.SynthInfo): if cls.synth.from_class is not None: accessors = [ ql.SynthUnderlyingAccessor( argument="Entity", type=_to_db_type(cls.synth.from_class), - constructorparams=["result"] + constructorparams=["result"], ) ] elif cls.synth.on_arguments is not None: @@ -358,28 +419,39 @@ def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str) ql.SynthUnderlyingAccessor( argument=inflection.camelize(arg), type=_to_db_type(type), - constructorparams=["result" if a == arg else "_" for a in cls.synth.on_arguments] - ) for arg, type in cls.synth.on_arguments.items() + constructorparams=[ + "result" if a == arg else "_" for a in cls.synth.on_arguments + ], + ) + for arg, type in cls.synth.on_arguments.items() ] else: accessors = [] - return ql.Stub(name=cls.name, base_import=base_import, import_prefix=generated_import_prefix, - doc=cls.doc, synth_accessors=accessors) + return ql.Stub( + name=cls.name, + base_import=base_import, + import_prefix=generated_import_prefix, + doc=cls.doc, + synth_accessors=accessors, + ) def _get_class_public(cls: schema.Class) -> ql.ClassPublic: - return ql.ClassPublic(name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas) + return ql.ClassPublic( + name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas + ) _stub_qldoc_header = "// the following QLdoc is generated: if you need to edit it, do it in the schema file\n " _class_qldoc_re = re.compile( rf"(?P(?:{re.escape(_stub_qldoc_header)})?/\*\*.*?\*/\s*|^\s*)(?:class\s+(?P\w+))?", - re.MULTILINE | re.DOTALL) + re.MULTILINE | re.DOTALL, +) def _patch_class_qldoc(cls: str, qldoc: str, stub_file: pathlib.Path): - """ Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file` """ + """Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file`""" if not qldoc or not stub_file.exists(): return qldoc = "\n ".join(l.rstrip() for l in qldoc.splitlines()) @@ -415,7 +487,11 @@ def generate(opts, renderer): data = schemaloader.load_file(input) - classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported} + classes = { + name: get_ql_class(cls, data.classes) + for name, cls in data.classes.items() + if not cls.imported + } if not classes: raise NoClasses root = next(iter(classes.values())) @@ -429,28 +505,47 @@ def generate(opts, renderer): cfg_classes = [] generated_import_prefix = get_import(out, opts.root_dir) registry = opts.generated_registry or pathlib.Path( - os.path.commonpath((out, stub_out, test_out)), ".generated.list") + os.path.commonpath((out, stub_out, test_out)), ".generated.list" + ) - with renderer.manage(generated=generated, stubs=stubs, registry=registry, - force=opts.force) as renderer: + with renderer.manage( + generated=generated, stubs=stubs, registry=registry, force=opts.force + ) as renderer: - db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth] - renderer.render(ql.DbClasses(classes=db_classes, 
imports=sorted(set(pre_imports.values()))), out / "Raw.qll") + db_classes = [ + cls for name, cls in classes.items() if not data.classes[name].synth + ] + renderer.render( + ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), + out / "Raw.qll", + ) - classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name)) + classes_by_dir_and_name = sorted( + classes.values(), key=lambda cls: (cls.dir, cls.name) + ) for c in classes_by_dir_and_name: - path = get_import(stub_out / c.dir / "internal" / - c.name if c.internal else stub_out / c.path, opts.root_dir) + path = get_import( + ( + stub_out / c.dir / "internal" / c.name + if c.internal + else stub_out / c.path + ), + opts.root_dir, + ) imports[c.name] = path - path_impl = get_import(stub_out / c.dir / "internal" / c.name, opts.root_dir) + path_impl = get_import( + stub_out / c.dir / "internal" / c.name, opts.root_dir + ) imports_impl[c.name + "Impl"] = path_impl + "Impl" if c.cfg: cfg_classes.append(get_ql_cfg_class(c, classes)) for c in classes.values(): qll = out / c.path.with_suffix(".qll") - c.imports = [imports[t] if t in imports else imports_impl[t] + - "::Impl as " + t for t in get_classes_used_by(c, is_impl=True)] + c.imports = [ + imports[t] if t in imports else imports_impl[t] + "::Impl as " + t + for t in get_classes_used_by(c, is_impl=True) + ] classes_used_by[c.name] = get_classes_used_by(c, is_impl=False) c.import_prefix = generated_import_prefix renderer.render(c, qll) @@ -458,7 +553,7 @@ def generate(opts, renderer): if cfg_out: cfg_classes_val = ql.CfgClasses( include_file_import=get_import(include_file, opts.root_dir), - classes=cfg_classes + classes=cfg_classes, ) cfg_qll = cfg_out / "CfgNodes.qll" renderer.render(cfg_classes_val, cfg_qll) @@ -475,7 +570,7 @@ def generate(opts, renderer): if not renderer.is_customized_stub(stub_file): renderer.render(stub, stub_file) else: - qldoc = renderer.render_str(stub, template='ql_stub_class_qldoc') + qldoc = renderer.render_str(stub, template="ql_stub_class_qldoc") _patch_class_qldoc(c.name, qldoc, stub_file) class_public = _get_class_public(c) path_public = _get_path_public(c) @@ -484,18 +579,31 @@ def generate(opts, renderer): renderer.render(class_public, class_public_file) # for example path/to/elements -> path/to/elements.qll - renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]), - include_file) + renderer.render( + ql.ImportList( + [ + i + for name, i in imports.items() + if name not in classes or not classes[name].internal + ] + ), + include_file, + ) elements_module = get_import(include_file, opts.root_dir) renderer.render( ql.GetParentImplementation( classes=list(classes.values()), - imports=[elements_module] + [i for name, - i in imports.items() if name in classes and classes[name].internal], + imports=[elements_module] + + [ + i + for name, i in imports.items() + if name in classes and classes[name].internal + ], ), - out / 'ParentChild.qll') + out / "ParentChild.qll", + ) if test_out: for c in data.classes.values(): @@ -507,39 +615,61 @@ def generate(opts, renderer): test_with = data.classes[test_with_name] if test_with_name else c test_dir = test_out / test_with.group / test_with.name test_dir.mkdir(parents=True, exist_ok=True) - if all(f.suffix in (".txt", ".ql", ".actual", ".expected") for f in test_dir.glob("*.*")): + if all( + f.suffix in (".txt", ".ql", ".actual", ".expected") + for f in test_dir.glob("*.*") + ): log.warning(f"no test source in 
{test_dir.relative_to(test_out)}") - renderer.render(ql.MissingTestInstructions(), - test_dir / missing_test_source_filename) + renderer.render( + ql.MissingTestInstructions(), + test_dir / missing_test_source_filename, + ) continue - total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, data.classes), - lambda p: p.is_total) - renderer.render(ql.ClassTester(class_name=c.name, - properties=total_props, - elements_module=elements_module, - # in case of collapsed hierarchies we want to see the actual QL class in results - show_ql_class="qltest_collapse_hierarchy" in c.pragmas), - test_dir / f"{c.name}.ql") + total_props, partial_props = _partition( + _get_all_properties_to_be_tested(c, data.classes), + lambda p: p.is_total, + ) + renderer.render( + ql.ClassTester( + class_name=c.name, + properties=total_props, + elements_module=elements_module, + # in case of collapsed hierarchies we want to see the actual QL class in results + show_ql_class="qltest_collapse_hierarchy" in c.pragmas, + ), + test_dir / f"{c.name}.ql", + ) for p in partial_props: - renderer.render(ql.PropertyTester(class_name=c.name, - elements_module=elements_module, - property=p), test_dir / f"{c.name}_{p.getter}.ql") + renderer.render( + ql.PropertyTester( + class_name=c.name, + elements_module=elements_module, + property=p, + ), + test_dir / f"{c.name}_{p.getter}.ql", + ) final_synth_types = [] non_final_synth_types = [] constructor_imports = [] synth_constructor_imports = [] stubs = {} - for cls in sorted((cls for cls in data.classes.values() if not cls.imported), - key=lambda cls: (cls.group, cls.name)): + for cls in sorted( + (cls for cls in data.classes.values() if not cls.imported), + key=lambda cls: (cls.group, cls.name), + ): synth_type = get_ql_synth_class(cls) if synth_type.is_final: final_synth_types.append(synth_type) if synth_type.has_params: - stub_file = stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll" + stub_file = ( + stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll" + ) if not renderer.is_customized_stub(stub_file): # stub rendering must be postponed as we might not have yet all subtracted synth types in `synth_type` - stubs[stub_file] = ql.Synth.ConstructorStub(synth_type, import_prefix=generated_import_prefix) + stubs[stub_file] = ql.Synth.ConstructorStub( + synth_type, import_prefix=generated_import_prefix + ) constructor_import = get_import(stub_file, opts.root_dir) constructor_imports.append(constructor_import) if synth_type.is_synth: @@ -549,9 +679,20 @@ def generate(opts, renderer): for stub_file, data in stubs.items(): renderer.render(data, stub_file) - renderer.render(ql.Synth.Types(root.name, generated_import_prefix, - final_synth_types, non_final_synth_types), out / "Synth.qll") - renderer.render(ql.ImportList(constructor_imports), out / "SynthConstructors.qll") - renderer.render(ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll") + renderer.render( + ql.Synth.Types( + root.name, + generated_import_prefix, + final_synth_types, + non_final_synth_types, + ), + out / "Synth.qll", + ) + renderer.render( + ql.ImportList(constructor_imports), out / "SynthConstructors.qll" + ) + renderer.render( + ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll" + ) if opts.ql_format: format(opts.codeql_binary, renderer.written) diff --git a/misc/codegen/generators/rustgen.py b/misc/codegen/generators/rustgen.py index d7025830bcbd..1f373151d6ad 100644 --- a/misc/codegen/generators/rustgen.py +++ 
b/misc/codegen/generators/rustgen.py @@ -55,7 +55,8 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field: def _get_properties( - cls: schema.Class, lookup: dict[str, schema.ClassBase], + cls: schema.Class, + lookup: dict[str, schema.ClassBase], ) -> typing.Iterable[tuple[schema.Class, schema.Property]]: for b in cls.bases: yield from _get_properties(lookup[b], lookup) @@ -92,8 +93,9 @@ def _get_class(self, name: str) -> rust.Class: # only generate detached fields in the actual class defining them, not the derived ones if c is cls: # TODO lift this restriction if required (requires change in dbschemegen as well) - assert c.derived or not p.is_single, \ - f"property {p.name} in concrete class marked as detached but not optional" + assert ( + c.derived or not p.is_single + ), f"property {p.name} in concrete class marked as detached but not optional" detached_fields.append(_get_field(c, p)) elif not cls.derived: # for non-detached ones, only generate fields in the concrete classes @@ -123,10 +125,12 @@ def generate(opts, renderer): processor = Processor(schemaloader.load_file(opts.schema)) out = opts.rust_output groups = set() - with renderer.manage(generated=out.rglob("*.rs"), - stubs=(), - registry=out / ".generated.list", - force=opts.force) as renderer: + with renderer.manage( + generated=out.rglob("*.rs"), + stubs=(), + registry=out / ".generated.list", + force=opts.force, + ) as renderer: for group, classes in processor.get_classes().items(): group = group or "top" groups.add(group) diff --git a/misc/codegen/generators/rusttestgen.py b/misc/codegen/generators/rusttestgen.py index e7a23fedacdc..a46d2584127b 100644 --- a/misc/codegen/generators/rusttestgen.py +++ b/misc/codegen/generators/rusttestgen.py @@ -42,7 +42,9 @@ def _get_code(doc: list[str]) -> list[str]: code.append(f"// {line}") case _, True: code.append(line) - assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc) + assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join( + doc + ) if has_code: return code return [] @@ -51,15 +53,19 @@ def _get_code(doc: list[str]) -> list[str]: def generate(opts, renderer): assert opts.ql_test_output schema = schemaloader.load_file(opts.schema) - with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"), - stubs=(), - registry=opts.ql_test_output / ".generated_tests.list", - force=opts.force) as renderer: + with renderer.manage( + generated=opts.ql_test_output.rglob("gen_*.rs"), + stubs=(), + registry=opts.ql_test_output / ".generated_tests.list", + force=opts.force, + ) as renderer: for cls in schema.classes.values(): if cls.imported: continue - if (qlgen.should_skip_qltest(cls, schema.classes) or - "rust_skip_doc_test" in cls.pragmas): + if ( + qlgen.should_skip_qltest(cls, schema.classes) + or "rust_skip_doc_test" in cls.pragmas + ): continue code = _get_code(cls.doc) for p in schema.iter_properties(cls.name): @@ -79,5 +85,10 @@ def generate(opts, renderer): code = [indent + l for l in code] test_with_name = typing.cast(str, cls.pragmas.get("qltest_test_with")) test_with = schema.classes[test_with_name] if test_with_name else cls - test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs" + test = ( + opts.ql_test_output + / test_with.group + / test_with.name + / f"gen_{test_name}.rs" + ) renderer.render(TestCode(code="\n".join(code), function=fn), test) diff --git a/misc/codegen/generators/trapgen.py b/misc/codegen/generators/trapgen.py index e22b3e4e0e73..1f33fd4a0ff8 100755 --- 
a/misc/codegen/generators/trapgen.py +++ b/misc/codegen/generators/trapgen.py @@ -86,13 +86,18 @@ def generate(opts, renderer): for dir, entries in traps.items(): dir = dir or pathlib.Path() relative_gen_dir = pathlib.Path(*[".." for _ in dir.parents]) - renderer.render(cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), out / dir / "TrapEntries") + renderer.render( + cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), + out / dir / "TrapEntries", + ) tags = [] for tag in toposort_flatten(tag_graph): - tags.append(cpp.Tag( - name=get_tag_name(tag), - bases=[get_tag_name(b) for b in sorted(tag_graph[tag])], - id=tag, - )) + tags.append( + cpp.Tag( + name=get_tag_name(tag), + bases=[get_tag_name(b) for b in sorted(tag_graph[tag])], + id=tag, + ) + ) renderer.render(cpp.TagList(tags, opts.dbscheme), out / "TrapTags") diff --git a/misc/codegen/lib/cpp.py b/misc/codegen/lib/cpp.py index eed7aba045cb..2b8c504caacd 100644 --- a/misc/codegen/lib/cpp.py +++ b/misc/codegen/lib/cpp.py @@ -4,20 +4,111 @@ from typing import List, ClassVar # taken from https://en.cppreference.com/w/cpp/keyword -cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept", - "auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t", - "class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue", - "co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast", - "else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if", - "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr", - "operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast", - "requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct", - "switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef", - "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while", - "xor", "xor_eq"} +cpp_keywords = { + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char8_t", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "consteval", + "constexpr", + "constinit", + "const_cast", + "continue", + "co_await", + "co_return", + "co_yield", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "float", + "for", + "friend", + "goto", + "if", + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "private", + "protected", + "public", + "reflexpr", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", +} _field_overrides = [ - (re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": 
"unsigned"}), + ( + re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), + {"base_type": "unsigned"}, + ), (re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}), ] @@ -108,7 +199,7 @@ def has_bases(self): @dataclass class TrapList: - template: ClassVar = 'trap_traps' + template: ClassVar = "trap_traps" extensions = ["h", "cpp"] traps: List[Trap] source: str @@ -118,7 +209,7 @@ class TrapList: @dataclass class TagList: - template: ClassVar = 'trap_tags' + template: ClassVar = "trap_tags" extensions = ["h"] tags: List[Tag] @@ -127,7 +218,7 @@ class TagList: @dataclass class ClassBase: - ref: 'Class' + ref: "Class" first: bool = False @@ -140,7 +231,9 @@ class Class: trap_name: str = None def __post_init__(self): - self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)] + self.bases = [ + ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name) + ] if self.bases: self.bases[0].first = True diff --git a/misc/codegen/lib/dbscheme.py b/misc/codegen/lib/dbscheme.py index eee0191b6788..03c9878d7f11 100644 --- a/misc/codegen/lib/dbscheme.py +++ b/misc/codegen/lib/dbscheme.py @@ -1,4 +1,4 @@ -""" dbscheme format representation """ +"""dbscheme format representation""" import logging import pathlib @@ -100,7 +100,7 @@ class SchemeInclude: @dataclass class Scheme: - template: ClassVar = 'dbscheme' + template: ClassVar = "dbscheme" src: str includes: List[SchemeInclude] diff --git a/misc/codegen/lib/paths.py b/misc/codegen/lib/paths.py index b102987a2267..f56bbb9d8171 100644 --- a/misc/codegen/lib/paths.py +++ b/misc/codegen/lib/paths.py @@ -1,4 +1,4 @@ -""" module providing useful filesystem paths """ +"""module providing useful filesystem paths""" import pathlib import sys @@ -7,13 +7,15 @@ _this_file = pathlib.Path(__file__).resolve() try: - workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run - root_dir = workspace_dir / 'swift' + workspace_dir = pathlib.Path( + os.environ["BUILD_WORKSPACE_DIRECTORY"] + ).resolve() # <- means we are using bazel run + root_dir = workspace_dir / "swift" except KeyError: root_dir = _this_file.parents[2] workspace_dir = root_dir.parent -lib_dir = _this_file.parents[2] / 'codegen' / 'lib' -templates_dir = _this_file.parents[2] / 'codegen' / 'templates' +lib_dir = _this_file.parents[2] / "codegen" / "lib" +templates_dir = _this_file.parents[2] / "codegen" / "templates" exe_file = pathlib.Path(sys.argv[0]).resolve() diff --git a/misc/codegen/lib/ql.py b/misc/codegen/lib/ql.py index 0200477eb32c..7537aac995c5 100644 --- a/misc/codegen/lib/ql.py +++ b/misc/codegen/lib/ql.py @@ -100,7 +100,7 @@ def __str__(self): @dataclass class Class: - template: ClassVar = 'ql_class' + template: ClassVar = "ql_class" name: str bases: List[Base] = field(default_factory=list) @@ -116,7 +116,12 @@ class Class: cfg: bool = False def __post_init__(self): - def get_bases(bases): return [Base(str(b), str(prev)) for b, prev in zip(bases, itertools.chain([""], bases))] + def get_bases(bases): + return [ + Base(str(b), str(prev)) + for b, prev in zip(bases, itertools.chain([""], bases)) + ] + self.bases = get_bases(self.bases) self.bases_impl = get_bases(self.bases_impl) if self.properties: @@ -164,7 +169,7 @@ def __post_init__(self): @dataclass class Stub: - template: ClassVar = 'ql_stub' + template: ClassVar = "ql_stub" name: str base_import: str @@ -183,7 +188,7 @@ def has_qldoc(self) -> bool: @dataclass class ClassPublic: - template: ClassVar = 'ql_class_public' + template: ClassVar = 
"ql_class_public" name: str imports: List[str] = field(default_factory=list) @@ -197,7 +202,7 @@ def has_qldoc(self) -> bool: @dataclass class DbClasses: - template: ClassVar = 'ql_db' + template: ClassVar = "ql_db" classes: List[Class] = field(default_factory=list) imports: List[str] = field(default_factory=list) @@ -205,14 +210,14 @@ class DbClasses: @dataclass class ImportList: - template: ClassVar = 'ql_imports' + template: ClassVar = "ql_imports" imports: List[str] = field(default_factory=list) @dataclass class GetParentImplementation: - template: ClassVar = 'ql_parent' + template: ClassVar = "ql_parent" classes: List[Class] = field(default_factory=list) imports: List[str] = field(default_factory=list) @@ -234,7 +239,7 @@ class TesterBase: @dataclass class ClassTester(TesterBase): - template: ClassVar = 'ql_test_class' + template: ClassVar = "ql_test_class" properties: List[PropertyForTest] = field(default_factory=list) show_ql_class: bool = False @@ -242,14 +247,14 @@ class ClassTester(TesterBase): @dataclass class PropertyTester(TesterBase): - template: ClassVar = 'ql_test_property' + template: ClassVar = "ql_test_property" property: PropertyForTest @dataclass class MissingTestInstructions: - template: ClassVar = 'ql_test_missing' + template: ClassVar = "ql_test_missing" class Synth: @@ -306,7 +311,9 @@ class FinalClassDb(FinalClass): subtracted_synth_types: List["Synth.Class"] = field(default_factory=list) def subtract_type(self, type: str): - self.subtracted_synth_types.append(Synth.Class(type, first=not self.subtracted_synth_types)) + self.subtracted_synth_types.append( + Synth.Class(type, first=not self.subtracted_synth_types) + ) @property def has_subtracted_synth_types(self) -> bool: @@ -357,6 +364,6 @@ class CfgClass: @dataclass class CfgClasses: - template: ClassVar = 'ql_cfg_nodes' + template: ClassVar = "ql_cfg_nodes" include_file_import: Optional[str] = None classes: List[CfgClass] = field(default_factory=list) diff --git a/misc/codegen/lib/render.py b/misc/codegen/lib/render.py index ac43a515de10..5ab746107ee7 100644 --- a/misc/codegen/lib/render.py +++ b/misc/codegen/lib/render.py @@ -1,4 +1,4 @@ -""" template renderer module, wrapping around `pystache.Renderer` +"""template renderer module, wrapping around `pystache.Renderer` `pystache` is a python mustache engine, and mustache is a template language. More information on @@ -23,14 +23,21 @@ class Error(Exception): class Renderer: - """ Template renderer using mustache templates in the `templates` directory """ + """Template renderer using mustache templates in the `templates` directory""" def __init__(self, generator: pathlib.Path): - self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u) + self._r = pystache.Renderer( + search_dirs=str(paths.templates_dir), escape=lambda u: u + ) self._generator = generator - def render(self, data: object, output: typing.Optional[pathlib.Path], template: typing.Optional[str] = None): - """ Render `data` to `output`. + def render( + self, + data: object, + output: typing.Optional[pathlib.Path], + template: typing.Optional[str] = None, + ): + """Render `data` to `output`. `data` must have a `template` attribute denoting which template to use from the template directory. 
@@ -58,13 +65,18 @@ def _do_write(self, mnemonic: str, contents: str, output: pathlib.Path): out.write(contents) log.debug(f"{mnemonic}: generated {output.name}") - def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path], - registry: pathlib.Path, force: bool = False) -> "RenderManager": + def manage( + self, + generated: typing.Iterable[pathlib.Path], + stubs: typing.Iterable[pathlib.Path], + registry: pathlib.Path, + force: bool = False, + ) -> "RenderManager": return RenderManager(self._generator, generated, stubs, registry, force) class RenderManager(Renderer): - """ A context manager allowing to manage checked in generated files and their cleanup, able + """A context manager allowing to manage checked in generated files and their cleanup, able to skip unneeded writes. This is done by using and updating a checked in list of generated files that assigns two @@ -74,6 +86,7 @@ class RenderManager(Renderer): * the other is the hash of the actual file after code generation has finished. This will be different from the above because of post-processing like QL formatting. This hash is used to detect invalid modification of generated files""" + written: typing.Set[pathlib.Path] @dataclass @@ -82,12 +95,18 @@ class Hashes: pre contains the hash of a file as rendered, post is the hash after postprocessing (for example QL formatting) """ + pre: str post: typing.Optional[str] = None - def __init__(self, generator: pathlib.Path, generated: typing.Iterable[pathlib.Path], - stubs: typing.Iterable[pathlib.Path], - registry: pathlib.Path, force: bool = False): + def __init__( + self, + generator: pathlib.Path, + generated: typing.Iterable[pathlib.Path], + stubs: typing.Iterable[pathlib.Path], + registry: pathlib.Path, + force: bool = False, + ): super().__init__(generator) self._registry_path = registry self._force = force @@ -142,10 +161,14 @@ def _process_generated(self, generated: typing.Iterable[pathlib.Path]): if self._force: pass elif rel_path not in self._hashes: - log.warning(f"{rel_path} marked as generated but absent from the registry") + log.warning( + f"{rel_path} marked as generated but absent from the registry" + ) elif self._hashes[rel_path].post != self._hash_file(f): - raise Error(f"{rel_path} is generated but was modified, please revert the file " - "or pass --force to overwrite") + raise Error( + f"{rel_path} is generated but was modified, please revert the file " + "or pass --force to overwrite" + ) def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]): for f in stubs: @@ -159,8 +182,10 @@ def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]): elif rel_path not in self._hashes: log.warning(f"{rel_path} marked as stub but absent from the registry") elif self._hashes[rel_path].post != self._hash_file(f): - raise Error(f"{rel_path} is a stub marked as generated, but it was modified, " - "please remove the `// generated` header, revert the file or pass --force to overwrite it") + raise Error( + f"{rel_path} is a stub marked as generated, but it was modified, " + "please remove the `// generated` header, revert the file or pass --force to overwrite it" + ) @staticmethod def is_customized_stub(file: pathlib.Path) -> bool: @@ -191,13 +216,17 @@ def _load_registry(self): for line in reg: if line.strip(): filename, prehash, posthash = line.split() - self._hashes[pathlib.Path(filename)] = self.Hashes(prehash, posthash) + self._hashes[pathlib.Path(filename)] = self.Hashes( + prehash, posthash + ) except FileNotFoundError: pass def 
_dump_registry(self): self._registry_path.parent.mkdir(parents=True, exist_ok=True) - with open(self._registry_path, 'w') as out, open(self._registry_path.parent / ".gitattributes", "w") as attrs: + with open(self._registry_path, "w") as out, open( + self._registry_path.parent / ".gitattributes", "w" + ) as attrs: print(f"/{self._registry_path.name}", "linguist-generated", file=attrs) print("/.gitattributes", "linguist-generated", file=attrs) for f, hashes in sorted(self._hashes.items()): diff --git a/misc/codegen/lib/schema.py b/misc/codegen/lib/schema.py index 5178e61d3844..efcfb5c5fc2e 100644 --- a/misc/codegen/lib/schema.py +++ b/misc/codegen/lib/schema.py @@ -1,4 +1,5 @@ -""" schema format representation """ +"""schema format representation""" + import abc import typing from collections.abc import Iterable @@ -52,7 +53,11 @@ def is_optional(self) -> bool: @property def is_repeated(self) -> bool: - return self.kind in (self.Kind.REPEATED, self.Kind.REPEATED_OPTIONAL, self.Kind.REPEATED_UNORDERED) + return self.kind in ( + self.Kind.REPEATED, + self.Kind.REPEATED_OPTIONAL, + self.Kind.REPEATED_UNORDERED, + ) @property def is_unordered(self) -> bool: @@ -74,10 +79,11 @@ def has_builtin_type(self) -> bool: SingleProperty = functools.partial(Property, Property.Kind.SINGLE) OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL) RepeatedProperty = functools.partial(Property, Property.Kind.REPEATED) -RepeatedOptionalProperty = functools.partial( - Property, Property.Kind.REPEATED_OPTIONAL) +RepeatedOptionalProperty = functools.partial(Property, Property.Kind.REPEATED_OPTIONAL) PredicateProperty = functools.partial(Property, Property.Kind.PREDICATE) -RepeatedUnorderedProperty = functools.partial(Property, Property.Kind.REPEATED_UNORDERED) +RepeatedUnorderedProperty = functools.partial( + Property, Property.Kind.REPEATED_UNORDERED +) @dataclass @@ -197,9 +203,9 @@ def _make_property(arg: object) -> Property: class PropertyModifier(abc.ABC): - """ Modifier of `Property` objects. - Being on the right of `|` it will trigger construction of a `Property` from - the left operand. + """Modifier of `Property` objects. + Being on the right of `|` it will trigger construction of a `Property` from + the left operand. """ def __ror__(self, other: object) -> Property: @@ -210,11 +216,9 @@ def __ror__(self, other: object) -> Property: def __invert__(self) -> "PropertyModifier": return self.negate() - def modify(self, prop: Property): - ... + def modify(self, prop: Property): ... - def negate(self) -> "PropertyModifier": - ... + def negate(self) -> "PropertyModifier": ... 
def split_doc(doc): @@ -224,7 +228,11 @@ def split_doc(doc): lines = doc.splitlines() # Determine minimum indentation (first line doesn't count): strippedlines = (line.lstrip() for line in lines[1:]) - indents = [len(line) - len(stripped) for line, stripped in zip(lines[1:], strippedlines) if stripped] + indents = [ + len(line) - len(stripped) + for line, stripped in zip(lines[1:], strippedlines) + if stripped + ] # Remove indentation (first line is special): trimmed = [lines[0].strip()] if indents: diff --git a/misc/codegen/lib/schemadefs.py b/misc/codegen/lib/schemadefs.py index b0cf2b038a8d..5841b9ac874d 100644 --- a/misc/codegen/lib/schemadefs.py +++ b/misc/codegen/lib/schemadefs.py @@ -39,7 +39,9 @@ class _DocModifier(_schema.PropertyModifier, metaclass=_DocModifierMetaclass): def modify(self, prop: _schema.Property): if self.doc and ("\n" in self.doc or self.doc[-1] == "."): - raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?") + raise _schema.Error( + "No newlines or trailing dots are allowed in doc, did you intend to use desc?" + ) prop.doc = self.doc def negate(self) -> _schema.PropertyModifier: @@ -73,10 +75,13 @@ def include(source: str): @_dataclass class _Namespace: - """ simple namespacing mechanism """ + """simple namespacing mechanism""" + _name: str - def add(self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None): + def add( + self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None + ): self.__dict__[pragma.pragma] = pragma pragma.pragma = key or f"{self._name}_{pragma.pragma}" @@ -110,15 +115,18 @@ def _apply(self, pragmas: _Dict[str, object]) -> None: @_dataclass class _ClassPragma(_PragmaBase): - """ A class pragma. + """A class pragma. For schema classes it acts as a python decorator with `@`. """ + inherited: bool = False def __call__(self, cls: type) -> type: - """ use this pragma as a decorator on classes """ + """use this pragma as a decorator on classes""" if self.inherited: - setattr(cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value) + setattr( + cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value + ) else: # not using hasattr as we don't want to land on inherited pragmas if "_pragmas" not in cls.__dict__: @@ -129,9 +137,10 @@ def __call__(self, cls: type) -> type: @_dataclass class _PropertyPragma(_PragmaBase, _schema.PropertyModifier): - """ A property pragma. + """A property pragma. It functions similarly to a `_PropertyModifier` with `|`, adding the pragma. """ + remove: bool = False def modify(self, prop: _schema.Property): @@ -149,21 +158,23 @@ def _apply(self, pragmas: _Dict[str, object]) -> None: @_dataclass class _Pragma(_ClassPragma, _PropertyPragma): - """ A class or property pragma. + """A class or property pragma. For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma. For schema classes it acts as a python decorator with `@`. """ class _Parametrized[P, **Q, T]: - """ A parametrized pragma. + """A parametrized pragma. Needs to be applied to a parameter to give a pragma. 
""" def __init__(self, pragma_instance: P, factory: _Callable[Q, T]): self.pragma_instance = pragma_instance self.factory = factory - self.__signature__ = _inspect.signature(self.factory).replace(return_annotation=type(self.pragma_instance)) + self.__signature__ = _inspect.signature(self.factory).replace( + return_annotation=type(self.pragma_instance) + ) @property def pragma(self): @@ -187,7 +198,8 @@ def modify(self, prop: _schema.Property): K = _schema.Property.Kind if prop.kind != K.SINGLE: raise _schema.Error( - "optional should only be applied to simple property types") + "optional should only be applied to simple property types" + ) prop.kind = K.OPTIONAL @@ -200,7 +212,8 @@ def modify(self, prop: _schema.Property): prop.kind = K.REPEATED_OPTIONAL else: raise _schema.Error( - "list should only be applied to simple or optional property types") + "list should only be applied to simple or optional property types" + ) class _Setifier(_schema.PropertyModifier): @@ -212,7 +225,7 @@ def modify(self, prop: _schema.Property): class _TypeModifier: - """ Modifies types using get item notation """ + """Modifies types using get item notation""" def __init__(self, modifier: _schema.PropertyModifier): self.modifier = modifier @@ -242,7 +255,11 @@ def __getitem__(self, item): qltest.add(_ClassPragma("skip")) qltest.add(_ClassPragma("collapse_hierarchy")) qltest.add(_ClassPragma("uncollapse_hierarchy")) -qltest.add(_Parametrized(_ClassPragma("test_with", inherited=True), factory=_schema.get_type_name)) +qltest.add( + _Parametrized( + _ClassPragma("test_with", inherited=True), factory=_schema.get_type_name + ) +) ql.add(_Parametrized(_ClassPragma("default_doc_name"), factory=lambda doc: doc)) ql.add(_ClassPragma("hideable", inherited=True)) @@ -255,15 +272,33 @@ def __getitem__(self, item): rust.add(_PropertyPragma("detach")) rust.add(_Pragma("skip_doc_test")) -rust.add(_Parametrized(_ClassPragma("doc_test_signature"), factory=lambda signature: signature)) +rust.add( + _Parametrized( + _ClassPragma("doc_test_signature"), factory=lambda signature: signature + ) +) -group = _Parametrized(_ClassPragma("group", inherited=True), factory=lambda group: group) +group = _Parametrized( + _ClassPragma("group", inherited=True), factory=lambda group: group +) -synth.add(_Parametrized(_ClassPragma("from_class"), factory=lambda ref: _schema.SynthInfo( - from_class=_schema.get_type_name(ref))), key="synth") -synth.add(_Parametrized(_ClassPragma("on_arguments"), factory=lambda **kwargs: - _schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()})), key="synth") +synth.add( + _Parametrized( + _ClassPragma("from_class"), + factory=lambda ref: _schema.SynthInfo(from_class=_schema.get_type_name(ref)), + ), + key="synth", +) +synth.add( + _Parametrized( + _ClassPragma("on_arguments"), + factory=lambda **kwargs: _schema.SynthInfo( + on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()} + ), + ), + key="synth", +) @_dataclass(frozen=True) @@ -283,7 +318,12 @@ def modify(self, prop: _schema.Property): drop = object() -def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None, cfg: bool = False) -> _Callable[[type], _PropertyModifierList]: +def annotate( + annotated_cls: type, + add_bases: _Iterable[type] | None = None, + replace_bases: _Dict[type, type] | None = None, + cfg: bool = False, +) -> _Callable[[type], _PropertyModifierList]: """ Add or modify schema annotations after a class has been defined previously. 
@@ -291,6 +331,7 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl `replace_bases` can be used to replace bases on the annotated class. """ + def decorator(cls: type) -> _PropertyModifierList: if cls.__name__ != "_": raise _schema.Error("Annotation classes must be named _") @@ -299,7 +340,9 @@ def decorator(cls: type) -> _PropertyModifierList: for p, v in cls.__dict__.get("_pragmas", {}).items(): _ClassPragma(p, value=v)(annotated_cls) if replace_bases: - annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__) + annotated_cls.__bases__ = tuple( + replace_bases.get(b, b) for b in annotated_cls.__bases__ + ) if add_bases: annotated_cls.__bases__ += tuple(add_bases) annotated_cls.__cfg__ = cfg @@ -312,9 +355,12 @@ def decorator(cls: type) -> _PropertyModifierList: elif p in annotated_cls.__annotations__: annotated_cls.__annotations__[p] |= a elif isinstance(a, (_PropertyModifierList, _PropertyModifierList)): - raise _schema.Error(f"annotated property {p} not present in annotated class " - f"{annotated_cls.__name__}") + raise _schema.Error( + f"annotated property {p} not present in annotated class " + f"{annotated_cls.__name__}" + ) else: annotated_cls.__annotations__[p] = a return _ + return decorator diff --git a/misc/codegen/loaders/dbschemeloader.py b/misc/codegen/loaders/dbschemeloader.py index f6fbab50499c..a9b599ef0c3c 100644 --- a/misc/codegen/loaders/dbschemeloader.py +++ b/misc/codegen/loaders/dbschemeloader.py @@ -12,9 +12,13 @@ class _Re: "|" r"^(?P@\w+)\s*=\s*(?P@\w+(?:\s*\|\s*@\w+)*)\s*;?" ) - field = re.compile(r"(?m)[\w\s]*\s(?P\w+)\s*:\s*(?P@?\w+)(?P\s+ref)?") + field = re.compile( + r"(?m)[\w\s]*\s(?P\w+)\s*:\s*(?P@?\w+)(?P\s+ref)?" + ) key = re.compile(r"@\w+") - comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$") # lookahead avoid ignoring metadata like //dir=foo + comment = re.compile( + r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$" + ) # lookahead avoid ignoring metadata like //dir=foo def _get_column(match): diff --git a/misc/codegen/loaders/schemaloader.py b/misc/codegen/loaders/schemaloader.py index 3b5f20cbbede..eaf08a04f571 100644 --- a/misc/codegen/loaders/schemaloader.py +++ b/misc/codegen/loaders/schemaloader.py @@ -1,4 +1,5 @@ -""" schema loader """ +"""schema loader""" + import sys import inflection @@ -33,37 +34,56 @@ def _get_class(cls: type) -> schema.Class: raise schema.Error(f"Only class definitions allowed in schema, found {cls}") # we must check that going to dbscheme names and back is preserved # In particular this will not happen if uppercase acronyms are included in the name - to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True) + to_underscore_and_back = inflection.camelize( + inflection.underscore(cls.__name__), uppercase_first_letter=True + ) if cls.__name__ != to_underscore_and_back: - raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} " - f"instead of {to_underscore_and_back}") - if len({g for g in (getattr(b, f"{schema.inheritable_pragma_prefix}group", None) - for b in cls.__bases__) if g}) > 1: + raise schema.Error( + f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} " + f"instead of {to_underscore_and_back}" + ) + if ( + len( + { + g + for g in ( + getattr(b, f"{schema.inheritable_pragma_prefix}group", None) + for b in cls.__bases__ + ) + if g + } + ) + > 1 + ): raise schema.Error(f"Bases with mixed groups for 
{cls.__name__}") pragmas = { # dir and getattr inherit from bases - a[len(schema.inheritable_pragma_prefix):]: getattr(cls, a) - for a in dir(cls) if a.startswith(schema.inheritable_pragma_prefix) + a[len(schema.inheritable_pragma_prefix) :]: getattr(cls, a) + for a in dir(cls) + if a.startswith(schema.inheritable_pragma_prefix) } pragmas |= cls.__dict__.get("_pragmas", {}) derived = {d.__name__ for d in cls.__subclasses__()} if "null" in pragmas and derived: raise schema.Error(f"Null class cannot be derived") - return schema.Class(name=cls.__name__, - bases=[b.__name__ for b in cls.__bases__ if b is not object], - derived=derived, - pragmas=pragmas, - cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False, - # in the following we don't use `getattr` to avoid inheriting - properties=[ - a | _PropertyNamer(n) - for n, a in cls.__dict__.get("__annotations__", {}).items() - ], - doc=schema.split_doc(cls.__doc__), - ) - - -def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]: + return schema.Class( + name=cls.__name__, + bases=[b.__name__ for b in cls.__bases__ if b is not object], + derived=derived, + pragmas=pragmas, + cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False, + # in the following we don't use `getattr` to avoid inheriting + properties=[ + a | _PropertyNamer(n) + for n, a in cls.__dict__.get("__annotations__", {}).items() + ], + doc=schema.split_doc(cls.__doc__), + ) + + +def _toposort_classes_by_group( + classes: typing.Dict[str, schema.Class], +) -> typing.Dict[str, schema.Class]: groups = {} ret = {} @@ -79,7 +99,7 @@ def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typin def _fill_synth_information(classes: typing.Dict[str, schema.Class]): - """ Take a dictionary where the `synth` field is filled for all explicitly synthesized classes + """Take a dictionary where the `synth` field is filled for all explicitly synthesized classes and update it so that all non-final classes that have only synthesized final descendants get `True` as` value for the `synth` field """ @@ -109,7 +129,7 @@ def fill_is_synth(name: str): def _fill_hideable_information(classes: typing.Dict[str, schema.Class]): - """ Update the class map propagating the `hideable` attribute upwards in the hierarchy """ + """Update the class map propagating the `hideable` attribute upwards in the hierarchy""" todo = [cls for cls in classes.values() if "ql_hideable" in cls.pragmas] while todo: cls = todo.pop() @@ -123,10 +143,14 @@ def _fill_hideable_information(classes: typing.Dict[str, schema.Class]): def _check_test_with(classes: typing.Dict[str, schema.Class]): for cls in classes.values(): test_with = typing.cast(str, cls.pragmas.get("qltest_test_with")) - transitive_test_with = test_with and classes[test_with].pragmas.get("qltest_test_with") + transitive_test_with = test_with and classes[test_with].pragmas.get( + "qltest_test_with" + ) if test_with and transitive_test_with: - raise schema.Error(f"{cls.name} has test_with {test_with} which in turn " - f"has test_with {transitive_test_with}, use that directly") + raise schema.Error( + f"{cls.name} has test_with {test_with} which in turn " + f"has test_with {transitive_test_with}, use that directly" + ) def load(m: types.ModuleType) -> schema.Schema: @@ -136,6 +160,7 @@ def load(m: types.ModuleType) -> schema.Schema: known = {"int", "string", "boolean"} known.update(n for n in m.__dict__ if not n.startswith("__")) import misc.codegen.lib.schemadefs as defs + null = None for name, data in 
m.__dict__.items(): if hasattr(defs, name): @@ -152,21 +177,26 @@ def load(m: types.ModuleType) -> schema.Schema: continue cls = _get_class(data) if classes and not cls.bases: - raise schema.Error( - f"Only one root class allowed, found second root {name}") + raise schema.Error(f"Only one root class allowed, found second root {name}") cls.check_types(known) classes[name] = cls if "null" in cls.pragmas: del cls.pragmas["null"] if null is not None: - raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed") + raise schema.Error( + f"Null class {null} already defined, second null class {name} not allowed" + ) null = name _fill_synth_information(classes) _fill_hideable_information(classes) _check_test_with(classes) - return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null) + return schema.Schema( + includes=includes, + classes=imported_classes | _toposort_classes_by_group(classes), + null=null, + ) def load_file(path: pathlib.Path) -> schema.Schema: diff --git a/misc/codegen/test/test_cpp.py b/misc/codegen/test/test_cpp.py index c4bee337a4f7..77295bb0d828 100644 --- a/misc/codegen/test/test_cpp.py +++ b/misc/codegen/test/test_cpp.py @@ -17,34 +17,49 @@ def test_field_name(): assert f.field_name == "foo" -@pytest.mark.parametrize("type,expected", [ - ("std::string", "trapQuoted(value)"), - ("bool", '(value ? "true" : "false")'), - ("something_else", "value"), -]) +@pytest.mark.parametrize( + "type,expected", + [ + ("std::string", "trapQuoted(value)"), + ("bool", '(value ? "true" : "false")'), + ("something_else", "value"), + ], +) def test_field_get_streamer(type, expected): f = cpp.Field("name", type) assert f.get_streamer()("value") == expected -@pytest.mark.parametrize("is_optional,is_repeated,is_predicate,expected", [ - (False, False, False, True), - (True, False, False, False), - (False, True, False, False), - (True, True, False, False), - (False, False, True, False), -]) +@pytest.mark.parametrize( + "is_optional,is_repeated,is_predicate,expected", + [ + (False, False, False, True), + (True, False, False, False), + (False, True, False, False), + (True, True, False, False), + (False, False, True, False), + ], +) def test_field_is_single(is_optional, is_repeated, is_predicate, expected): - f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated, is_predicate=is_predicate) + f = cpp.Field( + "name", + "type", + is_optional=is_optional, + is_repeated=is_repeated, + is_predicate=is_predicate, + ) assert f.is_single is expected -@pytest.mark.parametrize("is_optional,is_repeated,expected", [ - (False, False, "bar"), - (True, False, "std::optional"), - (False, True, "std::vector"), - (True, True, "std::vector>"), -]) +@pytest.mark.parametrize( + "is_optional,is_repeated,expected", + [ + (False, False, "bar"), + (True, False, "std::optional"), + (False, True, "std::vector"), + (True, True, "std::vector>"), + ], +) def test_field_modal_types(is_optional, is_repeated, expected): f = cpp.Field("name", "bar", is_optional=is_optional, is_repeated=is_repeated) assert f.type == expected @@ -69,11 +84,9 @@ def test_tag_has_first_base_marked(): assert t.bases == expected -@pytest.mark.parametrize("bases,expected", [ - ([], False), - (["a"], True), - (["a", "b"], True) -]) +@pytest.mark.parametrize( + "bases,expected", [([], False), (["a"], True), (["a", "b"], True)] +) def test_tag_has_bases(bases, expected): t = cpp.Tag("name", bases, "id") assert t.has_bases is expected @@ -91,11 +104,9 @@ 
def test_class_has_first_base_marked(): assert c.bases == expected -@pytest.mark.parametrize("bases,expected", [ - ([], False), - (["a"], True), - (["a", "b"], True) -]) +@pytest.mark.parametrize( + "bases,expected", [([], False), (["a"], True), (["a", "b"], True)] +) def test_class_has_bases(bases, expected): t = cpp.Class("name", [cpp.Class(b) for b in bases]) assert t.has_bases is expected @@ -113,5 +124,5 @@ def test_class_single_fields(): assert c.single_fields == fields[::2] -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_cppgen.py b/misc/codegen/test/test_cppgen.py index 063940322412..8d0d4605b052 100644 --- a/misc/codegen/test/test_cppgen.py +++ b/misc/codegen/test/test_cppgen.py @@ -18,7 +18,10 @@ def ret(classes): assert isinstance(g, cpp.ClassList), f assert g.include_parent is (f.parent != output_dir) assert f.name == "TrapClasses", f - return {str(f.parent.relative_to(output_dir)): g.classes for f, g in generated.items()} + return { + str(f.parent.relative_to(output_dir)): g.classes + for f, g in generated.items() + } return ret @@ -38,129 +41,193 @@ def test_empty(generate): def test_empty_class(generate): - assert generate([ - schema.Class(name="MyClass"), - ]) == [ - cpp.Class(name="MyClass", final=True, trap_name="MyClasses") - ] + assert generate( + [ + schema.Class(name="MyClass"), + ] + ) == [cpp.Class(name="MyClass", final=True, trap_name="MyClasses")] def test_two_class_hierarchy(generate): base = cpp.Class(name="A") - assert generate([ - schema.Class(name="A", derived={"B"}), - schema.Class(name="B", bases=["A"]), - ]) == [ + assert generate( + [ + schema.Class(name="A", derived={"B"}), + schema.Class(name="B", bases=["A"]), + ] + ) == [ base, cpp.Class(name="B", bases=[base], final=True, trap_name="Bs"), ] -@pytest.mark.parametrize("type,expected", [ - ("a", "a"), - ("string", "std::string"), - ("boolean", "bool"), - ("MyClass", "TrapLabel"), -]) -@pytest.mark.parametrize("property_cls,optional,repeated,unordered,trap_name", [ - (schema.SingleProperty, False, False, False, None), - (schema.OptionalProperty, True, False, False, "MyClassProps"), - (schema.RepeatedProperty, False, True, False, "MyClassProps"), - (schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"), - (schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"), -]) -def test_class_with_field(generate, type, expected, property_cls, optional, repeated, unordered, trap_name): - assert generate([ - schema.Class(name="MyClass", properties=[property_cls("prop", type)]), - ]) == [ - cpp.Class(name="MyClass", - fields=[cpp.Field("prop", expected, is_optional=optional, - is_repeated=repeated, is_unordered=unordered, trap_name=trap_name)], - trap_name="MyClasses", - final=True) +@pytest.mark.parametrize( + "type,expected", + [ + ("a", "a"), + ("string", "std::string"), + ("boolean", "bool"), + ("MyClass", "TrapLabel"), + ], +) +@pytest.mark.parametrize( + "property_cls,optional,repeated,unordered,trap_name", + [ + (schema.SingleProperty, False, False, False, None), + (schema.OptionalProperty, True, False, False, "MyClassProps"), + (schema.RepeatedProperty, False, True, False, "MyClassProps"), + (schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"), + (schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"), + ], +) +def test_class_with_field( + generate, type, expected, property_cls, optional, repeated, unordered, trap_name +): + assert generate( + [ + 
schema.Class(name="MyClass", properties=[property_cls("prop", type)]), + ] + ) == [ + cpp.Class( + name="MyClass", + fields=[ + cpp.Field( + "prop", + expected, + is_optional=optional, + is_repeated=repeated, + is_unordered=unordered, + trap_name=trap_name, + ) + ], + trap_name="MyClasses", + final=True, + ) ] def test_class_field_with_null(generate, input): input.null = "Null" a = cpp.Class(name="A") - assert generate([ - schema.Class(name="A", derived={"B"}), - schema.Class(name="B", bases=["A"], properties=[ - schema.SingleProperty("x", "A"), - schema.SingleProperty("y", "B"), - ]) - ]) == [ + assert generate( + [ + schema.Class(name="A", derived={"B"}), + schema.Class( + name="B", + bases=["A"], + properties=[ + schema.SingleProperty("x", "A"), + schema.SingleProperty("y", "B"), + ], + ), + ] + ) == [ a, - cpp.Class(name="B", bases=[a], final=True, trap_name="Bs", - fields=[ - cpp.Field("x", "TrapLabel"), - cpp.Field("y", "TrapLabel"), - ]), + cpp.Class( + name="B", + bases=[a], + final=True, + trap_name="Bs", + fields=[ + cpp.Field("x", "TrapLabel"), + cpp.Field("y", "TrapLabel"), + ], + ), ] def test_class_with_predicate(generate): - assert generate([ - schema.Class(name="MyClass", properties=[ - schema.PredicateProperty("prop")]), - ]) == [ - cpp.Class(name="MyClass", - fields=[ - cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)], - trap_name="MyClasses", - final=True) + assert generate( + [ + schema.Class(name="MyClass", properties=[schema.PredicateProperty("prop")]), + ] + ) == [ + cpp.Class( + name="MyClass", + fields=[ + cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True) + ], + trap_name="MyClasses", + final=True, + ) ] -@pytest.mark.parametrize("name", - ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"]) +@pytest.mark.parametrize( + "name", + [ + "start_line", + "start_column", + "end_line", + "end_column", + "index", + "num_whatever", + "width", + ], +) def test_class_with_overridden_unsigned_field(generate, name): - assert generate([ - schema.Class(name="MyClass", properties=[ - schema.SingleProperty(name, "bar")]), - ]) == [ - cpp.Class(name="MyClass", - fields=[cpp.Field(name, "unsigned")], - trap_name="MyClasses", - final=True) + assert generate( + [ + schema.Class( + name="MyClass", properties=[schema.SingleProperty(name, "bar")] + ), + ] + ) == [ + cpp.Class( + name="MyClass", + fields=[cpp.Field(name, "unsigned")], + trap_name="MyClasses", + final=True, + ) ] def test_class_with_overridden_underscore_field(generate): - assert generate([ - schema.Class(name="MyClass", properties=[ - schema.SingleProperty("something_", "bar")]), - ]) == [ - cpp.Class(name="MyClass", - fields=[cpp.Field("something", "bar")], - trap_name="MyClasses", - final=True) + assert generate( + [ + schema.Class( + name="MyClass", properties=[schema.SingleProperty("something_", "bar")] + ), + ] + ) == [ + cpp.Class( + name="MyClass", + fields=[cpp.Field("something", "bar")], + trap_name="MyClasses", + final=True, + ) ] @pytest.mark.parametrize("name", cpp.cpp_keywords) def test_class_with_keyword_field(generate, name): - assert generate([ - schema.Class(name="MyClass", properties=[ - schema.SingleProperty(name, "bar")]), - ]) == [ - cpp.Class(name="MyClass", - fields=[cpp.Field(name + "_", "bar")], - trap_name="MyClasses", - final=True) + assert generate( + [ + schema.Class( + name="MyClass", properties=[schema.SingleProperty(name, "bar")] + ), + ] + ) == [ + cpp.Class( + name="MyClass", + fields=[cpp.Field(name + "_", 
"bar")], + trap_name="MyClasses", + final=True, + ) ] def test_classes_with_dirs(generate_grouped): cbase = cpp.Class(name="CBase") - assert generate_grouped([ - schema.Class(name="A"), - schema.Class(name="B", pragmas={"group": "foo"}), - schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}), - schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}), - schema.Class(name="D", pragmas={"group": "foo/bar/baz"}), - ]) == { + assert generate_grouped( + [ + schema.Class(name="A"), + schema.Class(name="B", pragmas={"group": "foo"}), + schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}), + schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}), + schema.Class(name="D", pragmas={"group": "foo/bar/baz"}), + ] + ) == { ".": [cpp.Class(name="A", trap_name="As", final=True)], "foo": [cpp.Class(name="B", trap_name="Bs", final=True)], "bar": [cbase, cpp.Class(name="C", bases=[cbase], trap_name="Cs", final=True)], @@ -169,81 +236,126 @@ def test_classes_with_dirs(generate_grouped): def test_cpp_skip_pragma(generate): - assert generate([ - schema.Class(name="A", properties=[ - schema.SingleProperty("x", "foo"), - schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]), - ]) - ]) == [ - cpp.Class(name="A", final=True, trap_name="As", fields=[ - cpp.Field("x", "foo"), - ]), + assert generate( + [ + schema.Class( + name="A", + properties=[ + schema.SingleProperty("x", "foo"), + schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]), + ], + ) + ] + ) == [ + cpp.Class( + name="A", + final=True, + trap_name="As", + fields=[ + cpp.Field("x", "foo"), + ], + ), ] def test_synth_classes_ignored(generate): - assert generate([ - schema.Class( - name="W", - pragmas={"synth": schema.SynthInfo()}, - ), - schema.Class( - name="X", - pragmas={"synth": schema.SynthInfo(from_class="A")}, - ), - schema.Class( - name="Y", - pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})}, - ), - schema.Class( - name="Z", - ), - ]) == [ + assert generate( + [ + schema.Class( + name="W", + pragmas={"synth": schema.SynthInfo()}, + ), + schema.Class( + name="X", + pragmas={"synth": schema.SynthInfo(from_class="A")}, + ), + schema.Class( + name="Y", + pragmas={ + "synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"}) + }, + ), + schema.Class( + name="Z", + ), + ] + ) == [ cpp.Class(name="Z", final=True, trap_name="Zs"), ] def test_synth_properties_ignored(generate): - assert generate([ - schema.Class( + assert generate( + [ + schema.Class( + name="X", + properties=[ + schema.SingleProperty("x", "a"), + schema.SingleProperty("y", "b", synth=True), + schema.SingleProperty("z", "c"), + schema.OptionalProperty("foo", "bar", synth=True), + schema.RepeatedProperty("baz", "bazz", synth=True), + schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True), + schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True), + ], + ), + ] + ) == [ + cpp.Class( name="X", - properties=[ - schema.SingleProperty("x", "a"), - schema.SingleProperty("y", "b", synth=True), - schema.SingleProperty("z", "c"), - schema.OptionalProperty("foo", "bar", synth=True), - schema.RepeatedProperty("baz", "bazz", synth=True), - schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True), - schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True), + final=True, + trap_name="Xes", + fields=[ + cpp.Field("x", "a"), + cpp.Field("z", "c"), ], ), - ]) == [ - cpp.Class(name="X", final=True, trap_name="Xes", fields=[ - cpp.Field("x", "a"), - cpp.Field("z", "c"), - 
]), ] def test_properties_with_custom_db_table_names(generate): - assert generate([ - schema.Class("Obj", properties=[ - schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}), - schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}), - schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}), - schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}), - schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}), - ]), - ]) == [ - cpp.Class(name="Obj", final=True, trap_name="Objs", fields=[ - cpp.Field("x", "a", is_optional=True, trap_name="Foo"), - cpp.Field("y", "b", is_repeated=True, trap_name="Bar"), - cpp.Field("z", "c", is_repeated=True, is_optional=True, trap_name="Baz"), - cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"), - cpp.Field("q", "d", is_repeated=True, is_unordered=True, trap_name="World"), - ]), + assert generate( + [ + schema.Class( + "Obj", + properties=[ + schema.OptionalProperty( + "x", "a", pragmas={"ql_db_table_name": "foo"} + ), + schema.RepeatedProperty( + "y", "b", pragmas={"ql_db_table_name": "bar"} + ), + schema.RepeatedOptionalProperty( + "z", "c", pragmas={"ql_db_table_name": "baz"} + ), + schema.PredicateProperty( + "p", pragmas={"ql_db_table_name": "hello"} + ), + schema.RepeatedUnorderedProperty( + "q", "d", pragmas={"ql_db_table_name": "world"} + ), + ], + ), + ] + ) == [ + cpp.Class( + name="Obj", + final=True, + trap_name="Objs", + fields=[ + cpp.Field("x", "a", is_optional=True, trap_name="Foo"), + cpp.Field("y", "b", is_repeated=True, trap_name="Bar"), + cpp.Field( + "z", "c", is_repeated=True, is_optional=True, trap_name="Baz" + ), + cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"), + cpp.Field( + "q", "d", is_repeated=True, is_unordered=True, trap_name="World" + ), + ], + ), ] -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_dbscheme.py b/misc/codegen/test/test_dbscheme.py index e2635ecee5ac..2ba7a269c6e8 100644 --- a/misc/codegen/test/test_dbscheme.py +++ b/misc/codegen/test/test_dbscheme.py @@ -14,12 +14,15 @@ def test_dbcolumn_keyword_name(keyword): assert dbscheme.Column(keyword, "some_type").name == keyword + "_" -@pytest.mark.parametrize("type,binding,lhstype,rhstype", [ - ("builtin_type", False, "builtin_type", "builtin_type ref"), - ("builtin_type", True, "builtin_type", "builtin_type ref"), - ("@at_type", False, "int", "@at_type ref"), - ("@at_type", True, "unique int", "@at_type"), -]) +@pytest.mark.parametrize( + "type,binding,lhstype,rhstype", + [ + ("builtin_type", False, "builtin_type", "builtin_type ref"), + ("builtin_type", True, "builtin_type", "builtin_type ref"), + ("@at_type", False, "int", "@at_type ref"), + ("@at_type", True, "unique int", "@at_type"), + ], +) def test_dbcolumn_types(type, binding, lhstype, rhstype): col = dbscheme.Column("foo", type, binding) assert col.lhstype == lhstype @@ -34,7 +37,11 @@ def test_keyset_has_first_id_marked(): def test_table_has_first_column_marked(): - columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")] + columns = [ + dbscheme.Column("a", "x"), + dbscheme.Column("b", "y", binding=True), + dbscheme.Column("c", "z"), + ] expected = deepcopy(columns) table = dbscheme.Table("foo", columns) expected[0].first = True @@ -48,5 +55,5 @@ def test_union_has_first_case_marked(): assert [c.type for c in u.rhs] == rhs -if __name__ == '__main__': +if 
__name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_dbschemegen.py b/misc/codegen/test/test_dbschemegen.py index 653ad7fc8a39..7ae1941fe8d1 100644 --- a/misc/codegen/test/test_dbschemegen.py +++ b/misc/codegen/test/test_dbschemegen.py @@ -8,10 +8,12 @@ InputExpectedPair = collections.namedtuple("InputExpectedPair", ("input", "expected")) -@pytest.fixture(params=[ - InputExpectedPair(None, None), - InputExpectedPair("foodir", pathlib.Path("foodir")), -]) +@pytest.fixture( + params=[ + InputExpectedPair(None, None), + InputExpectedPair("foodir", pathlib.Path("foodir")), + ] +) def dir_param(request): return request.param @@ -21,7 +23,7 @@ def generate(opts, input, renderer): def func(classes, null=None): input.classes = {cls.name: cls for cls in classes} input.null = null - (out, data), = run_generation(dbschemegen.generate, opts, renderer).items() + ((out, data),) = run_generation(dbschemegen.generate, opts, renderer).items() assert out is opts.dbscheme return data @@ -48,23 +50,26 @@ def test_includes(input, opts, generate): dbscheme.SchemeInclude( src=pathlib.Path(i), data=i + " data", - ) for i in includes + ) + for i in includes ], declarations=[], ) def test_empty_final_class(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class("Object", pragmas={"group": dir_param.input}), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), + dbscheme.Column("id", "@object", binding=True), ], dir=dir_param.expected, ) @@ -73,218 +78,279 @@ def test_empty_final_class(generate, dir_param): def test_final_class_with_single_scalar_field(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.SingleProperty("foo", "bar"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + schema.SingleProperty("foo", "bar"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - dbscheme.Column('foo', 'bar'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + dbscheme.Column("foo", "bar"), + ], + dir=dir_param.expected, ) ], ) def test_final_class_with_single_class_field(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.SingleProperty("foo", "Bar"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + schema.SingleProperty("foo", "Bar"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - dbscheme.Column('foo', '@bar'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + dbscheme.Column("foo", "@bar"), + ], + dir=dir_param.expected, ) ], ) def test_final_class_with_optional_field(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.OptionalProperty("foo", "bar"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + 
pragmas={"group": dir_param.input}, + properties=[ + schema.OptionalProperty("foo", "bar"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_foos", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('foo', 'bar'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("foo", "bar"), + ], + dir=dir_param.expected, ), ], ) -@pytest.mark.parametrize("property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty]) +@pytest.mark.parametrize( + "property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty] +) def test_final_class_with_repeated_field(generate, property_cls, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - property_cls("foo", "bar"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + property_cls("foo", "bar"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_foos", keyset=dbscheme.KeySet(["id", "index"]), columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('index', 'int'), - dbscheme.Column('foo', 'bar'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("index", "int"), + dbscheme.Column("foo", "bar"), + ], + dir=dir_param.expected, ), ], ) def test_final_class_with_repeated_unordered_field(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.RepeatedUnorderedProperty("foo", "bar"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + schema.RepeatedUnorderedProperty("foo", "bar"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_foos", columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('foo', 'bar'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("foo", "bar"), + ], + dir=dir_param.expected, ), ], ) def test_final_class_with_predicate_field(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.PredicateProperty("foo"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + schema.PredicateProperty("foo"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_foo", 
keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@object'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + ], + dir=dir_param.expected, ), ], ) def test_final_class_with_more_fields(generate, dir_param): - assert generate([ - schema.Class("Object", pragmas={"group": dir_param.input}, properties=[ - schema.SingleProperty("one", "x"), - schema.SingleProperty("two", "y"), - schema.OptionalProperty("three", "z"), - schema.RepeatedProperty("four", "u"), - schema.RepeatedOptionalProperty("five", "v"), - schema.PredicateProperty("six"), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Object", + pragmas={"group": dir_param.input}, + properties=[ + schema.SingleProperty("one", "x"), + schema.SingleProperty("two", "y"), + schema.OptionalProperty("three", "z"), + schema.RepeatedProperty("four", "u"), + schema.RepeatedOptionalProperty("five", "v"), + schema.PredicateProperty("six"), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ dbscheme.Table( name="objects", columns=[ - dbscheme.Column('id', '@object', binding=True), - dbscheme.Column('one', 'x'), - dbscheme.Column('two', 'y'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object", binding=True), + dbscheme.Column("one", "x"), + dbscheme.Column("two", "y"), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_threes", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('three', 'z'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("three", "z"), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_fours", keyset=dbscheme.KeySet(["id", "index"]), columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('index', 'int'), - dbscheme.Column('four', 'u'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("index", "int"), + dbscheme.Column("four", "u"), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_fives", keyset=dbscheme.KeySet(["id", "index"]), columns=[ - dbscheme.Column('id', '@object'), - dbscheme.Column('index', 'int'), - dbscheme.Column('five', 'v'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + dbscheme.Column("index", "int"), + dbscheme.Column("five", "v"), + ], + dir=dir_param.expected, ), dbscheme.Table( name="object_six", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@object'), - ], dir=dir_param.expected, + dbscheme.Column("id", "@object"), + ], + dir=dir_param.expected, ), ], ) def test_empty_class_with_derived(generate): - assert generate([ - schema.Class(name="Base", derived={"Left", "Right"}), - schema.Class(name="Left", bases=["Base"]), - schema.Class(name="Right", bases=["Base"]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class(name="Base", derived={"Left", "Right"}), + schema.Class(name="Left", bases=["Base"]), + schema.Class(name="Right", bases=["Base"]), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -305,17 +371,20 @@ def test_empty_class_with_derived(generate): def test_class_with_derived_and_single_property(generate, dir_param): - assert generate([ - schema.Class( - name="Base", - derived={"Left", "Right"}, - pragmas={"group": dir_param.input}, - properties=[ - schema.SingleProperty("single", "Prop"), - ]), - schema.Class(name="Left", bases=["Base"]), - schema.Class(name="Right", bases=["Base"]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + name="Base", + 
derived={"Left", "Right"}, + pragmas={"group": dir_param.input}, + properties=[ + schema.SingleProperty("single", "Prop"), + ], + ), + schema.Class(name="Left", bases=["Base"]), + schema.Class(name="Right", bases=["Base"]), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -327,8 +396,8 @@ def test_class_with_derived_and_single_property(generate, dir_param): name="bases", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@base'), - dbscheme.Column('single', '@prop'), + dbscheme.Column("id", "@base"), + dbscheme.Column("single", "@prop"), ], dir=dir_param.expected, ), @@ -345,17 +414,20 @@ def test_class_with_derived_and_single_property(generate, dir_param): def test_class_with_derived_and_optional_property(generate, dir_param): - assert generate([ - schema.Class( - name="Base", - derived={"Left", "Right"}, - pragmas={"group": dir_param.input}, - properties=[ - schema.OptionalProperty("opt", "Prop"), - ]), - schema.Class(name="Left", bases=["Base"]), - schema.Class(name="Right", bases=["Base"]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + name="Base", + derived={"Left", "Right"}, + pragmas={"group": dir_param.input}, + properties=[ + schema.OptionalProperty("opt", "Prop"), + ], + ), + schema.Class(name="Left", bases=["Base"]), + schema.Class(name="Right", bases=["Base"]), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -367,8 +439,8 @@ def test_class_with_derived_and_optional_property(generate, dir_param): name="base_opts", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@base'), - dbscheme.Column('opt', '@prop'), + dbscheme.Column("id", "@base"), + dbscheme.Column("opt", "@prop"), ], dir=dir_param.expected, ), @@ -385,17 +457,20 @@ def test_class_with_derived_and_optional_property(generate, dir_param): def test_class_with_derived_and_repeated_property(generate, dir_param): - assert generate([ - schema.Class( - name="Base", - pragmas={"group": dir_param.input}, - derived={"Left", "Right"}, - properties=[ - schema.RepeatedProperty("rep", "Prop"), - ]), - schema.Class(name="Left", bases=["Base"]), - schema.Class(name="Right", bases=["Base"]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + name="Base", + pragmas={"group": dir_param.input}, + derived={"Left", "Right"}, + properties=[ + schema.RepeatedProperty("rep", "Prop"), + ], + ), + schema.Class(name="Left", bases=["Base"]), + schema.Class(name="Right", bases=["Base"]), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -407,9 +482,9 @@ def test_class_with_derived_and_repeated_property(generate, dir_param): name="base_reps", keyset=dbscheme.KeySet(["id", "index"]), columns=[ - dbscheme.Column('id', '@base'), - dbscheme.Column('index', 'int'), - dbscheme.Column('rep', '@prop'), + dbscheme.Column("id", "@base"), + dbscheme.Column("index", "int"), + dbscheme.Column("rep", "@prop"), ], dir=dir_param.expected, ), @@ -426,38 +501,41 @@ def test_class_with_derived_and_repeated_property(generate, dir_param): def test_null_class(generate): - assert generate([ - schema.Class( - name="Base", - derived={"W", "X", "Y", "Z", "Null"}, - ), - schema.Class( - name="W", - bases=["Base"], - properties=[ - schema.SingleProperty("w", "W"), - schema.SingleProperty("x", "X"), - schema.OptionalProperty("y", "Y"), - schema.RepeatedProperty("z", "Z"), - ] - ), - schema.Class( - name="X", - bases=["Base"], - ), - schema.Class( - name="Y", - bases=["Base"], - ), - schema.Class( - name="Z", - 
bases=["Base"], - ), - schema.Class( - name="Null", - bases=["Base"], - ), - ], null="Null") == dbscheme.Scheme( + assert generate( + [ + schema.Class( + name="Base", + derived={"W", "X", "Y", "Z", "Null"}, + ), + schema.Class( + name="W", + bases=["Base"], + properties=[ + schema.SingleProperty("w", "W"), + schema.SingleProperty("x", "X"), + schema.OptionalProperty("y", "Y"), + schema.RepeatedProperty("z", "Z"), + ], + ), + schema.Class( + name="X", + bases=["Base"], + ), + schema.Class( + name="Y", + bases=["Base"], + ), + schema.Class( + name="Z", + bases=["Base"], + ), + schema.Class( + name="Null", + bases=["Base"], + ), + ], + null="Null", + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -468,50 +546,50 @@ def test_null_class(generate): dbscheme.Table( name="ws", columns=[ - dbscheme.Column('id', '@w', binding=True), - dbscheme.Column('w', '@w_or_none'), - dbscheme.Column('x', '@x_or_none'), + dbscheme.Column("id", "@w", binding=True), + dbscheme.Column("w", "@w_or_none"), + dbscheme.Column("x", "@x_or_none"), ], ), dbscheme.Table( name="w_ies", keyset=dbscheme.KeySet(["id"]), columns=[ - dbscheme.Column('id', '@w'), - dbscheme.Column('y', '@y_or_none'), + dbscheme.Column("id", "@w"), + dbscheme.Column("y", "@y_or_none"), ], ), dbscheme.Table( name="w_zs", keyset=dbscheme.KeySet(["id", "index"]), columns=[ - dbscheme.Column('id', '@w'), - dbscheme.Column('index', 'int'), - dbscheme.Column('z', '@z_or_none'), + dbscheme.Column("id", "@w"), + dbscheme.Column("index", "int"), + dbscheme.Column("z", "@z_or_none"), ], ), dbscheme.Table( name="xes", columns=[ - dbscheme.Column('id', '@x', binding=True), + dbscheme.Column("id", "@x", binding=True), ], ), dbscheme.Table( name="ys", columns=[ - dbscheme.Column('id', '@y', binding=True), + dbscheme.Column("id", "@y", binding=True), ], ), dbscheme.Table( name="zs", columns=[ - dbscheme.Column('id', '@z', binding=True), + dbscheme.Column("id", "@z", binding=True), ], ), dbscheme.Table( name="nulls", columns=[ - dbscheme.Column('id', '@null', binding=True), + dbscheme.Column("id", "@null", binding=True), ], ), dbscheme.Union( @@ -535,11 +613,15 @@ def test_null_class(generate): def test_synth_classes_ignored(generate): - assert generate([ - schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}), - schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}), - schema.Class(name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}), + schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}), + schema.Class( + name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})} + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[], @@ -547,11 +629,13 @@ def test_synth_classes_ignored(generate): def test_synth_derived_classes_ignored(generate): - assert generate([ - schema.Class(name="A", derived={"B", "C"}), - schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}), - schema.Class(name="C", bases=["A"]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class(name="A", derived={"B", "C"}), + schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}), + schema.Class(name="C", bases=["A"]), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -561,23 +645,28 @@ def test_synth_derived_classes_ignored(generate): columns=[ dbscheme.Column("id", "@c", 
binding=True), ], - ) + ), ], ) def test_synth_properties_ignored(generate): - assert generate([ - schema.Class(name="A", properties=[ - schema.SingleProperty("x", "a"), - schema.SingleProperty("y", "b", synth=True), - schema.SingleProperty("z", "c"), - schema.OptionalProperty("foo", "bar", synth=True), - schema.RepeatedProperty("baz", "bazz", synth=True), - schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True), - schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + name="A", + properties=[ + schema.SingleProperty("x", "a"), + schema.SingleProperty("y", "b", synth=True), + schema.SingleProperty("z", "c"), + schema.OptionalProperty("foo", "bar", synth=True), + schema.RepeatedProperty("baz", "bazz", synth=True), + schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True), + schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -595,24 +684,44 @@ def test_synth_properties_ignored(generate): def test_table_conflict(generate): with pytest.raises(dbschemegen.Error): - generate([ - schema.Class("Foo", properties=[ - schema.OptionalProperty("bar", "FooBar"), - ]), - schema.Class("FooBar"), - ]) + generate( + [ + schema.Class( + "Foo", + properties=[ + schema.OptionalProperty("bar", "FooBar"), + ], + ), + schema.Class("FooBar"), + ] + ) def test_table_name_overrides(generate): - assert generate([ - schema.Class("Obj", properties=[ - schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}), - schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}), - schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}), - schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}), - schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}), - ]), - ]) == dbscheme.Scheme( + assert generate( + [ + schema.Class( + "Obj", + properties=[ + schema.OptionalProperty( + "x", "a", pragmas={"ql_db_table_name": "foo"} + ), + schema.RepeatedProperty( + "y", "b", pragmas={"ql_db_table_name": "bar"} + ), + schema.RepeatedOptionalProperty( + "z", "c", pragmas={"ql_db_table_name": "baz"} + ), + schema.PredicateProperty( + "p", pragmas={"ql_db_table_name": "hello"} + ), + schema.RepeatedUnorderedProperty( + "q", "d", pragmas={"ql_db_table_name": "world"} + ), + ], + ), + ] + ) == dbscheme.Scheme( src=schema_file.name, includes=[], declarations=[ @@ -666,5 +775,5 @@ def test_table_name_overrides(generate): ) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_dbschemelaoder.py b/misc/codegen/test/test_dbschemelaoder.py index ab4efbff75a1..e9079b699b93 100644 --- a/misc/codegen/test/test_dbschemelaoder.py +++ b/misc/codegen/test/test_dbschemelaoder.py @@ -22,26 +22,42 @@ def test_load_empty(load): def test_load_one_empty_table(load): - assert load(""" + assert ( + load( + """ test_foos(); -""") == [ - dbscheme.Table(name="test_foos", columns=[]) - ] +""" + ) + == [dbscheme.Table(name="test_foos", columns=[])] + ) def test_load_table_with_keyset(load): - assert load(""" + assert ( + load( + """ #keyset[x, y,z] test_foos(); -""") == [ - dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"])) - ] +""" + ) + == [ + dbscheme.Table( + name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]) + ) + ] + ) expected_columns = 
[ ("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)), - (" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)), - ("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)), + ( + " int bar : int ref", + dbscheme.Column(schema_name="bar", type="int", binding=False), + ), + ( + "str baz_: str ref", + dbscheme.Column(schema_name="baz", type="str", binding=False), + ), ("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)), ("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)), ("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)), @@ -50,42 +66,58 @@ def test_load_table_with_keyset(load): @pytest.mark.parametrize("column,expected", expected_columns) def test_load_table_with_column(load, column, expected): - assert load(f""" + assert ( + load( + f""" foos( {column} ); -""") == [ - dbscheme.Table(name="foos", columns=[deepcopy(expected)]) - ] +""" + ) + == [dbscheme.Table(name="foos", columns=[deepcopy(expected)])] + ) def test_load_table_with_multiple_columns(load): columns = ",\n".join(c for c, _ in expected_columns) expected = [deepcopy(e) for _, e in expected_columns] - assert load(f""" + assert ( + load( + f""" foos( {columns} ); -""") == [ - dbscheme.Table(name="foos", columns=expected) - ] +""" + ) + == [dbscheme.Table(name="foos", columns=expected)] + ) def test_load_table_with_multiple_columns_and_dir(load): columns = ",\n".join(c for c, _ in expected_columns) expected = [deepcopy(e) for _, e in expected_columns] - assert load(f""" + assert ( + load( + f""" foos( //dir=foo/bar/baz {columns} ); -""") == [ - dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz")) - ] +""" + ) + == [ + dbscheme.Table( + name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz") + ) + ] + ) def test_load_multiple_table_with_columns(load): tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)] - expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)] + expected = [ + dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) + for i, (_, e) in enumerate(expected_columns) + ] assert load("\n".join(tables)) == expected @@ -96,28 +128,41 @@ def test_union(load): def test_table_and_union(load): - assert load(""" + assert ( + load( + """ foos(); -@foo = @bar | @baz | @bla;""") == [ - dbscheme.Table(name="foos", columns=[]), - dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]), - ] +@foo = @bar | @baz | @bla;""" + ) + == [ + dbscheme.Table(name="foos", columns=[]), + dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]), + ] + ) def test_comments_ignored(load): - assert load(""" + assert ( + load( + """ // fake_table(); foos(/* x */unique /*y*/int/* z */ id/* */: /* * */ @bar/*, int ignored: int ref*/); -@foo = @bar | @baz | @bla; // | @xxx""") == [ - dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]), - dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]), - ] - - -if __name__ == '__main__': +@foo = @bar | @baz | @bla; // | @xxx""" + ) + == [ + dbscheme.Table( + name="foos", + columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)], + ), + dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]), + ] + ) + + +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_ql.py b/misc/codegen/test/test_ql.py index 
e326e65a9e4f..406c6134a477 100644 --- a/misc/codegen/test/test_ql.py +++ b/misc/codegen/test/test_ql.py @@ -34,37 +34,55 @@ def test_property_unordered_getter(name, expected_getter): assert prop.getter == expected_getter -@pytest.mark.parametrize("plural,expected", [ - (None, False), - ("", False), - ("X", True), -]) +@pytest.mark.parametrize( + "plural,expected", + [ + (None, False), + ("", False), + ("X", True), + ], +) def test_property_is_repeated(plural, expected): prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural) assert prop.is_repeated is expected -@pytest.mark.parametrize("plural,unordered,expected", [ - (None, False, False), - ("", False, False), - ("X", False, True), - ("X", True, False), -]) +@pytest.mark.parametrize( + "plural,unordered,expected", + [ + (None, False, False), + ("", False, False), + ("X", False, True), + ("X", True, False), + ], +) def test_property_is_indexed(plural, unordered, expected): - prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered) + prop = ql.Property( + "foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered + ) assert prop.is_indexed is expected -@pytest.mark.parametrize("is_optional,is_predicate,plural,expected", [ - (False, False, None, True), - (False, False, "", True), - (False, False, "X", False), - (True, False, None, False), - (False, True, None, False), -]) +@pytest.mark.parametrize( + "is_optional,is_predicate,plural,expected", + [ + (False, False, None, True), + (False, False, "", True), + (False, False, "X", False), + (True, False, None, False), + (False, True, None, False), + ], +) def test_property_is_single(is_optional, is_predicate, plural, expected): - prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural, - is_predicate=is_predicate, is_optional=is_optional) + prop = ql.Property( + "foo", + "Foo", + "props", + ["result"], + plural=plural, + is_predicate=is_predicate, + is_optional=is_optional, + ) assert prop.is_single is expected @@ -85,7 +103,12 @@ def test_property_predicate_getter(): def test_class_processes_bases(): bases = ["B", "Ab", "C", "Aa"] - expected = [ql.Base("B"), ql.Base("Ab", prev="B"), ql.Base("C", prev="Ab"), ql.Base("Aa", prev="C")] + expected = [ + ql.Base("B"), + ql.Base("Ab", prev="B"), + ql.Base("C", prev="Ab"), + ql.Base("Aa", prev="C"), + ] cls = ql.Class("Foo", bases=bases) assert cls.bases == expected @@ -110,7 +133,9 @@ def test_non_root_class(): assert not cls.root -@pytest.mark.parametrize("prev_child,is_child", [(None, False), ("", True), ("x", True)]) +@pytest.mark.parametrize( + "prev_child,is_child", [(None, False), ("", True), ("x", True)] +) def test_is_child(prev_child, is_child): p = ql.Property("Foo", "int", prev_child=prev_child) assert p.is_child is is_child @@ -122,22 +147,27 @@ def test_empty_class_no_children(): def test_class_no_children(): - cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")]) + cls = ql.Class( + "Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")] + ) assert cls.has_children is False def test_class_with_children(): - cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Child", "x", prev_child=""), - ql.Property("Bar", "string")]) + cls = ql.Class( + "Class", + properties=[ + ql.Property("Foo", "int"), + ql.Property("Child", "x", prev_child=""), + ql.Property("Bar", "string"), + ], + ) assert cls.has_children is True -@pytest.mark.parametrize("doc,expected", - [ - (["foo", "bar"], True), - 
(["foo", "bar"], True), - ([], False) - ]) +@pytest.mark.parametrize( + "doc,expected", [(["foo", "bar"], True), (["foo", "bar"], True), ([], False)] +) def test_has_doc(doc, expected): stub = ql.Stub("Class", base_import="foo", import_prefix="bar", doc=doc) assert stub.has_qldoc is expected @@ -150,5 +180,5 @@ def test_synth_accessor_has_first_constructor_param_marked(): assert [p.param for p in x.constructorparams] == params -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_qlgen.py b/misc/codegen/test/test_qlgen.py index 75e587fbd5eb..43617d5f9e42 100644 --- a/misc/codegen/test/test_qlgen.py +++ b/misc/codegen/test/test_qlgen.py @@ -17,22 +17,28 @@ def run_mock(): # these are lambdas so that they will use patched paths when called -def stub_path(): return paths.root_dir / "ql/lib/stub/path" +def stub_path(): + return paths.root_dir / "ql/lib/stub/path" -def ql_output_path(): return paths.root_dir / "ql/lib/other/path" +def ql_output_path(): + return paths.root_dir / "ql/lib/other/path" -def ql_test_output_path(): return paths.root_dir / "ql/test/path" +def ql_test_output_path(): + return paths.root_dir / "ql/test/path" -def generated_registry_path(): return paths.root_dir / "ql/registry.list" +def generated_registry_path(): + return paths.root_dir / "ql/registry.list" -def import_file(): return stub_path().with_suffix(".qll") +def import_file(): + return stub_path().with_suffix(".qll") -def children_file(): return ql_output_path() / "ParentChild.qll" +def children_file(): + return ql_output_path() / "ParentChild.qll" stub_import = "stub.path" @@ -63,7 +69,9 @@ def generate(input, qlgen_opts, renderer, render_manager): def func(classes): input.classes = {cls.name: cls for cls in classes} - return run_managed_generation(qlgen.generate, qlgen_opts, renderer, render_manager) + return run_managed_generation( + qlgen.generate, qlgen_opts, renderer, render_manager + ) return func @@ -109,20 +117,38 @@ def _filter_generated_classes(ret, output_test_files=False): except ValueError: assert False, f"{f} is in wrong directory" if output_test_files: - return { - str(f): ret[ql_test_output_path() / f] - for f in test_files - } - base_files -= {pathlib.Path(f"{name}.qll") for name in - ("Raw", "Synth", "SynthConstructors", "PureSynthConstructors")} - stub_files = {pathlib.Path(f.parent.parent, f.stem + ".qll") if f.parent.name == - "internal" and pathlib.Path(f.parent.parent, f.stem + ".qll") in base_files else f for f in stub_files} + return {str(f): ret[ql_test_output_path() / f] for f in test_files} + base_files -= { + pathlib.Path(f"{name}.qll") + for name in ("Raw", "Synth", "SynthConstructors", "PureSynthConstructors") + } + stub_files = { + ( + pathlib.Path(f.parent.parent, f.stem + ".qll") + if f.parent.name == "internal" + and pathlib.Path(f.parent.parent, f.stem + ".qll") in base_files + else f + ) + for f in stub_files + } assert base_files <= stub_files return { - str(f): (ret[stub_path() / "internal" / f] if stub_path() / "internal" / f in ret else ret[stub_path() / f], - ret[stub_path() / pathlib.Path(f.parent, "internal" if not f.parent.name == - "internal" else "", f.stem + "Impl.qll")], - ret[ql_output_path() / f]) + str(f): ( + ( + ret[stub_path() / "internal" / f] + if stub_path() / "internal" / f in ret + else ret[stub_path() / f] + ), + ret[ + stub_path() + / pathlib.Path( + f.parent, + "internal" if not f.parent.name == "internal" else "", + f.stem + "Impl.qll", + ) + ], + 
ret[ql_output_path() / f], + ) for f in base_files } @@ -148,8 +174,12 @@ def a_ql_class(**kwargs): def a_ql_stub(*, name, import_prefix="", **kwargs): - return ql.Stub(name=name, **kwargs, import_prefix=gen_import, - base_import=f"{gen_import_prefix}{import_prefix}{name}") + return ql.Stub( + name=name, + **kwargs, + import_prefix=gen_import, + base_import=f"{gen_import_prefix}{import_prefix}{name}", + ) def a_ql_class_public(*, name, **kwargs): @@ -157,347 +187,674 @@ def a_ql_class_public(*, name, **kwargs): def test_one_empty_class(generate_classes): - assert generate_classes([ - schema.Class("A") - ]) == { - "A.qll": (a_ql_class_public(name="A"), - a_ql_stub(name="A"), - a_ql_class(name="A", final=True, imports=[stub_import_prefix + "A"])) + assert generate_classes([schema.Class("A")]) == { + "A.qll": ( + a_ql_class_public(name="A"), + a_ql_stub(name="A"), + a_ql_class(name="A", final=True, imports=[stub_import_prefix + "A"]), + ) } def test_one_empty_internal_class(generate_classes): - assert generate_classes([ - schema.Class("A", pragmas=["ql_internal"]) - ]) == { - "A.qll": (a_ql_class_public(name="A", internal=True), - a_ql_stub(name="A"), - a_ql_class(name="A", final=True, internal=True, imports=[stub_import_prefix_internal + "A"])), + assert generate_classes([schema.Class("A", pragmas=["ql_internal"])]) == { + "A.qll": ( + a_ql_class_public(name="A", internal=True), + a_ql_stub(name="A"), + a_ql_class( + name="A", + final=True, + internal=True, + imports=[stub_import_prefix_internal + "A"], + ), + ), } def test_hierarchy(generate_classes): - assert generate_classes([ - schema.Class("D", bases=["B", "C"]), - schema.Class("C", bases=["A"], derived={"D"}), - schema.Class("B", bases=["A"], derived={"D"}), - schema.Class("A", derived={"B", "C"}), - ]) == { - "A.qll": (a_ql_class_public(name="A"), a_ql_stub(name="A"), a_ql_class(name="A", imports=[stub_import_prefix + "A"])), - "B.qll": (a_ql_class_public(name="B", imports=[stub_import_prefix + "A"]), a_ql_stub(name="B"), a_ql_class(name="B", bases=["A"], bases_impl=["AImpl::A"], imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"])), - "C.qll": (a_ql_class_public(name="C", imports=[stub_import_prefix + "A"]), a_ql_stub(name="C"), a_ql_class(name="C", bases=["A"], bases_impl=["AImpl::A"], imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"])), - "D.qll": (a_ql_class_public(name="D", imports=[stub_import_prefix + "B", stub_import_prefix + "C"]), a_ql_stub(name="D"), a_ql_class(name="D", final=True, bases=["B", "C"], bases_impl=["BImpl::B", "CImpl::C"], - imports=[stub_import_prefix_internal + cls + "Impl::Impl as " + cls + "Impl" for cls in "BC"])), + assert generate_classes( + [ + schema.Class("D", bases=["B", "C"]), + schema.Class("C", bases=["A"], derived={"D"}), + schema.Class("B", bases=["A"], derived={"D"}), + schema.Class("A", derived={"B", "C"}), + ] + ) == { + "A.qll": ( + a_ql_class_public(name="A"), + a_ql_stub(name="A"), + a_ql_class(name="A", imports=[stub_import_prefix + "A"]), + ), + "B.qll": ( + a_ql_class_public(name="B", imports=[stub_import_prefix + "A"]), + a_ql_stub(name="B"), + a_ql_class( + name="B", + bases=["A"], + bases_impl=["AImpl::A"], + imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"], + ), + ), + "C.qll": ( + a_ql_class_public(name="C", imports=[stub_import_prefix + "A"]), + a_ql_stub(name="C"), + a_ql_class( + name="C", + bases=["A"], + bases_impl=["AImpl::A"], + imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"], + ), + ), + "D.qll": ( + a_ql_class_public( + 
name="D", imports=[stub_import_prefix + "B", stub_import_prefix + "C"] + ), + a_ql_stub(name="D"), + a_ql_class( + name="D", + final=True, + bases=["B", "C"], + bases_impl=["BImpl::B", "CImpl::C"], + imports=[ + stub_import_prefix_internal + cls + "Impl::Impl as " + cls + "Impl" + for cls in "BC" + ], + ), + ), } def test_hierarchy_imports(generate_import_list): - assert generate_import_list([ - schema.Class("D", bases=["B", "C"]), - schema.Class("C", bases=["A"], derived={"D"}), - schema.Class("B", bases=["A"], derived={"D"}), - schema.Class("A", derived={"B", "C"}), - ]) == ql.ImportList([stub_import_prefix + cls for cls in "ABCD"]) + assert generate_import_list( + [ + schema.Class("D", bases=["B", "C"]), + schema.Class("C", bases=["A"], derived={"D"}), + schema.Class("B", bases=["A"], derived={"D"}), + schema.Class("A", derived={"B", "C"}), + ] + ) == ql.ImportList([stub_import_prefix + cls for cls in "ABCD"]) def test_internal_not_in_import_list(generate_import_list): - assert generate_import_list([ - schema.Class("D", bases=["B", "C"]), - schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]), - schema.Class("B", bases=["A"], derived={"D"}), - schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]), - ]) == ql.ImportList([stub_import_prefix + cls for cls in "BD"]) + assert generate_import_list( + [ + schema.Class("D", bases=["B", "C"]), + schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]), + schema.Class("B", bases=["A"], derived={"D"}), + schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]), + ] + ) == ql.ImportList([stub_import_prefix + cls for cls in "BD"]) def test_hierarchy_children(generate_children_implementations): - assert generate_children_implementations([ - schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]), - schema.Class("B", bases=["A"], derived={"D"}), - schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]), - schema.Class("D", bases=["B", "C"]), - ]) == ql.GetParentImplementation( - classes=[a_ql_class(name="A", internal=True, imports=[stub_import_prefix_internal + "A"]), - a_ql_class(name="B", bases=["A"], bases_impl=["AImpl::A"], imports=[ - stub_import_prefix_internal + "AImpl::Impl as AImpl"]), - a_ql_class(name="C", bases=["A"], bases_impl=["AImpl::A"], imports=[ - stub_import_prefix_internal + "AImpl::Impl as AImpl"], internal=True), - a_ql_class(name="D", final=True, bases=["B", "C"], bases_impl=["BImpl::B", "CImpl::C"], - imports=[stub_import_prefix_internal + cls + "Impl::Impl as " + cls + "Impl" for cls in "BC"]), - ], + assert generate_children_implementations( + [ + schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]), + schema.Class("B", bases=["A"], derived={"D"}), + schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]), + schema.Class("D", bases=["B", "C"]), + ] + ) == ql.GetParentImplementation( + classes=[ + a_ql_class( + name="A", internal=True, imports=[stub_import_prefix_internal + "A"] + ), + a_ql_class( + name="B", + bases=["A"], + bases_impl=["AImpl::A"], + imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"], + ), + a_ql_class( + name="C", + bases=["A"], + bases_impl=["AImpl::A"], + imports=[stub_import_prefix_internal + "AImpl::Impl as AImpl"], + internal=True, + ), + a_ql_class( + name="D", + final=True, + bases=["B", "C"], + bases_impl=["BImpl::B", "CImpl::C"], + imports=[ + stub_import_prefix_internal + cls + "Impl::Impl as " + cls + "Impl" + for cls in "BC" + ], + ), + ], imports=[stub_import] + 
[stub_import_prefix_internal + cls for cls in "AC"], ) def test_single_property(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], doc="foo of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + tableparams=["this", "result"], + doc="foo of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_internal_property(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar", pragmas=["ql_internal"])]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], doc="foo of this my object", - internal=True), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.SingleProperty("foo", "bar", pragmas=["ql_internal"]) + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + tableparams=["this", "result"], + doc="foo of this my object", + internal=True, + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_children(generate_classes): - assert generate_classes([ - schema.Class("FakeRoot"), - schema.Class("MyObject", properties=[ - schema.SingleProperty("a", "int"), - schema.SingleProperty("child_1", "int", is_child=True), - schema.RepeatedProperty("bs", "int"), - schema.RepeatedProperty("children", "int", is_child=True), - schema.OptionalProperty("c", "int"), - schema.OptionalProperty("child_3", "int", is_child=True), - schema.RepeatedOptionalProperty("d", "int"), - schema.RepeatedOptionalProperty("child_4", "int", is_child=True), - ]), - ]) == { - "FakeRoot.qll": (a_ql_class_public(name="FakeRoot"), a_ql_stub(name="FakeRoot"), a_ql_class(name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"])), - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="A", type="int", tablename="my_objects", - tableparams=["this", "result", "_"], - doc="a of this my object"), - ql.Property(singular="Child1", type="int", tablename="my_objects", - tableparams=["this", "_", "result"], prev_child="", - doc="child 1 of this my object"), - ql.Property(singular="B", plural="Bs", type="int", - tablename="my_object_bs", - tableparams=["this", "index", "result"], - doc="b of this my object", - doc_plural="bs of this my object"), - ql.Property(singular="Child", plural="Children", type="int", - tablename="my_object_children", - 
tableparams=["this", "index", "result"], prev_child="Child1", - doc="child of this my object", - doc_plural="children of this my object"), - ql.Property(singular="C", type="int", tablename="my_object_cs", - tableparams=["this", "result"], is_optional=True, - doc="c of this my object"), - ql.Property(singular="Child3", type="int", - tablename="my_object_child_3s", - tableparams=["this", "result"], is_optional=True, - prev_child="Child", doc="child 3 of this my object"), - ql.Property(singular="D", plural="Ds", type="int", - tablename="my_object_ds", - tableparams=["this", "index", "result"], is_optional=True, - doc="d of this my object", - doc_plural="ds of this my object"), - ql.Property(singular="Child4", plural="Child4s", type="int", - tablename="my_object_child_4s", - tableparams=["this", "index", "result"], is_optional=True, - prev_child="Child3", doc="child 4 of this my object", - doc_plural="child 4s of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("FakeRoot"), + schema.Class( + "MyObject", + properties=[ + schema.SingleProperty("a", "int"), + schema.SingleProperty("child_1", "int", is_child=True), + schema.RepeatedProperty("bs", "int"), + schema.RepeatedProperty("children", "int", is_child=True), + schema.OptionalProperty("c", "int"), + schema.OptionalProperty("child_3", "int", is_child=True), + schema.RepeatedOptionalProperty("d", "int"), + schema.RepeatedOptionalProperty("child_4", "int", is_child=True), + ], + ), + ] + ) == { + "FakeRoot.qll": ( + a_ql_class_public(name="FakeRoot"), + a_ql_stub(name="FakeRoot"), + a_ql_class( + name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"] + ), + ), + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="A", + type="int", + tablename="my_objects", + tableparams=["this", "result", "_"], + doc="a of this my object", + ), + ql.Property( + singular="Child1", + type="int", + tablename="my_objects", + tableparams=["this", "_", "result"], + prev_child="", + doc="child 1 of this my object", + ), + ql.Property( + singular="B", + plural="Bs", + type="int", + tablename="my_object_bs", + tableparams=["this", "index", "result"], + doc="b of this my object", + doc_plural="bs of this my object", + ), + ql.Property( + singular="Child", + plural="Children", + type="int", + tablename="my_object_children", + tableparams=["this", "index", "result"], + prev_child="Child1", + doc="child of this my object", + doc_plural="children of this my object", + ), + ql.Property( + singular="C", + type="int", + tablename="my_object_cs", + tableparams=["this", "result"], + is_optional=True, + doc="c of this my object", + ), + ql.Property( + singular="Child3", + type="int", + tablename="my_object_child_3s", + tableparams=["this", "result"], + is_optional=True, + prev_child="Child", + doc="child 3 of this my object", + ), + ql.Property( + singular="D", + plural="Ds", + type="int", + tablename="my_object_ds", + tableparams=["this", "index", "result"], + is_optional=True, + doc="d of this my object", + doc_plural="ds of this my object", + ), + ql.Property( + singular="Child4", + plural="Child4s", + type="int", + tablename="my_object_child_4s", + tableparams=["this", "index", "result"], + is_optional=True, + prev_child="Child3", + doc="child 4 of this my object", + doc_plural="child 4s of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } 
def test_single_properties(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("one", "x"), - schema.SingleProperty("two", "y"), - schema.SingleProperty("three", "z"), - ]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="One", type="x", tablename="my_objects", - tableparams=["this", "result", "_", "_"], - doc="one of this my object"), - ql.Property(singular="Two", type="y", tablename="my_objects", - tableparams=["this", "_", "result", "_"], - doc="two of this my object"), - ql.Property(singular="Three", type="z", tablename="my_objects", - tableparams=["this", "_", "_", "result"], - doc="three of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.SingleProperty("one", "x"), + schema.SingleProperty("two", "y"), + schema.SingleProperty("three", "z"), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="One", + type="x", + tablename="my_objects", + tableparams=["this", "result", "_", "_"], + doc="one of this my object", + ), + ql.Property( + singular="Two", + type="y", + tablename="my_objects", + tableparams=["this", "_", "result", "_"], + doc="two of this my object", + ), + ql.Property( + singular="Three", + type="z", + tablename="my_objects", + tableparams=["this", "_", "_", "result"], + doc="three of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } @pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")]) def test_optional_property(generate_classes, is_child, prev_child): - assert generate_classes([ - schema.Class("FakeRoot"), - schema.Class("MyObject", properties=[ - schema.OptionalProperty("foo", "bar", is_child=is_child)]), - ]) == { - "FakeRoot.qll": (a_ql_class_public(name="FakeRoot"), a_ql_stub(name="FakeRoot"), a_ql_class(name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"])), - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_object_foos", - tableparams=["this", "result"], - is_optional=True, prev_child=prev_child, doc="foo of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("FakeRoot"), + schema.Class( + "MyObject", + properties=[schema.OptionalProperty("foo", "bar", is_child=is_child)], + ), + ] + ) == { + "FakeRoot.qll": ( + a_ql_class_public(name="FakeRoot"), + a_ql_stub(name="FakeRoot"), + a_ql_class( + name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"] + ), + ), + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_object_foos", + tableparams=["this", "result"], + is_optional=True, + prev_child=prev_child, + doc="foo of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } @pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")]) def test_repeated_property(generate_classes, is_child, prev_child): - assert generate_classes([ - 
schema.Class("FakeRoot"), - schema.Class("MyObject", properties=[ - schema.RepeatedProperty("foo", "bar", is_child=is_child)]), - ]) == { - "FakeRoot.qll": (a_ql_class_public(name="FakeRoot"), a_ql_stub(name="FakeRoot"), a_ql_class(name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"])), - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", - tableparams=["this", "index", "result"], prev_child=prev_child, - doc="foo of this my object", doc_plural="foos of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("FakeRoot"), + schema.Class( + "MyObject", + properties=[schema.RepeatedProperty("foo", "bar", is_child=is_child)], + ), + ] + ) == { + "FakeRoot.qll": ( + a_ql_class_public(name="FakeRoot"), + a_ql_stub(name="FakeRoot"), + a_ql_class( + name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"] + ), + ), + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + plural="Foos", + type="bar", + tablename="my_object_foos", + tableparams=["this", "index", "result"], + prev_child=prev_child, + doc="foo of this my object", + doc_plural="foos of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_repeated_unordered_property(generate_classes): - assert generate_classes([ - schema.Class("FakeRoot"), - schema.Class("MyObject", properties=[ - schema.RepeatedUnorderedProperty("foo", "bar")]), - ]) == { - "FakeRoot.qll": (a_ql_class_public(name="FakeRoot"), a_ql_stub(name="FakeRoot"), a_ql_class(name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"])), - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", - tableparams=["this", "result"], is_unordered=True, - doc="foo of this my object", doc_plural="foos of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("FakeRoot"), + schema.Class( + "MyObject", properties=[schema.RepeatedUnorderedProperty("foo", "bar")] + ), + ] + ) == { + "FakeRoot.qll": ( + a_ql_class_public(name="FakeRoot"), + a_ql_stub(name="FakeRoot"), + a_ql_class( + name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"] + ), + ), + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + plural="Foos", + type="bar", + tablename="my_object_foos", + tableparams=["this", "result"], + is_unordered=True, + doc="foo of this my object", + doc_plural="foos of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } @pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")]) def test_repeated_optional_property(generate_classes, is_child, prev_child): - assert generate_classes([ - schema.Class("FakeRoot"), - schema.Class("MyObject", properties=[ - schema.RepeatedOptionalProperty("foo", "bar", is_child=is_child)]), - ]) == { - - "FakeRoot.qll": (a_ql_class_public(name="FakeRoot"), a_ql_stub(name="FakeRoot"), a_ql_class(name="FakeRoot", 
final=True, imports=[stub_import_prefix + "FakeRoot"])), - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", - tableparams=["this", "index", "result"], is_optional=True, - prev_child=prev_child, doc="foo of this my object", - doc_plural="foos of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("FakeRoot"), + schema.Class( + "MyObject", + properties=[ + schema.RepeatedOptionalProperty("foo", "bar", is_child=is_child) + ], + ), + ] + ) == { + "FakeRoot.qll": ( + a_ql_class_public(name="FakeRoot"), + a_ql_stub(name="FakeRoot"), + a_ql_class( + name="FakeRoot", final=True, imports=[stub_import_prefix + "FakeRoot"] + ), + ), + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + plural="Foos", + type="bar", + tablename="my_object_foos", + tableparams=["this", "index", "result"], + is_optional=True, + prev_child=prev_child, + doc="foo of this my object", + doc_plural="foos of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_predicate_property(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.PredicateProperty("is_foo")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="isFoo", type="predicate", tablename="my_object_is_foo", - tableparams=["this"], is_predicate=True, doc="this my object is foo"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("MyObject", properties=[schema.PredicateProperty("is_foo")]), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="isFoo", + type="predicate", + tablename="my_object_is_foo", + tableparams=["this"], + is_predicate=True, + doc="this my object is foo", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } @pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")]) def test_single_class_property(generate_classes, is_child, prev_child): - assert generate_classes([ - schema.Class("Bar"), - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "Bar", is_child=is_child)]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject", imports=[stub_import_prefix + "Bar"]), - a_ql_stub(name="MyObject"), - a_ql_class( - name="MyObject", final=True, imports=[stub_import_prefix + "Bar", stub_import_prefix + "MyObject"], properties=[ - ql.Property(singular="Foo", type="Bar", tablename="my_objects", - tableparams=[ - "this", "result"], - prev_child=prev_child, doc="foo of this my object", - type_is_codegen_class=True), - ], - )), - "Bar.qll": (a_ql_class_public(name="Bar"), a_ql_stub(name="Bar"), a_ql_class(name="Bar", final=True, imports=[stub_import_prefix + "Bar"])), + assert generate_classes( + [ + schema.Class("Bar"), + schema.Class( + "MyObject", + properties=[schema.SingleProperty("foo", "Bar", is_child=is_child)], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject", imports=[stub_import_prefix + "Bar"]), + a_ql_stub(name="MyObject"), + 
a_ql_class( + name="MyObject", + final=True, + imports=[stub_import_prefix + "Bar", stub_import_prefix + "MyObject"], + properties=[ + ql.Property( + singular="Foo", + type="Bar", + tablename="my_objects", + tableparams=["this", "result"], + prev_child=prev_child, + doc="foo of this my object", + type_is_codegen_class=True, + ), + ], + ), + ), + "Bar.qll": ( + a_ql_class_public(name="Bar"), + a_ql_stub(name="Bar"), + a_ql_class(name="Bar", final=True, imports=[stub_import_prefix + "Bar"]), + ), } def test_class_with_doc(generate_classes): doc = ["Very important class.", "Very."] - assert generate_classes([ - schema.Class("A", doc=doc), - ]) == { - "A.qll": (a_ql_class_public(name="A", doc=doc), a_ql_stub(name="A", doc=doc), a_ql_class(name="A", final=True, doc=doc, imports=[stub_import_prefix + "A"])), + assert generate_classes( + [ + schema.Class("A", doc=doc), + ] + ) == { + "A.qll": ( + a_ql_class_public(name="A", doc=doc), + a_ql_stub(name="A", doc=doc), + a_ql_class( + name="A", final=True, doc=doc, imports=[stub_import_prefix + "A"] + ), + ), } def test_class_dir(generate_classes): dir = "another/rel/path" - assert generate_classes([ - schema.Class("A", derived={"B"}, pragmas={"group": dir}), - schema.Class("B", bases=["A"]), - ]) == { + assert generate_classes( + [ + schema.Class("A", derived={"B"}, pragmas={"group": dir}), + schema.Class("B", bases=["A"]), + ] + ) == { f"{dir}/A.qll": ( - a_ql_class_public(name="A"), a_ql_stub(name="A", import_prefix="another.rel.path."), a_ql_class(name="A", dir=pathlib.Path(dir), imports=[stub_import_prefix + "another.rel.path.A"])), - "B.qll": (a_ql_class_public(name="B", imports=[stub_import_prefix + "another.rel.path.A"]), - a_ql_stub(name="B"), - a_ql_class(name="B", final=True, bases=["A"], bases_impl=["AImpl::A"], - imports=[stub_import_prefix + "another.rel.path.internal.AImpl::Impl as AImpl"])), + a_ql_class_public(name="A"), + a_ql_stub(name="A", import_prefix="another.rel.path."), + a_ql_class( + name="A", + dir=pathlib.Path(dir), + imports=[stub_import_prefix + "another.rel.path.A"], + ), + ), + "B.qll": ( + a_ql_class_public( + name="B", imports=[stub_import_prefix + "another.rel.path.A"] + ), + a_ql_stub(name="B"), + a_ql_class( + name="B", + final=True, + bases=["A"], + bases_impl=["AImpl::A"], + imports=[ + stub_import_prefix + + "another.rel.path.internal.AImpl::Impl as AImpl" + ], + ), + ), } def test_root_element_cannot_have_children(generate_classes): with pytest.raises(qlgen.RootElementHasChildren): - generate_classes([ - schema.Class('A', properties=[schema.SingleProperty("x", is_child=True)]) - ]) + generate_classes( + [schema.Class("A", properties=[schema.SingleProperty("x", is_child=True)])] + ) def test_class_dir_imports(generate_import_list): dir = "another/rel/path" - assert generate_import_list([ - schema.Class("A", derived={"B"}, pragmas={"group": dir}), - schema.Class("B", bases=["A"]), - ]) == ql.ImportList([ - stub_import_prefix + "B", - stub_import_prefix + "another.rel.path.A", - ]) + assert generate_import_list( + [ + schema.Class("A", derived={"B"}, pragmas={"group": dir}), + schema.Class("B", bases=["A"]), + ] + ) == ql.ImportList( + [ + stub_import_prefix + "B", + stub_import_prefix + "another.rel.path.A", + ] + ) def test_format(opts, generate, render_manager, run_mock): @@ -507,10 +864,21 @@ def test_format(opts, generate, render_manager, run_mock): pathlib.Path("bar.qll"), pathlib.Path("y", "baz.txt"), ] - generate([schema.Class('A')]) + generate([schema.Class("A")]) assert run_mock.mock_calls == [ - 
mock.call([opts.codeql_binary, "query", "format", "--in-place", "--", "x/foo.ql", "bar.qll"], - stderr=subprocess.PIPE, text=True), + mock.call( + [ + opts.codeql_binary, + "query", + "format", + "--in-place", + "--", + "x/foo.ql", + "bar.qll", + ], + stderr=subprocess.PIPE, + text=True, + ), ] @@ -523,7 +891,7 @@ def test_format_error(opts, generate, render_manager, run_mock): pathlib.Path("y", "baz.txt"), ] with pytest.raises(qlgen.FormatError): - generate([schema.Class('A')]) + generate([schema.Class("A")]) def test_format_no_codeql(opts, generate, render_manager, run_mock): @@ -532,7 +900,7 @@ def test_format_no_codeql(opts, generate, render_manager, run_mock): pathlib.Path("bar.qll"), ] with pytest.raises(qlgen.FormatError): - generate([schema.Class('A')]) + generate([schema.Class("A")]) def test_format_no_codeql_in_path(opts, generate, render_manager, run_mock): @@ -541,7 +909,7 @@ def test_format_no_codeql_in_path(opts, generate, render_manager, run_mock): pathlib.Path("bar.qll"), ] with pytest.raises(qlgen.FormatError): - generate([schema.Class('A')]) + generate([schema.Class("A")]) @pytest.mark.parametrize("force", [False, True]) @@ -561,23 +929,29 @@ def test_manage_parameters(opts, generate, renderer, force): write(test_a) write(test_b) write(test_c) - generate([schema.Class('A')]) + generate([schema.Class("A")]) assert renderer.mock_calls == [ - mock.call.manage(generated={ql_a, ql_b, test_a, test_b, import_file()}, stubs={stub_a, stub_b}, - registry=opts.generated_registry, force=force) + mock.call.manage( + generated={ql_a, ql_b, test_a, test_b, import_file()}, + stubs={stub_a, stub_b}, + registry=opts.generated_registry, + force=force, + ) ] def test_modified_stub_skipped(qlgen_opts, generate, render_manager): stub = qlgen_opts.ql_stub_output / "AImpl.qll" render_manager.is_customized_stub.side_effect = lambda f: f == stub - assert stub not in generate([schema.Class('A')]) + assert stub not in generate([schema.Class("A")]) def test_test_missing_source(generate_tests): - generate_tests([ - schema.Class("A"), - ]) == { + generate_tests( + [ + schema.Class("A"), + ] + ) == { "A/MISSING_SOURCE.txt": ql.MissingTestInstructions(), } @@ -592,144 +966,236 @@ def a_ql_property_tester(**kwargs): def test_test_source_present(opts, generate_tests): write(opts.ql_test_output / "A" / "test.swift") - assert generate_tests([ - schema.Class("A"), - ]) == { + assert generate_tests( + [ + schema.Class("A"), + ] + ) == { "A/A.ql": a_ql_class_tester(class_name="A"), } def test_test_source_present_with_dir(opts, generate_tests): write(opts.ql_test_output / "foo" / "A" / "test.swift") - assert generate_tests([ - schema.Class("A", pragmas={"group": "foo"}), - ]) == { + assert generate_tests( + [ + schema.Class("A", pragmas={"group": "foo"}), + ] + ) == { "foo/A/A.ql": a_ql_class_tester(class_name="A"), } def test_test_total_properties(opts, generate_tests): write(opts.ql_test_output / "B" / "test.swift") - assert generate_tests([ - schema.Class("A", derived={"B"}, properties=[ - schema.SingleProperty("x", "string"), - ]), - schema.Class("B", bases=["A"], properties=[ - schema.PredicateProperty("y", "int"), - ]), - ]) == { - "B/B.ql": a_ql_class_tester(class_name="B", properties=[ - ql.PropertyForTest(getter="getX", type="string"), - ql.PropertyForTest(getter="y"), - ]) + assert generate_tests( + [ + schema.Class( + "A", + derived={"B"}, + properties=[ + schema.SingleProperty("x", "string"), + ], + ), + schema.Class( + "B", + bases=["A"], + properties=[ + schema.PredicateProperty("y", "int"), + ], + 
), + ] + ) == { + "B/B.ql": a_ql_class_tester( + class_name="B", + properties=[ + ql.PropertyForTest(getter="getX", type="string"), + ql.PropertyForTest(getter="y"), + ], + ) } def test_test_partial_properties(opts, generate_tests): write(opts.ql_test_output / "B" / "test.swift") - assert generate_tests([ - schema.Class("A", derived={"B", "C"}, properties=[ - schema.OptionalProperty("x", "string"), - ]), - schema.Class("B", bases=["A"], properties=[ - schema.RepeatedProperty("y", "bool"), - schema.RepeatedOptionalProperty("z", "int"), - schema.RepeatedUnorderedProperty("w", "string"), - ]), - ]) == { - "B/B.ql": a_ql_class_tester(class_name="B", properties=[ - ql.PropertyForTest(getter="hasX"), - ql.PropertyForTest(getter="getNumberOfYs", type="int"), - ql.PropertyForTest(getter="getNumberOfWs", type="int"), - ]), - "B/B_getX.ql": a_ql_property_tester(class_name="B", - property=ql.PropertyForTest(getter="getX", is_total=False, - type="string")), - "B/B_getY.ql": a_ql_property_tester(class_name="B", - property=ql.PropertyForTest(getter="getY", is_total=False, - is_indexed=True, - type="bool")), - "B/B_getZ.ql": a_ql_property_tester(class_name="B", - property=ql.PropertyForTest(getter="getZ", is_total=False, - is_indexed=True, - type="int")), - "B/B_getAW.ql": a_ql_property_tester(class_name="B", - property=ql.PropertyForTest(getter="getAW", is_total=False, - type="string")), + assert generate_tests( + [ + schema.Class( + "A", + derived={"B", "C"}, + properties=[ + schema.OptionalProperty("x", "string"), + ], + ), + schema.Class( + "B", + bases=["A"], + properties=[ + schema.RepeatedProperty("y", "bool"), + schema.RepeatedOptionalProperty("z", "int"), + schema.RepeatedUnorderedProperty("w", "string"), + ], + ), + ] + ) == { + "B/B.ql": a_ql_class_tester( + class_name="B", + properties=[ + ql.PropertyForTest(getter="hasX"), + ql.PropertyForTest(getter="getNumberOfYs", type="int"), + ql.PropertyForTest(getter="getNumberOfWs", type="int"), + ], + ), + "B/B_getX.ql": a_ql_property_tester( + class_name="B", + property=ql.PropertyForTest(getter="getX", is_total=False, type="string"), + ), + "B/B_getY.ql": a_ql_property_tester( + class_name="B", + property=ql.PropertyForTest( + getter="getY", is_total=False, is_indexed=True, type="bool" + ), + ), + "B/B_getZ.ql": a_ql_property_tester( + class_name="B", + property=ql.PropertyForTest( + getter="getZ", is_total=False, is_indexed=True, type="int" + ), + ), + "B/B_getAW.ql": a_ql_property_tester( + class_name="B", + property=ql.PropertyForTest(getter="getAW", is_total=False, type="string"), + ), } def test_test_properties_deduplicated(opts, generate_tests): write(opts.ql_test_output / "Final" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"A", "B"}, properties=[ - schema.SingleProperty("x", "string"), - schema.RepeatedProperty("y", "bool"), - ]), - schema.Class("A", bases=["Base"], derived={"Final"}), - schema.Class("B", bases=["Base"], derived={"Final"}), - schema.Class("Final", bases=["A", "B"]), - ]) == { - "Final/Final.ql": a_ql_class_tester(class_name="Final", properties=[ - ql.PropertyForTest(getter="getX", type="string"), - ql.PropertyForTest(getter="getNumberOfYs", type="int"), - ]), - "Final/Final_getY.ql": a_ql_property_tester(class_name="Final", - property=ql.PropertyForTest(getter="getY", is_total=False, - is_indexed=True, - type="bool")), + assert generate_tests( + [ + schema.Class( + "Base", + derived={"A", "B"}, + properties=[ + schema.SingleProperty("x", "string"), + schema.RepeatedProperty("y", "bool"), + ], + 
), + schema.Class("A", bases=["Base"], derived={"Final"}), + schema.Class("B", bases=["Base"], derived={"Final"}), + schema.Class("Final", bases=["A", "B"]), + ] + ) == { + "Final/Final.ql": a_ql_class_tester( + class_name="Final", + properties=[ + ql.PropertyForTest(getter="getX", type="string"), + ql.PropertyForTest(getter="getNumberOfYs", type="int"), + ], + ), + "Final/Final_getY.ql": a_ql_property_tester( + class_name="Final", + property=ql.PropertyForTest( + getter="getY", is_total=False, is_indexed=True, type="bool" + ), + ), } def test_test_properties_skipped(opts, generate_tests): write(opts.ql_test_output / "Derived" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"Derived"}, properties=[ - schema.SingleProperty("x", "string", pragmas=["qltest_skip", "foo"]), - schema.RepeatedProperty("y", "int", pragmas=["bar", "qltest_skip"]), - ]), - schema.Class("Derived", bases=["Base"], properties=[ - schema.PredicateProperty("a", pragmas=["qltest_skip"]), - schema.OptionalProperty( - "b", "int", pragmas=["bar", "qltest_skip", "baz"]), - ]), - ]) == { + assert generate_tests( + [ + schema.Class( + "Base", + derived={"Derived"}, + properties=[ + schema.SingleProperty( + "x", "string", pragmas=["qltest_skip", "foo"] + ), + schema.RepeatedProperty("y", "int", pragmas=["bar", "qltest_skip"]), + ], + ), + schema.Class( + "Derived", + bases=["Base"], + properties=[ + schema.PredicateProperty("a", pragmas=["qltest_skip"]), + schema.OptionalProperty( + "b", "int", pragmas=["bar", "qltest_skip", "baz"] + ), + ], + ), + ] + ) == { "Derived/Derived.ql": a_ql_class_tester(class_name="Derived"), } def test_test_base_class_skipped(opts, generate_tests): write(opts.ql_test_output / "Derived" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"Derived"}, pragmas=["qltest_skip", "foo"], properties=[ - schema.SingleProperty("x", "string"), - schema.RepeatedProperty("y", "int"), - ]), - schema.Class("Derived", bases=["Base"]), - ]) == { + assert generate_tests( + [ + schema.Class( + "Base", + derived={"Derived"}, + pragmas=["qltest_skip", "foo"], + properties=[ + schema.SingleProperty("x", "string"), + schema.RepeatedProperty("y", "int"), + ], + ), + schema.Class("Derived", bases=["Base"]), + ] + ) == { "Derived/Derived.ql": a_ql_class_tester(class_name="Derived"), } def test_test_final_class_skipped(opts, generate_tests): write(opts.ql_test_output / "Derived" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"Derived"}), - schema.Class("Derived", bases=["Base"], pragmas=["qltest_skip", "foo"], properties=[ - schema.SingleProperty("x", "string"), - schema.RepeatedProperty("y", "int"), - ]), - ]) == {} + assert ( + generate_tests( + [ + schema.Class("Base", derived={"Derived"}), + schema.Class( + "Derived", + bases=["Base"], + pragmas=["qltest_skip", "foo"], + properties=[ + schema.SingleProperty("x", "string"), + schema.RepeatedProperty("y", "int"), + ], + ), + ] + ) + == {} + ) def test_test_class_hierarchy_collapse(opts, generate_tests): write(opts.ql_test_output / "Base" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]), - schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]), - schema.Class("D2", bases=["Base"], derived={"D3"}, properties=[schema.SingleProperty("y", "string")]), - schema.Class("D3", bases=["D2"], properties=[schema.SingleProperty("z", "string")]), - ]) == { + assert generate_tests( + [ + schema.Class( + 
"Base", + derived={"D1", "D2"}, + pragmas=["foo", "qltest_collapse_hierarchy"], + ), + schema.Class( + "D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")] + ), + schema.Class( + "D2", + bases=["Base"], + derived={"D3"}, + properties=[schema.SingleProperty("y", "string")], + ), + schema.Class( + "D3", bases=["D2"], properties=[schema.SingleProperty("z", "string")] + ), + ] + ) == { "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True), } @@ -737,13 +1203,26 @@ def test_test_class_hierarchy_collapse(opts, generate_tests): def test_test_class_hierarchy_uncollapse(opts, generate_tests): for d in ("Base", "D3", "D4"): write(opts.ql_test_output / d / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]), - schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]), - schema.Class("D2", bases=["Base"], derived={"D3", "D4"}, pragmas=["qltest_uncollapse_hierarchy", "bar"]), - schema.Class("D3", bases=["D2"]), - schema.Class("D4", bases=["D2"]), - ]) == { + assert generate_tests( + [ + schema.Class( + "Base", + derived={"D1", "D2"}, + pragmas=["foo", "qltest_collapse_hierarchy"], + ), + schema.Class( + "D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")] + ), + schema.Class( + "D2", + bases=["Base"], + derived={"D3", "D4"}, + pragmas=["qltest_uncollapse_hierarchy", "bar"], + ), + schema.Class("D3", bases=["D2"]), + schema.Class("D4", bases=["D2"]), + ] + ) == { "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True), "D3/D3.ql": a_ql_class_tester(class_name="D3"), "D4/D4.ql": a_ql_class_tester(class_name="D4"), @@ -753,12 +1232,22 @@ def test_test_class_hierarchy_uncollapse(opts, generate_tests): def test_test_class_hierarchy_uncollapse_at_final(opts, generate_tests): for d in ("Base", "D3"): write(opts.ql_test_output / d / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]), - schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]), - schema.Class("D2", bases=["Base"], derived={"D3"}), - schema.Class("D3", bases=["D2"], pragmas=["qltest_uncollapse_hierarchy", "bar"]), - ]) == { + assert generate_tests( + [ + schema.Class( + "Base", + derived={"D1", "D2"}, + pragmas=["foo", "qltest_collapse_hierarchy"], + ), + schema.Class( + "D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")] + ), + schema.Class("D2", bases=["Base"], derived={"D3"}), + schema.Class( + "D3", bases=["D2"], pragmas=["qltest_uncollapse_hierarchy", "bar"] + ), + ] + ) == { "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True), "D3/D3.ql": a_ql_class_tester(class_name="D3"), } @@ -766,11 +1255,13 @@ def test_test_class_hierarchy_uncollapse_at_final(opts, generate_tests): def test_test_with(opts, generate_tests): write(opts.ql_test_output / "B" / "test.swift") - assert generate_tests([ - schema.Class("Base", derived={"A", "B"}), - schema.Class("A", bases=["Base"], pragmas={"qltest_test_with": "B"}), - schema.Class("B", bases=["Base"]), - ]) == { + assert generate_tests( + [ + schema.Class("Base", derived={"A", "B"}), + schema.Class("A", bases=["Base"], pragmas={"qltest_test_with": "B"}), + schema.Class("B", bases=["Base"]), + ] + ) == { "B/A.ql": a_ql_class_tester(class_name="A"), "B/B.ql": a_ql_class_tester(class_name="B"), } @@ -778,291 +1269,605 @@ def test_test_with(opts, generate_tests): def 
test_property_description(generate_classes): description = ["Lorem", "Ipsum"] - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar", description=description), - ]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], - doc="foo of this my object", - description=description), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.SingleProperty("foo", "bar", description=description), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + tableparams=["this", "result"], + doc="foo of this my object", + description=description, + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_property_doc_override(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar", doc="baz")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], doc="baz"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", properties=[schema.SingleProperty("foo", "bar", doc="baz")] + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + tableparams=["this", "result"], + doc="baz", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_repeated_property_doc_override(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.RepeatedProperty("x", "int", doc="children of this"), - schema.RepeatedOptionalProperty("y", "int", doc="child of this")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="X", plural="Xes", type="int", - tablename="my_object_xes", - tableparams=["this", "index", "result"], - doc="child of this", doc_plural="children of this"), - ql.Property(singular="Y", plural="Ys", type="int", - tablename="my_object_ies", is_optional=True, - tableparams=["this", "index", "result"], - doc="child of this", doc_plural="children of this"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.RepeatedProperty("x", "int", doc="children of this"), + schema.RepeatedOptionalProperty("y", "int", doc="child of this"), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="X", + plural="Xes", + type="int", + tablename="my_object_xes", + tableparams=["this", "index", "result"], + doc="child of this", + doc_plural="children of this", + ), + ql.Property( + singular="Y", + plural="Ys", + 
type="int", + tablename="my_object_ies", + is_optional=True, + tableparams=["this", "index", "result"], + doc="child of this", + doc_plural="children of this", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } @pytest.mark.parametrize("abbr,expected", list(qlgen.abbreviations.items())) def test_property_doc_abbreviations(generate_classes, abbr, expected): expected_doc = f"foo {expected} bar of this object" - assert generate_classes([ - schema.Class("Object", properties=[ - schema.SingleProperty(f"foo_{abbr}_bar", "baz")]), - ]) == { - "Object.qll": (a_ql_class_public(name="Object"), - a_ql_stub(name="Object"), - a_ql_class(name="Object", final=True, - properties=[ - ql.Property(singular=f"Foo{abbr.capitalize()}Bar", type="baz", - tablename="objects", - tableparams=["this", "result"], doc=expected_doc), - ], - imports=[stub_import_prefix + "Object"])), + assert generate_classes( + [ + schema.Class( + "Object", properties=[schema.SingleProperty(f"foo_{abbr}_bar", "baz")] + ), + ] + ) == { + "Object.qll": ( + a_ql_class_public(name="Object"), + a_ql_stub(name="Object"), + a_ql_class( + name="Object", + final=True, + properties=[ + ql.Property( + singular=f"Foo{abbr.capitalize()}Bar", + type="baz", + tablename="objects", + tableparams=["this", "result"], + doc=expected_doc, + ), + ], + imports=[stub_import_prefix + "Object"], + ), + ), } @pytest.mark.parametrize("abbr,expected", list(qlgen.abbreviations.items())) -def test_property_doc_abbreviations_ignored_if_within_word(generate_classes, abbr, expected): +def test_property_doc_abbreviations_ignored_if_within_word( + generate_classes, abbr, expected +): expected_doc = f"foo {abbr}acadabra bar of this object" - assert generate_classes([ - schema.Class("Object", properties=[ - schema.SingleProperty(f"foo_{abbr}acadabra_bar", "baz")]), - ]) == { - "Object.qll": (a_ql_class_public(name="Object"), - a_ql_stub(name="Object"), - a_ql_class(name="Object", final=True, - properties=[ - ql.Property(singular=f"Foo{abbr.capitalize()}acadabraBar", type="baz", - tablename="objects", - tableparams=["this", "result"], doc=expected_doc), - ], - imports=[stub_import_prefix + "Object"])), + assert generate_classes( + [ + schema.Class( + "Object", + properties=[schema.SingleProperty(f"foo_{abbr}acadabra_bar", "baz")], + ), + ] + ) == { + "Object.qll": ( + a_ql_class_public(name="Object"), + a_ql_stub(name="Object"), + a_ql_class( + name="Object", + final=True, + properties=[ + ql.Property( + singular=f"Foo{abbr.capitalize()}acadabraBar", + type="baz", + tablename="objects", + tableparams=["this", "result"], + doc=expected_doc, + ), + ], + imports=[stub_import_prefix + "Object"], + ), + ), } def test_repeated_property_doc_override_with_format(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.RepeatedProperty("x", "int", doc="special {children} of this"), - schema.RepeatedOptionalProperty("y", "int", doc="special {child} of this")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="X", plural="Xes", type="int", - tablename="my_object_xes", - tableparams=["this", "index", "result"], - doc="special child of this", - doc_plural="special children of this"), - ql.Property(singular="Y", plural="Ys", type="int", - tablename="my_object_ies", is_optional=True, - tableparams=["this", "index", "result"], - doc="special child of this", - doc_plural="special children of this"), - ], - 
imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.RepeatedProperty( + "x", "int", doc="special {children} of this" + ), + schema.RepeatedOptionalProperty( + "y", "int", doc="special {child} of this" + ), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="X", + plural="Xes", + type="int", + tablename="my_object_xes", + tableparams=["this", "index", "result"], + doc="special child of this", + doc_plural="special children of this", + ), + ql.Property( + singular="Y", + plural="Ys", + type="int", + tablename="my_object_ies", + is_optional=True, + tableparams=["this", "index", "result"], + doc="special child of this", + doc_plural="special children of this", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_repeated_property_doc_override_with_multiple_formats(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.RepeatedProperty("x", "int", doc="{cat} or {dog}"), - schema.RepeatedOptionalProperty("y", "int", doc="{cats} or {dogs}")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="X", plural="Xes", type="int", - tablename="my_object_xes", - tableparams=["this", "index", "result"], - doc="cat or dog", doc_plural="cats or dogs"), - ql.Property(singular="Y", plural="Ys", type="int", - tablename="my_object_ies", is_optional=True, - tableparams=["this", "index", "result"], - doc="cat or dog", doc_plural="cats or dogs"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.RepeatedProperty("x", "int", doc="{cat} or {dog}"), + schema.RepeatedOptionalProperty("y", "int", doc="{cats} or {dogs}"), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="X", + plural="Xes", + type="int", + tablename="my_object_xes", + tableparams=["this", "index", "result"], + doc="cat or dog", + doc_plural="cats or dogs", + ), + ql.Property( + singular="Y", + plural="Ys", + type="int", + tablename="my_object_ies", + is_optional=True, + tableparams=["this", "index", "result"], + doc="cat or dog", + doc_plural="cats or dogs", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_property_doc_override_with_format(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar", doc="special {baz} of this")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], doc="special baz of this"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[ + schema.SingleProperty("foo", "bar", doc="special {baz} of this") + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + 
tablename="my_objects", + tableparams=["this", "result"], + doc="special baz of this", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_property_on_class_with_default_doc_name(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar")], - pragmas={"ql_default_doc_name": "baz"}), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - tableparams=["this", "result"], doc="foo of this baz"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + properties=[schema.SingleProperty("foo", "bar")], + pragmas={"ql_default_doc_name": "baz"}, + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + tableparams=["this", "result"], + doc="foo of this baz", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_stub_on_class_with_synth_from_class(generate_classes): - assert generate_classes([ - schema.Class("MyObject", pragmas={"synth": schema.SynthInfo(from_class="A")}, - properties=[schema.SingleProperty("foo", "bar")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject", synth_accessors=[ - ql.SynthUnderlyingAccessor(argument="Entity", type="Raw::A", constructorparams=["result"]), - ]), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", synth=True, - tableparams=["this", "result"], doc="foo of this my object"), - ], imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", + pragmas={"synth": schema.SynthInfo(from_class="A")}, + properties=[schema.SingleProperty("foo", "bar")], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub( + name="MyObject", + synth_accessors=[ + ql.SynthUnderlyingAccessor( + argument="Entity", type="Raw::A", constructorparams=["result"] + ), + ], + ), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + synth=True, + tableparams=["this", "result"], + doc="foo of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_stub_on_class_with_synth_on_arguments(generate_classes): - assert generate_classes([ - schema.Class("MyObject", pragmas={"synth": schema.SynthInfo(on_arguments={"base": "A", "index": "int", "label": "string"})}, - properties=[schema.SingleProperty("foo", "bar")]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject", synth_accessors=[ - ql.SynthUnderlyingAccessor(argument="Base", type="Raw::A", constructorparams=["result", "_", "_"]), - ql.SynthUnderlyingAccessor(argument="Index", type="int", constructorparams=["_", "result", "_"]), - ql.SynthUnderlyingAccessor(argument="Label", type="string", constructorparams=["_", "_", "result"]), - ]), - a_ql_class(name="MyObject", final=True, properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", synth=True, - tableparams=["this", "result"], doc="foo of this my object"), - ], imports=[stub_import_prefix + "MyObject"])), + assert 
generate_classes( + [ + schema.Class( + "MyObject", + pragmas={ + "synth": schema.SynthInfo( + on_arguments={"base": "A", "index": "int", "label": "string"} + ) + }, + properties=[schema.SingleProperty("foo", "bar")], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub( + name="MyObject", + synth_accessors=[ + ql.SynthUnderlyingAccessor( + argument="Base", + type="Raw::A", + constructorparams=["result", "_", "_"], + ), + ql.SynthUnderlyingAccessor( + argument="Index", + type="int", + constructorparams=["_", "result", "_"], + ), + ql.SynthUnderlyingAccessor( + argument="Label", + type="string", + constructorparams=["_", "_", "result"], + ), + ], + ), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + synth=True, + tableparams=["this", "result"], + doc="foo of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_synth_property(generate_classes): - assert generate_classes([ - schema.Class("MyObject", properties=[ - schema.SingleProperty("foo", "bar", synth=True)]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), - a_ql_stub(name="MyObject"), - a_ql_class(name="MyObject", final=True, - properties=[ - ql.Property(singular="Foo", type="bar", tablename="my_objects", - synth=True, - tableparams=["this", "result"], doc="foo of this my object"), - ], - imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class( + "MyObject", properties=[schema.SingleProperty("foo", "bar", synth=True)] + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + properties=[ + ql.Property( + singular="Foo", + type="bar", + tablename="my_objects", + synth=True, + tableparams=["this", "result"], + doc="foo of this my object", + ), + ], + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_hideable_class(generate_classes): - assert generate_classes([ - schema.Class("MyObject", pragmas=["ql_hideable"]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject"), a_ql_class(name="MyObject", final=True, hideable=True, imports=[stub_import_prefix + "MyObject"])), + assert generate_classes( + [ + schema.Class("MyObject", pragmas=["ql_hideable"]), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + hideable=True, + imports=[stub_import_prefix + "MyObject"], + ), + ), } def test_hideable_property(generate_classes): - assert generate_classes([ - schema.Class("MyObject", pragmas=["ql_hideable"]), - schema.Class("Other", properties=[ - schema.SingleProperty("x", "MyObject"), - ]), - ]) == { - "MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject"), a_ql_class(name="MyObject", final=True, hideable=True, imports=[stub_import_prefix + "MyObject"])), - "Other.qll": (a_ql_class_public(name="Other", imports=[stub_import_prefix + "MyObject"]), - a_ql_stub(name="Other"), - a_ql_class(name="Other", imports=[stub_import_prefix + "MyObject", stub_import_prefix + "Other"], - final=True, properties=[ - ql.Property(singular="X", type="MyObject", tablename="others", - type_is_hideable=True, - type_is_codegen_class=True, - tableparams=["this", "result"], doc="x of this other"), - ])), + assert generate_classes( + [ + schema.Class("MyObject", pragmas=["ql_hideable"]), + schema.Class( 
+ "Other", + properties=[ + schema.SingleProperty("x", "MyObject"), + ], + ), + ] + ) == { + "MyObject.qll": ( + a_ql_class_public(name="MyObject"), + a_ql_stub(name="MyObject"), + a_ql_class( + name="MyObject", + final=True, + hideable=True, + imports=[stub_import_prefix + "MyObject"], + ), + ), + "Other.qll": ( + a_ql_class_public(name="Other", imports=[stub_import_prefix + "MyObject"]), + a_ql_stub(name="Other"), + a_ql_class( + name="Other", + imports=[stub_import_prefix + "MyObject", stub_import_prefix + "Other"], + final=True, + properties=[ + ql.Property( + singular="X", + type="MyObject", + tablename="others", + type_is_hideable=True, + type_is_codegen_class=True, + tableparams=["this", "result"], + doc="x of this other", + ), + ], + ), + ), } def test_property_with_custom_db_table_name(generate_classes): - assert generate_classes([ - schema.Class("Obj", properties=[ - schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}), - schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}), - schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}), - schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}), - schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}), - ]), - ]) == { - "Obj.qll": (a_ql_class_public(name="Obj"), - a_ql_stub(name="Obj"), - a_ql_class(name="Obj", final=True, properties=[ - ql.Property(singular="X", type="a", tablename="foo", - tableparams=["this", "result"], - is_optional=True, doc="x of this obj"), - ql.Property(singular="Y", plural="Ys", type="b", tablename="bar", - tableparams=["this", "index", "result"], - doc="y of this obj", doc_plural="ys of this obj"), - ql.Property(singular="Z", plural="Zs", type="c", tablename="baz", - tableparams=["this", "index", "result"], - is_optional=True, doc="z of this obj", doc_plural="zs of this obj"), - ql.Property(singular="p", type="predicate", tablename="hello", - tableparams=["this"], is_predicate=True, - doc="this obj p"), - ql.Property(singular="Q", plural="Qs", type="d", tablename="world", - tableparams=["this", "result"], is_unordered=True, - doc="q of this obj", doc_plural="qs of this obj"), - ], - imports=[stub_import_prefix + "Obj"])), + assert generate_classes( + [ + schema.Class( + "Obj", + properties=[ + schema.OptionalProperty( + "x", "a", pragmas={"ql_db_table_name": "foo"} + ), + schema.RepeatedProperty( + "y", "b", pragmas={"ql_db_table_name": "bar"} + ), + schema.RepeatedOptionalProperty( + "z", "c", pragmas={"ql_db_table_name": "baz"} + ), + schema.PredicateProperty( + "p", pragmas={"ql_db_table_name": "hello"} + ), + schema.RepeatedUnorderedProperty( + "q", "d", pragmas={"ql_db_table_name": "world"} + ), + ], + ), + ] + ) == { + "Obj.qll": ( + a_ql_class_public(name="Obj"), + a_ql_stub(name="Obj"), + a_ql_class( + name="Obj", + final=True, + properties=[ + ql.Property( + singular="X", + type="a", + tablename="foo", + tableparams=["this", "result"], + is_optional=True, + doc="x of this obj", + ), + ql.Property( + singular="Y", + plural="Ys", + type="b", + tablename="bar", + tableparams=["this", "index", "result"], + doc="y of this obj", + doc_plural="ys of this obj", + ), + ql.Property( + singular="Z", + plural="Zs", + type="c", + tablename="baz", + tableparams=["this", "index", "result"], + is_optional=True, + doc="z of this obj", + doc_plural="zs of this obj", + ), + ql.Property( + singular="p", + type="predicate", + tablename="hello", + tableparams=["this"], + is_predicate=True, + doc="this obj p", + ), + 
ql.Property( + singular="Q", + plural="Qs", + type="d", + tablename="world", + tableparams=["this", "result"], + is_unordered=True, + doc="q of this obj", + doc_plural="qs of this obj", + ), + ], + imports=[stub_import_prefix + "Obj"], + ), + ), } -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_render.py b/misc/codegen/test/test_render.py index 21378e715bb2..74803c2300c4 100644 --- a/misc/codegen/test/test_render.py +++ b/misc/codegen/test/test_render.py @@ -46,7 +46,10 @@ def write_registry(file, *files_and_hashes): def assert_registry(file, *files_and_hashes): assert_file(file, create_registry(files_and_hashes)) files = [file.name, ".gitattributes"] + [f for f, _, _ in files_and_hashes] - assert_file(file.parent / ".gitattributes", "\n".join(f"/{f} linguist-generated" for f in files) + "\n") + assert_file( + file.parent / ".gitattributes", + "\n".join(f"/{f} linguist-generated" for f in files) + "\n", + ) def hash(text): @@ -56,11 +59,11 @@ def hash(text): def test_constructor(pystache_renderer_cls, sut): - pystache_init, = pystache_renderer_cls.mock_calls - assert set(pystache_init.kwargs) == {'search_dirs', 'escape'} - assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir) + (pystache_init,) = pystache_renderer_cls.mock_calls + assert set(pystache_init.kwargs) == {"search_dirs", "escape"} + assert pystache_init.kwargs["search_dirs"] == str(paths.templates_dir) an_object = object() - assert pystache_init.kwargs['escape'](an_object) is an_object + assert pystache_init.kwargs["escape"](an_object) is an_object def test_render(pystache_renderer, sut): @@ -218,7 +221,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut): some_processed_output = "// generated some processed output" registry = paths.root_dir / "a/registry.list" write(stub, some_processed_output) - write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))) + write_registry( + registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)) + ) pystache_renderer.render_name.side_effect = (some_output,) @@ -227,7 +232,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut): assert renderer.written == set() assert_file(stub, some_processed_output) - assert_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))) + assert_registry( + registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)) + ) assert pystache_renderer.mock_calls == [ mock.call.render_name(data.template, data, generator=generator), ] @@ -238,13 +245,17 @@ def test_managed_render_with_modified_generated_file(pystache_renderer, sut): some_processed_output = "// some processed output" registry = paths.root_dir / "a/registry.list" write(output, "// something else") - write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output))) + write_registry( + registry, ("some/output.txt", "whatever", hash(some_processed_output)) + ) with pytest.raises(render.Error): sut.manage(generated=(output,), stubs=(), registry=registry) -def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystache_renderer, sut): +def test_managed_render_with_modified_stub_file_still_marked_as_generated( + pystache_renderer, sut +): stub = paths.root_dir / "a/some/stub.txt" some_processed_output = "// generated some processed output" registry = paths.root_dir / "a/registry.list" @@ -255,7 +266,9 @@ def 
test_managed_render_with_modified_stub_file_still_marked_as_generated(pystac sut.manage(generated=(), stubs=(stub,), registry=registry) -def test_managed_render_with_modified_stub_file_not_marked_as_generated(pystache_renderer, sut): +def test_managed_render_with_modified_stub_file_not_marked_as_generated( + pystache_renderer, sut +): stub = paths.root_dir / "a/some/stub.txt" some_processed_output = "// generated some processed output" registry = paths.root_dir / "a/registry.list" @@ -272,7 +285,9 @@ class MyError(Exception): pass -def test_managed_render_exception_drops_written_and_inexsistent_from_registry(pystache_renderer, sut): +def test_managed_render_exception_drops_written_and_inexsistent_from_registry( + pystache_renderer, sut +): data = mock.Mock(spec=("template",)) text = "some text" pystache_renderer.render_name.side_effect = (text,) @@ -281,11 +296,9 @@ def test_managed_render_exception_drops_written_and_inexsistent_from_registry(py write(output, text) write(paths.root_dir / "a/a") write(paths.root_dir / "a/c") - write_registry(registry, - "aaa", - ("some/output.txt", "whatever", hash(text)), - "bbb", - "ccc") + write_registry( + registry, "aaa", ("some/output.txt", "whatever", hash(text)), "bbb", "ccc" + ) with pytest.raises(MyError): with sut.manage(generated=(), stubs=(), registry=registry) as renderer: @@ -299,17 +312,14 @@ def test_managed_render_drops_inexsistent_from_registry(pystache_renderer, sut): registry = paths.root_dir / "a/registry.list" write(paths.root_dir / "a/a") write(paths.root_dir / "a/c") - write_registry(registry, - ("a", hash(''), hash('')), - "bbb", - ("c", hash(''), hash(''))) + write_registry( + registry, ("a", hash(""), hash("")), "bbb", ("c", hash(""), hash("")) + ) with sut.manage(generated=(), stubs=(), registry=registry): pass - assert_registry(registry, - ("a", hash(''), hash('')), - ("c", hash(''), hash(''))) + assert_registry(registry, ("a", hash(""), hash("")), ("c", hash(""), hash(""))) def test_managed_render_exception_does_not_erase(pystache_renderer, sut): @@ -321,7 +331,9 @@ def test_managed_render_exception_does_not_erase(pystache_renderer, sut): write_registry(registry) with pytest.raises(MyError): - with sut.manage(generated=(output,), stubs=(stub,), registry=registry) as renderer: + with sut.manage( + generated=(output,), stubs=(stub,), registry=registry + ) as renderer: raise MyError assert output.is_file() @@ -333,14 +345,15 @@ def test_render_with_extensions(pystache_renderer, sut): data.template = "test_template" data.extensions = ["foo", "bar", "baz"] output = pathlib.Path("my", "test", "file") - expected_outputs = [pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")] + expected_outputs = [ + pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz") + ] rendered = [f"text{i}" for i in range(len(expected_outputs))] pystache_renderer.render_name.side_effect = rendered sut.render(data, output) expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"] assert pystache_renderer.mock_calls == [ - mock.call.render_name(t, data, generator=generator) - for t in expected_templates + mock.call.render_name(t, data, generator=generator) for t in expected_templates ] for expected_output, expected_contents in zip(expected_outputs, rendered): assert_file(expected_output, expected_contents) @@ -356,7 +369,9 @@ def test_managed_render_with_force_not_skipping_generated_file(pystache_renderer pystache_renderer.render_name.side_effect = (some_output,) - with 
sut.manage(generated=(output,), stubs=(), registry=registry, force=True) as renderer: + with sut.manage( + generated=(output,), stubs=(), registry=registry, force=True + ) as renderer: renderer.render(data, output) assert renderer.written == {output} assert_file(output, some_output) @@ -374,11 +389,15 @@ def test_managed_render_with_force_not_skipping_stub_file(pystache_renderer, sut some_processed_output = "// generated some processed output" registry = paths.root_dir / "a/registry.list" write(stub, some_processed_output) - write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))) + write_registry( + registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)) + ) pystache_renderer.render_name.side_effect = (some_output,) - with sut.manage(generated=(), stubs=(stub,), registry=registry, force=True) as renderer: + with sut.manage( + generated=(), stubs=(stub,), registry=registry, force=True + ) as renderer: renderer.render(data, stub) assert renderer.written == {stub} assert_file(stub, some_output) @@ -394,13 +413,17 @@ def test_managed_render_with_force_ignores_modified_generated_file(sut): some_processed_output = "// some processed output" registry = paths.root_dir / "a/registry.list" write(output, "// something else") - write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output))) + write_registry( + registry, ("some/output.txt", "whatever", hash(some_processed_output)) + ) with sut.manage(generated=(output,), stubs=(), registry=registry, force=True): pass -def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(sut): +def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated( + sut, +): stub = paths.root_dir / "a/some/stub.txt" some_processed_output = "// generated some processed output" registry = paths.root_dir / "a/registry.list" @@ -411,5 +434,5 @@ def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_ge pass -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_schemaloader.py b/misc/codegen/test/test_schemaloader.py index 1c8bfba271b2..5e8ba91b742a 100644 --- a/misc/codegen/test/test_schemaloader.py +++ b/misc/codegen/test/test_schemaloader.py @@ -26,9 +26,9 @@ class MyClass: pass assert data.classes == { - 'MyClass': schema.Class('MyClass'), + "MyClass": schema.Class("MyClass"), } - assert data.root_class is data.classes['MyClass'] + assert data.root_class is data.classes["MyClass"] def test_two_empty_classes(): @@ -41,10 +41,10 @@ class MyClass2(MyClass1): pass assert data.classes == { - 'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}), - 'MyClass2': schema.Class('MyClass2', bases=['MyClass1']), + "MyClass1": schema.Class("MyClass1", derived={"MyClass2"}), + "MyClass2": schema.Class("MyClass2", bases=["MyClass1"]), } - assert data.root_class is data.classes['MyClass1'] + assert data.root_class is data.classes["MyClass1"] def test_no_external_bases(): @@ -52,6 +52,7 @@ class A: pass with pytest.raises(schema.Error): + @load class data: class MyClass(A): @@ -60,6 +61,7 @@ class MyClass(A): def test_no_multiple_roots(): with pytest.raises(schema.Error): + @load class data: class MyClass1: @@ -85,10 +87,10 @@ class D(B, C): pass assert data.classes == { - 'A': schema.Class('A', derived={'B', 'C'}), - 'B': schema.Class('B', bases=['A'], derived={'D'}), - 'C': schema.Class('C', bases=['A'], derived={'D'}), - 'D': schema.Class('D', 
bases=['B', 'C']), + "A": schema.Class("A", derived={"B", "C"}), + "B": schema.Class("B", bases=["A"], derived={"D"}), + "C": schema.Class("C", bases=["A"], derived={"D"}), + "D": schema.Class("D", bases=["B", "C"]), } @@ -101,7 +103,7 @@ class A: pass assert data.classes == { - 'A': schema.Class('A', pragmas={"group": "xxx"}), + "A": schema.Class("A", pragmas={"group": "xxx"}), } @@ -114,7 +116,7 @@ class A: class B(A): pass - @defs.group('xxx') + @defs.group("xxx") class C(A): pass @@ -122,25 +124,26 @@ class D(B, C): pass assert data.classes == { - 'A': schema.Class('A', derived={'B', 'C'}), - 'B': schema.Class('B', bases=['A'], derived={'D'}), - 'C': schema.Class('C', bases=['A'], derived={'D'}, pragmas={"group": "xxx"}), - 'D': schema.Class('D', bases=['B', 'C'], pragmas={"group": "xxx"}), + "A": schema.Class("A", derived={"B", "C"}), + "B": schema.Class("B", bases=["A"], derived={"D"}), + "C": schema.Class("C", bases=["A"], derived={"D"}, pragmas={"group": "xxx"}), + "D": schema.Class("D", bases=["B", "C"], pragmas={"group": "xxx"}), } def test_no_mixed_groups_in_bases(): with pytest.raises(schema.Error): + @load class data: class A: pass - @defs.group('x') + @defs.group("x") class B(A): pass - @defs.group('y') + @defs.group("y") class C(A): pass @@ -153,6 +156,7 @@ class D(B, C): def test_lowercase_rejected(): with pytest.raises(schema.Error): + @load class data: class aLowerCase: @@ -171,14 +175,17 @@ class A: six: defs.set[defs.string] assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('one', 'string'), - schema.OptionalProperty('two', 'int'), - schema.RepeatedProperty('three', 'boolean'), - schema.RepeatedOptionalProperty('four', 'string'), - schema.PredicateProperty('five'), - schema.RepeatedUnorderedProperty('six', 'string'), - ]), + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty("one", "string"), + schema.OptionalProperty("two", "int"), + schema.RepeatedProperty("three", "boolean"), + schema.RepeatedOptionalProperty("four", "string"), + schema.PredicateProperty("five"), + schema.RepeatedUnorderedProperty("six", "string"), + ], + ), } @@ -199,14 +206,18 @@ class B(A): five: defs.set[A] assert data.classes == { - 'A': schema.Class('A', derived={'B'}), - 'B': schema.Class('B', bases=['A'], properties=[ - schema.SingleProperty('one', 'A'), - schema.OptionalProperty('two', 'A'), - schema.RepeatedProperty('three', 'A'), - schema.RepeatedOptionalProperty('four', 'A'), - schema.RepeatedUnorderedProperty('five', 'A'), - ]), + "A": schema.Class("A", derived={"B"}), + "B": schema.Class( + "B", + bases=["A"], + properties=[ + schema.SingleProperty("one", "A"), + schema.OptionalProperty("two", "A"), + schema.RepeatedProperty("three", "A"), + schema.RepeatedOptionalProperty("four", "A"), + schema.RepeatedUnorderedProperty("five", "A"), + ], + ), } @@ -221,20 +232,31 @@ class A: five: defs.set["A"] assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('one', 'A'), - schema.OptionalProperty('two', 'A'), - schema.RepeatedProperty('three', 'A'), - schema.RepeatedOptionalProperty('four', 'A'), - schema.RepeatedUnorderedProperty('five', 'A'), - ]), + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty("one", "A"), + schema.OptionalProperty("two", "A"), + schema.RepeatedProperty("three", "A"), + schema.RepeatedOptionalProperty("four", "A"), + schema.RepeatedUnorderedProperty("five", "A"), + ], + ), } -@pytest.mark.parametrize("spec", [lambda t: t, lambda t: defs.optional[t], lambda t: defs.list[t], 
- lambda t: defs.list[defs.optional[t]]]) +@pytest.mark.parametrize( + "spec", + [ + lambda t: t, + lambda t: defs.optional[t], + lambda t: defs.list[t], + lambda t: defs.list[defs.optional[t]], + ], +) def test_string_reference_dangling(spec): with pytest.raises(schema.Error): + @load class data: class A: @@ -251,18 +273,24 @@ class A: four: defs.list[defs.optional["A"]] | defs.child assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('one', 'A', is_child=True), - schema.OptionalProperty('two', 'A', is_child=True), - schema.RepeatedProperty('three', 'A', is_child=True), - schema.RepeatedOptionalProperty('four', 'A', is_child=True), - ]), + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty("one", "A", is_child=True), + schema.OptionalProperty("two", "A", is_child=True), + schema.RepeatedProperty("three", "A", is_child=True), + schema.RepeatedOptionalProperty("four", "A", is_child=True), + ], + ), } -@pytest.mark.parametrize("spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]]) +@pytest.mark.parametrize( + "spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]] +) def test_builtin_predicate_and_set_children_not_allowed(spec): with pytest.raises(schema.Error): + @load class data: class A: @@ -291,9 +319,12 @@ class A: x: defs.string | pragma assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('x', 'string', pragmas=[expected]), - ]), + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty("x", "string", pragmas=[expected]), + ], + ), } @@ -308,9 +339,16 @@ class A: x: spec assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('x', 'string', pragmas=[expected for _, expected in _property_pragmas]), - ]), + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", + "string", + pragmas=[expected for _, expected in _property_pragmas], + ), + ], + ), } @@ -323,7 +361,7 @@ class A: pass assert data.classes == { - 'A': schema.Class('A', pragmas=[expected]), + "A": schema.Class("A", pragmas=[expected]), } @@ -340,7 +378,7 @@ class A: apply_pragmas(A) assert data.classes == { - 'A': schema.Class('A', pragmas=[e for _, e in _pragmas]), + "A": schema.Class("A", pragmas=[e for _, e in _pragmas]), } @@ -355,8 +393,10 @@ class B(A): pass assert data.classes == { - 'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}), - 'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(from_class="A")}), + "A": schema.Class("A", derived={"B"}, pragmas={"synth": True}), + "B": schema.Class( + "B", bases=["A"], pragmas={"synth": schema.SynthInfo(from_class="A")} + ), } @@ -371,13 +411,16 @@ class B(A): pass assert data.classes == { - 'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(from_class="B")}), - 'B': schema.Class('B', bases=['A']), + "A": schema.Class( + "A", derived={"B"}, pragmas={"synth": schema.SynthInfo(from_class="B")} + ), + "B": schema.Class("B", bases=["A"]), } def test_synth_from_class_dangling(): with pytest.raises(schema.Error): + @load class data: @defs.synth.from_class("X") @@ -396,8 +439,12 @@ class B(A): pass assert data.classes == { - 'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}), - 'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'A', 'i': 'int'})}), + "A": schema.Class("A", derived={"B"}, pragmas={"synth": True}), + "B": schema.Class( + "B", + bases=["A"], + pragmas={"synth": 
schema.SynthInfo(on_arguments={"a": "A", "i": "int"})}, + ), } @@ -415,13 +462,18 @@ class B(A): pass assert data.classes == { - 'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(on_arguments={'b': 'B', 'i': 'int'})}), - 'B': schema.Class('B', bases=['A']), + "A": schema.Class( + "A", + derived={"B"}, + pragmas={"synth": schema.SynthInfo(on_arguments={"b": "B", "i": "int"})}, + ), + "B": schema.Class("B", bases=["A"]), } def test_synth_class_on_dangling(): with pytest.raises(schema.Error): + @load class data: @defs.synth.on_arguments(s=defs.string, a="A", i=defs.int) @@ -453,12 +505,25 @@ class C(Root): pass assert data.classes == { - 'Root': schema.Class('Root', derived={'Base', 'C'}), - 'Base': schema.Class('Base', bases=['Root'], derived={'Intermediate', 'B'}, pragmas={"synth": True}), - 'Intermediate': schema.Class('Intermediate', bases=['Base'], derived={'A'}, pragmas={"synth": True}), - 'A': schema.Class('A', bases=['Intermediate'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'Base', 'i': 'int'})}), - 'B': schema.Class('B', bases=['Base'], pragmas={"synth": schema.SynthInfo(from_class='Base')}), - 'C': schema.Class('C', bases=['Root']), + "Root": schema.Class("Root", derived={"Base", "C"}), + "Base": schema.Class( + "Base", + bases=["Root"], + derived={"Intermediate", "B"}, + pragmas={"synth": True}, + ), + "Intermediate": schema.Class( + "Intermediate", bases=["Base"], derived={"A"}, pragmas={"synth": True} + ), + "A": schema.Class( + "A", + bases=["Intermediate"], + pragmas={"synth": schema.SynthInfo(on_arguments={"a": "Base", "i": "int"})}, + ), + "B": schema.Class( + "B", bases=["Base"], pragmas={"synth": schema.SynthInfo(from_class="Base")} + ), + "C": schema.Class("C", bases=["Root"]), } @@ -479,9 +544,7 @@ class data: class A: """Very important class.""" - assert data.classes == { - 'A': schema.Class('A', doc=["Very important class."]) - } + assert data.classes == {"A": schema.Class("A", doc=["Very important class."])} def test_property_docstring(): @@ -491,7 +554,14 @@ class A: x: int | defs.desc("very important property.") assert data.classes == { - 'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])]) + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", "int", description=["very important property."] + ) + ], + ) } @@ -502,21 +572,27 @@ class A: """Very important class.""" - assert data.classes == { - 'A': schema.Class('A', doc=["Very important", "class."]) - } + assert data.classes == {"A": schema.Class("A", doc=["Very important", "class."])} def test_property_docstring_newline(): @load class data: class A: - x: int | defs.desc("""very important - property.""") + x: int | defs.desc( + """very important + property.""" + ) assert data.classes == { - 'A': schema.Class('A', - properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])]) + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", "int", description=["very important", "property."] + ) + ], + ) } @@ -530,23 +606,30 @@ class A: """ - assert data.classes == { - 'A': schema.Class('A', doc=["Very important class."]) - } + assert data.classes == {"A": schema.Class("A", doc=["Very important class."])} def test_property_docstring_stripped(): @load class data: class A: - x: int | defs.desc(""" + x: int | defs.desc( + """ very important property. 
- """) + """ + ) assert data.classes == { - 'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])]) + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", "int", description=["very important property."] + ) + ], + ) } @@ -559,7 +642,9 @@ class A: As said, very important.""" assert data.classes == { - 'A': schema.Class('A', doc=["Very important class.", "", "As said, very important."]) + "A": schema.Class( + "A", doc=["Very important class.", "", "As said, very important."] + ) } @@ -567,13 +652,27 @@ def test_property_docstring_split(): @load class data: class A: - x: int | defs.desc("""very important property. + x: int | defs.desc( + """very important property. - Very very important.""") + Very very important.""" + ) assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('x', 'int', description=["very important property.", "", "Very very important."])]) + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", + "int", + description=[ + "very important property.", + "", + "Very very important.", + ], + ) + ], + ) } @@ -587,7 +686,9 @@ class A: """ assert data.classes == { - 'A': schema.Class('A', doc=["Very important class.", " As said, very important."]) + "A": schema.Class( + "A", doc=["Very important class.", " As said, very important."] + ) } @@ -595,14 +696,24 @@ def test_property_docstring_indent(): @load class data: class A: - x: int | defs.desc(""" + x: int | defs.desc( + """ very important property. Very very important. - """) + """ + ) assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('x', 'int', description=["very important property.", " Very very important."])]) + "A": schema.Class( + "A", + properties=[ + schema.SingleProperty( + "x", + "int", + description=["very important property.", " Very very important."], + ) + ], + ) } @@ -613,13 +724,13 @@ class A: x: int | defs.doc("y") assert data.classes == { - 'A': schema.Class('A', properties=[ - schema.SingleProperty('x', 'int', doc="y")]), + "A": schema.Class("A", properties=[schema.SingleProperty("x", "int", doc="y")]), } def test_property_doc_override_no_newlines(): with pytest.raises(schema.Error): + @load class data: class A: @@ -628,6 +739,7 @@ class A: def test_property_doc_override_no_trailing_dot(): with pytest.raises(schema.Error): + @load class data: class A: @@ -642,7 +754,7 @@ class A: pass assert data.classes == { - 'A': schema.Class('A', pragmas={"ql_default_doc_name": "b"}), + "A": schema.Class("A", pragmas={"ql_default_doc_name": "b"}), } @@ -653,7 +765,12 @@ class A: x: optional[int] | defs.ql.db_table_name("foo") assert data.classes == { - 'A': schema.Class('A', properties=[schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"})]), + "A": schema.Class( + "A", + properties=[ + schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"}) + ], + ), } @@ -668,15 +785,16 @@ class Null(Root): pass assert data.classes == { - 'Root': schema.Class('Root', derived={'Null'}), - 'Null': schema.Class('Null', bases=['Root']), + "Root": schema.Class("Root", derived={"Null"}), + "Null": schema.Class("Null", bases=["Root"]), } - assert data.null == 'Null' + assert data.null == "Null" assert data.null_class is data.classes[data.null] def test_null_class_cannot_be_derived(): with pytest.raises(schema.Error): + @load class data: class Root: @@ -692,6 +810,7 @@ class Impossible(Null): def test_null_class_cannot_be_defined_multiple_times(): 
with pytest.raises(schema.Error): + @load class data: class Root: @@ -708,6 +827,7 @@ class Null2(Root): def test_uppercase_acronyms_are_rejected(): with pytest.raises(schema.Error): + @load class data: class Root: @@ -737,10 +857,18 @@ class NonHideable(Root): pass assert data.classes == { - "Root": schema.Class("Root", derived={"A", "IndirectlyHideable", "NonHideable"}, pragmas=["ql_hideable"]), + "Root": schema.Class( + "Root", + derived={"A", "IndirectlyHideable", "NonHideable"}, + pragmas=["ql_hideable"], + ), "A": schema.Class("A", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]), - "IndirectlyHideable": schema.Class("IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]), - "B": schema.Class("B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"]), + "IndirectlyHideable": schema.Class( + "IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"] + ), + "B": schema.Class( + "B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"] + ), "NonHideable": schema.Class("NonHideable", bases=["Root"]), } @@ -771,7 +899,9 @@ class E(B): assert data.classes == { "Root": schema.Class("Root", derived=set("ABCD")), "A": schema.Class("A", bases=["Root"]), - "B": schema.Class("B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={'E'}), + "B": schema.Class( + "B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={"E"} + ), "C": schema.Class("C", bases=["Root"], pragmas={"qltest_test_with": "D"}), "D": schema.Class("D", bases=["Root"]), "E": schema.Class("E", bases=["B"], pragmas={"qltest_test_with": "A"}), @@ -782,10 +912,10 @@ def test_annotate_docstring(): @load class data: class Root: - """ old docstring """ + """old docstring""" class A(Root): - """ A docstring """ + """A docstring""" @defs.annotate(Root) class _: @@ -819,7 +949,15 @@ class _: pass assert data.classes == { - "Root": schema.Class("Root", pragmas=["qltest_skip", "cpp_skip", "ql_hideable", "qltest_collapse_hierarchy"]), + "Root": schema.Class( + "Root", + pragmas=[ + "qltest_skip", + "cpp_skip", + "ql_hideable", + "qltest_collapse_hierarchy", + ], + ), } @@ -837,11 +975,16 @@ class _: z: defs.string assert data.classes == { - "Root": schema.Class("Root", properties=[ - schema.SingleProperty("x", "int", doc="foo"), - schema.OptionalProperty("y", "Root", pragmas=["ql_internal"], is_child=True), - schema.SingleProperty("z", "string"), - ]), + "Root": schema.Class( + "Root", + properties=[ + schema.SingleProperty("x", "int", doc="foo"), + schema.OptionalProperty( + "y", "Root", pragmas=["ql_internal"], is_child=True + ), + schema.SingleProperty("z", "string"), + ], + ), } @@ -860,16 +1003,20 @@ class _: z: defs._ | ~defs.synth | ~defs.doc assert data.classes == { - "Root": schema.Class("Root", properties=[ - schema.SingleProperty("x", "int"), - schema.OptionalProperty("y", "Root"), - schema.SingleProperty("z", "string"), - ]), + "Root": schema.Class( + "Root", + properties=[ + schema.SingleProperty("x", "int"), + schema.OptionalProperty("y", "Root"), + schema.SingleProperty("z", "string"), + ], + ), } def test_annotate_non_existing_field(): with pytest.raises(schema.Error): + @load class data: class Root: @@ -882,6 +1029,7 @@ class _: def test_annotate_not_underscore(): with pytest.raises(schema.Error): + @load class data: class Root: @@ -916,6 +1064,7 @@ class Derived(A, B): @defs.annotate(Derived, replace_bases={B: C}) class _: pass + assert data.classes == { "Root": schema.Class("Root", derived={"A", "B"}), "A": schema.Class("A", bases=["Root"], 
derived={"Derived"}), @@ -946,6 +1095,7 @@ class Derived(A): @defs.annotate(Derived, add_bases=(B, C)) class _: pass + assert data.classes == { "Root": schema.Class("Root", derived={"A", "B", "C"}), "A": schema.Class("A", bases=["Root"], derived={"Derived"}), @@ -968,15 +1118,19 @@ class _: y: defs.drop assert data.classes == { - "Root": schema.Class("Root", properties=[ - schema.SingleProperty("x", "int"), - schema.SingleProperty("z", "boolean"), - ]), + "Root": schema.Class( + "Root", + properties=[ + schema.SingleProperty("x", "int"), + schema.SingleProperty("z", "boolean"), + ], + ), } def test_test_with_unknown_string(): with pytest.raises(schema.Error): + @load class data: class Root: @@ -989,6 +1143,7 @@ class A(Root): def test_test_with_unknown_class(): with pytest.raises(schema.Error): + class B: pass @@ -1004,6 +1159,7 @@ class A(Root): def test_test_with_double(): with pytest.raises(schema.Error): + class B: pass @@ -1024,5 +1180,5 @@ class C(Root): pass -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/misc/codegen/test/test_trapgen.py b/misc/codegen/test/test_trapgen.py index a81f40e0dd83..590c83aa7347 100644 --- a/misc/codegen/test/test_trapgen.py +++ b/misc/codegen/test/test_trapgen.py @@ -17,10 +17,16 @@ def ret(entities): dirs = {f.parent for f in generated} assert all(isinstance(f, pathlib.Path) for f in generated) assert all(f.name in ("TrapEntries", "TrapTags") for f in generated) - assert set(f for f in generated if f.name == "TrapTags") == {output_dir / "TrapTags"} - return ({ - str(d.relative_to(output_dir)): generated[d / "TrapEntries"] for d in dirs - }, generated[output_dir / "TrapTags"]) + assert set(f for f in generated if f.name == "TrapTags") == { + output_dir / "TrapTags" + } + return ( + { + str(d.relative_to(output_dir)): generated[d / "TrapEntries"] + for d in dirs + }, + generated[output_dir / "TrapTags"], + ) return ret @@ -65,87 +71,130 @@ def test_empty_tags(generate_tags): def test_one_empty_table_rejected(generate_traps): with pytest.raises(AssertionError): - generate_traps([ - dbscheme.Table(name="foos", columns=[]), - ]) + generate_traps( + [ + dbscheme.Table(name="foos", columns=[]), + ] + ) def test_one_table(generate_traps): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), - ]) == [ + assert generate_traps( + [ + dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), + ] + ) == [ cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]), ] def test_one_table(generate_traps): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), - ]) == [ + assert generate_traps( + [ + dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), + ] + ) == [ cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]), ] def test_one_table_with_id(generate_traps): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[ - dbscheme.Column("bla", "int", binding=True)]), - ]) == [ - cpp.Trap("foos", name="Foos", fields=[cpp.Field( - "bla", "int")], id=cpp.Field("bla", "int")), + assert generate_traps( + [ + dbscheme.Table( + name="foos", columns=[dbscheme.Column("bla", "int", binding=True)] + ), + ] + ) == [ + cpp.Trap( + "foos", + name="Foos", + fields=[cpp.Field("bla", "int")], + id=cpp.Field("bla", "int"), + ), ] def test_one_table_with_two_binding_first_is_id(generate_traps): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[ - dbscheme.Column("x", 
"a", binding=True), - dbscheme.Column("y", "b", binding=True), - ]), - ]) == [ - cpp.Trap("foos", name="Foos", fields=[ - cpp.Field("x", "a"), - cpp.Field("y", "b"), - ], id=cpp.Field("x", "a")), + assert generate_traps( + [ + dbscheme.Table( + name="foos", + columns=[ + dbscheme.Column("x", "a", binding=True), + dbscheme.Column("y", "b", binding=True), + ], + ), + ] + ) == [ + cpp.Trap( + "foos", + name="Foos", + fields=[ + cpp.Field("x", "a"), + cpp.Field("y", "b"), + ], + id=cpp.Field("x", "a"), + ), ] -@pytest.mark.parametrize("column,field", [ - (dbscheme.Column("x", "string"), cpp.Field("x", "std::string")), - (dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")), - (dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel")), -]) +@pytest.mark.parametrize( + "column,field", + [ + (dbscheme.Column("x", "string"), cpp.Field("x", "std::string")), + (dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")), + (dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel")), + ], +) def test_one_table_special_types(generate_traps, column, field): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[column]), - ]) == [ + assert generate_traps( + [ + dbscheme.Table(name="foos", columns=[column]), + ] + ) == [ cpp.Trap("foos", name="Foos", fields=[field]), ] -@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"]) +@pytest.mark.parametrize( + "name", + ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"], +) def test_one_table_overridden_unsigned_field(generate_traps, name): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]), - ]) == [ + assert generate_traps( + [ + dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]), + ] + ) == [ cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]), ] def test_one_table_overridden_underscore_named_field(generate_traps): - assert generate_traps([ - dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]), - ]) == [ + assert generate_traps( + [ + dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]), + ] + ) == [ cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]), ] def test_tables_with_dir(generate_grouped_traps): - assert generate_grouped_traps([ - dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]), - dbscheme.Table(name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")), - dbscheme.Table(name="z", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo/bar")), - ]) == { + assert generate_grouped_traps( + [ + dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]), + dbscheme.Table( + name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo") + ), + dbscheme.Table( + name="z", + columns=[dbscheme.Column("i", "int")], + dir=pathlib.Path("foo/bar"), + ), + ] + ) == { ".": [cpp.Trap("x", name="X", fields=[cpp.Field("i", "int")])], "foo": [cpp.Trap("y", name="Y", fields=[cpp.Field("i", "int")])], "foo/bar": [cpp.Trap("z", name="Z", fields=[cpp.Field("i", "int")])], @@ -153,15 +202,22 @@ def test_tables_with_dir(generate_grouped_traps): def test_one_table_no_tags(generate_tags): - assert generate_tags([ - dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), - ]) == [] + assert ( + generate_tags( + [ + dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]), + ] + ) + == [] + ) def test_one_union_tags(generate_tags): - assert 
generate_tags([
-        dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
-    ]) == [
+    assert generate_tags(
+        [
+            dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
+        ]
+    ) == [
         cpp.Tag(name="LeftHandSide", bases=[], id="@left_hand_side"),
         cpp.Tag(name="A", bases=["LeftHandSide"], id="@a"),
         cpp.Tag(name="B", bases=["LeftHandSide"], id="@b"),
@@ -170,11 +226,13 @@ def test_one_union_tags(generate_tags):


 def test_multiple_union_tags(generate_tags):
-    assert generate_tags([
-        dbscheme.Union(lhs="@d", rhs=["@a"]),
-        dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
-        dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
-    ]) == [
+    assert generate_tags(
+        [
+            dbscheme.Union(lhs="@d", rhs=["@a"]),
+            dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
+            dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
+        ]
+    ) == [
         cpp.Tag(name="D", bases=[], id="@d"),
         cpp.Tag(name="E", bases=[], id="@e"),
         cpp.Tag(name="A", bases=["D"], id="@a"),
@@ -184,5 +242,5 @@ def test_multiple_union_tags(generate_tags):
     ]


-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/misc/codegen/test/utils.py b/misc/codegen/test/utils.py
index e33500711f25..094455d3d14d 100644
--- a/misc/codegen/test/utils.py
+++ b/misc/codegen/test/utils.py
@@ -39,8 +39,9 @@ def opts():

 @pytest.fixture(autouse=True)
 def override_paths(tmp_path):
-    with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), \
-            mock.patch("misc.codegen.lib.paths.exe_file", tmp_path / "exe"):
+    with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), mock.patch(
+        "misc.codegen.lib.paths.exe_file", tmp_path / "exe"
+    ):
         yield