transformer-lm-japanese-1.0b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).

Source Code

We've modified Flax's 'lm1b' example to train on a Japanese dataset. You can find the code on GitHub.

Our Blog Post

Model Details

| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|---|---|---|---|---|---|---|---|---|
| transformer-lm-japanese-1.0b | 1.0B | 18 | 2048 | 16 | wiki40b/ja | 2.19GB | 4 days | 31.47 |
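
For readers unfamiliar with the PPL column, the sketch below shows how perplexity is conventionally computed, assuming the standard definition: the exponential of the mean per-token negative log-likelihood. The exact evaluation procedure behind the 31.47 figure is not stated here, so the function name `perplexity` and the sample inputs are illustrative only. Under that assumption, a perplexity of 31.47 corresponds to a mean loss of roughly 3.45 nats per token.

```python
import math


def perplexity(token_log_probs):
    """Perplexity from per-token natural-log probabilities.

    Assumes the standard definition exp(mean negative log-likelihood);
    this is an illustration, not the model's actual evaluation code.
    """
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# Hypothetical input: four tokens, each assigned probability 0.25.
# A model that is uniformly unsure among 4 options has perplexity close to 4.
example_log_probs = [math.log(0.25)] * 4
print(perplexity(example_log_probs))
```

Lower is better: a perplexity of N means the model is, on average, about as uncertain as if it were choosing uniformly among N tokens.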

Usage: FlaxAutoModel

Requirements:

pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99

# For CPU
pip install -U "jax[cpu]==0.4.31"

# For GPU
pip install -U "jax[cuda12]==0.4.31"

Note: Set trust_remote_code=True to load our custom model.

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-1.0b", trust_remote_code=True)

text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

# output_ids[0] holds the generated sequences; decode the first one
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
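
To give a feel for what `temperature=0.6` and `top_k=20` do in the call above, here is a minimal standard-library sketch of top-k filtering followed by temperature scaling. It is not the transformers implementation, and the helper names `top_k_temperature_probs` and `sample_token` as well as the toy logits are hypothetical, chosen only for illustration.

```python
import math
import random


def top_k_temperature_probs(logits, top_k, temperature):
    """Return {token_index: probability} after top-k filtering and temperature scaling.

    Illustrative only; the real library operates on arrays of logits.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    k = min(top_k, len(logits))
    # Keep only the indices of the k largest logits.
    keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Dividing by a temperature below 1 sharpens the distribution.
    scaled = {i: logits[i] / temperature for i in keep}
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled.values())
    exps = {i: math.exp(v - peak) for i, v in scaled.items()}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


def sample_token(logits, top_k=20, temperature=0.6, rng=None):
    """Draw one token index from the filtered, rescaled distribution."""
    rng = rng or random.Random()
    probs = top_k_temperature_probs(logits, top_k, temperature)
    indices = list(probs)
    weights = [probs[i] for i in indices]
    return rng.choices(indices, weights=weights, k=1)[0]


# Hypothetical logits over a 4-token vocabulary.
toy_logits = [2.0, 1.0, 0.5, -1.0]
print(top_k_temperature_probs(toy_logits, top_k=2, temperature=0.6))
print(sample_token(toy_logits, top_k=2, temperature=0.6))
```

A smaller `top_k` or lower `temperature` makes the output more predictable; larger values make it more varied.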

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

  • GPU Type: NVIDIA L4 (x 1)
  • Machine Type: g2-standard-16 (16 CPUs, 64GB Memory)
  • Disk: 256GB
  • OS: Ubuntu 22.04 LTS x86/64

Dataset

Tokenization

Author

Ryoichi Fukugawa
