diff --git a/popper/loop.py b/popper/loop.py
index 8aeb69ed..2ce3d6c6 100644
--- a/popper/loop.py
+++ b/popper/loop.py
@@ -3,7 +3,9 @@ from bitarray.util import subset, any_and, ones
 
 from functools import cache
 from itertools import chain, combinations, permutations
-from . util import timeout, format_rule, rule_is_recursive, prog_is_recursive, prog_has_invention, calc_prog_size, format_literal, Constraint, mdl_score, suppress_stdout_stderr, get_raw_prog, Literal, remap_variables, format_prog
+from .util import timeout, format_rule, rule_is_recursive, prog_is_recursive, prog_has_invention, calc_prog_size, \
+    format_literal, Constraint, mdl_score, suppress_stdout_stderr, get_raw_prog, Literal, remap_variables, format_prog, \
+    order_prog
 from . tester import Tester
 from . bkcons import deduce_bk_cons, deduce_recalls, deduce_type_cons, deduce_non_singletons
 from . combine import Combiner
@@ -1509,7 +1511,7 @@ def print_incomplete_solution2(self, prog, tp, fn, tn, fp, size):
         else:
             self.logger.info(f'tp:{tp} fn:{fn} tn:{tn} fp:{fp} size:{size}')
         for rule in order_prog(prog):
-            self.logger.info(format_rule(order_rule(rule)))
+            self.logger.info(format_rule(self.settings.order_rule(rule)))
         self.logger.info('*'*20)
 
     def needs_datalog(self, prog):
@@ -1613,6 +1615,10 @@ def learn_solution(settings):
     bkcons = get_bk_cons(settings, tester)
     time_so_far = time.time()-t1
-    timeout(settings, popper, (settings, tester, bkcons), timeout_duration=int(settings.timeout-time_so_far),)
+    try:
+        timeout(settings, popper, (settings, tester, bkcons), timeout_duration=int(settings.timeout-time_so_far),)
+    finally:
+        # Release the temporary Prolog module even if learning raises.
+        tester.destroy_prolog_module()
     return settings.solution, settings.best_prog_score, settings.stats
 
 def generalisations(prog, allow_headless=True, recursive=False):
diff --git a/popper/tester.py b/popper/tester.py
index 49a9ca5f..f3270f8d 100644
--- a/popper/tester.py
+++ b/popper/tester.py
@@ -1,47 +1,66 @@
+import datetime
 import os
-import time
-import pkg_resources
-from janus_swi import query_once, consult
-from functools import cache
+from collections import defaultdict
 from contextlib import contextmanager
-from . util import order_prog, prog_is_recursive, rule_is_recursive, calc_rule_size, calc_prog_size, prog_hash, format_rule, format_literal, Literal
+from functools import cache
+from itertools import product
+from typing import Any, Dict, Optional, cast, Tuple
+
+import pkg_resources
 from bitarray import bitarray, frozenbitarray
 from bitarray.util import ones
-from collections import defaultdict
-from itertools import product
+from janus_swi import query_once, consult, cmd
+
+from .util import order_prog, prog_is_recursive, calc_rule_size, calc_prog_size, prog_hash, \
+    format_rule, Literal, Settings
+
+
 
 def format_literal_janus(literal):
     args = ','.join(f'_V{i}' for i in literal.arguments)
     return f'{literal.predicate}({args})'
 
-def bool_query(query):
-    return query_once(query)['truth']
+class PopperTesterError(Exception):
+    """Raised when a Prolog query or module operation needed by the Tester fails."""
 
 class Tester():
 
+    settings: Settings
+    cached_pos_covered: Dict[int, frozenbitarray]
+    neg_fact_str: str
+    neg_literal_set: frozenset
+    module_name: str
+
     def __init__(self, settings):
         self.settings = settings
+        self.module_name = module_name = 'popper_tester_module_' + datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
 
         bk_pl_path = self.settings.bk_file
         exs_pl_path = self.settings.ex_file
         test_pl_path = pkg_resources.resource_filename(__name__, "lp/test.pl")
