GSoC 2020-Blog#2: Summary of the Summer - HuggingFace Transformers with Julia

It's coming to the end of GSoC 2020. We have implemented a lot of stuff during this summer, and some work still remains unfinished, so let's see what was done during the GSoC coding period. Here is a summary of my work. All the code lives under Transformers.HuggingFace.

using Pkg
pkg"add Transformers#master CUDA Flux PyCall; build"
using Transformers, Flux, CUDA
using Transformers.HuggingFace

The code is implemented in pure Julia, but to verify that it works correctly, we'll use PyCall here just for demonstration. PyCall is not actually required to use Transformers.HuggingFace.

using PyCall
pytorch = pyimport_conda("torch", "pytorch")
@pyimport pip
pip.main(["install", "transformers"])
pytransformers = pyimport("transformers")
PyObject <module 'transformers' from '/opt/conda/lib/python3.7/site-packages/transformers/__init__.py'>

Project: Leveraging Hugging Face Transformers package in Julia

As the title says, the goal of this project is to reuse the existing Python transformer ecosystem - the huggingface/transformers package. To achieve this, we start with the model loader and saver.

The Loader API

Huggingface has a really great ecosystem: they built not only the Python library but also an amazing model hub that everyone can upload models to and download models from. This has helped the NLP community grow faster than ever before. We don't want Julia to miss out on this fast-growing trend in NLP technology, so we built some functionality on top of their download mechanism and ended up with something like this:

cfg = hgf"bert-base-cased:config"
HGFBertConfig with 15 entries: :vocab_size => 28996 :hidden_size => 768 :num_hidden_layers => 12 :num_attention_heads => 12 :intermediate_size => 3072 :hidden_act => "gelu" :hidden_dropout_prob => 0.1 :attention_probs_dropout_prob => 0.1 :max_position_embeddings => 512 :type_vocab_size => 2 :initializer_range => 0.02 :layer_norm_eps => 1.0f-12 :pad_token_id => 0 :model_type => "bert" :architectures => Any["BertForMaskedLM"]

The @hgf_str API handles the whole process automatically. It downloads the required file (here, the BERT config file) from their model hub. The downloaded file is managed by Julia's Artifacts system, so there will be no duplicate files on our computer. Moreover, if there are files already cached by huggingface/transformers, we reuse those files as well.

pygpt_cfg = pytransformers.AutoConfig.from_pretrained("gpt2")
PyObject GPT2Config { "activation_function": "gelu_new", "architectures": [ "GPT2LMHeadModel" ], "attn_pdrop": 0.1, "bos_token_id": 50256, "embd_pdrop": 0.1, "eos_token_id": 50256, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "model_type": "gpt2", "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12, "n_positions": 1024, "resid_pdrop": 0.1, "summary_activation": null, "summary_first_dropout": 0.1, "summary_proj_to_labels": true, "summary_type": "cls_index", "summary_use_proj": true, "task_specific_params": { "text-generation": { "do_sample": true, "max_length": 50 } }, "vocab_size": 50257 }
gpt_cfg = hgf"gpt2:config"
HGFGPT2Config with 23 entries: :vocab_size => 50257 :n_positions => 1024 :n_ctx => 1024 :n_embd => 768 :n_layer => 12 :n_head => 12 :n_inner => nothing :activation_function => "gelu_new" :resid_pdrop => 0.1 :embd_pdrop => 0.1 :attn_pdrop => 0.1 :layer_norm_epsilon => 1.0f-5 :initializer_range => 0.02 :summary_type => "cls_index" :summary_use_proj => true :summary_activation => nothing :summary_proj_to_labels => true :summary_first_dropout => 0.1 :bos_token_id => 50256 ⋮ => ⋮

Besides the model hub, we can also use the API with our own local files. Under the hood, several low-level APIs make this happen. The complete workflow looks like this (a sketch of the local-file path follows the list):

  1. Is the pretrained model from the huggingface model hub? (yes => 2. / no => 3.)

  2. Have you already used the model in Python before? (yes => 2-1. / no => 2-2.)

    1. There is a cached pretrained file on the computer: use get_or_download_hgf_file directly. This copies the file from the cache dir to our Artifacts dir and registers it in Artifacts.toml (Julia's Artifacts system).

    2. There are no cached files: also use get_or_download_hgf_file directly. This downloads the file from huggingface's model server to our Artifacts dir and registers it in Artifacts.toml.

  3. You have your own local pretrained files: use HuggingFace.find_or_register_hgf_file_hash to register the file with the Artifacts system. Once the file is registered, you will find its entry in Artifacts.toml.

  4. Once the model is managed under Julia's Artifacts system, we can use either HuggingFace.get_registered_file_dir or get_or_download_hgf_file to get the pretrained file or directory.

  5. Then we can enjoy the @hgf_str API.
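
For the local-file path (steps 3 to 5), a minimal sketch could look like the following. The file paths and the model name my-bert are hypothetical; the function signatures follow the docstrings shown below.

# a minimal sketch of the local-file workflow; paths and the model name are hypothetical
using Transformers.HuggingFace

# step 3: register local pretrained files under Julia's Artifacts system
HuggingFace.find_or_register_hgf_file_hash("/path/to/my-bert/config.json", "my-bert", "config.json")
HuggingFace.find_or_register_hgf_file_hash("/path/to/my-bert/pytorch_model.bin", "my-bert", "pytorch_model.bin")

# step 4: resolve a registered file back to a path inside the Artifacts dir
cfg_file = get_or_download_hgf_file("my-bert", "config.json")

# step 5: enjoy the high-level API
cfg = hgf"my-bert:config"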

?get_or_download_hgf_file

get_or_download_hgf_file(model_name, file_name)

get the file path of the given model_name and file_name. Automatically download and register from huggingface server if file not found on Artifacts.toml. To use with a local file, register with find_or_register_hgf_file_hash first.

?HuggingFace.find_or_register_hgf_file_hash

find_or_register_hgf_file_hash(path, model_name, file_name)

Get the artifacts hash of the give <model_name>/<file_name>. If not found in Artifacts.toml, get the file from path and register on Artifacts.toml. path can be either a url or a local path.

?@hgf_str

hgf"<model-name>:<item>"

Get item from model-name. This will ensure the required data are downloaded and registered. item can be "config", "tokenizer", and model related like "model", or "formaskedlm", etc. Use get_model_type to see what model/task are supported.

Loading model

We have seen how to use the API with the config, but what about the model? Can we also use the @hgf_str API? The answer is yes, but things are a bit more complicated here. Huggingface transformers uses a distinct model type for each task, so you need to make sure a model exists for the task you want. You can find out which tasks are supported with the `HuggingFace.get_model_type` function. For example:

HuggingFace.get_model_type(Val(:bert))
(:model=>HGFBertModel, :forpretraining=>HGFBertForPreTraining, :lmheadmodel=>HGFBertLMHeadModel, :formaskedlm=>HGFBertForMaskedLM, :fornextsentenceprediction=>HGFBertForNextSentencePrediction, :forsequenceclassification=>HGFBertForSequenceClassification, :formultiplechoice=>HGFBertForMultipleChoice, :fortokenclassification=>HGFBertForTokenClassification, :forquestionanswering=>HGFBertForQuestionAnswering)
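
The same query works for the other implemented architectures, for example GPT-2 (the exact set of tasks returned depends on what has been implemented so far):

# list the task => model type pairs implemented for GPT-2
HuggingFace.get_model_type(Val(:gpt2))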

It returns the existing task names together with their corresponding model types. Once you know which model you need, you can load it with the @hgf_str API:

bert_model = hgf"bert-base-cased:forquestionanswering";

During the loading process, two kinds of warnings might appear.

The first kind of warning says that there are some extra variables in the loaded state, which will therefore be ignored. In the example above, the loaded model was pretrained on the masked language modeling task, so it has some layers for decoding tokens (the cls field). However, those layers aren't needed for the question answering task, so we simply ignore them.

The second warning, also visible in the block above, lists the variables (here qa_outputs) that are missing from the saved state (i.e. not initialized from the loaded state) and are therefore randomly initialized. This happens because the loaded model was not pretrained on the question answering task, while here we want to fine-tune its weights on exactly that task. This warning tells you which weights are not being loaded, so if a weight that should have been loaded shows up in the warning, something has gone wrong.

You can also see the whole model architecture by printing the model. These models are implemented with Flux, so we can move them to the GPU like a regular Flux layer:

gpu_model = gpu(bert_model)
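
Since the loaded model behaves like a regular tree of Flux layers (as the gpu call above suggests), the usual Flux training utilities apply. A minimal sketch of preparing it for fine-tuning; the learning rate is only illustrative and the training loop itself is omitted:

using Flux

# collect the trainable parameters of the loaded model, as for any Flux model
ps = Flux.params(gpu_model)

# set up an optimizer; 3f-5 is just an illustrative fine-tuning learning rate
opt = ADAM(3f-5)

# the actual loss/gradient/update loop is omitted here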

Saving the model for Python

We have shown how to get a pretrained model and load it in Julia. The next problem is: even if we train a model in Julia, how can someone from the Python world use it? If we save the model with BSON, the regular Julian way, there is currently no easy way to load it in Python, let alone upload it to Huggingface's model hub.

As a consequence, we decided to follow exactly the same format used by huggingface/transformers, which is the PyTorch state_dict for models. Fortunately, it's just a modified Python pickle format, so we implemented another package called Pickle.jl in pure Julia (also done during this summer). With it, things work quite smoothly in Transformers.HuggingFace. For example:

mycfg = HuggingFace.HGFGPT2Config(vocab_size=10, 
  n_embd=128, n_layer=2, n_head=2, n_positions=100, n_ctx=100, 
  bos_token_id=0, eos_token_id=1)
HGFGPT2Config with 20 entries: :vocab_size => 10 :n_positions => 100 :n_ctx => 100 :n_embd => 128 :n_layer => 2 :n_head => 2 :n_inner => nothing :activation_function => "gelu_new" :resid_pdrop => 0.1 :embd_pdrop => 0.1 :attn_pdrop => 0.1 :layer_norm_epsilon => 1.0f-5 :initializer_range => 0.02 :summary_type => "cls_index" :summary_use_proj => true :summary_activation => nothing :summary_proj_to_labels => true :summary_first_dropout => 0.1 :bos_token_id => 0 :eos_token_id => 1

We create our own GPT-2 model in Julia:

mygpt2 = HuggingFace.HGFGPT2LMHeadModel(mycfg)
;mkdir jlgpt2
save_config("jlgpt2", mycfg)
save_model("jlgpt2", mygpt2)
"/jlgpt2/pytorch_model.bin"

The config and the model are saved under jlgpt2, and they can be loaded from Python directly:

pygpt2 = pytransformers.GPT2LMHeadModel.from_pretrained("./jlgpt2")
pygpt2.lm_head.weight
PyObject Parameter containing: tensor([[-0.0232, -0.0051, -0.0137, ..., -0.0370, 0.0028, -0.0031], [ 0.0153, -0.0023, 0.0060, ..., 0.0181, -0.0063, 0.0163], [-0.0185, 0.0004, 0.0131, ..., 0.0014, -0.0339, 0.0246], ..., [-0.0077, 0.0027, -0.0333, ..., 0.0046, -0.0132, -0.0128], [-0.0217, 0.0042, -0.0397, ..., -0.0111, -0.0136, -0.0336], [ 0.0014, -0.0021, -0.0174, ..., 0.0060, -0.0120, 0.0098]], requires_grad=True)
mygpt2.lm_head.weight
10×128 Array{Float32,2}: -0.0232328 -0.00507754 -0.0136856 … 0.00284516 -0.00307978 0.0153007 -0.00230721 0.00597577 -0.00630203 0.0163045 -0.0184844 0.000391856 0.0130985 -0.0338721 0.0246367 -0.00319505 -0.0164045 -0.00548908 -0.0190406 -0.0182496 0.000148878 0.0121778 -0.0139525 0.0104698 0.0337191 -0.0351344 -0.0337594 0.00567676 … 0.00542427 0.00574145 0.0354903 0.0145032 -0.0212157 -0.00514954 -0.0244643 -0.00770666 0.00274134 -0.0332742 -0.013247 -0.0127596 -0.0216602 0.00422401 -0.0396937 -0.0135971 -0.0335842 0.00141776 -0.00210052 -0.017447 -0.0120234 0.00975189

You can see that the loaded values are correct. Everything works like a charm.
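
Under the hood, the saved file is an ordinary PyTorch state_dict, so it can also be inspected from Julia with Pickle.jl. Here is a hedged sketch, assuming Pickle.jl exposes a Torch-format loader as Pickle.Torch.THload:

using Pickle

# load the state_dict that save_model wrote above (assuming Pickle.Torch.THload
# is the Torch-format loader); the result behaves like a Dict of tensors
state_dict = Pickle.Torch.THload("jlgpt2/pytorch_model.bin")

# the keys are the usual huggingface parameter names, e.g. "lm_head.weight"
keys(state_dict)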

The unfinished parts and future work

Besides the code above, a lot of stuff still remained unfinished at the end of GSoC 2020. Here is the list:

  • No tokenizer:

One of the best parts of huggingface/transformers is that they wrap the tokenizers in a way that is easy to use. They even support multiple languages and several pre-processing pipelines. Unfortunately, this work was too large to fit into the schedule alongside the other workload. What's worse, without the tokenizer we cannot upload our trained models to Huggingface's model hub. Currently we can only keep using the old tokenizers from Transformers.jl to train models, but implementing the tokenizer part will be the top priority of the future development of this package.

  • Only a small number of model types supported:

There are many models supported by huggingface/transformers, and we only implemented 3 of them during the coding period. Since the models are manually translated from Python+PyTorch code to Julia+Flux code, each model takes several days to keep the Julia API consistent with Python's. We'll add more model implementations in the future once the tokenizer problem is solved.

  • Lack of examples and tutorials:

Up to now, we haven't provided any complete training example with Transformers.HuggingFace.

Conclusion

After GSoC 2020, we have a package that can fetch pretrained models from Huggingface and train them in pure Julia with Flux. Moreover, we can also save the models for Python to use. Although there is still work to be done, this should improve the status of NLP in JuliaLang step by step.

Acknowledgement

Thanks to JuliaLang and the Julia community for this opportunity. Special thanks to my mentors, Avik Sengupta, Jun Tian, and Dhairya Gandhi. This project wouldn't exist without their help.

After GSoC, I will keep developing the package. Hopefully, the project can serve as a bridge between the two communities and help JuliaLang appear on the list of recommendations for NLP projects.
