Skip to content

Rust: Resolve function calls to traits methods #19575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion rust/ql/lib/codeql/rust/elements/internal/CallExprImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ private import codeql.rust.elements.PathExpr
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution
private import codeql.rust.internal.TypeInference as TypeInference

pragma[nomagic]
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
Expand All @@ -36,7 +37,14 @@ module Impl {
class CallExpr extends Generated::CallExpr {
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }

override Callable getStaticTarget() { result = getResolvedFunction(this) }
override Callable getStaticTarget() {
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
// if type inference can resolve it to the correct trait implementation.
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = getResolvedFunction(this)
}

/** Gets the struct that this call resolves to, if any. */
Struct getStruct() { result = getResolvedFunction(this) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ private import codeql.rust.internal.TypeInference
* be referenced directly.
*/
module Impl {
private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A method call expression. For example:
Expand All @@ -31,38 +23,7 @@ module Impl {
* ```
*/
class MethodCallExpr extends Generated::MethodCallExpr {
private Function getStaticTargetFrom(boolean fromSource) {
result = resolveMethodCallExpr(this) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
not isInherentImplFunction(resolveMethodCallExpr(this)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(resolveMethodCallExpr(this)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not resolveMethodCallExpr(this).hasBody()
)
)
)
}

override Function getStaticTarget() {
// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = this.getStaticTargetFrom(true)
or
not exists(this.getStaticTargetFrom(true)) and
result = this.getStaticTargetFrom(false)
}
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }

private string toStringPart(int index) {
index = 0 and
Expand Down
257 changes: 179 additions & 78 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
Declaration getTarget() {
result = CallExprImpl::getResolvedFunction(this)
or
result = resolveMethodCallExpr(this) // mutual recursion; resolving method calls requires resolving types and vice versa
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}

Expand Down Expand Up @@ -1000,6 +1000,150 @@ private StructType inferLiteralType(LiteralExpr le) {
)
}

private module MethodCall {
/** An expression that calls a method. */
abstract private class MethodCallImpl extends Expr {
/** Gets the name of the method targeted. */
abstract string getMethodName();

/** Gets the number of arguments _excluding_ the `self` argument. */
abstract int getArity();

/** Gets the trait targeted by this method call, if any. */
Trait getTrait() { none() }

/** Gets the type of the receiver of the method call at `path`. */
abstract Type getTypeAt(TypePath path);
}

final class MethodCall = MethodCallImpl;

private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
override string getMethodName() { result = super.getIdentifier().getText() }

override int getArity() { result = super.getArgList().getNumberOfArgs() }

pragma[nomagic]
override Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = path0
)
}
}

private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
TraitItemNode trait;
string methodName;
Expr receiver;

CallExprMethodCall() {
receiver = this.getArgList().getArg(0) and
exists(Path path, Function f |
path = this.getFunction().(PathExpr).getPath() and
f = resolvePath(path) and
f.getParamList().hasSelfParam() and
trait = resolvePath(path.getQualifier()) and
trait.getAnAssocItem() = f and
path.getSegment().getIdentifier().getText() = methodName
)
}

override string getMethodName() { result = methodName }

override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }

override Trait getTrait() { result = trait }

pragma[nomagic]
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
}
}

import MethodCall

/**
* Holds if a method for `type` with the name `name` and the arity `arity`
* exists in `impl`.
*/
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
type = impl.getSelfTy().(TypeMention).resolveType() and
exists(Function f |
f = impl.(ImplItemNode).getASuccessor(name) and
f.getParamList().hasSelfParam() and
arity = f.getParamList().getNumberOfParams()
)
}

/**
* Holds if a method for `type` for `trait` with the name `name` and the arity
* `arity` exists in `impl`.
*/
pragma[nomagic]
private predicate methodCandidateTrait(Type type, Trait trait, string name, int arity, Impl impl) {
trait = resolvePath(impl.(ImplItemNode).getTraitPath()) and
methodCandidate(type, name, arity, impl)
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
exists(Type rootType, string name, int arity |
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getArity() and
constraint = impl.(ImplTypeAbstraction).getSelfTy()
|
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
or
not exists(mc.getTrait()) and
methodCandidate(rootType, name, arity, impl)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new disjunct and methodCandidateTrait is the crux of the change.

)
}

predicate relevantTypeMention(TypeMention constraint) {
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
}
}

bindingset[item, name]
pragma[inline_late]
private Function getMethodSuccessor(ItemNode item, string name) {
result = item.getASuccessor(name)
}

bindingset[tp, name]
pragma[inline_late]
private Function getTypeParameterMethod(TypeParameter tp, string name) {
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
or
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
}

/** Gets a method from an `impl` block that matches the method call `mc`. */
private Function getMethodFromImpl(MethodCall mc) {
exists(Impl impl |
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
)
}

/**
* Gets a method that the method call `mc` resolves to based on type inference,
* if any.
*/
private Function inferMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

cached
private module Cached {
private import codeql.rust.internal.CachedStages
Expand All @@ -1026,90 +1170,47 @@ private module Cached {
)
}

private class ReceiverExpr extends Expr {
MethodCallExpr mce;

ReceiverExpr() { mce.getReceiver() = this }

string getField() { result = mce.getIdentifier().getText() }

int getNumberOfArgs() { result = mce.getArgList().getNumberOfArgs() }

pragma[nomagic]
Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(this, path0) |
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = path0
)
}
}

/** Holds if a method for `type` with the name `name` and the arity `arity` exists in `impl`. */
pragma[nomagic]
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
type = impl.getSelfTy().(TypeReprMention).resolveType() and
exists(Function f |
f = impl.(ImplItemNode).getASuccessor(name) and
f.getParamList().hasSelfParam() and
arity = f.getParamList().getNumberOfParams()
)
}

private module IsInstantiationOfInput implements IsInstantiationOfInputSig<ReceiverExpr> {
pragma[nomagic]
predicate potentialInstantiationOf(
ReceiverExpr receiver, TypeAbstraction impl, TypeMention constraint
) {
methodCandidate(receiver.getTypeAt(TypePath::nil()), receiver.getField(),
receiver.getNumberOfArgs(), impl) and
constraint = impl.(ImplTypeAbstraction).getSelfTy()
}

predicate relevantTypeMention(TypeMention constraint) {
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
}
}

bindingset[item, name]
pragma[inline_late]
private Function getMethodSuccessor(ItemNode item, string name) {
result = item.getASuccessor(name)
private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

bindingset[tp, name]
pragma[inline_late]
private Function getTypeParameterMethod(TypeParameter tp, string name) {
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
or
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}

/**
* Gets the method from an `impl` block with an implementing type that matches
* the type of `receiver` and with a name of the method call in which
* `receiver` occurs, if any.
*/
private Function getMethodFromImpl(ReceiverExpr receiver) {
exists(Impl impl |
IsInstantiationOf<ReceiverExpr, IsInstantiationOfInput>::isInstantiationOf(receiver, impl, _) and
result = getMethodSuccessor(impl, receiver.getField())
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
result = inferMethodCallTarget(mc) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
not isInherentImplFunction(inferMethodCallTarget(mc)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(inferMethodCallTarget(mc)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not inferMethodCallTarget(mc).hasBody()
)
)
)
}

/** Gets a method that the method call `mce` resolves to, if any. */
/** Gets a method that the method call `mc` resolves to, if any. */
cached
Function resolveMethodCallExpr(MethodCallExpr mce) {
exists(ReceiverExpr receiver | mce.getReceiver() = receiver |
// The method comes from an `impl` block targeting the type of `receiver`.
result = getMethodFromImpl(receiver)
or
// The type of `receiver` is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(receiver.getTypeAt(TypePath::nil()), receiver.getField())
)
Function resolveMethodCallTarget(MethodCall mc) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This resolveMethodCallTarget was previously getStaticTarget inside MethodCallExpr. But since we now need for CallExpr as well, I've moved it in here.

// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = resolveMethodCallTargetFrom(mc, true)
or
not exists(resolveMethodCallTargetFrom(mc, true)) and
result = resolveMethodCallTargetFrom(mc, false)
}

pragma[inline]
Expand Down Expand Up @@ -1243,6 +1344,6 @@ private module Debug {

Function debugResolveMethodCallExpr(MethodCallExpr mce) {
mce = getRelevantLocatable() and
result = resolveMethodCallExpr(mce)
result = resolveMethodCallTarget(mce)
}
}
Loading