@@ -272,28 +272,18 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
272272 end
273273end
274274
275- struct ConditionContext{Names, Values,Ctx<: AbstractContext } <: AbstractContext
275+ struct ConditionContext{Values,Ctx<: AbstractContext } <: AbstractContext
276276 values:: Values
277277 context:: Ctx
278-
279- function ConditionContext {Values} (
280- values:: Values , context:: AbstractContext
281- ) where {names,Values<: NamedTuple{names} }
282- return new {names,typeof(values),typeof(context)} (values, context)
283- end
284278end
285279
286- function ConditionContext (values:: NamedTuple )
287- return ConditionContext (values, DefaultContext ())
288- end
289- function ConditionContext (values:: NamedTuple , context:: AbstractContext )
290- return ConditionContext {typeof(values)} (values, context)
291- end
280+ const NamedConditionContext{Names} = ConditionContext{<: NamedTuple{Names} }
281+ const DictConditionContext = ConditionContext{<: AbstractDict }
282+
283+ ConditionContext (values) = ConditionContext (values, DefaultContext ())
292284
293285# Try to avoid nested `ConditionContext`.
294- function ConditionContext (
295- values:: NamedTuple{Names} , context:: ConditionContext
296- ) where {Names}
286+ function ConditionContext (values:: NamedTuple , context:: NamedConditionContext )
297287 # Note that this potentially overrides values from `context`, thus giving
298288 # precedence to the outmost `ConditionContext`.
299289 return ConditionContext (merge (context. values, values), childcontext (context))
@@ -303,7 +293,7 @@ function Base.show(io::IO, context::ConditionContext)
303293 return print (io, " ConditionContext($(context. values) , $(childcontext (context)) )" )
304294end
305295
306- NodeTrait (context :: ConditionContext ) = IsParent ()
296+ NodeTrait (:: ConditionContext ) = IsParent ()
307297childcontext (context:: ConditionContext ) = context. context
308298setchildcontext (parent:: ConditionContext , child) = ConditionContext (parent. values, child)
309299
@@ -313,14 +303,9 @@ setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.value
313303Return `true` if `vn` is found in `context`.
314304"""
315305hasvalue (context, vn) = false
316-
317- function hasvalue (context:: ConditionContext{vars} , vn:: VarName{sym} ) where {vars,sym}
318- return sym in vars
319- end
320- function hasvalue (
321- context:: ConditionContext{vars} , vn:: AbstractArray{<:VarName{sym}}
322- ) where {vars,sym}
323- return sym in vars
306+ hasvalue (context:: ConditionContext , vn:: VarName ) = nested_haskey (context. values, vn)
307+ function hasvalue (context:: ConditionContext , vns:: AbstractArray{<:VarName} )
308+ return all (Base. Fix1 (nested_haskey, context. values), vns)
324309end
325310
326311"""
@@ -331,7 +316,8 @@ Return value of `vn` in `context`.
331316function getvalue (context:: AbstractContext , vn)
332317 return error (" context $(context) does not contain value for $vn " )
333318end
334- getvalue (context:: ConditionContext , vn) = get (context. values, vn)
319+ getvalue (context:: NamedConditionContext , vn) = get (context. values, vn)
320+ getvalue (context:: ConditionContext , vn) = nested_getindex (context. values, vn)
335321
336322"""
337323 hasvalue_nested(context, vn)
@@ -386,15 +372,33 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
386372
387373See also: [`decondition`](@ref)
388374"""
389- AbstractPPL. condition (; values... ) = condition (DefaultContext (), NamedTuple (values))
375+ AbstractPPL. condition (; values... ) = condition (NamedTuple (values))
390376AbstractPPL. condition (values:: NamedTuple ) = condition (DefaultContext (), values)
377+ function AbstractPPL. condition (value:: Pair{<:VarName} , values:: Pair{<:VarName} ...)
378+ return condition ((value, values... ))
379+ end
380+ function AbstractPPL. condition (values:: NTuple{<:Any,<:Pair{<:VarName}} )
381+ return condition (DefaultContext (), values)
382+ end
391383AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple{()} ) = context
392- function AbstractPPL. condition (context:: AbstractContext , values:: NamedTuple )
384+ function AbstractPPL. condition (
385+ context:: AbstractContext , values:: Union{AbstractDict,NamedTuple}
386+ )
393387 return ConditionContext (values, context)
394388end
395389function AbstractPPL. condition (context:: AbstractContext ; values... )
396390 return condition (context, NamedTuple (values))
397391end
392+ function AbstractPPL. condition (
393+ context:: AbstractContext , value:: Pair{<:VarName} , values:: Pair{<:VarName} ...
394+ )
395+ return condition (context, (value, values... ))
396+ end
397+ function AbstractPPL. condition (
398+ context:: AbstractContext , values:: NTuple{<:Any,Pair{<:VarName}}
399+ )
400+ return condition (context, Dict (values))
401+ end
398402
399403"""
400404 decondition(context::AbstractContext, syms...)
@@ -430,6 +434,19 @@ function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
430434 )
431435end
432436
437+ function AbstractPPL. decondition (
438+ context:: NamedConditionContext , vn:: VarName{sym}
439+ ) where {sym}
440+ return condition (
441+ decondition (childcontext (context), vn), BangBang. delete!! (context. values, sym)
442+ )
443+ end
444+ function AbstractPPL. decondition (context:: ConditionContext , vn:: VarName )
445+ return condition (
446+ decondition (childcontext (context), vn), BangBang. delete!! (context. values, vn)
447+ )
448+ end
449+
433450"""
434451 conditioned(context::AbstractContext)
435452
0 commit comments