Supervised 3D Mesh Reconstruction

I am really excited to be a part of GSoC'20, Google Summer of Code. During this program, I will be working on the project "Deep learning for 3D Computer Vision".

With recent advances and success in machine learning, researchers have come up with novel deep learning architectures to analyse and find high-order patterns in domains like 3D structures. However, domains like 3D computer vision are not easy to implement with the essential utilities of a standard NN library, which is primarily focused on classic computer vision.

Processing 3D data and applying transforms/metrics to it involves intense computation, and therefore JuliaLang is an ideal programming language. Its ML ecosystem (Flux.jl and Zygote.jl) will enable us to train our 3D models.

For this reason, I am working on a 3D vision framework (Flux3D.jl) which will assist users in processing 3D structures (like PointCloud, TriMesh and Voxels) and applying transforms/metrics to them. The framework will also allow interconversion among the different 3D structures and visualizing them.

In the following demonstration, I will use this package along with Flux and Zygote to perform a 3D mesh reconstruction task.

Problem Description

We are given an initial source shape (a sphere in this case) and we want to deform this source shape to fit a target shape (a dolphin in this case). For this demonstration, we will use a triangle mesh to represent both the source and the target shape.

A triangle mesh has two main components: vertices and faces. Deforming the source shape to fit the target shape can be achieved by offsetting the source's vertices until they fit the target surface, as sketched below. Note that the source and target shapes do not need to have the same number of vertices or faces.
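
Schematically, the idea looks like the following sketch (illustrative only, with placeholder sizes, not the demo's actual code): the deformed mesh reuses the source's faces and only shifts its vertices by a learnable per-vertex offset of the same shape.

verts_src = rand(Float32, 2562, 3)           # placeholder source vertices (V×3)
offset    = zeros(Float32, size(verts_src))  # learnable per-vertex offset, starts at zero
verts_def = verts_src .+ offset              # deformed vertices; the faces stay the same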

Installing and importing packages

Flux3D will be used for processing data and defining metrics, Zygote and Flux for computing gradients and optimization, and Plots for creating the GIF.

pkg"up"
pkg"add https://github.com/nirmal-suthar/Flux3D.jl#ns/mesh"
pkg"add Flux Zygote ImageIO Plots"
pkg"precompile"
using Flux3D, Zygote, Flux, FileIO, Statistics, Plots 
Flux3D.AbstractPlotting.inline!(true)
Flux3D.AbstractPlotting.set_theme!(show_axis = false)

Downloading obj files of sphere and dolphin

download("https://github.com/nirmalsuthar/public_files/raw/master/dolphin.obj",
         "dolphin.obj")
download("https://github.com/nirmalsuthar/public_files/raw/master/sphere.obj",
         "sphere.obj")
"sphere.obj"

Loading Triangle Mesh

Triangle meshes are handled by the TriMesh structure in Flux3D. TriMesh also supports batched formats, namely padded, packed and list, which allow us to use fast batched operations. We can load a TriMesh with the load_trimesh function (supports obj, stl, ply, off and 2DM).

dolphin = load_trimesh("dolphin.obj")
src = load_trimesh("sphere.obj")
TriMesh{Float32,UInt32}(1, 2562, 5120, true, Bool[1], -1, [2562], [5120], nothing, nothing, Array{Float32,2}[[-0.525687 0.850678 0.0; 0.525687 0.850678 0.0; … ; -0.334152 0.041299 -0.941614; -0.400982 0.0 -0.916086]], nothing, nothing, Array{UInt32,2}[[0x00000001 0x00000287 0x00000284; 0x00000001 0x00000284 0x00000283; … ; 0x000009ff 0x000009fd 0x000009fe; 0x000009de 0x000009dd 0x000009dc]], nothing, nothing, nothing, nothing)
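
As an aside (not part of the original demo), the batched representations mentioned above can be inspected through accessor functions; get_verts_packed is used later in this demo, and I am assuming the padded and list variants follow the same naming convention.

v_packed = get_verts_packed(src)   # all vertices concatenated into a single V×3 array
v_padded = get_verts_padded(src)   # zero-padded array with one slice per mesh in the batch
v_list   = get_verts_list(src)     # plain list with one Vi×3 array per mesh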

Preprocessing data

We preprocess tgt (dolphin) so that its mean is zero, and scale it according to the bounding box of src (sphere), so that src can converge faster.

tgt = deepcopy(dolphin)
verts = get_verts_packed(tgt)
center = mean(verts, dims=1)
verts = verts .- center
scale = maximum(abs.(verts))
verts = verts ./ scale
tgt._verts_packed = verts
2562×3 Array{Float32,2}:
 -0.101319     0.28717     0.0906198
  0.100773     0.28668     0.0908082
 -0.102468    -0.158384   -0.122923
  0.0990615   -0.158078   -0.120888
 -0.000445389  0.0352555   0.459447
  0.00411056   0.364373    0.702185
 -0.000652599 -0.709654   -0.957446
  0.00149755   0.124749   -0.391589
  0.158752    -0.0786782  -0.254219
  0.272774    -0.0218674   0.358715
  ⋮
 -0.104368     0.112222    0.584489
 -0.108841     0.131724    0.609518
 -0.121354     0.116081    0.564659
  0.0595596   -0.214296   -0.648242
  0.0619035   -0.160664   -0.6125
  0.0734885   -0.181159   -0.605178
 -0.0598853   -0.211955   -0.649319
 -0.0614622   -0.158247   -0.614201
 -0.0734454   -0.179162   -0.607734

Visualizing TriMesh

We will use the visualize function for visualizing a TriMesh. This function uses Makie for plotting. In fact, we can also visualize a PointCloud using the same function, which makes it handy when dealing with different 3D formats.

Flux3D.AbstractPlotting.vbox(visualize(src), visualize(tgt))

Defining loss objective

Starting from the src mesh, we will deform it by offsetting its vertices (by an offset array) such that the new deformed mesh is close to the target mesh. Therefore, our loss function will optimize the offset array. We will use the following metrics to define the loss objective:

  • chamfer_distance - the distance between the deformed mesh and the target mesh, calculated by randomly sampling 5000 points from the surface of each mesh and computing the chamfer distance between the two resulting point clouds (see the sketch after this list).

  • laplacian_loss - Laplacian smoothing, which will act as a regularizer.

  • edge_loss - penalizes the lengths of the edges in the deformed mesh, also acting as a regularizer.
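
To make the first term concrete, here is a minimal, unoptimized sketch of the symmetric chamfer distance between two point clouds given as N×3 and M×3 arrays (my own illustration; Flux3D's chamfer_distance additionally samples the points from the mesh surfaces).

# Average squared distance from every point in A to its nearest neighbour in B,
# plus the same term from B to A. Uses Statistics.mean, imported above.
function naive_chamfer(A::AbstractMatrix, B::AbstractMatrix)
    nn(P, Q) = mean(minimum(sum((reshape(P, size(P, 1), 1, 3) .-
                                 reshape(Q, 1, size(Q, 1), 3)) .^ 2, dims = 3), dims = 2))
    return nn(A, B) + nn(B, A)
end

Note that this builds the full N×M distance matrix, so it only shows what the metric computes, not how to compute it efficiently.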

function loss_dolphin(x::Array, src::TriMesh, tgt::TriMesh)
    src = Flux3D.offset(src, x)
    loss1 = chamfer_distance(src, tgt, 5000)
    loss2 = laplacian_loss(src)
    loss3 = edge_loss(src)
    return loss1 + 0.1*loss2 + loss3
end
loss_dolphin (generic function with 1 method)

Defining learning rate and optimizer

lr = 1.0
opt = Flux.Optimise.Momentum(lr, 0.9)
Momentum(1.0, 0.9, IdDict{Any,Any}())

Optimizing the offset array

We first initialize the offset array as zeros, so the deformed mesh is initially identical to the src mesh (sphere). Next, we compute the loss using this offset array, take derivatives with respect to the offset array, and finally optimize the array.

@info("Training...")
_offset = zeros(Float32, size(get_verts_packed(src))...)
θ = Zygote.Params([_offset])
for itr in 1:2001
    gs = gradient(θ) do
        loss_dolphin(_offset, src, tgt)
    end
    Flux.update!(opt, _offset, gs[_offset])
    if (itr%250 == 1)
        loss = loss_dolphin(_offset, src, tgt)
        @show itr, loss
        save("src_$(itr).png", visualize(Flux3D.offset(src, _offset)))
    end
end
anim = @animate for i in 1:8
    Plots.plot(load("src_$(1+250*(i-1)).png"), showaxis=false)
end
gif(anim, "src_deform.gif", fps = 2)

Postprocessing the predicted mesh

We create a new TriMesh by offsetting src by the final offset array, and scale final_mesh up by the same factor by which we scaled down tgt, so that final_mesh has a bounding box similar to the original dolphin mesh.

final_mesh = Flux3D.offset(src, _offset)
final_mesh = Flux3D.scale!(final_mesh, scale)
TriMesh{Float32,UInt32}(1, 2562, 5120, true, Bool[1], -1, [2562], [5120], Float32[-0.0196962 0.121593 0.0174378; 0.0190504 0.12169 0.0170868; … ; -0.0195147 -0.0431417 -0.219973; -0.0238372 -0.0503164 -0.219878], nothing, AbstractArray{Float32,2}[[-0.0530034 0.327212 0.0469259; 0.0512655 0.327473 0.0459815; … ; -0.0525149 -0.116096 -0.591958; -0.0641471 -0.135404 -0.591702]], nothing, nothing, Array{UInt32,2}[[0x00000001 0x00000287 0x00000284; 0x00000001 0x00000284 0x00000283; … ; 0x000009ff 0x000009fd 0x000009fe; 0x000009de 0x000009dd 0x000009dc]], nothing, nothing, nothing, nothing)
Flux3D.AbstractPlotting.vbox(visualize(final_mesh), visualize(dolphin))

Saving the final_mesh

Flux3D provides the IO function save_trimesh to save a TriMesh (supports obj, stl, ply, off and 2DM).

save_trimesh("results/final_mesh.off", final_mesh)
save("results/final_mesh.png", visualize(final_mesh))

Conclusion

In this demonstration, we were able to fit a source shape (sphere) to a target shape (dolphin), using the TriMesh structure to represent the triangle meshes and the IO functions to save the result. One interesting difference between the predicted mesh and the dolphin mesh is the lack of sharp features (like the nose tip and tail corners) in the predicted mesh. This may be because the random points (for chamfer_distance) are sampled from a categorical distribution over face areas, and sharp regions have small face area.
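
To illustrate that last point, here is a small sketch of area-weighted face sampling (my own illustration, not the routine used inside chamfer_distance): faces are drawn from a categorical distribution proportional to their area, so tiny faces at sharp features are rarely picked.

using LinearAlgebra   # for cross and norm

# Draw n face indices with probability proportional to the triangle areas.
# verts is a V×3 array of coordinates, faces an F×3 array of vertex indices.
function sample_faces_by_area(verts::AbstractMatrix, faces::AbstractMatrix, n::Int)
    areas = [0.5 * norm(cross(verts[faces[i, 2], :] .- verts[faces[i, 1], :],
                              verts[faces[i, 3], :] .- verts[faces[i, 1], :]))
             for i in 1:size(faces, 1)]
    cdf = cumsum(areas ./ sum(areas))
    return [searchsortedfirst(cdf, rand()) for _ in 1:n]   # categorical draws
end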

Future Work

  • Add GPU support for fast training.

  • Add 3D structure and transforms for voxels.

  • Interconversion between different 3D structures like PointCloud, Voxel and TriMesh.

  • Add more metrics for TriMesh (like normal_consistency and cloud_mesh_distance).
