Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,13 @@ public override StatementSyntax GenerateUnmanagedToManagedByValueOutMarshalState
private List<StatementSyntax> GenerateElementStages(
StubIdentifierContext context,
IBoundMarshallingGenerator elementMarshaller,
out LinearCollectionElementIdentifierContext elementSetupSubContext,
out string indexer,
params StubIdentifierContext.Stage[] stagesToGeneratePerElement)
{
string managedSpanIdentifier = MarshallerHelpers.GetManagedSpanIdentifier(CollectionSource.TypeInfo, context);
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(CollectionSource.TypeInfo, context);
StubCodeContext elementCodeContext = StubCodeContext.CreateElementMarshallingContext(CollectionSource.CodeContext);
elementSetupSubContext = new LinearCollectionElementIdentifierContext(
LinearCollectionElementIdentifierContext elementSetupSubContext = new(
context,
elementMarshaller.TypeInfo,
managedSpanIdentifier,
Expand All @@ -485,13 +485,74 @@ private List<StatementSyntax> GenerateElementStages(
CodeEmitOptions = context.CodeEmitOptions
};

indexer = elementSetupSubContext.IndexerIdentifier;

StubIdentifierContext identifierContext = elementSetupSubContext;

if (elementMarshaller.NativeType is PointerTypeInfo)
{
identifierContext = new GenericFriendlyPointerIdentifierContext(elementSetupSubContext, elementMarshaller.TypeInfo, $"{nativeSpanIdentifier}__{indexer}")
{
CodeEmitOptions = elementSetupSubContext.CodeEmitOptions,
};
}

List<StatementSyntax> elementStatements = [];
foreach (StubIdentifierContext.Stage stage in stagesToGeneratePerElement)
{
var elementSubContext = elementSetupSubContext with { CurrentStage = stage };
elementStatements.AddRange(elementMarshaller.Generate(elementSubContext));
var elementIdentifierContext = identifierContext with { CurrentStage = stage };
elementStatements.AddRange(elementMarshaller.Generate(elementIdentifierContext));
}

if (elementStatements.Count == 0)
{
return [];
}

// Only add the setup stage if we generated code for other stages.
elementStatements.InsertRange(0, elementMarshaller.Generate(identifierContext with { CurrentStage = StubIdentifierContext.Stage.Setup }));

if (identifierContext is not GenericFriendlyPointerIdentifierContext)
{
// If we didn't need to account for pointer types, we have the statements we need.
return elementStatements;
}
return elementStatements;

// If we have the generic friendly pointer context, we need to declare the special identifier and assign to/from it.

// <native_type> <native_exactType> = (<native_type>)<native_collection>[i];
StatementSyntax exactTypeDeclaration =
LocalDeclarationStatement(
VariableDeclaration(
elementMarshaller.NativeType.Syntax,
SingletonSeparatedList(
VariableDeclarator(
Identifier(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native))
.WithInitializer(
EqualsValueClause(
CastExpression(elementMarshaller.NativeType.Syntax,
ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native)))))));

if (stagesToGeneratePerElement.Any(stage => stage is StubIdentifierContext.Stage.Marshal or StubIdentifierContext.Stage.PinnedMarshal))
{
// <native_collection>[i] = (<generic_compatible_native_type>)<native_exactType>;
StatementSyntax propagateResult = AssignmentStatement(
ParseExpression(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native),
CastExpression(TypeSyntaxes.System_IntPtr,
IdentifierName(identifierContext.GetIdentifiers(elementMarshaller.TypeInfo).native)));

return
[
exactTypeDeclaration,
..elementStatements,
propagateResult
];
}

return [
exactTypeDeclaration,
..elementStatements
];
}

private StatementSyntax GenerateContentsMarshallingStatement(
Expand All @@ -500,22 +561,14 @@ private StatementSyntax GenerateContentsMarshallingStatement(
IBoundMarshallingGenerator elementMarshaller,
params StubIdentifierContext.Stage[] stagesToGeneratePerElement)
{
var elementStatements = GenerateElementStages(context, elementMarshaller, out var elementSetupSubContext, stagesToGeneratePerElement);
var elementStatements = GenerateElementStages(context, elementMarshaller, out string indexer, stagesToGeneratePerElement);

if (elementStatements.Count != 0)
{
StatementSyntax marshallingStatement = Block(
List(elementMarshaller.Generate(elementSetupSubContext)
.Concat(elementStatements)));

if (elementMarshaller.NativeType is PointerTypeInfo nativeTypeInfo)
{
PointerNativeTypeAssignmentRewriter rewriter = new(elementSetupSubContext.GetIdentifiers(elementMarshaller.TypeInfo).native, (PointerTypeSyntax)nativeTypeInfo.Syntax);
marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement);
}
StatementSyntax marshallingStatement = Block(elementStatements);

// Iterate through the elements of the native collection to marshal them
var forLoop = ForLoop(elementSetupSubContext.IndexerIdentifier, lengthExpression)
var forLoop = ForLoop(indexer, lengthExpression)
.WithStatement(marshallingStatement);
// If we're tracking LastIndexMarshalled, increment that each iteration as well.
if (UsesLastIndexMarshalled(CollectionSource.TypeInfo, CollectionSource.CodeContext) && stagesToGeneratePerElement.Contains(StubIdentifierContext.Stage.Marshal))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Microsoft.Interop
{
internal sealed record GenericFriendlyPointerIdentifierContext : StubIdentifierContext
{
private readonly StubIdentifierContext _innerContext;
private readonly TypePositionInfo _adaptedInfo;
private readonly string _nativeIdentifier;

public GenericFriendlyPointerIdentifierContext(StubIdentifierContext inner, TypePositionInfo adaptedInfo, string baseIdentifier)
{
_innerContext = inner;
_adaptedInfo = adaptedInfo;
_nativeIdentifier = baseIdentifier + "_exactType";
CurrentStage = inner.CurrentStage;
}

public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
if (info.PositionsEqual(_adaptedInfo))
{
(string managed, _) = _innerContext.GetIdentifiers(info);
return (managed, _nativeIdentifier);
}

return _innerContext.GetIdentifiers(info);
}

public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => _innerContext.GetAdditionalIdentifier(info, name);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.Interop.SyntaxFactoryExtensions;

namespace Microsoft.Interop
{
Expand Down Expand Up @@ -71,18 +72,26 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(StubIdentifierCont
yield break;
}

// <nativeIdentifier> = <convertToUnmanaged>;
var assignment = AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
convertToUnmanaged);

ExpressionSyntax assignment;

if (unmanagedType is PointerTypeInfo pointer)
// For some of our exception marshallers, our marshaller returns nint for pointer types.
// As a result, we need to insert a cast here in case we're in that scenario (which we can't detect specifically).
if (unmanagedType is PointerTypeInfo ptrType)
{
// <nativeIdentifier> = (<nativeType>)<convertToUnmanaged>;
assignment = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
CastExpression(ptrType.Syntax, convertToUnmanaged));
}
else
{
var rewriter = new PointerNativeTypeAssignmentRewriter(assignment.Right.ToString(), (PointerTypeSyntax)pointer.Syntax);
assignment = (AssignmentExpressionSyntax)rewriter.Visit(assignment);
// <nativeIdentifier> = <convertToUnmanaged>;
assignment = AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
convertToUnmanaged);
}

Comment on lines -74 to +94
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need anything like this for native marshalling?

Copy link
Member Author

@jkoritzinsky jkoritzinsky Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need this for unmarshalling because we only hit it in a specific (due to us artificially resolving the marshaller) scenario. We don't put ourselves into that situation in the native->managed case.

yield return ExpressionStatement(assignment);
}

Expand Down