---
license: other
license_name: coqui-public-model-license
license_link: https://coqui.ai/cpml.txt
datasets:
- ivanzhu109/zh-taiwan
language:
- zh
- en
base_model:
- coqui/XTTS-v2
---

### Model Description

This model is a fine-tuned version of Coqui XTTS-v2, optimized to generate text-to-speech (TTS) output with a Taiwanese Mandarin accent.

### Features

- Languages: Mandarin Chinese and English (the training data mixes both)
- Fine-tuned from: [coqui/XTTS-v2](https://huggingface.co/coqui/XTTS-v2)

### Training Data

The model was trained on the [zh-taiwan dataset](https://huggingface.co/datasets/ivanzhu109/zh-taiwan), which consists of a mixture of Mandarin and English audio samples.

## How to Get Started with the Model

Install the Idiap-maintained fork of Coqui TTS from source, then initialize the model from `XttsConfig` and load the fine-tuned checkpoint:

```bash
git clone https://github.com/idiap/coqui-ai-TTS
cd coqui-ai-TTS
pip install -e .
```

```python
import logging
import time
from datetime import datetime

import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Without basicConfig, logger.info() messages are not printed
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Assumes config.json and checkpoint.pth from this repository are in ./xtts-v2-zh-tw/
logger.info("Loading model...")
config = XttsConfig()
config.load_json("xtts-v2-zh-tw/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path="xtts-v2-zh-tw/checkpoint.pth",
    use_deepspeed=True,  # requires the `deepspeed` package; set to False if it is not installed
    eval=True,
)
model.cuda()

# XTTS generates audio at config.audio.output_sample_rate (24000 Hz in the stock XTTS-v2 config)
sample_rate = config.audio.output_sample_rate

phrases = [
    "合併稅後盈餘653.22億元",  # consolidated after-tax earnings: 65.322 billion
    "EPS 為11.52元創下新紀錄",  # EPS of 11.52, a new record
]
logger.info("Number of phrases: %d", len(phrases))

logger.info("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["YOUR_REFERENCE.wav"]  # replace with your own reference recording
)

logger.info("Inference...")
wav_list = []
for idx, sub in enumerate(phrases):
    start_time = time.time()  # time each phrase separately
    out = model.inference(
        sub,
        "zh-cn",  # XTTS-v2 language code for Chinese
        gpt_cond_latent,
        speaker_embedding,
        enable_text_splitting=True,
        # top_k=40,
        # top_p=0.5,
        speed=1.2,
        # temperature=0.4,
    )
    # compute stats for this phrase
    process_time = time.time() - start_time
    audio_time = len(out["wav"]) / sample_rate  # generated audio duration in seconds
    logger.info("Phrase %d processing time: %.3f s", idx, process_time)
    logger.info("Phrase %d real-time factor: %.3f", idx, process_time / audio_time)
    wav_list.append(torch.tensor(out["wav"]).unsqueeze(0))

# Concatenate all phrases along the time axis: shape (1, total_samples)
combined_wav = torch.cat(wav_list, dim=1)
output_path = f"voice-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}-xtts.wav"
logger.info("export: %s", output_path)
torchaudio.save(output_path, combined_wav, sample_rate)
```
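The timing code above reports a real-time factor (RTF): processing time divided by the duration of the generated audio, so values below 1.0 mean synthesis runs faster than playback. The sketch below isolates that bookkeeping without loading the model, assuming each phrase's waveform is a 1-D NumPy array. The helper names `audio_duration`, `real_time_factor`, and `concat_with_silence` are hypothetical, the optional silence gap between phrases is an illustrative addition that the original script does not use, and the 24 kHz rate in the demo is only the stock XTTS-v2 `output_sample_rate`; check your own `config.json`.

```python
import numpy as np


def audio_duration(wav: np.ndarray, sample_rate: int) -> float:
    """Duration in seconds of a 1-D waveform (samples / sample rate)."""
    if sample_rate <= 0:
        raise ValueError("sample_rate must be positive")
    return wav.shape[-1] / sample_rate


def real_time_factor(processing_seconds: float, wav: np.ndarray, sample_rate: int) -> float:
    """Processing time divided by audio duration; < 1.0 means faster than real time."""
    duration = audio_duration(wav, sample_rate)
    if duration == 0:
        raise ValueError("cannot compute RTF for an empty waveform")
    return processing_seconds / duration


def concat_with_silence(wavs, sample_rate: int, gap_seconds: float = 0.0) -> np.ndarray:
    """Join 1-D waveforms end to end, inserting gap_seconds of silence between them.

    Hypothetical helper: the original script concatenates phrases with no gap.
    """
    if not wavs:
        return np.zeros(0, dtype=np.float32)
    gap = np.zeros(int(round(gap_seconds * sample_rate)), dtype=np.float32)
    pieces = []
    for i, w in enumerate(wavs):
        if i > 0 and gap.size:
            pieces.append(gap)  # silence only *between* phrases, not at the ends
        pieces.append(np.asarray(w, dtype=np.float32))
    return np.concatenate(pieces)


if __name__ == "__main__":
    sr = 24000  # assumed output rate (stock XTTS-v2 config)
    a = np.zeros(sr, dtype=np.float32)       # 1.0 s of audio
    b = np.zeros(sr // 2, dtype=np.float32)  # 0.5 s of audio
    combined = concat_with_silence([a, b], sr, gap_seconds=0.25)
    print(audio_duration(combined, sr))           # 1.75
    print(real_time_factor(0.875, combined, sr))  # 0.5
```

In the inference loop you could pass `time.time() - start_time`, `out["wav"]`, and the model's output sample rate to `real_time_factor`; timing each phrase from its own start, rather than from a single start time before the loop, keeps the per-phrase RTF from accumulating earlier phrases' processing time.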