+        results: Dict[str, Any] = query_once('use_module(library(modules))')
+        if not results['truth']:
+            raise PopperTesterError("Unable to use library(modules)")
+
+        results = query_once("modules:prepare_temporary_module(X)", {"X": module_name})
+        if not results['truth']:
+            raise PopperTesterError(f'Unable to create temporary module named {module_name}')
 
         if not settings.pi_enabled:
-            consult('prog', f':- dynamic {settings.head_literal.predicate}/{len(settings.head_literal.arguments)}.')
+            self.consult('prog', f':- dynamic {settings.head_literal.predicate}/{len(settings.head_literal.arguments)}.')
 
         for x in [exs_pl_path, bk_pl_path, test_pl_path]:
             if os.name == 'nt': # if on Windows, SWI requires escaped directory separators
                 x = x.replace('\\', '\\\\')
-            consult(x)
+            self.consult(x)
 
-        query_once('load_examples')
+        self.query_once('load_examples')
 
         neg_literal = Literal('neg_fact', tuple(range(len(self.settings.head_literal.arguments))))
         self.neg_fact_str = format_literal_janus(neg_literal)
         self.neg_literal_set = frozenset([neg_literal])
 
         q = 'findall(_Atom2, (neg_index(_K, _Atom1), term_string(_Atom1, _Atom2)), S)'
-        res = query_once(q)['S']
+        res = self.query_once(q)['S']
         atoms = []
         for x in res:
             x = x[:-1].split('(')[1].split(',')
@@ -53,19 +72,35 @@ def __init__(self, settings):
             except Exception as e:
                 print(e)
 
-        self.num_pos = query_once('findall(_K, pos_index(_K, _Atom), _S), length(_S, N)')['N']
-        self.num_neg = query_once('findall(_K, neg_index(_K, _Atom), _S), length(_S, N)')['N']
+        self.num_pos = self.query_once('findall(_K, pos_index(_K, _Atom), _S), length(_S, N)')['N']
+        self.num_neg = self.query_once('findall(_K, neg_index(_K, _Atom), _S), length(_S, N)')['N']
 
         self.pos_examples_ = ones(self.num_pos)
 
         self.cached_pos_covered = {}
-        self.cached_inconsistent = {}
+        # self.cached_inconsistent = {} -- never set or referenced.
 
         if self.settings.recursion_enabled:
-            query_once(f'assert(timeout({self.settings.eval_timeout})), fail')
+            self.query_once(f'assert(timeout({self.settings.eval_timeout})), fail')
+
+    def consult(self, file: str, data: Optional[str] = None):
+        """Consult `file` (or the `data` string) in the Tester's module."""
+        consult(file, data=data, module=self.module_name)
+
+    def query_once(self, query: str, inputs: Optional[Dict[str, Any]] = None, error_on_failure: bool = False):
+        """Run `query` once, qualified with the Tester's module; raise PopperTesterError on failure if `error_on_failure`."""
+        query_string = self.module_name + ":(" + query + ")"
+        res = query_once(query_string, inputs=inputs if inputs is not None else {})
+        if error_on_failure and not res['truth']:
+            raise PopperTesterError(f'Unexpected query failure on: "{query_string}"')
+        return res
+
+    def bool_query(self, query) -> bool:
+        """Return the 'truth' value of `query` run through `self.query_once`."""
+        return cast(bool, self.query_once(query)['truth'])
 
     def janus_clear_cache(self):
-        return query_once('retractall(janus:py_call_cache(_String,_Input,_TV,_M,_Goal,_Dict,_Truth,_OutVars))')
+        return self.query_once('retractall(janus:py_call_cache(_String,_Input,_TV,_M,_Goal,_Dict,_Truth,_OutVars))')
 
     def parse_single_rule(self, prog):
         rule = next(iter(prog))
@@ -118,17 +153,17 @@ def test_prog(self, prog):
             if len(prog) == 1:
                 atom_str, body_str = self.parse_single_rule(prog)
                 q = f'findall(_ID, (pos_index(_ID, {atom_str}), ({body_str} -> true)), S)'
-                pos_covered = query_once(q)['S']
+                pos_covered = self.query_once(q)['S']
                 inconsistent = False
                 if self.num_neg > 0:
                     q = f'neg_index(_ID, {atom_str}), {body_str}'
