Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public struct RayInfo
public Vector3 worldEnd;
public bool castHit;
public float hitFraction;
public float castRadius;
}

public void Reset()
Expand Down Expand Up @@ -145,22 +146,22 @@ public virtual SensorCompressionType GetCompressionType()
/// nothing was hit.
///
/// </summary>
/// <param name="rayLength"></param>
/// <param name="unscaledRayLength"></param>
/// <param name="rayAngles">List of angles (in degrees) used to define the rays. 90 degrees is considered
/// "forward" relative to the game object</param>
/// <param name="detectableObjects">List of tags which correspond to object types agent can see</param>
/// <param name="startOffset">Starting height offset of ray from center of agent.</param>
/// <param name="endOffset">Ending height offset of ray from center of agent.</param>
/// <param name="castRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// <param name="unscaledCastRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// instead - this may be faster, especially for complex environments.</param>
/// <param name="transform">Transform of the GameObject</param>
/// <param name="castType">Whether to perform the casts in 2D or 3D.</param>
/// <param name="perceptionBuffer">Output array of floats. Must be (num rays) * (num tags + 2) in size.</param>
/// <param name="debugInfo">Optional debug information output, only used by RayPerceptionSensor.</param>
///
public static void PerceiveStatic(float rayLength,
public static void PerceiveStatic(float unscaledRayLength,
IReadOnlyList<float> rayAngles, IReadOnlyList<string> detectableObjects,
float startOffset, float endOffset, float castRadius,
float startOffset, float endOffset, float unscaledCastRadius,
Transform transform, CastType castType, float[] perceptionBuffer,
int layerMask = Physics.DefaultRaycastLayers,
DebugDisplayInfo debugInfo = null)
Expand All @@ -185,20 +186,26 @@ public static void PerceiveStatic(float rayLength,
if (castType == CastType.Cast3D)
{
startPositionLocal = new Vector3(0, startOffset, 0);
endPositionLocal = PolarToCartesian3D(rayLength, angle);
endPositionLocal = PolarToCartesian3D(unscaledRayLength, angle);
endPositionLocal.y += endOffset;
}
else
{
// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
startPositionLocal = new Vector2();
endPositionLocal = PolarToCartesian2D(rayLength, angle);
endPositionLocal = PolarToCartesian2D(unscaledRayLength, angle);
}

var startPositionWorld = transform.TransformPoint(startPositionLocal);
var endPositionWorld = transform.TransformPoint(endPositionLocal);

var rayDirection = endPositionWorld - startPositionWorld;
// If there is non-unity scale, |rayDirection| will be different from rayLength.
// We want to use this transformed ray length for determining cast length, hit fraction etc.
// We also it to scale up or down the sphere or circle radii
var scaledRayLength = rayDirection.magnitude;
// Avoid 0/0 if unscaledRayLength is 0
var scaledCastRadius = unscaledRayLength > 0 ? unscaledCastRadius * scaledRayLength / unscaledRayLength : unscaledCastRadius;

// Do the cast and assign the hit information for each detectable object.
// sublist[0 ] <- did hit detectableObjects[0]
Expand All @@ -214,31 +221,33 @@ public static void PerceiveStatic(float rayLength,
if (castType == CastType.Cast3D)
{
RaycastHit rayHit;
if (castRadius > 0f)
if (scaledCastRadius > 0f)
{
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit,
rayLength, layerMask);
castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
scaledRayLength, layerMask);
}
else
{
castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
rayLength, layerMask);
scaledRayLength, layerMask);
}

hitFraction = castHit ? rayHit.distance / rayLength : 1.0f;
// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
// To avoid 0/0, set the fraction to 0.
hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
hitObject = castHit ? rayHit.collider.gameObject : null;
}
else
{
RaycastHit2D rayHit;
if (castRadius > 0f)
if (scaledCastRadius > 0f)
{
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection,
rayLength, layerMask);
rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
scaledRayLength, layerMask);
}
else
{
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength, layerMask);
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, layerMask);
}

castHit = rayHit;
Expand All @@ -254,6 +263,7 @@ public static void PerceiveStatic(float rayLength,
debugInfo.rayInfos[rayIndex].worldEnd = endPositionWorld;
debugInfo.rayInfos[rayIndex].castHit = castHit;
debugInfo.rayInfos[rayIndex].hitFraction = hitFraction;
debugInfo.rayInfos[rayIndex].castRadius = scaledCastRadius;
}
else if (Application.isEditor)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ public void OnDrawGizmos()
// hit fraction ^2 will shift "far" hits closer to the hit color
var lerpT = rayInfo.hitFraction * rayInfo.hitFraction;
var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
color.a = alpha;
color.a *= alpha;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran into this when trying to get a screenshot; I wanted to hide the "higher" raycast component by setting its alpha to 0

Gizmos.color = color;
Gizmos.DrawRay(startPositionWorld, rayDirection);

// Draw the hit point as a sphere. If using rays to cast (0 radius), use a small sphere.
if (rayInfo.castHit)
{
var hitRadius = Mathf.Max(sphereCastRadius, .05f);
var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,89 @@ public void TestRayFilter()
}
}
}

[Test]
public void TestRaycastsScaled()
{
SetupScene();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
obj.transform.localScale = new Vector3(2, 2,2 );

perception.raysPerDirection = 0;
perception.maxRayDegrees = 45;
perception.rayLength = 20;
perception.detectableTags = new List<string>();
perception.detectableTags.Add(k_CubeTag);

var radii = new[] { 0f, .5f };
foreach (var castRadius in radii)
{
perception.sphereCastRadius = castRadius;
var sensor = perception.CreateSensor();

var expectedObs = (2 * perception.raysPerDirection + 1) * (perception.detectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
var outputBuffer = new float[expectedObs];

WriteAdapter writer = new WriteAdapter();
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);

var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

// Expected hits:
// ray 0 should hit the cube at roughly 1/4 way
//
Assert.AreEqual(1.0f, outputBuffer[0]); // hit cube
Assert.AreEqual(0.0f, outputBuffer[1]); // missed unknown tag

// Hit is at z=9.0 in world space, ray length was 20
// But scale increases the cast size and the ray length
var scaledRayLength = 2 * perception.rayLength;
var scaledCastRadius = 2 * castRadius;
Assert.That(
outputBuffer[2], Is.EqualTo((9.5f - scaledCastRadius) / scaledRayLength).Within(.0005f)
);
}
}

[Test]
public void TestRayZeroLength()
{
// Place the cube touching the origin
var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
cube.transform.position = new Vector3(0, 0, .5f);
cube.tag = k_CubeTag;

Physics.SyncTransforms();

var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
perception.raysPerDirection = 0;
perception.rayLength = 0.0f;
perception.sphereCastRadius = .5f;
perception.detectableTags = new List<string>();
perception.detectableTags.Add(k_CubeTag);

{
// Set the layer mask to either the default, or one that ignores the close cube's layer

var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.raysPerDirection + 1) * (perception.detectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
var outputBuffer = new float[expectedObs];

WriteAdapter writer = new WriteAdapter();
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);

var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);

// hit fraction is arbitrary but should be finite in [0,1]
Assert.GreaterOrEqual(outputBuffer[2], 0.0f);
Assert.LessOrEqual(outputBuffer[2], 1.0f);
}
}
}
}