Skip to content

Commit 7d3d2a1

Browse files
committed
Add Differentiable Physics: Mass-Spring System example
1 parent 65afde6 commit 7d3d2a1

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

differentiable_physics/mass_spring.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
import argparse
5+
6+
7+
class MassSpringSystem(nn.Module):
8+
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"):
9+
super().__init__()
10+
self.device = device
11+
self.mass = mass
12+
self.springs = springs
13+
self.dt = dt
14+
self.gravity = gravity
15+
16+
# 🛑 Particle 0 fixed at origin
17+
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device)
18+
19+
# 🛑 Only remaining particles are trainable
20+
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device))
21+
22+
# Velocities
23+
self.velocities = torch.zeros(num_particles, 2, device=device)
24+
25+
def forward(self, steps):
26+
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0)
27+
velocities = self.velocities
28+
29+
for _ in range(steps):
30+
forces = torch.zeros_like(positions)
31+
32+
# Compute spring forces
33+
for (i, j, rest_length, stiffness) in self.springs:
34+
xi, xj = positions[i], positions[j]
35+
dir_vec = xj - xi
36+
dist = dir_vec.norm()
37+
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6)
38+
forces[i] += force
39+
forces[j] -= force
40+
41+
# Apply gravity
42+
forces[:, 1] -= self.gravity * self.mass
43+
44+
# Integrate (semi-implicit Euler)
45+
acceleration = forces / self.mass
46+
velocities = velocities + acceleration * self.dt
47+
positions = positions + velocities * self.dt
48+
49+
# Fix particle 0 after integration
50+
positions[0] = self.initial_position_0
51+
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device)
52+
53+
return positions
54+
55+
56+
57+
def train(args):
58+
"""
59+
Train the MassSpringSystem to match a target configuration.
60+
"""
61+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62+
system = MassSpringSystem(
63+
num_particles=args.num_particles,
64+
springs=[(0, 1, 1.0, args.stiffness)],
65+
mass=args.mass,
66+
dt=args.dt,
67+
gravity=args.gravity,
68+
device=device,
69+
)
70+
71+
optimizer = optim.Adam(system.parameters(), lr=args.lr)
72+
target_positions = torch.tensor(
73+
[[0.0, 0.0], [1.0, 0.0]], device=device
74+
) # Target: particle 0 at (0,0), particle 1 at (1,0)
75+
76+
for epoch in range(args.epochs):
77+
optimizer.zero_grad()
78+
final_positions = system(args.steps) # <--- final_positions comes from forward()
79+
loss = (final_positions - target_positions).pow(2).mean()
80+
loss.backward()
81+
optimizer.step()
82+
83+
if (epoch + 1) % args.log_interval == 0:
84+
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}")
85+
86+
print("\nTraining completed.")
87+
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}") # <--- print final_positions
88+
print(f"Target positions:\n{target_positions.cpu().numpy()}")
89+
90+
91+
def evaluate(args):
92+
"""
93+
Evaluate the trained MassSpringSystem without optimization.
94+
"""
95+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96+
system = MassSpringSystem(
97+
num_particles=args.num_particles,
98+
springs=[(0, 1, 1.0, args.stiffness)],
99+
mass=args.mass,
100+
dt=args.dt,
101+
gravity=args.gravity, # <-- Gravity passed here too
102+
device=device,
103+
)
104+
105+
with torch.no_grad():
106+
final_positions = system(args.steps)
107+
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}")
108+
109+
110+
def parse_args():
111+
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System")
112+
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs")
113+
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass")
114+
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
115+
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration")
116+
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle")
117+
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant")
118+
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system")
119+
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval")
120+
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs")
121+
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength")
122+
return parser.parse_args()
123+
124+
125+
def main():
126+
args = parse_args()
127+
128+
if args.mode == "train":
129+
train(args)
130+
elif args.mode == "eval":
131+
evaluate(args)
132+
133+
134+
if __name__ == "__main__":
135+
main()

differentiable_physics/readme.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Differentiable Physics: Mass-Spring System
2+
3+
This example demonstrates a simple differentiable mass-spring system using PyTorch.
4+
5+
Particles are connected by springs and evolve under the forces exerted by the springs and gravity.
6+
The system is fully differentiable, allowing the optimization of particle positions to match a target configuration using gradient-based learning.
7+
8+
---
9+
10+
## Files
11+
12+
- `mass_spring.py` — Implements the mass-spring simulation, training loop, and evaluation.
13+
- `README.md` — Usage instructions and description.
14+
15+
---
16+
17+
## Requirements
18+
19+
- Python 3.8+
20+
- PyTorch
21+
22+
No external dependencies are required apart from PyTorch.
23+
24+
---
25+
26+
## Usage
27+
28+
First, ensure PyTorch is installed.
29+
30+
### Train the system
31+
32+
```bash
33+
python mass_spring.py --mode train
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torch

0 commit comments

Comments
 (0)