-                    inconsistent = bool_query(q)
+                    inconsistent = self.bool_query(q)
             else:
                 with self.using(prog):
-                    pos_covered = query_once('pos_covered(S)')['S']
+                    pos_covered = self.query_once('pos_covered(S)')['S']
                     inconsistent = False
                     if self.num_neg > 0:
-                        inconsistent = bool_query("inconsistent")
+                        inconsistent = self.bool_query("inconsistent")
 
             pos_covered_bits = bitarray(self.num_pos)
             pos_covered_bits[pos_covered] = 1
@@ -136,7 +171,7 @@ def test_prog(self, prog):
         else:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'findall(_ID, (pos_index(_ID, {atom_str}),({body_str}-> true)), S)'
-            pos_covered = query_once(q)['S']
+            pos_covered = self.query_once(q)['S']
             pos_covered_bits = bitarray(self.num_pos)
             pos_covered_bits[pos_covered] = 1
             pos_covered = frozenbitarray(pos_covered_bits)
@@ -151,24 +186,23 @@ def test_prog(self, prog):
                 head, body = next(iter(prog))
                 head, ordered_body = self.settings.order_rule((None, body | self.neg_literal_set))
                 q = ','.join(format_literal_janus(literal) for literal in ordered_body)
-                inconsistent = bool_query(q)
+                inconsistent = self.bool_query(q)
 
         self.cached_pos_covered[hash(prog)] = pos_covered
         return pos_covered, inconsistent
 
-    def test_prog_all(self, prog):
-
+    def test_prog_all(self, prog) -> Tuple[frozenbitarray, frozenbitarray]:
         if len(prog) == 1:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'findall(_ID, (pos_index(_ID, {atom_str}), ({body_str}-> true)), S)'
-            pos_covered = query_once(q)['S']
+            pos_covered = self.query_once(q)['S']
             neg_covered = []
             if self.num_neg > 0:
                 q = f'findall(_ID, (neg_index(_ID, {atom_str}),({body_str}-> true)), S)'
-                neg_covered = query_once(q)['S']
+                neg_covered = self.query_once(q)['S']
         else:
             with self.using(prog):
-                res = query_once(f'pos_covered(S1), neg_covered(S2)')
+                res = self.query_once(f'pos_covered(S1), neg_covered(S2)')
                 pos_covered = res['S1']
                 neg_covered = res['S2']
 
@@ -182,15 +216,15 @@ def test_prog_all(self, prog):
 
         return pos_covered, neg_covered
 
-    def test_prog_pos(self, prog):
+    def test_prog_pos(self, prog) -> frozenbitarray:
 
         if len(prog) == 1:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'findall(_ID, (pos_index(_ID, {atom_str}),({body_str}-> true)), S)'
-            pos_covered = query_once(q)['S']
+            pos_covered = self.query_once(q)['S']
         else:
             with self.using(prog):
-                pos_covered = query_once('pos_covered(S)')['S']
+                pos_covered = self.query_once('pos_covered(S)')['S']
 
         pos_covered_bits = bitarray(self.num_pos)
         pos_covered_bits[pos_covered] = 1
@@ -204,10 +238,10 @@ def test_prog_inconsistent(self, prog):
         if len(prog) == 1:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'neg_index(_ID, {atom_str}), {body_str}'
-            return bool_query(q)
+            return self.bool_query(q)
 
         with self.using(prog):
-            return bool_query("inconsistent")
+            return self.bool_query("inconsistent")
 
     def test_single_rule_neg_at_most_k(self, prog, k):
 
@@ -215,7 +249,7 @@ def test_single_rule_neg_at_most_k(self, prog, k):
         if self.num_neg > 0:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'findfirstn(K, _ID, (neg_index(_ID, {atom_str}),({body_str}-> true)), S)'
-            neg_covered = query_once(q, {'K':k})['S']
+            neg_covered = self.query_once(q, {'K':k})['S']
 
         neg_covered_bits = bitarray(self.num_neg)
         neg_covered_bits[neg_covered] = 1
