Supervised 3D Mesh Reconstruction
I am really excited to be a part of Google Summer of Code 2020 (GSoC'20). During this program, I will be working on the project "Deep learning for 3D Computer Vision".
With recent advances in machine learning, researchers have come up with novel deep learning architectures that analyse and find high-order patterns in domains like 3D structure. However, it is not easy to work in a domain like 3D computer vision using only the essential utilities of a neural network library, which are primarily focused on classic computer vision.
Processing 3D data and applying transforms and metrics to it involves intense computation, which makes JuliaLang a good fit. Its ML ecosystem (Flux.jl and Zygote.jl) will enable us to train our 3D models.
For this reason, I am working on a 3D vision framework (Flux3D.jl) which will assist users in processing 3D structures (like PointCloud, TriMesh and Voxels) and applying transforms and metrics to them. This framework will also allow interconversion among the different 3D structures and visualization of them.
In the following demonstration, I will be using this package along with Flux and Zygote to perform a 3D mesh reconstruction task.
Problem Description:
We are given an initial source shape (a sphere in this case) and we want to deform it to fit a target shape (a dolphin in this case). For this demonstration, we will use a triangle mesh to represent both the source and the target shape.
A triangle mesh has two main components, vertices and faces. The source shape can be deformed to fit the target shape by offsetting the source's vertices until they fit the target surface. Note that the number of vertices and faces is not the same in the source and target shapes.
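To make the vertices/faces idea concrete, here is a minimal sketch using plain Julia arrays instead of TriMesh. The tetrahedron data and the names verts, faces and offsets are made up for illustration; the only assumption carried over from this post is that vertices are stored one per row.

```julia
# A tetrahedron as a toy triangle mesh (hypothetical data, plain arrays, not a TriMesh).
# Each row of `verts` is one vertex (x, y, z).
verts = Float32[0 0 0;
                1 0 0;
                0 1 0;
                0 0 1]

# Each row of `faces` holds the 1-based indices of the three vertices of one triangle.
faces = [1 2 3;
         1 2 4;
         1 3 4;
         2 3 4]

# Deforming the mesh means adding a per-vertex offset; the faces stay unchanged.
offsets = fill(0.1f0, size(verts))
deformed = verts .+ offsets

# Corner coordinates of the first triangle, before and after the deformation.
tri_before = verts[faces[1, :], :]
tri_after = deformed[faces[1, :], :]
```

Because only the vertex positions move, the deformed mesh always keeps the connectivity of the source, which is why the sphere and the dolphin can have different vertex and face counts.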
Installing and importing packages
Flux3D will be used for processing data and defining metrics, Zygote and Flux for computing gradients and optimizing, and Plots for creating the GIF.
using Pkg
pkg"up"
pkg"add https://github.com/nirmal-suthar/Flux3D.jl#ns/mesh"
pkg"add Flux Zygote FileIO ImageIO Plots"
pkg"precompile"
using Flux3D, Zygote, Flux, FileIO, Statistics, Plots
Flux3D.AbstractPlotting.inline!(true)
Flux3D.AbstractPlotting.set_theme!(show_axis = false)
Downloading obj files of sphere and dolphin
download("https://github.com/nirmalsuthar/public_files/raw/master/dolphin.obj",
"dolphin.obj")
download("https://github.com/nirmalsuthar/public_files/raw/master/sphere.obj",
"sphere.obj")
Loading Triangle Mesh
Triangle meshes are handled by TriMesh in Flux3D. TriMesh also supports batched formats, namely padded, packed and list, which allow us to use fast batched operations. We can load a TriMesh with the load_trimesh function (supports obj, stl, ply, off and 2DM).
dolphin = load_trimesh("dolphin.obj")
src = load_trimesh("sphere.obj")
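As a rough illustration of what the list, packed and padded formats mean, the sketch below builds all three from two toy vertex arrays. This uses plain arrays only; the dimension order (vertices × 3 × batch) and the zero padding are assumptions of this sketch and may not match how TriMesh stores its data internally.

```julia
# list: one vertex array per mesh, each with its own vertex count (hypothetical data).
verts_list = [rand(Float32, 4, 3), rand(Float32, 6, 3)]

# packed: all vertices stacked along the first dimension -> (4 + 6) × 3.
verts_packed = reduce(vcat, verts_list)

# padded: one (max_verts × 3 × batch) array, with shorter meshes zero-padded.
max_verts = maximum(size.(verts_list, 1))
verts_padded = zeros(Float32, max_verts, 3, length(verts_list))
for (i, v) in enumerate(verts_list)
    verts_padded[1:size(v, 1), :, i] .= v
end
```

The packed form suits operations that treat every vertex alike, while the padded form gives a regular array that batched operations can index per mesh.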
Preprocessing data
We preprocess tgt (dolphin) so that its mean is zero, and scale it to match the bounding box of src (sphere), so that src can converge faster.
tgt = deepcopy(dolphin)
verts = get_verts_packed(tgt)
center = mean(verts, dims=1)
verts = verts .- center
scale = maximum(abs.(verts))
verts = verts ./ scale
tgt._verts_packed = verts
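The same normalization can be checked on a small plain array. The vertex values below are made up, and the array simply stands in for the output of get_verts_packed; the last line shows how the transform would be undone.

```julia
using Statistics

# Hypothetical vertex array (one vertex per row), standing in for get_verts_packed(tgt).
verts = Float32[2 4 6;
                4 8 6;
                6 6 9]

center = mean(verts, dims=1)        # 1×3 centroid
centered = verts .- center          # the centroid moves to the origin
scale = maximum(abs.(centered))     # largest absolute coordinate
normalized = centered ./ scale      # all coordinates now lie in [-1, 1]

# Undoing the transform recovers the original vertices.
restored = normalized .* scale .+ center
```

After this step the target fits inside the cube [-1, 1]^3, which is close to the bounding box of the source if the sphere has roughly unit radius (an assumption about sphere.obj).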
Visualizing TriMesh
We will use the visualize function for visualizing TriMesh. This function uses Makie for plotting. In fact, we can also visualize a PointCloud using this function, which makes it handy when dealing with different 3D formats.
Flux3D.AbstractPlotting.vbox(visualize(src), visualize(tgt))
Defining loss objective
Starting from the src mesh, we will deform it by offsetting its vertices (by an offset array), such that the new deformed mesh is close to the target mesh. Our loss function is therefore minimized with respect to the offset array. We will use the following metrics to define the loss objective:
chamfer_distance - the distance between the deformed mesh and the target mesh, calculated by randomly sampling 5000 points from the surface of each mesh and computing the chamfer distance between the two resulting point clouds.
laplacian_loss - also known as Laplacian smoothing, this acts as a regularizer.
edge_loss - this minimizes the edge lengths in the deformed mesh and also acts as a regularizer.
function loss_dolphin(x::Array, src::TriMesh, tgt::TriMesh)
    src = Flux3D.offset(src, x)
    loss1 = chamfer_distance(src, tgt, 5000)
    loss2 = laplacian_loss(src)
    loss3 = edge_loss(src)
    return loss1 + 0.1*loss2 + loss3
end
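To give a feel for what two of these terms measure, here is a brute-force sketch of a chamfer distance between two point sets and of an edge-length penalty, written on plain arrays. The function names chamfer_sketch and edge_loss_sketch are hypothetical, and details such as using squared distances and averaging are assumptions of this sketch; the Flux3D implementations may differ.

```julia
using Statistics

# Symmetric chamfer distance between point sets stored one point per row.
# Brute force O(N*M), using squared Euclidean distances (an assumption of this sketch).
function chamfer_sketch(A::AbstractMatrix, B::AbstractMatrix)
    # d2[i, j] = squared distance between A[i, :] and B[j, :]
    d2 = [sum(abs2, A[i, :] .- B[j, :]) for i in 1:size(A, 1), j in 1:size(B, 1)]
    a_to_b = mean(minimum(d2, dims=2))   # each point of A to its nearest point in B
    b_to_a = mean(minimum(d2, dims=1))   # each point of B to its nearest point in A
    return a_to_b + b_to_a
end

# Mean squared edge length over all faces (verts: V×3, faces: F×3, 1-based indices).
# Edges shared by two faces are counted twice, which is fine for a sketch.
function edge_loss_sketch(verts::AbstractMatrix, faces::AbstractMatrix{<:Integer})
    total = 0.0
    for f in 1:size(faces, 1), (a, b) in ((1, 2), (2, 3), (3, 1))
        total += sum(abs2, verts[faces[f, a], :] .- verts[faces[f, b], :])
    end
    return total / (3 * size(faces, 1))
end
```

The chamfer term pulls the deformed surface toward the target, while the edge term discourages long, stretched triangles that the chamfer term alone would happily produce.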
Defining learning rate and optimizer
lr = 1.0
opt = Flux.Optimise.Momentum(lr, 0.9)
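For readers unfamiliar with momentum, the sketch below applies the classical momentum update to a toy quadratic objective. The function momentum_descent, its default values and the toy objective are all hypothetical and chosen only for illustration; Flux's Momentum optimizer may differ in implementation details.

```julia
# Classical momentum on the toy objective f(x) = sum(x.^2) / 2, whose gradient is x.
function momentum_descent(x0; lr=0.1, rho=0.9, steps=200)
    x = copy(x0)
    v = zero(x)                    # velocity, starts at zero
    for _ in 1:steps
        g = x                      # gradient of the toy objective at x
        v .= rho .* v .- lr .* g   # decay the old velocity, add the new gradient step
        x .+= v                    # move along the velocity
    end
    return x
end

x_final = momentum_descent([1.0, -2.0, 3.0])
```

The velocity accumulates past gradients, so consistent descent directions speed up while oscillating ones partly cancel out.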
Optimizing the offset array
We first initialize the offset array to zeros, so the deformed mesh is initially identical to the src mesh (sphere). In each iteration, we calculate the loss using this offset array, compute its derivatives with respect to the offset array, and update the array with the optimizer.
@info("Training...")
_offset = zeros(Float32, size(get_verts_packed(src))...)
θ = Zygote.Params([_offset])
for itr in 1:2001
    gs = gradient(θ) do
        loss_dolphin(_offset, src, tgt)
    end
    Flux.update!(opt, _offset, gs[_offset])
    # every 250 iterations, report the loss and save a snapshot of the deformed mesh
    if (itr%250 == 1)
        loss = loss_dolphin(_offset, src, tgt)
        @show itr, loss
        save("src_$(itr).png", visualize(Flux3D.offset(src, _offset)))
    end
end
anim = @animate for i ∈ 1:8
    Plots.plot(load("src_$(1+250*(i-1)).png"), showaxis=false)
end
gif(anim, "src_deform.gif", fps = 2)
Postprocessing the predicted mesh
We create a new TriMesh by offsetting src by the final offset array, and scale up final_mesh by the same factor we used to scale down tgt, so that final_mesh has a bounding box similar to that of the dolphin mesh.
final_mesh = Flux3D.offset(src, _offset)
final_mesh = Flux3D.scale!(final_mesh, scale)
Flux3D.AbstractPlotting.vbox(visualize(final_mesh), visualize(dolphin))
Saving the final_mesh
Flux3D provides the IO function save_trimesh to save a TriMesh (supports obj, stl, ply, off and 2DM).
mkpath("results")  # make sure the output directory exists before saving
save_trimesh("results/final_mesh.off", final_mesh)
save("results/final_mesh.png", visualize(final_mesh))
Conclusion
In this demonstration, we fitted a source shape (sphere) to a target shape (dolphin), used the TriMesh structure to represent triangle meshes, and finally used the IO functions to save the result. One interesting difference between the predicted mesh and the dolphin mesh is the lack of sharp features (like the nose tip and tail corners) in the predicted mesh. This may be because the random points (for chamfer_distance) are sampled from a categorical distribution weighted by face area, and sharp regions have low face area.
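The effect of area-weighted sampling can be sketched in plain Julia. This is a generic version of the technique and not the Flux3D code: the helper names face_areas, sample_faces and sample_point are hypothetical, and the toy mesh has one large triangle and one tiny one standing in for a sharp feature.

```julia
using LinearAlgebra, Random

# Area of each triangle (verts: V×3, faces: F×3 with 1-based indices).
function face_areas(verts::AbstractMatrix, faces::AbstractMatrix{<:Integer})
    map(1:size(faces, 1)) do f
        a = verts[faces[f, 1], :]
        b = verts[faces[f, 2], :]
        c = verts[faces[f, 3], :]
        norm(cross(b .- a, c .- a)) / 2
    end
end

# Draw n face indices with probability proportional to face area (inverse-CDF sampling).
function sample_faces(rng::AbstractRNG, areas::AbstractVector, n::Integer)
    cdf = cumsum(areas) ./ sum(areas)
    return [min(searchsortedfirst(cdf, rand(rng)), length(cdf)) for _ in 1:n]
end

# Uniform random point inside one triangle via barycentric coordinates.
function sample_point(rng::AbstractRNG, a, b, c)
    r1, r2 = sqrt(rand(rng)), rand(rng)
    return (1 - r1) .* a .+ (r1 * (1 - r2)) .* b .+ (r1 * r2) .* c
end

# Toy mesh: a large triangle (face 1) and a tiny one (face 2).
verts = [0.0 0.0 0.0; 1.0 0.0 0.0; 0.0 1.0 0.0;
         0.0 0.0 1.0; 0.1 0.0 1.0; 0.0 0.1 1.0]
faces = [1 2 3; 4 5 6]

rng = MersenneTwister(0)
areas = face_areas(verts, faces)
idx = sample_faces(rng, areas, 10_000)
n_large, n_small = count(==(1), idx), count(==(2), idx)
```

In this toy setup the tiny face holds about 1% of the total area, so it receives only a small share of the samples, which is consistent with the explanation above for why thin, sharp regions are under-represented in the chamfer term.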
Future Work
Add GPU support for fast training.
Add 3D structure and transforms for voxels.
Interconversion between different 3D structures like PointCloud, Voxel and TriMesh.
Add more metrics for TriMesh (like normal_consistency and cloud_mesh_distance).