@@ -237,10 +271,10 @@ def get_pos_covered(self, prog):
         if len(prog) == 1:
             atom_str, body_str = self.parse_single_rule(prog)
             q = f'findall(_ID, (pos_index(_ID, {atom_str}),({body_str}-> true)), S)'
-            pos_covered = query_once(q)['S']
+            pos_covered = self.query_once(q)['S']
         else:
             with self.using(prog):
-                pos_covered = query_once('pos_covered(S)')['S']
+                pos_covered = self.query_once('pos_covered(S)')['S']
 
         pos_covered_bits = bitarray(self.num_pos)
         pos_covered_bits[pos_covered] = 1
@@ -275,15 +309,15 @@ def using(self, prog):
                 str_prog.append(f':- dynamic {p}/{a}')
 
         str_prog = '.\n'.join(str_prog) +'.'
-        consult('prog', str_prog)
+        self.consult('prog', str_prog)
         yield
         for predicate, arity in current_clauses:
             args = ','.join(['_'] * arity)
-            x = query_once(f"retractall({predicate}({args}))")
+            self.query_once(f"retractall({predicate}({args}))", error_on_failure=True)
 
-    def is_non_functional(self, prog):
+    def is_non_functional(self, prog) -> bool:
         with self.using(prog):
-            return bool_query('non_functional')
+            return self.bool_query('non_functional')
 
     def reduce_inconsistent(self, program):
         if len(program) < 3:
@@ -314,20 +348,20 @@ def is_sat(self, prog):
             _, ordered_body = self.parse_single_rule(prog)
             if self.settings.noisy:
                 q = f'succeeds_k_times({new_head},({ordered_body}),K)'
-                return query_once(q, {'K':calc_rule_size(rule)})['truth']
+                return self.query_once(q, {'K':calc_rule_size(rule)})['truth']
             else:
                 if self.settings.min_coverage == 1:
                     q = f'{new_head},{ordered_body}'
-                    return bool_query(q)
+                    return self.bool_query(q)
                 else:
                     q = f'succeeds_k_times({new_head},({ordered_body}),K)'
-                    return query_once(q, {'K':self.settings.min_coverage})['truth']
+                    return self.query_once(q, {'K':self.settings.min_coverage})['truth']
         else:
             with self.using(prog):
                 if self.settings.noisy:
-                    return query_once(f'covers_at_least_k_pos(K)',{'K':calc_prog_size(prog)})['truth']
+                    return self.query_once(f'covers_at_least_k_pos(K)',{'K':calc_prog_size(prog)})['truth']
                 else:
-                    return bool_query('sat')
+                    return self.bool_query('sat')
 
     def is_body_sat(self, body):
         if len(body) > 1:
@@ -335,7 +369,7 @@ def is_body_sat(self, body):
         else:
             q = format_literal_janus(next(iter(body)))
 
-        return bool_query(q)
+        return self.bool_query(q)
 
     def is_literal_redundant(self, body, literal):
         literal_str = format_literal_janus(literal)
@@ -344,12 +378,12 @@ def is_literal_redundant(self, body, literal):
         else:
             x = format_literal_janus(next(iter(body)))
         q = f'{x}, \\+ {literal_str}'
-        return not bool_query(q)
+        return not self.bool_query(q)
 
     def diff_subs_single(self, literal):
         literal_str = format_literal_janus(literal)
         q = f'{self.neg_fact_str}, \\+ {literal_str}'
-        return not bool_query(q)
+        return not self.bool_query(q)
 
     def is_neg_reducible(self, body, literal):
         # AC: we do not cache as we can never see body + neg_literal again
@@ -357,7 +391,7 @@ def is_neg_reducible(self, body, literal):
         body_str = ','.join(format_literal_janus(literal) for literal in ordered_body)
         literal_str = format_literal_janus(literal)
         q = f'{body_str}, \\+ {literal_str}'
-        return not bool_query(q)
+        return not self.bool_query(q)
 
     @cache
     def has_redundant_literal(self, prog):
@@ -368,41 +402,41 @@ def has_redundant_literal(self, prog):
             else:
                 c = f"[{','.join(tuple(format_literal_janus(lit) for lit in body))}]"
             q = f'redundant_literal({c})'
-            if query_once(q)['truth']:
+            if self.query_once(q)['truth']:
                 # print(q, True)
                 return True
             # print(q, False)
         return False
 
-    # # WE ASSUME THAT THERE IS A REUNDANT RULE
-    def find_redundant_rule_(self, prog):
-        prog_ = []
-        for i, (head, body) in enumerate(prog):
-            c = f"{i}-[{','.join(('not_'+ format_literal(head),) + tuple(format_literal(lit) for lit in body))}]"
-            prog_.append(c)
-        prog_ = f"[{','.join(prog_)}]"
-        prog_ = janus_format_rule(prog_)
-        q = f'find_redundant_rule({prog_}, K1, K2)'
-        res = query_once(q)
-        k1 = res['K1']
-        k2 = res['K2']
-        return prog[k1], prog[k2]
-
-    def find_redundant_rules(self, prog):
-        # assert(False)
-        # AC: if the overhead of this call becomes too high, such as when learning programs with lots of clauses, we can improve it by not comparing already compared clauses
-        base = []
-        step = []
-        for rule in prog:
-            if rule_is_recursive(rule):
-                step.append(rule)
-            else:
-                base.append(rule)
-        if len(base) > 1 and self.has_redundant_rule(base):
-            return self.find_redundant_rule_(base)
-        if len(step) > 1 and self.has_redundant_rule(step):
-            return self.find_redundant_rule_(step)
-        return None
+    # # WE ASSUME THAT THERE IS A REDUNDANT RULE
+    # def find_redundant_rule_(self, prog):
+    #     prog_ = []
+    #     for i, (head, body) in enumerate(prog):
+    #         c = f"{i}-[{','.join(('not_'+ format_literal(head),) + tuple(format_literal(lit) for lit in body))}]"
+    #         prog_.append(c)
+    #     prog_ = f"[{','.join(prog_)}]"
+    #     prog_ = janus_format_rule(prog_)
+    #     q = f'find_redundant_rule({prog_}, K1, K2)'
+    #     res = query_once(q)
+    #     k1 = res['K1']
+    #     k2 = res['K2']
+    #     return prog[k1], prog[k2]
+    #
+    # def find_redundant_rules(self, prog):
+    #     # assert(False)
+    #     # AC: if the overhead of this call becomes too high, such as when learning programs with lots of clauses, we can improve it by not comparing already compared clauses
+    #     base = []
+    #     step = []
+    #     for rule in prog:
+    #         if rule_is_recursive(rule):
+    #             step.append(rule)
+    #         else:
+    #             base.append(rule)
+    #     if len(base) > 1 and self.has_redundant_rule(base):
+    #         return self.find_redundant_rule_(base)
+    #     if len(step) > 1 and self.has_redundant_rule(step):
+    #         return self.find_redundant_rule_(step)
+    #     return None
 
     def find_pointless_relations(self):
         settings = self.settings
@@ -417,7 +451,7 @@ def find_pointless_relations(self):
 
             query = f'current_predicate({p}/{pa})'
             try:
-                if not query_once(query)['truth']:
+                if not self.query_once(query)['truth']:
                     pointless.add((p, pa))
                     # print(p, pa)
                     missing.add(p)
@@ -449,7 +483,7 @@ def find_pointless_relations(self):
                     # print(query1)
                     # print(query2)
                     try:
-                        if query_once(query1)['truth'] or query_once(query2)['truth']:
+                        if self.query_once(query1)['truth'] or self.query_once(query2)['truth']:
                             continue
                     except Exception as Err:
                         print('ERROR detecting pointless relations', Err)
@@ -483,6 +517,12 @@ def find_pointless_relations(self):
         # return frozenset((p, arities[p]) for p in pointless)
         return pointless
 
+    def destroy_prolog_module(self):
+        """Destroy the Tester's temporary Prolog module; raise PopperTesterError if that fails."""
+        if not query_once('modules:destroy_module(Module)', {'Module': self.module_name})['truth']:
+            raise PopperTesterError(f'module {self.module_name} not destroyed')
+
+
 
 def deduce_neg_example_recalls(settings, atoms):
     # Jan Struyf, Hendrik Blockeel: Query Optimization in Inductive Logic Programming by Reordering Literals. ILP 2003: 329-346