diff --git a/app.py b/app.py index d99078015d7f1308d8cd4823430cce685503c289..5232107c72d94193a849588ab507af3a5d88554f 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ os.system('nvidia-smi') os.system('apt update -y && apt-get install -y apt-utils && apt install -y unzip') # os.system('pip install flash-attn --no-build-isolation') # os.system('git submodule update --init --recursive') -os.system('git clone https://github.com/shivammehta25/Matcha-TTS.git third_party/') +# os.system('git clone https://github.com/shivammehta25/Matcha-TTS.git third_party/') os.system('mkdir pretrained_models && cd pretrained_models && git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git &&git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git && for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long; do sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml; done && cd ..') # os.system('mkdir pretrained_models && cd pretrained_models && git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git && for i in InspireMusic-Base; do sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml; done && cd ..') diff --git a/third_party/Matcha-TTS/.env.example b/third_party/Matcha-TTS/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..a790e320464ebc778ca07f5bcd826a9c8412ed0e --- /dev/null +++ b/third_party/Matcha-TTS/.env.example @@ -0,0 +1,6 @@ +# example of file for storing private and user specific environment variables, like keys or system paths +# rename it to ".env" (excluded from version control by default) +# .env is loaded by train.py automatically +# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} + +MY_VAR="/home/user/my/system/path" diff --git a/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md b/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..410bcd87a45297ab8f0d369574a032858b6b1811 --- /dev/null +++ b/third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## What does this PR do? + + + +Fixes #\ + +## Before submitting + +- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? +- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? +- [ ] Did you list all the **breaking changes** introduced by this pull request? +- [ ] Did you **test your PR locally** with `pytest` command? +- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? + +## Did you have fun? 
+ +Make sure you had fun coding 🙃 diff --git a/third_party/Matcha-TTS/.github/codecov.yml b/third_party/Matcha-TTS/.github/codecov.yml new file mode 100644 index 0000000000000000000000000000000000000000..c66853c4bd9991f730da5dda7dc9881986779558 --- /dev/null +++ b/third_party/Matcha-TTS/.github/codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + # measures overall project coverage + project: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + # measures PR or single commit coverage + patch: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + + # project: off + # patch: off diff --git a/third_party/Matcha-TTS/.github/dependabot.yml b/third_party/Matcha-TTS/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..b19ccab12a3c573025ce6ba6d9068b062b29cc1b --- /dev/null +++ b/third_party/Matcha-TTS/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + target-branch: "dev" + schedule: + interval: "daily" + ignore: + - dependency-name: "pytorch-lightning" + update-types: ["version-update:semver-patch"] + - dependency-name: "torchmetrics" + update-types: ["version-update:semver-patch"] diff --git a/third_party/Matcha-TTS/.github/release-drafter.yml b/third_party/Matcha-TTS/.github/release-drafter.yml new file mode 100644 index 0000000000000000000000000000000000000000..59af159f671abe75311eb626c8ec92ca6ea09d3c --- /dev/null +++ b/third_party/Matcha-TTS/.github/release-drafter.yml @@ -0,0 +1,44 @@ +name-template: "v$RESOLVED_VERSION" +tag-template: "v$RESOLVED_VERSION" + +categories: + - title: "🚀 Features" + labels: + - "feature" + - "enhancement" + - title: "🐛 Bug Fixes" + labels: + - "fix" + - "bugfix" + - "bug" + - title: "🧹 Maintenance" + labels: + - "maintenance" + - "dependencies" + - "refactoring" + - "cosmetic" + - "chore" + - title: "📝️ Documentation" + labels: + - "documentation" + - "docs" + +change-template: "- $TITLE @$AUTHOR (#$NUMBER)" +change-title-escapes: '\<*_&' # You can add # and @ to disable mentions + +version-resolver: + major: + labels: + - "major" + minor: + labels: + - "minor" + patch: + labels: + - "patch" + default: patch + +template: | + ## Changes + + $CHANGES diff --git a/third_party/Matcha-TTS/.gitignore b/third_party/Matcha-TTS/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cbec8b43a0414bbbf4cc9feae49b9dc091a60c92 --- /dev/null +++ b/third_party/Matcha-TTS/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +### VisualStudioCode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace +**/.vscode + +# JetBrains +.idea/ + +# Data & Models +*.h5 +*.tar +*.tar.gz + +# Lightning-Hydra-Template +configs/local/default.yaml +/data/ +/logs/ +.env + +# Aim logging +.aim + +# Cython complied files +matcha/utils/monotonic_align/core.c + +# Ignoring hifigan checkpoint +generator_v1 +g_02500000 +gradio_cached_examples/ +synth_output/ diff --git a/third_party/Matcha-TTS/.pre-commit-config.yaml b/third_party/Matcha-TTS/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..761867141e9f3a59316ab9f0b6eec6191d29900e --- /dev/null +++ b/third_party/Matcha-TTS/.pre-commit-config.yaml @@ -0,0 +1,59 @@ +default_language_version: + python: python3.11 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + # - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + + # python code formatting + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + args: [--line-length, "120"] + + # python import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + # python upgrading syntax to newer version + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + args: [--py38-plus] + + # python check (PEP8), programming errors and code complexity + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: + [ + "--max-line-length", "120", + "--extend-ignore", + "E203,E402,E501,F401,F841,RST2,RST301", + "--exclude", + "logs/*,data/*,matcha/hifigan/*", + ] + additional_dependencies: [flake8-rst-docstrings==0.3.0] + + # pylint + - repo: https://github.com/pycqa/pylint + rev: v3.0.3 + hooks: + - id: pylint diff 
--git a/third_party/Matcha-TTS/.project-root b/third_party/Matcha-TTS/.project-root new file mode 100644 index 0000000000000000000000000000000000000000..63eab774b9e36aa1a46cbd31b59cbd373bc5477f --- /dev/null +++ b/third_party/Matcha-TTS/.project-root @@ -0,0 +1,2 @@ +# this file is required for inferring the project root directory +# do not delete diff --git a/third_party/Matcha-TTS/.pylintrc b/third_party/Matcha-TTS/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..962864189eab99a66b315b80f5a9976e7a423d4a --- /dev/null +++ b/third_party/Matcha-TTS/.pylintrc @@ -0,0 +1,525 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". 
+disable=missing-docstring, + too-many-public-methods, + too-many-lines, + bare-except, + ## for avoiding weird p3.6 CI linter error + ## TODO: see later if we can remove this + assigning-non-slot, + unsupported-assignment-operation, + ## end + line-too-long, + fixme, + wrong-import-order, + ungrouped-imports, + wrong-import-position, + import-error, + invalid-name, + too-many-instance-attributes, + arguments-differ, + arguments-renamed, + no-name-in-module, + no-member, + unsubscriptable-object, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + useless-object-inheritance, + too-few-public-methods, + too-many-branches, + too-many-arguments, + too-many-locals, + too-many-statements, + duplicate-code, + not-callable, + import-outside-toplevel, + logging-fstring-interpolation, + logging-not-lazy, + unused-argument, + no-else-return, + chained-comparison, + redefined-outer-name + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, while `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package.. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. 
+spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=numpy.*,torch.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +argument-rgx=[a-z_][a-z0-9_]{0,30}$ + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + x, + ex, + Run, + _ + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. 
+name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +variable-rgx=[a-z_][a-z0-9_]{0,30}$ + + +[STRING] + +# This flag controls whether the implicit-str-concat-in-sequence should +# generate a warning on implicit string concatenation in sequences defined over +# several lines. +check-str-concat-over-line-jumps=no + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement. +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=15 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". 
+overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/third_party/Matcha-TTS/LICENSE b/third_party/Matcha-TTS/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..858018e750da7be7b271bb7307e68d159ed67ef6 --- /dev/null +++ b/third_party/Matcha-TTS/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shivam Mehta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/Matcha-TTS/MANIFEST.in b/third_party/Matcha-TTS/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..c013140cdfb9de19c4d4e73c73a44e33f33fa871 --- /dev/null +++ b/third_party/Matcha-TTS/MANIFEST.in @@ -0,0 +1,14 @@ +include README.md +include LICENSE.txt +include requirements.*.txt +include *.cff +include requirements.txt +include matcha/VERSION +recursive-include matcha *.json +recursive-include matcha *.html +recursive-include matcha *.png +recursive-include matcha *.md +recursive-include matcha *.py +recursive-include matcha *.pyx +recursive-exclude tests * +prune tests* diff --git a/third_party/Matcha-TTS/Makefile b/third_party/Matcha-TTS/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..4b523dd17b13a19617c9cc9d9dad7f7d8d4c24a0 --- /dev/null +++ b/third_party/Matcha-TTS/Makefile @@ -0,0 +1,42 @@ + +help: ## Show help + @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +clean: ## Clean autogenerated files + rm -rf dist + find . -type f -name "*.DS_Store" -ls -delete + find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf + find . | grep -E ".pytest_cache" | xargs rm -rf + find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf + rm -f .coverage + +clean-logs: ## Clean logs + rm -rf logs/** + +create-package: ## Create wheel and tar gz + rm -rf dist/ + python setup.py bdist_wheel --plat-name=manylinux1_x86_64 + python setup.py sdist + python -m twine upload dist/* --verbose --skip-existing + +format: ## Run pre-commit hooks + pre-commit run -a + +sync: ## Merge changes from main branch to your current branch + git pull + git pull origin main + +test: ## Run not slow tests + pytest -k "not slow" + +test-full: ## Run all tests + pytest + +train-ljspeech: ## Train the model + python matcha/train.py experiment=ljspeech + +train-ljspeech-min: ## Train the model with minimum memory + python matcha/train.py experiment=ljspeech_min_memory + +start_app: ## Start the app + python matcha/app.py diff --git a/third_party/Matcha-TTS/README.md b/third_party/Matcha-TTS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a5867d0a46568e545524073b435ad84d929b8d73 --- /dev/null +++ b/third_party/Matcha-TTS/README.md @@ -0,0 +1,315 @@ +
+ +# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching + +### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) + +[![python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-3100/) +[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) +[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) +[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) +[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) +[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) + +
+ +> This is the official code implementation of 🍵 Matcha-TTS [ICASSP 2024]. + +We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses [conditional flow matching](https://arxiv.org/abs/2210.02747) (similar to [rectified flows](https://arxiv.org/abs/2209.03003)) to speed up ODE-based speech synthesis. Our method: + +- Is probabilistic +- Has compact memory footprint +- Sounds highly natural +- Is very fast to synthesise from + +Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS) and read [our ICASSP 2024 paper](https://arxiv.org/abs/2309.03199) for more details. + +[Pre-trained models](https://drive.google.com/drive/folders/17C_gYgEHOxI5ZypcfE_k1piKCtyR0isJ?usp=sharing) will be automatically downloaded with the CLI or gradio interface. + +You can also [try 🍵 Matcha-TTS in your browser on HuggingFace 🤗 spaces](https://huggingface.co/spaces/shivammehta25/Matcha-TTS). + +## Teaser video + +[![Watch the video](https://img.youtube.com/vi/xmvJkz3bqw0/hqdefault.jpg)](https://youtu.be/xmvJkz3bqw0) + +## Installation + +1. Create an environment (suggested but optional) + +``` +conda create -n matcha-tts python=3.10 -y +conda activate matcha-tts +``` + +2. Install Matcha TTS using pip or from source + +```bash +pip install matcha-tts +``` + +from source + +```bash +pip install git+https://github.com/shivammehta25/Matcha-TTS.git +cd Matcha-TTS +pip install -e . +``` + +3. Run CLI / gradio app / jupyter notebook + +```bash +# This will download the required models +matcha-tts --text "" +``` + +or + +```bash +matcha-tts-app +``` + +or open `synthesis.ipynb` on jupyter notebook + +### CLI Arguments + +- To synthesise from given text, run: + +```bash +matcha-tts --text "" +``` + +- To synthesise from a file, run: + +```bash +matcha-tts --file +``` + +- To batch synthesise from a file, run: + +```bash +matcha-tts --file --batched +``` + +Additional arguments + +- Speaking rate + +```bash +matcha-tts --text "" --speaking_rate 1.0 +``` + +- Sampling temperature + +```bash +matcha-tts --text "" --temperature 0.667 +``` + +- Euler ODE solver steps + +```bash +matcha-tts --text "" --steps 10 +``` + +## Train with your own dataset + +Let's assume we are training with LJ Speech + +1. Download the dataset from [here](https://keithito.com/LJ-Speech-Dataset/), extract it to `data/LJSpeech-1.1`, and prepare the file lists to point to the extracted data like for [item 5 in the setup of the NVIDIA Tacotron 2 repo](https://github.com/NVIDIA/tacotron2#setup). + +2. Clone and enter the Matcha-TTS repository + +```bash +git clone https://github.com/shivammehta25/Matcha-TTS.git +cd Matcha-TTS +``` + +3. Install the package from source + +```bash +pip install -e . +``` + +4. Go to `configs/data/ljspeech.yaml` and change + +```yaml +train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt +valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt +``` + +5. Generate normalisation statistics with the yaml file of dataset configuration + +```bash +matcha-data-stats -i ljspeech.yaml +# Output: +#{'mel_mean': -5.53662231756592, 'mel_std': 2.1161014277038574} +``` + +Update these values in `configs/data/ljspeech.yaml` under `data_statistics` key. + +```bash +data_statistics: # Computed for ljspeech dataset + mel_mean: -5.536622 + mel_std: 2.116101 +``` + +to the paths of your train and validation filelists. + +6. 
Run the training script + +```bash +make train-ljspeech +``` + +or + +```bash +python matcha/train.py experiment=ljspeech +``` + +- for a minimum memory run + +```bash +python matcha/train.py experiment=ljspeech_min_memory +``` + +- for multi-gpu training, run + +```bash +python matcha/train.py experiment=ljspeech trainer.devices=[0,1] +``` + +7. Synthesise from the custom trained model + +```bash +matcha-tts --text "" --checkpoint_path +``` + +## ONNX support + +> Special thanks to [@mush42](https://github.com/mush42) for implementing ONNX export and inference support. + +It is possible to export Matcha checkpoints to [ONNX](https://onnx.ai/), and run inference on the exported ONNX graph. + +### ONNX export + +To export a checkpoint to ONNX, first install ONNX with + +```bash +pip install onnx +``` + +then run the following: + +```bash +python3 -m matcha.onnx.export matcha.ckpt model.onnx --n-timesteps 5 +``` + +Optionally, the ONNX exporter accepts **vocoder-name** and **vocoder-checkpoint** arguments. This enables you to embed the vocoder in the exported graph and generate waveforms in a single run (similar to end-to-end TTS systems). + +**Note** that `n_timesteps` is treated as a hyper-parameter rather than a model input. This means you should specify it during export (not during inference). If not specified, `n_timesteps` is set to **5**. + +**Important**: for now, torch>=2.1.0 is needed for export since the `scaled_product_attention` operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually as a pre-release. + +### ONNX Inference + +To run inference on the exported model, first install `onnxruntime` using + +```bash +pip install onnxruntime +pip install onnxruntime-gpu # for GPU inference +``` + +then use the following: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs +``` + +You can also control synthesis parameters: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --temperature 0.4 --speaking_rate 0.9 --spk 0 +``` + +To run inference on **GPU**, make sure to install **onnxruntime-gpu** package, and then pass `--gpu` to the inference command: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --gpu +``` + +If you exported only Matcha to ONNX, this will write mel-spectrogram as graphs and `numpy` arrays to the output directory. +If you embedded the vocoder in the exported graph, this will write `.wav` audio files to the output directory. + +If you exported only Matcha to ONNX, and you want to run a full TTS pipeline, you can pass a path to a vocoder model in `ONNX` format: + +```bash +python3 -m matcha.onnx.infer model.onnx --text "hey" --output-dir ./outputs --vocoder hifigan.small.onnx +``` + +This will write `.wav` audio files to the output directory. 
+ +## Extract phoneme alignments from Matcha-TTS + +If the dataset is structured as + +```bash +data/ +└── LJSpeech-1.1 + ├── metadata.csv + ├── README + ├── test.txt + ├── train.txt + ├── val.txt + └── wavs +``` +Then you can extract the phoneme level alignments from a Trained Matcha-TTS model using: +```bash +python matcha/utils/get_durations_from_trained_model.py -i dataset_yaml -c +``` +Example: +```bash +python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c matcha_ljspeech.ckpt +``` +or simply: +```bash +matcha-tts-get-durations -i ljspeech.yaml -c matcha_ljspeech.ckpt +``` +--- +## Train using extracted alignments + +In the datasetconfig turn on load duration. +Example: `ljspeech.yaml` +``` +load_durations: True +``` +or see an examples in configs/experiment/ljspeech_from_durations.yaml + + +## Citation information + +If you use our code or otherwise find this work useful, please cite our paper: + +```text +@inproceedings{mehta2024matcha, + title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching}, + author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje}, + booktitle={Proc. ICASSP}, + year={2024} +} +``` + +## Acknowledgements + +Since this code uses [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), you have all the powers that come with it. + +Other source code we would like to acknowledge: + +- [Coqui-TTS](https://github.com/coqui-ai/TTS/tree/dev): For helping me figure out how to make cython binaries pip installable and encouragement +- [Hugging Face Diffusers](https://huggingface.co/): For their awesome diffusers library and its components +- [Grad-TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): For the monotonic alignment search source code +- [torchdyn](https://github.com/DiffEqML/torchdyn): Useful for trying other ODE solvers during research and development +- [labml.ai](https://nn.labml.ai/transformers/rope/index.html): For the RoPE implementation diff --git a/third_party/Matcha-TTS/configs/__init__.py b/third_party/Matcha-TTS/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56bf7f4aa4906bc0f997132708cc0826c198e4aa --- /dev/null +++ b/third_party/Matcha-TTS/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/third_party/Matcha-TTS/configs/callbacks/default.yaml b/third_party/Matcha-TTS/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebaa3ed31a7f626bc62f90184dc4b25b631e52a9 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/default.yaml @@ -0,0 +1,5 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - rich_progress_bar.yaml + - _self_ diff --git a/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml b/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d085c711a8521b6b98ad6401b686bb601ceacd6 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir}/checkpoints # directory to save the model file + filename: checkpoint_{epoch:03d} # checkpoint filename + monitor: epoch # name of the logged metric which determines 
when model is improving + verbose: False # verbosity mode + save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 10 # save k best models (determined by above metric) + mode: "max" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 100 # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml b/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e5368d0e94298cce6d5421365b4583bd763ba92 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 3 # the maximum depth of layer nesting that the summary will include diff --git a/third_party/Matcha-TTS/configs/callbacks/none.yaml b/third_party/Matcha-TTS/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml b/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de6f1ccb11205a4db93645fb6f297e50205de172 --- /dev/null +++ b/third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml b/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1269f9b3b421d27a204bb0697e2f27a0fa0864a3 --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: hi-fi_en-US_female +train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt +valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt +batch_size: 32 +cleaners: [english_cleaners_piper] +data_statistics: # Computed for this dataset + mel_mean: -6.38385 + mel_std: 2.541796 diff --git a/third_party/Matcha-TTS/configs/data/ljspeech.yaml b/third_party/Matcha-TTS/configs/data/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee87a6a76a2c344b3a90196f87bb48f205e2e48d --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/ljspeech.yaml @@ -0,0 +1,22 @@ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: ljspeech +train_filelist_path: data/LJSpeech-1.1/train.txt +valid_filelist_path: data/LJSpeech-1.1/val.txt +batch_size: 32 +num_workers: 20 +pin_memory: True +cleaners: 
[english_cleaners2] +add_blank: True +n_spks: 1 +n_fft: 1024 +n_feats: 80 +sample_rate: 22050 +hop_length: 256 +win_length: 1024 +f_min: 0 +f_max: 8000 +data_statistics: # Computed for ljspeech dataset + mel_mean: -5.536622 + mel_std: 2.116101 +seed: ${seed} +load_durations: false diff --git a/third_party/Matcha-TTS/configs/data/vctk.yaml b/third_party/Matcha-TTS/configs/data/vctk.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba11cc63371ad6308d6711513268de7efe50eed9 --- /dev/null +++ b/third_party/Matcha-TTS/configs/data/vctk.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: vctk +train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt +valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt +batch_size: 32 +add_blank: True +n_spks: 109 +data_statistics: # Computed for vctk dataset + mel_mean: -6.630575 + mel_std: 2.482914 diff --git a/third_party/Matcha-TTS/configs/debug/default.yaml b/third_party/Matcha-TTS/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3932c82585fbe44047c1569a5cfe9ee9895c71a --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +# callbacks: null +# logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/third_party/Matcha-TTS/configs/debug/fdr.yaml b/third_party/Matcha-TTS/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f2d34fa37c31017e749d5a4fc5ae6763e688b46 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/third_party/Matcha-TTS/configs/debug/limit.yaml b/third_party/Matcha-TTS/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/third_party/Matcha-TTS/configs/debug/overfit.yaml b/third_party/Matcha-TTS/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + 
+trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/third_party/Matcha-TTS/configs/debug/profiler.yaml b/third_party/Matcha-TTS/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..266295f15e0166e1d1b58b88caa7673f4b6493b5 --- /dev/null +++ b/third_party/Matcha-TTS/configs/debug/profiler.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + # profiler: "simple" + profiler: "advanced" + # profiler: "pytorch" + accelerator: gpu + + limit_train_batches: 0.02 diff --git a/third_party/Matcha-TTS/configs/eval.yaml b/third_party/Matcha-TTS/configs/eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be312992b2a486b04d83a54dbd8f670d94979709 --- /dev/null +++ b/third_party/Matcha-TTS/configs/eval.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - _self_ + - data: mnist # choose datamodule with `test_dataloader()` for evaluation + - model: mnist + - logger: null + - trainer: default + - paths: default + - extras: default + - hydra: default + +task_name: "eval" + +tags: ["dev"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? diff --git a/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml b/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e6c57a0d0a399f7463f4ff2d96e1928c435779b --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: hi-fi_en-US_female.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] + +run_name: hi-fi_en-US_female_piper_phonemizer diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5723f42cf3552226c42bd91202cc18818b685f0 --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63f7d298280245b8ae4d3403f8540d0d2e8ada4f --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech_from_durations.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech + + +data: + 
load_durations: True + batch_size: 64 diff --git a/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml b/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef554dc633c392b1592d90d9d7734f2329264fdd --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech_min + + +model: + out_size: 172 diff --git a/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml b/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..553842f4e2168db0fee4e44db11b5d086295b044 --- /dev/null +++ b/third_party/Matcha-TTS/configs/experiment/multispeaker.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: vctk.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["multispeaker"] + +run_name: multispeaker diff --git a/third_party/Matcha-TTS/configs/extras/default.yaml b/third_party/Matcha-TTS/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/third_party/Matcha-TTS/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml b/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1391183ebcdec3d8f5eb61374e0719d13c7545da --- /dev/null +++ b/third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! 
+optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(32, 64, 128, 256) + model.net.lin1_size: choice(64, 128, 256) + model.net.lin2_size: choice(64, 128, 256) + model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/third_party/Matcha-TTS/configs/hydra/default.yaml b/third_party/Matcha-TTS/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1533136b22802a4f81e5387b74e407289edce94d --- /dev/null +++ b/third_party/Matcha-TTS/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log diff --git a/third_party/Matcha-TTS/configs/local/.gitkeep b/third_party/Matcha-TTS/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/configs/logger/aim.yaml b/third_party/Matcha-TTS/configs/logger/aim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not 
specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/third_party/Matcha-TTS/configs/logger/comet.yaml b/third_party/Matcha-TTS/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/csv.yaml b/third_party/Matcha-TTS/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/many_loggers.yaml b/third_party/Matcha-TTS/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/third_party/Matcha-TTS/configs/logger/mlflow.yaml b/third_party/Matcha-TTS/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/third_party/Matcha-TTS/configs/logger/neptune.yaml b/third_party/Matcha-TTS/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/third_party/Matcha-TTS/configs/logger/tensorboard.yaml 
b/third_party/Matcha-TTS/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/third_party/Matcha-TTS/configs/logger/wandb.yaml b/third_party/Matcha-TTS/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ece165889b3d0d9dc750a8f3c7454188cfdf12b7 --- /dev/null +++ b/third_party/Matcha-TTS/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "lightning-hydra-template" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/third_party/Matcha-TTS/configs/model/cfm/default.yaml b/third_party/Matcha-TTS/configs/model/cfm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d1d9609e2d05c7b0a12a26115520340ac18e584 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/cfm/default.yaml @@ -0,0 +1,3 @@ +name: CFM +solver: euler +sigma_min: 1e-4 diff --git a/third_party/Matcha-TTS/configs/model/decoder/default.yaml b/third_party/Matcha-TTS/configs/model/decoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aaa00e63402ade5c76247a2f1d6b294ec3c61e63 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/decoder/default.yaml @@ -0,0 +1,7 @@ +channels: [256, 256] +dropout: 0.05 +attention_head_dim: 64 +n_blocks: 1 +num_mid_blocks: 2 +num_heads: 2 +act_fn: snakebeta diff --git a/third_party/Matcha-TTS/configs/model/encoder/default.yaml b/third_party/Matcha-TTS/configs/model/encoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4d5e5adee8f707bd384b682a3ad9a116c40c6ed --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/encoder/default.yaml @@ -0,0 +1,18 @@ +encoder_type: RoPE Encoder +encoder_params: + n_feats: ${model.n_feats} + n_channels: 192 + filter_channels: 768 + filter_channels_dp: 256 + n_heads: 2 + n_layers: 6 + kernel_size: 3 + p_dropout: 0.1 + spk_emb_dim: 64 + n_spks: 1 + prenet: true + +duration_predictor_params: + filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} + kernel_size: 3 + p_dropout: ${model.encoder.encoder_params.p_dropout} diff --git a/third_party/Matcha-TTS/configs/model/matcha.yaml b/third_party/Matcha-TTS/configs/model/matcha.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2b5c78ddeb98fcca85093deba1cea3b1d1074e1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/matcha.yaml @@ -0,0 +1,16 @@ +defaults: + - _self_ + - encoder: default.yaml + - decoder: default.yaml + - cfm: default.yaml + - optimizer: adam.yaml + +_target_: matcha.models.matcha_tts.MatchaTTS +n_vocab: 178 +n_spks: ${data.n_spks} +spk_emb_dim: 64 +n_feats: 80 +data_statistics: ${data.data_statistics} +out_size: null # Must be 
divisible by 4 +prior_loss: true +use_precomputed_durations: ${data.load_durations} diff --git a/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml b/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42795577474eaee5b0b96845a95e1a11c9152385 --- /dev/null +++ b/third_party/Matcha-TTS/configs/model/optimizer/adam.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.Adam +_partial_: true +lr: 1e-4 +weight_decay: 0.0 diff --git a/third_party/Matcha-TTS/configs/paths/default.yaml b/third_party/Matcha-TTS/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec81db2d34712909a79be3e42e65efe08c35ecee --- /dev/null +++ b/third_party/Matcha-TTS/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/third_party/Matcha-TTS/configs/train.yaml b/third_party/Matcha-TTS/configs/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6f5c2e7b9781758c8d25f941f004ca383c3f494 --- /dev/null +++ b/third_party/Matcha-TTS/configs/train.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: ljspeech + - model: matcha + - callbacks: default + - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: default + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +run_name: ??? 
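+# run_name has no default (`???` marks it as a required value); set it when launching, e.g. `python train.py run_name=my_experiment` (placeholder name)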
+ +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 1234 diff --git a/third_party/Matcha-TTS/configs/trainer/cpu.yaml b/third_party/Matcha-TTS/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/third_party/Matcha-TTS/configs/trainer/ddp.yaml b/third_party/Matcha-TTS/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94b43e20ca7bf1f2ea92627fd46906e4f0a273a1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: [0,1] +num_nodes: 1 +sync_batchnorm: True diff --git a/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml b/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/third_party/Matcha-TTS/configs/trainer/default.yaml b/third_party/Matcha-TTS/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee3d370d8ca6b08d7ee7a86d34184c2104f0e1ef --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/default.yaml @@ -0,0 +1,20 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +max_epochs: -1 + +accelerator: gpu +devices: [0] + +# mixed precision for extra speed-up +precision: 16-mixed + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False + +gradient_clip_val: 5.0 diff --git a/third_party/Matcha-TTS/configs/trainer/gpu.yaml b/third_party/Matcha-TTS/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2389510a90f5f0161cff6ccfcb4a96097ddf9a1 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/third_party/Matcha-TTS/configs/trainer/mps.yaml b/third_party/Matcha-TTS/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/third_party/Matcha-TTS/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/third_party/Matcha-TTS/matcha/VERSION b/third_party/Matcha-TTS/matcha/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..ea5abc8f95c042c48eff77805a033599f816a545 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/VERSION 
@@ -0,0 +1 @@ +0.0.7.0 diff --git a/third_party/Matcha-TTS/matcha/__init__.py b/third_party/Matcha-TTS/matcha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/app.py b/third_party/Matcha-TTS/matcha/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d68fbaa2d10d1faab606d89906af5e8b6baa5aa4 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/app.py @@ -0,0 +1,357 @@ +import tempfile +from argparse import Namespace +from pathlib import Path + +import gradio as gr +import soundfile as sf +import torch + +from matcha.cli import ( + MATCHA_URLS, + VOCODER_URLS, + assert_model_downloaded, + get_device, + load_matcha, + load_vocoder, + process_text, + to_waveform, +) +from matcha.utils.utils import get_user_data_dir, plot_tensor + +LOCATION = Path(get_user_data_dir()) + +args = Namespace( + cpu=False, + model="matcha_vctk", + vocoder="hifigan_univ_v1", + spk=0, +) + +CURRENTLY_LOADED_MODEL = args.model + + +def MATCHA_TTS_LOC(x): + return LOCATION / f"{x}.ckpt" + + +def VOCODER_LOC(x): + return LOCATION / f"{x}" + + +LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png" +RADIO_OPTIONS = { + "Multi Speaker (VCTK)": { + "model": "matcha_vctk", + "vocoder": "hifigan_univ_v1", + }, + "Single Speaker (LJ Speech)": { + "model": "matcha_ljspeech", + "vocoder": "hifigan_T2_v1", + }, +} + +# Ensure all the required models are downloaded +assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"]) +assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"]) +assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"]) +assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"]) + +device = get_device(args) + +# Load default model +model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device) +vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device) + + +def load_model(model_name, vocoder_name): + model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device) + vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device) + return model, vocoder, denoiser + + +def load_model_ui(model_type, textbox): + model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"] + + global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != model_name: + model, vocoder, denoiser = load_model(model_name, vocoder_name) + CURRENTLY_LOADED_MODEL = model_name + + if model_name == "matcha_ljspeech": + spk_slider = gr.update(visible=False, value=-1) + single_speaker_examples = gr.update(visible=True) + multi_speaker_examples = gr.update(visible=False) + length_scale = gr.update(value=0.95) + else: + spk_slider = gr.update(visible=True, value=0) + single_speaker_examples = gr.update(visible=False) + multi_speaker_examples = gr.update(visible=True) + length_scale = gr.update(value=0.85) + + return ( + textbox, + gr.update(interactive=True), + spk_slider, + single_speaker_examples, + multi_speaker_examples, + length_scale, + ) + + +@torch.inference_mode() +def process_text_gradio(text): + output = process_text(1, text, device) + return output["x_phones"][1::2], output["x"], output["x_lengths"] + + +@torch.inference_mode() +def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk): + spk = 
torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None + output = model.synthesise( + text, + text_length, + n_timesteps=n_timesteps, + temperature=temperature, + spks=spk, + length_scale=length_scale, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + sf.write(fp.name, output["waveform"], 22050, "PCM_24") + + return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy()) + + +def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk): + global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != "matcha_vctk": + global model, vocoder, denoiser # pylint: disable=global-statement + model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1") + CURRENTLY_LOADED_MODEL = "matcha_vctk" + + phones, text, text_lengths = process_text_gradio(text) + audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + return phones, audio, mel_spectrogram + + +def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1): + global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement + if CURRENTLY_LOADED_MODEL != "matcha_ljspeech": + global model, vocoder, denoiser # pylint: disable=global-statement + model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1") + CURRENTLY_LOADED_MODEL = "matcha_ljspeech" + + phones, text, text_lengths = process_text_gradio(text) + audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + return phones, audio, mel_spectrogram + + +def main(): + description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching + ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/) + We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method: + + + * Is probabilistic + * Has compact memory footprint + * Sounds highly natural + * Is very fast to synthesise from + + + Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199). + Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models. + + Cached examples are available at the bottom of the page. + """ + + with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo: + processed_text = gr.State(value=None) + processed_text_len = gr.State(value=None) + + with gr.Box(): + with gr.Row(): + gr.Markdown(description, scale=3) + with gr.Column(): + gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False) + html = '
' + gr.HTML(html) + + with gr.Box(): + radio_options = list(RADIO_OPTIONS.keys()) + model_type = gr.Radio( + radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False + ) + + with gr.Row(): + gr.Markdown("# Text Input") + with gr.Row(): + text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3) + spk_slider = gr.Slider( + minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1 + ) + + with gr.Row(): + gr.Markdown("### Hyper parameters") + with gr.Row(): + n_timesteps = gr.Slider( + label="Number of ODE steps", + minimum=1, + maximum=100, + step=1, + value=10, + interactive=True, + ) + length_scale = gr.Slider( + label="Length scale (Speaking rate)", + minimum=0.5, + maximum=1.5, + step=0.05, + value=1.0, + interactive=True, + ) + mel_temp = gr.Slider( + label="Sampling temperature", + minimum=0.00, + maximum=2.001, + step=0.16675, + value=0.667, + interactive=True, + ) + + synth_btn = gr.Button("Synthesise") + + with gr.Box(): + with gr.Row(): + gr.Markdown("### Phonetised text") + phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text") + + with gr.Box(): + with gr.Row(): + mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram") + + # with gr.Row(): + audio = gr.Audio(interactive=False, label="Audio") + + with gr.Row(visible=False) as example_row_lj_speech: + examples = gr.Examples( # pylint: disable=unused-variable + examples=[ + [ + "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.", + 50, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 2, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 4, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 10, + 0.677, + 0.95, + ], + [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + 50, + 0.677, + 0.95, + ], + [ + "The narrative of these events is based largely on the recollections of the participants.", + 10, + 0.677, + 0.95, + ], + [ + "The jury did not believe him, and the verdict was for the defendants.", + 10, + 0.677, + 0.95, + ], + ], + fn=ljspeech_example_cacher, + inputs=[text, n_timesteps, mel_temp, length_scale], + outputs=[phonetised_text, audio, mel_spectrogram], + cache_examples=True, + ) + + with gr.Row() as example_row_multispeaker: + multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable + examples=[ + [ + "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!", + 10, + 0.677, + 0.85, + 0, + ], + [ + "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!", + 10, + 0.677, + 0.85, + 16, + ], + [ + "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!", + 50, + 0.677, + 0.85, + 44, + ], + [ + "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!", + 50, + 0.677, + 0.85, + 45, + ], + [ + "Hello everyone! 
I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!", + 4, + 0.677, + 0.85, + 58, + ], + ], + fn=multispeaker_example_cacher, + inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider], + outputs=[phonetised_text, audio, mel_spectrogram], + cache_examples=True, + label="Multi Speaker Examples", + ) + + model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then( + load_model_ui, + inputs=[model_type, text], + outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale], + ) + + synth_btn.click( + fn=process_text_gradio, + inputs=[ + text, + ], + outputs=[phonetised_text, processed_text, processed_text_len], + api_name="matcha_tts", + queue=True, + ).then( + fn=synthesise_mel, + inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider], + outputs=[audio, mel_spectrogram], + ) + + demo.queue().launch(share=True) + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/cli.py b/third_party/Matcha-TTS/matcha/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..7daf13073a01326cc8150a0f29453e635f31d719 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/cli.py @@ -0,0 +1,419 @@ +import argparse +import datetime as dt +import os +import warnings +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import soundfile as sf +import torch + +from matcha.hifigan.config import v1 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.env import AttrDict +from matcha.hifigan.models import Generator as HiFiGAN +from matcha.models.matcha_tts import MatchaTTS +from matcha.text import sequence_to_text, text_to_sequence +from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse + +MATCHA_URLS = { + "matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt", + "matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt", +} + +VOCODER_URLS = { + "hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1", # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link + "hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000", # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link +} + +MULTISPEAKER_MODEL = { + "matcha_vctk": {"vocoder": "hifigan_univ_v1", "speaking_rate": 0.85, "spk": 0, "spk_range": (0, 107)} +} + +SINGLESPEAKER_MODEL = {"matcha_ljspeech": {"vocoder": "hifigan_T2_v1", "speaking_rate": 0.95, "spk": None}} + + +def plot_spectrogram_to_numpy(spectrogram, filename): + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.title("Synthesised Mel-Spectrogram") + fig.canvas.draw() + plt.savefig(filename) + + +def process_text(i: int, text: str, device: torch.device): + print(f"[{i}] - Input text: {text}") + x = torch.tensor( + intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0), + dtype=torch.long, + device=device, + )[None] + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + x_phones = sequence_to_text(x.squeeze(0).tolist()) + print(f"[{i}] - Phonetised text: 
{x_phones[1::2]}") + + return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + + +def get_texts(args): + if args.text: + texts = [args.text] + else: + with open(args.file, encoding="utf-8") as f: + texts = f.readlines() + return texts + + +def assert_required_models_available(args): + save_dir = get_user_data_dir() + if not hasattr(args, "checkpoint_path") and args.checkpoint_path is None: + model_path = args.checkpoint_path + else: + model_path = save_dir / f"{args.model}.ckpt" + assert_model_downloaded(model_path, MATCHA_URLS[args.model]) + + vocoder_path = save_dir / f"{args.vocoder}" + assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder]) + return {"matcha": model_path, "vocoder": vocoder_path} + + +def load_hifigan(checkpoint_path, device): + h = AttrDict(v1) + hifigan = HiFiGAN(h).to(device) + hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"]) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def load_vocoder(vocoder_name, checkpoint_path, device): + print(f"[!] Loading {vocoder_name}!") + vocoder = None + if vocoder_name in ("hifigan_T2_v1", "hifigan_univ_v1"): + vocoder = load_hifigan(checkpoint_path, device) + else: + raise NotImplementedError( + f"Vocoder {vocoder_name} not implemented! define a load_<> method for it" + ) + + denoiser = Denoiser(vocoder, mode="zeros") + print(f"[+] {vocoder_name} loaded!") + return vocoder, denoiser + + +def load_matcha(model_name, checkpoint_path, device): + print(f"[!] Loading {model_name}!") + model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device) + _ = model.eval() + + print(f"[+] {model_name} loaded!") + return model + + +def to_waveform(mel, vocoder, denoiser=None): + audio = vocoder(mel).clamp(-1, 1) + if denoiser is not None: + audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze() + + return audio.cpu().squeeze() + + +def save_to_folder(filename: str, output: dict, folder: str): + folder = Path(folder) + folder.mkdir(exist_ok=True, parents=True) + plot_spectrogram_to_numpy(np.array(output["mel"].squeeze().float().cpu()), f"{filename}.png") + np.save(folder / f"{filename}", output["mel"].cpu().numpy()) + sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24") + return folder.resolve() / f"{filename}.wav" + + +def validate_args(args): + assert ( + args.text or args.file + ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." + assert args.temperature >= 0, "Sampling temperature cannot be negative" + assert args.steps > 0, "Number of ODE steps must be greater than 0" + + if args.checkpoint_path is None: + # When using pretrained models + if args.model in SINGLESPEAKER_MODEL: + args = validate_args_for_single_speaker_model(args) + + if args.model in MULTISPEAKER_MODEL: + args = validate_args_for_multispeaker_model(args) + else: + # When using a custom model + if args.vocoder != "hifigan_univ_v1": + warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech." 
+ warnings.warn(warn_, UserWarning) + if args.speaking_rate is None: + args.speaking_rate = 1.0 + + if args.batched: + assert args.batch_size > 0, "Batch size must be greater than 0" + assert args.speaking_rate > 0, "Speaking rate must be greater than 0" + + return args + + +def validate_args_for_multispeaker_model(args): + if args.vocoder is not None: + if args.vocoder != MULTISPEAKER_MODEL[args.model]["vocoder"]: + warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {MULTISPEAKER_MODEL[args.model]['vocoder']}" + warnings.warn(warn_, UserWarning) + else: + args.vocoder = MULTISPEAKER_MODEL[args.model]["vocoder"] + + if args.speaking_rate is None: + args.speaking_rate = MULTISPEAKER_MODEL[args.model]["speaking_rate"] + + spk_range = MULTISPEAKER_MODEL[args.model]["spk_range"] + if args.spk is not None: + assert ( + args.spk >= spk_range[0] and args.spk <= spk_range[-1] + ), f"Speaker ID must be between {spk_range} for this model." + else: + available_spk_id = MULTISPEAKER_MODEL[args.model]["spk"] + warn_ = f"[!] Speaker ID not provided! Using speaker ID {available_spk_id}" + warnings.warn(warn_, UserWarning) + args.spk = available_spk_id + + return args + + +def validate_args_for_single_speaker_model(args): + if args.vocoder is not None: + if args.vocoder != SINGLESPEAKER_MODEL[args.model]["vocoder"]: + warn_ = f"[-] Using {args.model} model! I would suggest passing --vocoder {SINGLESPEAKER_MODEL[args.model]['vocoder']}" + warnings.warn(warn_, UserWarning) + else: + args.vocoder = SINGLESPEAKER_MODEL[args.model]["vocoder"] + + if args.speaking_rate is None: + args.speaking_rate = SINGLESPEAKER_MODEL[args.model]["speaking_rate"] + + if args.spk != SINGLESPEAKER_MODEL[args.model]["spk"]: + warn_ = f"[-] Ignoring speaker id {args.spk} for {args.model}" + warnings.warn(warn_, UserWarning) + args.spk = SINGLESPEAKER_MODEL[args.model]["spk"] + + return args + + +@torch.inference_mode() +def cli(): + parser = argparse.ArgumentParser( + description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" + ) + parser.add_argument( + "--model", + type=str, + default="matcha_ljspeech", + help="Model to use", + choices=MATCHA_URLS.keys(), + ) + + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Path to the custom model checkpoint", + ) + + parser.add_argument( + "--vocoder", + type=str, + default=None, + help="Vocoder to use (default: will use the one suggested with the pretrained model))", + choices=VOCODER_URLS.keys(), + ) + parser.add_argument("--text", type=str, default=None, help="Text to synthesize") + parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") + parser.add_argument("--spk", type=int, default=None, help="Speaker ID") + parser.add_argument( + "--temperature", + type=float, + default=0.667, + help="Variance of the x0 noise (default: 0.667)", + ) + parser.add_argument( + "--speaking_rate", + type=float, + default=None, + help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)", + ) + parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)") + parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") + parser.add_argument( + "--denoiser_strength", + type=float, + default=0.00025, + help="Strength of the vocoder bias denoiser (default: 0.00025)", + ) + parser.add_argument( + "--output_folder", + type=str, + default=os.getcwd(), + help="Output folder to save results 
(default: current dir)", + ) + parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)" + ) + + args = parser.parse_args() + + args = validate_args(args) + device = get_device(args) + print_config(args) + paths = assert_required_models_available(args) + + if args.checkpoint_path is not None: + print(f"[🍵] Loading custom model from {args.checkpoint_path}") + paths["matcha"] = args.checkpoint_path + args.model = "custom_model" + + model = load_matcha(args.model, paths["matcha"], device) + vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device) + + texts = get_texts(args) + + spk = torch.tensor([args.spk], device=device, dtype=torch.long) if args.spk is not None else None + if len(texts) == 1 or not args.batched: + unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk) + else: + batched_synthesis(args, device, model, vocoder, denoiser, texts, spk) + + +class BatchedSynthesisDataset(torch.utils.data.Dataset): + def __init__(self, processed_texts): + self.processed_texts = processed_texts + + def __len__(self): + return len(self.processed_texts) + + def __getitem__(self, idx): + return self.processed_texts[idx] + + +def batched_collate_fn(batch): + x = [] + x_lengths = [] + + for b in batch: + x.append(b["x"].squeeze(0)) + x_lengths.append(b["x_lengths"]) + + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) + x_lengths = torch.concat(x_lengths, dim=0) + return {"x": x, "x_lengths": x_lengths} + + +def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk): + total_rtf = [] + total_rtf_w = [] + processed_text = [process_text(i, text, "cpu") for i, text in enumerate(texts)] + dataloader = torch.utils.data.DataLoader( + BatchedSynthesisDataset(processed_text), + batch_size=args.batch_size, + collate_fn=batched_collate_fn, + num_workers=8, + ) + for i, batch in enumerate(dataloader): + i = i + 1 + start_t = dt.datetime.now() + b = batch["x"].shape[0] + output = model.synthesise( + batch["x"].to(device), + batch["x_lengths"].to(device), + n_timesteps=args.steps, + temperature=args.temperature, + spks=spk.expand(b) if spk is not None else spk, + length_scale=args.speaking_rate, + ) + + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + t = (dt.datetime.now() - start_t).total_seconds() + rtf_w = t * 22050 / (output["waveform"].shape[-1]) + print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}") + print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}") + total_rtf.append(output["rtf"]) + total_rtf_w.append(rtf_w) + for j in range(output["mel"].shape[0]): + base_name = f"utterance_{j:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{j:03d}" + length = output["mel_lengths"][j] + new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]} + location = save_to_folder(base_name, new_dict, args.output_folder) + print(f"[🍵-{j}] Waveform saved: {location}") + + print("".join(["="] * 100)) + print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}") + print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}") + print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!") + + +def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk): + total_rtf = [] + total_rtf_w = [] + for i, text in enumerate(texts): + i = i + 1 + base_name = 
f"utterance_{i:03d}_speaker_{args.spk:03d}" if args.spk is not None else f"utterance_{i:03d}" + + print("".join(["="] * 100)) + text = text.strip() + text_processed = process_text(i, text, device) + + print(f"[🍵] Whisking Matcha-T(ea)TS for: {i}") + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=args.steps, + temperature=args.temperature, + spks=spk, + length_scale=args.speaking_rate, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + # RTF with HiFiGAN + t = (dt.datetime.now() - start_t).total_seconds() + rtf_w = t * 22050 / (output["waveform"].shape[-1]) + print(f"[🍵-{i}] Matcha-TTS RTF: {output['rtf']:.4f}") + print(f"[🍵-{i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}") + total_rtf.append(output["rtf"]) + total_rtf_w.append(rtf_w) + + location = save_to_folder(base_name, output, args.output_folder) + print(f"[+] Waveform saved: {location}") + + print("".join(["="] * 100)) + print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}") + print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}") + print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!") + + +def print_config(args): + print("[!] Configurations: ") + print(f"\t- Model: {args.model}") + print(f"\t- Vocoder: {args.vocoder}") + print(f"\t- Temperature: {args.temperature}") + print(f"\t- Speaking rate: {args.speaking_rate}") + print(f"\t- Number of ODE steps: {args.steps}") + print(f"\t- Speaker: {args.spk}") + + +def get_device(args): + if torch.cuda.is_available() and not args.cpu: + print("[+] GPU Available! Using GPU") + device = torch.device("cuda") + else: + print("[-] GPU not available or forced CPU run! Using CPU") + device = torch.device("cpu") + return device + + +if __name__ == "__main__": + cli() diff --git a/third_party/Matcha-TTS/matcha/data/__init__.py b/third_party/Matcha-TTS/matcha/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/data/components/__init__.py b/third_party/Matcha-TTS/matcha/data/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py b/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..e10dfcb8bba8fbd1d04272a70d5acfe886ae5107 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/data/text_mel_datamodule.py @@ -0,0 +1,274 @@ +import random +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torchaudio as ta +from lightning import LightningDataModule +from torch.utils.data.dataloader import DataLoader + +from matcha.text import text_to_sequence +from matcha.utils.audio import mel_spectrogram +from matcha.utils.model import fix_len_compatibility, normalize +from matcha.utils.utils import intersperse + + +def parse_filelist(filelist_path, split_char="|"): + with open(filelist_path, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split_char) for line in f] + return filepaths_and_text + + +class TextMelDataModule(LightningDataModule): + def __init__( # pylint: disable=unused-argument + self, + name, + train_filelist_path, + valid_filelist_path, + batch_size, + num_workers, + pin_memory, + cleaners, + add_blank, + n_spks, + n_fft, + 
n_feats, + sample_rate, + hop_length, + win_length, + f_min, + f_max, + data_statistics, + seed, + load_durations, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute things like random split twice! + """ + # load and split datasets only if not loaded already + + self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.train_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + self.hparams.load_durations, + ) + self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.valid_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + self.hparams.load_durations, + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.trainset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.validset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass # pylint: disable=unnecessary-pass + + def state_dict(self): + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass # pylint: disable=unnecessary-pass + + +class TextMelDataset(torch.utils.data.Dataset): + def __init__( + self, + filelist_path, + n_spks, + cleaners, + add_blank=True, + n_fft=1024, + n_mels=80, + sample_rate=22050, + hop_length=256, + win_length=1024, + f_min=0.0, + f_max=8000, + data_parameters=None, + seed=None, + load_durations=False, + ): + self.filepaths_and_text = parse_filelist(filelist_path) + self.n_spks = n_spks + self.cleaners = cleaners + self.add_blank = add_blank + self.n_fft = n_fft + self.n_mels = n_mels + self.sample_rate = sample_rate + self.hop_length = hop_length + self.win_length = win_length + self.f_min = f_min + self.f_max = f_max + self.load_durations = load_durations + + if data_parameters is not None: + self.data_parameters = data_parameters + else: + self.data_parameters = {"mel_mean": 0, "mel_std": 1} + random.seed(seed) + random.shuffle(self.filepaths_and_text) + + def get_datapoint(self, filepath_and_text): + if self.n_spks > 1: + filepath, spk, text = ( + filepath_and_text[0], + int(filepath_and_text[1]), + filepath_and_text[2], + ) + else: + 
filepath, text = filepath_and_text[0], filepath_and_text[1] + spk = None + + text, cleaned_text = self.get_text(text, add_blank=self.add_blank) + mel = self.get_mel(filepath) + + durations = self.get_durations(filepath, text) if self.load_durations else None + + return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations} + + def get_durations(self, filepath, text): + filepath = Path(filepath) + data_dir, name = filepath.parent.parent, filepath.stem + + try: + dur_loc = data_dir / "durations" / f"{name}.npy" + durs = torch.from_numpy(np.load(dur_loc).astype(int)) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n" + ) from e + + assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match" + + return durs + + def get_mel(self, filepath): + audio, sr = ta.load(filepath) + assert sr == self.sample_rate + mel = mel_spectrogram( + audio, + self.n_fft, + self.n_mels, + self.sample_rate, + self.hop_length, + self.win_length, + self.f_min, + self.f_max, + center=False, + ).squeeze() + mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) + return mel + + def get_text(self, text, add_blank=True): + text_norm, cleaned_text = text_to_sequence(text, self.cleaners) + if self.add_blank: + text_norm = intersperse(text_norm, 0) + text_norm = torch.IntTensor(text_norm) + return text_norm, cleaned_text + + def __getitem__(self, index): + datapoint = self.get_datapoint(self.filepaths_and_text[index]) + return datapoint + + def __len__(self): + return len(self.filepaths_and_text) + + +class TextMelBatchCollate: + def __init__(self, n_spks): + self.n_spks = n_spks + + def __call__(self, batch): + B = len(batch) + y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = fix_len_compatibility(y_max_length) + x_max_length = max([item["x"].shape[-1] for item in batch]) + n_feats = batch[0]["y"].shape[-2] + + y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) + x = torch.zeros((B, x_max_length), dtype=torch.long) + durations = torch.zeros((B, x_max_length), dtype=torch.long) + + y_lengths, x_lengths = [], [] + spks = [] + filepaths, x_texts = [], [] + for i, item in enumerate(batch): + y_, x_ = item["y"], item["x"] + y_lengths.append(y_.shape[-1]) + x_lengths.append(x_.shape[-1]) + y[i, :, : y_.shape[-1]] = y_ + x[i, : x_.shape[-1]] = x_ + spks.append(item["spk"]) + filepaths.append(item["filepath"]) + x_texts.append(item["x_text"]) + if item["durations"] is not None: + durations[i, : item["durations"].shape[-1]] = item["durations"] + + y_lengths = torch.tensor(y_lengths, dtype=torch.long) + x_lengths = torch.tensor(x_lengths, dtype=torch.long) + spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None + + return { + "x": x, + "x_lengths": x_lengths, + "y": y, + "y_lengths": y_lengths, + "spks": spks, + "filepaths": filepaths, + "x_texts": x_texts, + "durations": durations if not torch.eq(durations, 0).all() else None, + } diff --git a/third_party/Matcha-TTS/matcha/hifigan/LICENSE b/third_party/Matcha-TTS/matcha/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..91751daed806f63ac594cf077a3065f719a41662 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 
Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/Matcha-TTS/matcha/hifigan/README.md b/third_party/Matcha-TTS/matcha/hifigan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5db25850451a794b1db1b15b08e82c1d802edbb3 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer [requirements.txt](requirements.txt) +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). + And move all wav files to `LJSpeech-1.1/wavs` + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
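+For example, training the V2 generator would then be: + +``` +python train.py --config config_v2.json +```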
+Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
+You can change the path by adding the `--checkpoint_path` option. + +Validation loss during training with the V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
+Details of each folder are as follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match that of the audio file, and the extension should be `.npy`.
+ Example: ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create a `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make a `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding the `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make a `test_mel_files` directory and copy the generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/third_party/Matcha-TTS/matcha/hifigan/__init__.py b/third_party/Matcha-TTS/matcha/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/hifigan/config.py b/third_party/Matcha-TTS/matcha/hifigan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b3abea9e151a08864353d32066bd4935e24b82e7 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/config.py @@ -0,0 +1,28 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, +} diff --git a/third_party/Matcha-TTS/matcha/hifigan/denoiser.py b/third_party/Matcha-TTS/matcha/hifigan/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd33312a09b1940374a0e29a97fe3a1a1dac7d2 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/denoiser.py @@ -0,0 +1,64 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio produced with waveglow""" + + def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, 
device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + + @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/third_party/Matcha-TTS/matcha/hifigan/env.py b/third_party/Matcha-TTS/matcha/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea4f948a3f002921bf9bc24f52cbc1c0b1fc2ec --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/third_party/Matcha-TTS/matcha/hifigan/meldataset.py b/third_party/Matcha-TTS/matcha/hifigan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b43ea7965e04a52d5427a485ee911b743057c4a --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/meldataset.py @@ -0,0 +1,217 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + 
normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + 
self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/third_party/Matcha-TTS/matcha/hifigan/models.py b/third_party/Matcha-TTS/matcha/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d209d9a4e99ec29e4167a5a2eaa62d72b3eff694 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/models.py @@ -0,0 +1,368 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + 
h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + 
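# The strided convolutions above summarise the waveform at this scale; conv_post below collapses the 1024-channel feature map to one score per time step. +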
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/third_party/Matcha-TTS/matcha/hifigan/xutils.py b/third_party/Matcha-TTS/matcha/hifigan/xutils.py new file mode 100644 index 0000000000000000000000000000000000000000..eefadcb7a1d0bf9015e636b88fee3e22c9771bc5 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/third_party/Matcha-TTS/matcha/models/__init__.py b/third_party/Matcha-TTS/matcha/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/models/baselightningmodule.py b/third_party/Matcha-TTS/matcha/models/baselightningmodule.py new file mode 100644 index 0000000000000000000000000000000000000000..f8abe7b44f44688ff00720f7e56e34b75894d176 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/baselightningmodule.py @@ -0,0 +1,210 @@ +""" +This is a base lightning module that can be used to train a model. +The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. +""" +import inspect +from abc import ABC +from typing import Any, Dict + +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm + +from matcha import utils +from matcha.utils.utils import plot_tensor + +log = utils.get_pylogger(__name__) + + +class BaseLightningClass(LightningModule, ABC): + def update_data_statistics(self, data_statistics): + if data_statistics is None: + data_statistics = { + "mel_mean": 0.0, + "mel_std": 1.0, + } + + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler not in (None, {}): + scheduler_args = {} + # Manage last epoch for exponential schedulers + if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if hasattr(self, "ckpt_loaded_epoch"): + current_epoch = self.ckpt_loaded_epoch - 1 + else: + current_epoch = -1 + + scheduler_args.update({"optimizer": optimizer}) + scheduler = self.hparams.scheduler.scheduler(**scheduler_args) + scheduler.last_epoch = current_epoch + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.hparams.scheduler.lightning_args.interval, + "frequency": self.hparams.scheduler.lightning_args.frequency, + "name": "learning_rate", + }, + } + + return {"optimizer": optimizer} + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss, *_ = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + durations=batch["durations"], + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + + def training_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "step", + float(self.global_step), + on_step=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + "sub_loss/train_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/train", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return {"loss": total_loss, "log": 
loss_dict} + + def validation_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "sub_loss/val_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/val", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return total_loss + + def on_validation_end(self) -> None: + if self.trainer.is_global_zero: + one_batch = next(iter(self.trainer.val_dataloaders)) + if self.current_epoch == 0: + log.debug("Plotting original samples") + for i in range(2): + y = one_batch["y"][i].unsqueeze(0).to(self.device) + self.logger.experiment.add_image( + f"original/{i}", + plot_tensor(y.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + log.debug("Synthesising...") + for i in range(2): + x = one_batch["x"][i].unsqueeze(0).to(self.device) + x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) + spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None + output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) + y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] + attn = output["attn"] + self.logger.experiment.add_image( + f"generated_enc/{i}", + plot_tensor(y_enc.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"generated_dec/{i}", + plot_tensor(y_dec.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"alignment/{i}", + plot_tensor(attn.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + def on_before_optimizer_step(self, optimizer): + self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) diff --git a/third_party/Matcha-TTS/matcha/models/components/__init__.py b/third_party/Matcha-TTS/matcha/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/models/components/decoder.py b/third_party/Matcha-TTS/matcha/models/components/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1137cd7008e9d07b4f306926a82e44c2b2cddbdf --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/decoder.py @@ -0,0 +1,443 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from matcha.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), 
emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
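+ Example (illustrative; with the default transposed-convolution path the time axis is doubled):
+ >>> up = Upsample1D(256)
+ >>> up(torch.randn(1, 256, 100)).shape
+ torch.Size([1, 256, 200])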
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i 
in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
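+ mu (torch.Tensor): encoder output used as conditioning (same shape as `x` in this codebase)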
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/third_party/Matcha-TTS/matcha/models/components/flow_matching.py b/third_party/Matcha-TTS/matcha/models/components/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..5cad7431ef66a8d11da32a77c1af7f6e31d6b774 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/flow_matching.py @@ -0,0 +1,132 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from matcha.models.components.decoder import Decoder +from matcha.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( + torch.sum(mask) * u.shape[1] + ) + return loss, y + + +class CFM(BASECFM): + def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + # Just change the architecture of the estimator here + self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) diff --git a/third_party/Matcha-TTS/matcha/models/components/text_encoder.py b/third_party/Matcha-TTS/matcha/models/components/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a388d05d6351fa2c9d9632fed0942d51fbec067b --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/text_encoder.py @@ -0,0 +1,410 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import matcha.utils as utils +from matcha.utils.model import sequence_mask + +log = utils.get_pylogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = 
torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. 
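+ For position $m$ and pair index $i$ the rotation angle is $m\theta_i$ with
+ $\theta_i = \text{base}^{-2(i-1)/d}$, so the dot product between a rotated query and key
+ depends only on their relative offset.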
+ """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
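+ # _attention_bias_proximal adds -log(1 + |i - j|) to the scores, softly biasing each
+ # position towards attending to its neighbours.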
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + 
encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/third_party/Matcha-TTS/matcha/models/components/transformer.py b/third_party/Matcha-TTS/matcha/models/components/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1afa3aff5383912209e508676c6885e13ef4ee --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/components/transformer.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
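+ When alpha_logscale is True (the default here), alpha and beta are stored in log space
+ and exponentiated in forward, which keeps both strictly positive during training.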
+ """ + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. 
In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/third_party/Matcha-TTS/matcha/models/matcha_tts.py b/third_party/Matcha-TTS/matcha/models/matcha_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..07f95ad2e31a2de94974c21f15e28ab5445ff6fc --- /dev/null +++ b/third_party/Matcha-TTS/matcha/models/matcha_tts.py @@ -0,0 +1,244 @@ +import datetime as dt +import math +import random + +import torch + +import matcha.utils.monotonic_align as monotonic_align +from matcha import utils +from matcha.models.baselightningmodule import BaseLightningClass +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder +from matcha.utils.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) + +log = utils.get_pylogger(__name__) + + +class MatchaTTS(BaseLightningClass): # 🍵 + def __init__( + self, + n_vocab, + n_spks, + spk_emb_dim, + n_feats, + encoder, + decoder, + cfm, + data_statistics, + out_size, + optimizer=None, + scheduler=None, + prior_loss=True, + use_precomputed_durations=False, + ): + super().__init__() + + self.save_hyperparameters(logger=False) + + self.n_vocab = n_vocab + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.n_feats = n_feats + self.out_size = out_size + self.prior_loss = prior_loss + self.use_precomputed_durations = use_precomputed_durations + + if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) + + self.encoder = TextEncoder( + encoder.encoder_type, + encoder.encoder_params, + encoder.duration_predictor_params, + n_vocab, + n_spks, + spk_emb_dim, + ) + + self.decoder = CFM( + in_channels=2 * encoder.encoder_params.n_feats, + out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + self.update_data_statistics(data_statistics) + + @torch.inference_mode() + def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): + """ + Generates mel-spectrogram from text. Returns: + 1. encoder outputs + 2. decoder outputs + 3. generated alignment + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + spks (bool, optional): speaker ids. + shape: (batch_size,) + length_scale (float, optional): controls speech pace. + Increase value to slow down generated speech and vice versa. 
+ + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor + """ + # For RTF computation + t = dt.datetime.now() + + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks.long()) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): + """ + Computes 3 losses: + 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). + 2. prior loss: loss between mel-spectrogram and encoder outputs. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) + y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. + shape: (batch_size,) + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. 
+ shape: (batch_size,) + """ + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + if self.use_precomputed_durations: + attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) + else: + # Use MAS to find most likely alignment `attn` between text and mel-spectrogram + with torch.no_grad(): + const = -0.5 * math.log(2 * math.pi) * self.n_feats + factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) + y_square = torch.matmul(factor.transpose(1, 2), y**2) + y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) + mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) + log_prior = y_square - y_mu_double + mu_square + const + + attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) + attn = attn.detach() # b, t_text, T_mel + + # Compute loss between predicted log-scaled durations and those obtained from MAS + # refered to as prior loss in the paper + logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # Cut a small segment of mel-spectrogram in order to increase batch size + # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it + # - Do not need this hack for Matcha-TTS, but it works with it as well + if not isinstance(out_size, type(None)): + max_offset = (y_lengths - out_size).clamp(0) + offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) + out_offset = torch.LongTensor( + [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] + ).to(y_lengths) + attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) + y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) + + y_cut_lengths = [] + for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): + y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) + y_cut_lengths.append(y_cut_length) + cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length + y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] + attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] + + y_cut_lengths = torch.LongTensor(y_cut_lengths) + y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) + + attn = attn_cut + y = y_cut + y_mask = y_cut_mask + + # Align encoded text with mel-spectrogram and get mu_y segment + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) + + if self.prior_loss: + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 + + return dur_loss, prior_loss, diff_loss, attn diff --git a/third_party/Matcha-TTS/matcha/onnx/__init__.py b/third_party/Matcha-TTS/matcha/onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/matcha/onnx/export.py 
b/third_party/Matcha-TTS/matcha/onnx/export.py new file mode 100644 index 0000000000000000000000000000000000000000..9b795086158e1ad8a4bb5cd92306f3fa765f71ea --- /dev/null +++ b/third_party/Matcha-TTS/matcha/onnx/export.py @@ -0,0 +1,181 @@ +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +from lightning import LightningModule + +from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder + +DEFAULT_OPSET = 15 + +SEED = 1234 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +class MatchaWithVocoder(LightningModule): + def __init__(self, matcha, vocoder): + super().__init__() + self.matcha = matcha + self.vocoder = vocoder + + def forward(self, x, x_lengths, scales, spks=None): + mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) + wavs = self.vocoder(mel).clamp(-1, 1) + lengths = mel_lengths * 256 + return wavs.squeeze(1), lengths + + +def get_exportable_module(matcha, vocoder, n_timesteps): + """ + Return an appropriate `LighteningModule` and output-node names + based on whether the vocoder is embedded in the final graph + """ + + def onnx_forward_func(x, x_lengths, scales, spks=None): + """ + Custom forward function for accepting + scaler parameters as tensors + """ + # Extract scaler parameters from tensors + temperature = scales[0] + length_scale = scales[1] + output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) + return output["mel"], output["mel_lengths"] + + # Monkey-patch Matcha's forward function + matcha.forward = onnx_forward_func + + if vocoder is None: + model, output_names = matcha, ["mel", "mel_lengths"] + else: + model = MatchaWithVocoder(matcha, vocoder) + output_names = ["wav", "wav_lengths"] + return model, output_names + + +def get_inputs(is_multi_speaker): + """ + Create dummy inputs for tracing + """ + dummy_input_length = 50 + x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) + x_lengths = torch.LongTensor([dummy_input_length]) + + # Scales + temperature = 0.667 + length_scale = 1.0 + scales = torch.Tensor([temperature, length_scale]) + + model_inputs = [x, x_lengths, scales] + input_names = [ + "x", + "x_lengths", + "scales", + ] + + if is_multi_speaker: + spks = torch.LongTensor([1]) + model_inputs.append(spks) + input_names.append("spks") + + return tuple(model_inputs), input_names + + +def main(): + parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") + + parser.add_argument( + "checkpoint_path", + type=str, + help="Path to the model checkpoint", + ) + parser.add_argument("output", type=str, help="Path to output `.onnx` file") + parser.add_argument( + "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" + ) + parser.add_argument( + "--vocoder-name", + type=str, + choices=list(VOCODER_URLS.keys()), + default=None, + help="Name of the vocoder to embed in the ONNX graph", + ) + parser.add_argument( + "--vocoder-checkpoint-path", + type=str, + default=None, + help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", + ) + parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") + + args = parser.parse_args() + + print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") + print(f"Setting n_timesteps to {args.n_timesteps}") + + checkpoint_path = 
Path(args.checkpoint_path) + matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") + + if args.vocoder_name or args.vocoder_checkpoint_path: + assert ( + args.vocoder_name and args.vocoder_checkpoint_path + ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." + vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") + else: + vocoder = None + + is_multi_speaker = matcha.n_spks > 1 + + dummy_input, input_names = get_inputs(is_multi_speaker) + model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) + + # Set dynamic shape for inputs/outputs + dynamic_axes = { + "x": {0: "batch_size", 1: "time"}, + "x_lengths": {0: "batch_size"}, + } + + if vocoder is None: + dynamic_axes.update( + { + "mel": {0: "batch_size", 2: "time"}, + "mel_lengths": {0: "batch_size"}, + } + ) + else: + print("Embedding the vocoder in the ONNX graph") + dynamic_axes.update( + { + "wav": {0: "batch_size", 1: "time"}, + "wav_lengths": {0: "batch_size"}, + } + ) + + if is_multi_speaker: + dynamic_axes["spks"] = {0: "batch_size"} + + # Create the output directory (if not exists) + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + + model.to_onnx( + args.output, + dummy_input, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=args.opset, + export_params=True, + do_constant_folding=True, + ) + print(f"[🍵] ONNX model exported to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/onnx/infer.py b/third_party/Matcha-TTS/matcha/onnx/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..89ca92559c6df3776a07a038d7838242a3d19189 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/onnx/infer.py @@ -0,0 +1,168 @@ +import argparse +import os +import warnings +from pathlib import Path +from time import perf_counter + +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + +from matcha.cli import plot_spectrogram_to_numpy, process_text + + +def validate_args(args): + assert ( + args.text or args.file + ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 
+ assert args.temperature >= 0, "Sampling temperature cannot be negative" + assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" + return args + + +def write_wavs(model, inputs, output_dir, external_vocoder=None): + if external_vocoder is None: + print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") + t0 = perf_counter() + wavs, wav_lengths = model.run(None, inputs) + infer_secs = perf_counter() - t0 + mel_infer_secs = vocoder_infer_secs = None + else: + print("[🍵] Generating mel using Matcha") + mel_t0 = perf_counter() + mels, mel_lengths = model.run(None, inputs) + mel_infer_secs = perf_counter() - mel_t0 + print("Generating waveform from mel using external vocoder") + vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} + vocoder_t0 = perf_counter() + wavs = external_vocoder.run(None, vocoder_inputs)[0] + vocoder_infer_secs = perf_counter() - vocoder_t0 + wavs = wavs.squeeze(1) + wav_lengths = mel_lengths * 256 + infer_secs = mel_infer_secs + vocoder_infer_secs + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): + output_filename = output_dir.joinpath(f"output_{i + 1}.wav") + audio = wav[:wav_length] + print(f"Writing audio to {output_filename}") + sf.write(output_filename, audio, 22050, "PCM_24") + + wav_secs = wav_lengths.sum() / 22050 + print(f"Inference seconds: {infer_secs}") + print(f"Generated wav seconds: {wav_secs}") + rtf = infer_secs / wav_secs + if mel_infer_secs is not None: + mel_rtf = mel_infer_secs / wav_secs + print(f"Matcha RTF: {mel_rtf}") + if vocoder_infer_secs is not None: + vocoder_rtf = vocoder_infer_secs / wav_secs + print(f"Vocoder RTF: {vocoder_rtf}") + print(f"Overall RTF: {rtf}") + + +def write_mels(model, inputs, output_dir): + t0 = perf_counter() + mels, mel_lengths = model.run(None, inputs) + infer_secs = perf_counter() - t0 + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for i, mel in enumerate(mels): + output_stem = output_dir.joinpath(f"output_{i + 1}") + plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) + np.save(output_stem.with_suffix(".numpy"), mel) + + wav_secs = (mel_lengths * 256).sum() / 22050 + print(f"Inference seconds: {infer_secs}") + print(f"Generated wav seconds: {wav_secs}") + rtf = infer_secs / wav_secs + print(f"RTF: {rtf}") + + +def main(): + parser = argparse.ArgumentParser( + description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" + ) + parser.add_argument( + "model", + type=str, + help="ONNX model to use", + ) + parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") + parser.add_argument("--text", type=str, default=None, help="Text to synthesize") + parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") + parser.add_argument("--spk", type=int, default=None, help="Speaker ID") + parser.add_argument( + "--temperature", + type=float, + default=0.667, + help="Variance of the x0 noise (default: 0.667)", + ) + parser.add_argument( + "--speaking-rate", + type=float, + default=1.0, + help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)", + ) + parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") + parser.add_argument( + "--output-dir", + type=str, + default=os.getcwd(), + help="Output folder to save results (default: 
current dir)", + ) + + args = parser.parse_args() + args = validate_args(args) + + if args.gpu: + providers = ["GPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + model = ort.InferenceSession(args.model, providers=providers) + + model_inputs = model.get_inputs() + model_outputs = list(model.get_outputs()) + + if args.text: + text_lines = args.text.splitlines() + else: + with open(args.file, encoding="utf-8") as file: + text_lines = file.read().splitlines() + + processed_lines = [process_text(0, line, "cpu") for line in text_lines] + x = [line["x"].squeeze() for line in processed_lines] + # Pad + x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) + x = x.detach().cpu().numpy() + x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) + inputs = { + "x": x, + "x_lengths": x_lengths, + "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), + } + is_multi_speaker = len(model_inputs) == 4 + if is_multi_speaker: + if args.spk is None: + args.spk = 0 + warn = "[!] Speaker ID not provided! Using speaker ID 0" + warnings.warn(warn, UserWarning) + inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) + + has_vocoder_embedded = model_outputs[0].name == "wav" + if has_vocoder_embedded: + write_wavs(model, inputs, args.output_dir) + elif args.vocoder: + external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) + write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) + else: + warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" + warnings.warn(warn, UserWarning) + write_mels(model, inputs, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/text/__init__.py b/third_party/Matcha-TTS/matcha/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c75d6b5714a0a2d30b95e00a5377c13f29d9b8a --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/__init__.py @@ -0,0 +1,53 @@ +""" from https://github.com/keithito/tacotron """ +from matcha.text import cleaners +from matcha.text.symbols import symbols + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence, clean_text + + +def cleaned_text_to_sequence(cleaned_text): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
+ Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text diff --git a/third_party/Matcha-TTS/matcha/text/cleaners.py b/third_party/Matcha-TTS/matcha/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..36776e355257625749290f04c705335e72ffcb52 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/cleaners.py @@ -0,0 +1,121 @@ +""" from https://github.com/keithito/tacotron + +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import logging +import re + +import phonemizer +from unidecode import unidecode + +# To avoid excessive logging we set the log level of the phonemizer package to Critical +critical_logger = logging.getLogger("phonemizer") +critical_logger.setLevel(logging.CRITICAL) + +# Intializing the phonemizer globally significantly reduces the speed +# now the phonemizer is not initialising at every call +# Might be less flexible, but it is much-much faster +global_phonemizer = phonemizer.backend.EspeakBackend( + language="en-us", + preserve_punctuation=True, + with_stress=True, + language_switch="remove-flags", + logger=critical_logger, +) + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners2(text): + """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] + phonemes = collapse_whitespace(phonemes) + return phonemes + + +# I am removing this due to incompatibility with several version of python +# However, if you want to use it, you can uncomment it +# and install piper-phonemize with the following command: +# pip install piper-phonemize + +# import piper_phonemize +# def english_cleaners_piper(text): +# """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" +# text = convert_to_ascii(text) +# text = lowercase(text) +# text = expand_abbreviations(text) +# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) +# phonemes = collapse_whitespace(phonemes) +# return phonemes diff --git a/third_party/Matcha-TTS/matcha/text/numbers.py b/third_party/Matcha-TTS/matcha/text/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..f99a8686dcb73532091122613e74bd643a8a327f --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return f"{dollars} {dollar_unit}, {cents} {cent_unit}" + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return f"{dollars} {dollar_unit}" + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return f"{cents} {cent_unit}" + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/third_party/Matcha-TTS/matcha/text/symbols.py b/third_party/Matcha-TTS/matcha/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..7018df549a1e50c3be20416069b6913c641024bd --- /dev/null +++ b/third_party/Matcha-TTS/matcha/text/symbols.py @@ -0,0 +1,17 @@ +""" from https://github.com/keithito/tacotron + +Defines the set of symbols used in text input to the model. 
+""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = ( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" +) + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/third_party/Matcha-TTS/matcha/train.py b/third_party/Matcha-TTS/matcha/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d64c6c44af2622be5e6bf368616feb6619ed7e --- /dev/null +++ b/third_party/Matcha-TTS/matcha/train.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha import utils + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ utils.extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/third_party/Matcha-TTS/matcha/utils/__init__.py b/third_party/Matcha-TTS/matcha/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074db6461184e8cbb86d977cb41d9ebd918e958a --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/__init__.py @@ -0,0 +1,5 @@ +from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +from matcha.utils.logging_utils import log_hyperparameters +from matcha.utils.pylogger import get_pylogger +from matcha.utils.rich_utils import enforce_tags, print_config_tree +from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/third_party/Matcha-TTS/matcha/utils/audio.py b/third_party/Matcha-TTS/matcha/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcd74df47fb006f68deb5a5f4a4c2fb0aa84f57 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py 
b/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..49ed3c1b072cc3292c899b200d657a8beec197f8 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py @@ -0,0 +1,112 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from tqdm.auto import tqdm + +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.utils.logging_utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): + """Generate data mean and standard deviation helpful in data normalisation + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + """ + total_mel_sum = 0 + total_mel_sq_sum = 0 + total_mel_len = 0 + + for batch in tqdm(data_loader, leave=False): + mels = batch["y"] + mel_lengths = batch["y_lengths"] + + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + + return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="vctk.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="256", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + args = parser.parse_args() + output_file = Path(args.input_config).with_suffix(".json") + + if os.path.exists(output_file) and not args.force: + print("File already exists. Use -f to force overwrite") + sys.exit(1) + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["data_statistics"] = None + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + data_loader = text_mel_datamodule.train_dataloader() + log.info("Dataloader loaded! 
Now computing stats...") + params = compute_data_statistics(data_loader, cfg["n_feats"]) + print(params) + json.dump( + params, + open(output_file, "w"), + ) + + +if __name__ == "__main__": + main() diff --git a/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py b/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe2f35c4238756158370ed1463bfa06f05f7e3d --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py @@ -0,0 +1,195 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import lightning +import numpy as np +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from torch import nn +from tqdm.auto import tqdm + +from matcha.cli import get_device +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.models.matcha_tts import MatchaTTS +from matcha.utils.logging_utils import pylogger +from matcha.utils.utils import get_phoneme_durations + +log = pylogger.get_pylogger(__name__) + + +def save_durations_to_folder( + attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str +): + durations = attn.squeeze().sum(1)[:x_length].numpy() + durations_json = get_phoneme_durations(durations, text) + output = output_folder / Path(filepath).name.replace(".wav", ".npy") + with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: + json.dump(durations_json, f, indent=4, ensure_ascii=False) + + np.save(output, durations) + + +@torch.inference_mode() +def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder): + """Generate durations from the model for each datapoint and save it in a folder + + Args: + data_loader (torch.utils.data.DataLoader): Dataloader + model (nn.Module): MatchaTTS model + device (torch.device): GPU or CPU + """ + + for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + x = x.to(device) + y = y.to(device) + x_lengths = x_lengths.to(device) + y_lengths = y_lengths.to(device) + spks = spks.to(device) if spks is not None else None + + _, _, _, attn = model( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + ) + attn = attn.cpu() + for i in range(attn.shape[0]): + save_durations_to_folder( + attn[i], + x_lengths[i].item(), + y_lengths[i].item(), + batch["filepaths"][i], + output_folder, + batch["x_texts"][i], + ) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="ljspeech.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="32", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + parser.add_argument( + "-c", + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file to load the model from", + ) + + parser.add_argument( + "-o", + "--output-folder", + 
type=str, + default=None, + help="Output folder to save the data statistics", + ) + + parser.add_argument( + "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + ) + + args = parser.parse_args() + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + if args.output_folder is not None: + output_folder = Path(args.output_folder) + else: + output_folder = Path(cfg["train_filelist_path"]).parent / "durations" + + print(f"Output folder set to: {output_folder}") + + if os.path.exists(output_folder) and not args.force: + print("Folder already exists. Use -f to force overwrite") + sys.exit(1) + + output_folder.mkdir(parents=True, exist_ok=True) + + print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print("Loading model...") + device = get_device(args) + model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + try: + print("Computing stats for training set if exists...") + train_dataloader = text_mel_datamodule.train_dataloader() + compute_durations(train_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No training set found") + + try: + print("Computing stats for validation set if exists...") + val_dataloader = text_mel_datamodule.val_dataloader() + compute_durations(val_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No validation set found") + + try: + print("Computing stats for test set if exists...") + test_dataloader = text_mel_datamodule.test_dataloader() + compute_durations(test_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No test set found") + + print(f"[+] Done! 
Data statistics saved to: {output_folder}") + + +if __name__ == "__main__": + # Helps with generating durations for the dataset to train other architectures + # that cannot learn to align due to limited size of dataset + # Example usage: + # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model + # This will create a folder in data/processed_data/durations/ljspeech with the durations + main() diff --git a/third_party/Matcha-TTS/matcha/utils/instantiators.py b/third_party/Matcha-TTS/matcha/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..5547b4ed61ed8c21e63c528f58526a949879a94f --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/third_party/Matcha-TTS/matcha/utils/logging_utils.py b/third_party/Matcha-TTS/matcha/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a12d1ddafa25ca3ae8e497bcd7de2191f13659b --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/logging_utils.py @@ -0,0 +1,53 @@ +from typing import Any, Dict + +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import OmegaConf + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. 
+ """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) + hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/third_party/Matcha-TTS/matcha/utils/model.py b/third_party/Matcha-TTS/matcha/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..869cc6092f5952930534c47544fae88308e96abf --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/model.py @@ -0,0 +1,90 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import numpy as np +import torch + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length + + +def convert_pad_shape(pad_shape): + inverted_shape = pad_shape[::-1] + pad_shape = [item for sublist in inverted_shape for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) + return loss + + +def normalize(data, mu, std): + if not isinstance(mu, (float, int)): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, (float, int)): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return (data - mu) / std + + +def denormalize(data, mu, std): + if not isinstance(mu, float): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = 
mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, float): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return data * std + mu diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py b/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eee6e0d47c2e3612ef02bc17442e6886998e5a94 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py @@ -0,0 +1,22 @@ +import numpy as np +import torch + +from matcha.utils.monotonic_align.core import maximum_path_c + + +def maximum_path(value, mask): + """Cython optimised version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx b/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..091fcc3a50a51f3d3fee47a70825260757e6d885 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. 
+ else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py b/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f22bc6a35a5a04c9e6d7b82040973722c9b770c9 --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py @@ -0,0 +1,7 @@ +# from distutils.core import setup +# from Cython.Build import cythonize +# import numpy + +# setup(name='monotonic_align', +# ext_modules=cythonize("core.pyx"), +# include_dirs=[numpy.get_include()]) diff --git a/third_party/Matcha-TTS/matcha/utils/pylogger.py b/third_party/Matcha-TTS/matcha/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..61600678029362e110f655edb91d5f3bc5b1cd1c --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/pylogger.py @@ -0,0 +1,21 @@ +import logging + +from lightning.pytorch.utilities import rank_zero_only + + +def get_pylogger(name: str = __name__) -> logging.Logger: + """Initializes a multi-GPU-friendly python command line logger. + + :param name: The name of the logger, defaults to ``__name__``. + + :return: A logger object. + """ + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/third_party/Matcha-TTS/matcha/utils/rich_utils.py b/third_party/Matcha-TTS/matcha/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f602f6e9351d948946eb419eb4e420190ea634bc --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/rich_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 
+ :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + _ = ( + queue.append(field) + if field in cfg + else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/third_party/Matcha-TTS/matcha/utils/utils.py b/third_party/Matcha-TTS/matcha/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3a48ec2b532ff8e034181d71ed5f4d7823d9be --- /dev/null +++ b/third_party/Matcha-TTS/matcha/utils/utils.py @@ -0,0 +1,259 @@ +import os +import sys +import warnings +from importlib.util import find_spec +from math import ceil +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import gdown +import matplotlib.pyplot as plt +import numpy as np +import torch +import wget +from omegaconf import DictConfig + +from matcha.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! 
") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: The name of the metric to retrieve. + :return: The value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise ValueError( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_tensor(tensor): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def save_plot(tensor, savepath): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + plt.savefig(savepath) + plt.close() + + +def to_numpy(tensor): + if isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + elif isinstance(tensor, list): + return np.array(tensor) + else: + raise TypeError("Unsupported type for conversion to numpy array") + + +def get_user_data_dir(appname="matcha_tts"): + """ + Args: + appname (str): Name of application + + Returns: + Path: path to user data directory + """ + + MATCHA_HOME = os.environ.get("MATCHA_HOME") + if MATCHA_HOME is not None: + ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + + final_path = ans.joinpath(appname) + final_path.mkdir(parents=True, exist_ok=True) + return final_path + + +def assert_model_downloaded(checkpoint_path, url, use_wget=True): + if Path(checkpoint_path).exists(): + log.debug(f"[+] Model already present at {checkpoint_path}!") + print(f"[+] Model already present at {checkpoint_path}!") + return + log.info(f"[-] Model not found at {checkpoint_path}! Will download it") + print(f"[-] Model not found at {checkpoint_path}! 
Will download it") + checkpoint_path = str(checkpoint_path) + if not use_wget: + gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) + else: + wget.download(url=url, out=checkpoint_path) + + +def get_phoneme_durations(durations, phones): + prev = durations[0] + merged_durations = [] + # Convolve with stride 2 + for i in range(1, len(durations), 2): + if i == len(durations) - 2: + # if it is last take full value + next_half = durations[i + 1] + else: + next_half = ceil(durations[i + 1] / 2) + + curr = prev + durations[i] + next_half + prev = durations[i + 1] - next_half + merged_durations.append(curr) + + assert len(phones) == len(merged_durations) + assert len(merged_durations) == (len(durations) - 1) // 2 + + merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) + start = torch.tensor(0) + duration_json = [] + for i, duration in enumerate(merged_durations): + duration_json.append( + { + phones[i]: { + "starttime": start.item(), + "endtime": duration.item(), + "duration": duration.item() - start.item(), + } + } + ) + start = duration + + assert list(duration_json[-1].values())[0]["endtime"] == sum( + durations + ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" + return duration_json diff --git a/third_party/Matcha-TTS/notebooks/.gitkeep b/third_party/Matcha-TTS/notebooks/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/Matcha-TTS/pyproject.toml b/third_party/Matcha-TTS/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..74aa39300a61b8b3607dc634d68aa47013141ec5 --- /dev/null +++ b/third_party/Matcha-TTS/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] + +[tool.black] +line-length = 120 +target-version = ['py310'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ + | foo.py # also separately exclude a file named foo.py in + # the root of the project +) +''' + +[tool.pytest.ini_options] +addopts = [ + "--color=yes", + "--durations=0", + "--strict-markers", + "--doctest-modules", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", +] +log_cli = "True" +markers = [ + "slow: slow tests", +] +minversion = "6.0" +testpaths = "tests/" + +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "raise NotImplementedError()", + "if __name__ == .__main__.:", +] diff --git a/third_party/Matcha-TTS/requirements.txt b/third_party/Matcha-TTS/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b0eabfbd38a9483d84d3ea960671460fe69c89e --- /dev/null +++ b/third_party/Matcha-TTS/requirements.txt @@ -0,0 +1,44 @@ +# --------- pytorch --------- # +torch>=2.0.0 +torchvision>=0.15.0 +lightning>=2.0.0 +torchmetrics>=0.11.4 + +# --------- hydra --------- # +hydra-core==1.3.2 +hydra-colorlog==1.2.0 +hydra-optuna-sweeper==1.2.0 + +# --------- loggers --------- # +# wandb +# neptune-client +# mlflow +# comet-ml +# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 + +# --------- others --------- # +rootutils # standardizing the project root setup +pre-commit # hooks for applying linters on commit +rich # beautiful text formatting in terminal +pytest # tests 
+# sh # for running bash commands in some tests (linux/macos only) +phonemizer # phonemization of text +tensorboard +librosa +Cython +numpy +einops +inflect +Unidecode +scipy +torchaudio +matplotlib +pandas +conformer==0.3.2 +diffusers # developed using version ==0.25.0 +notebook +ipywidgets +gradio==3.43.2 +gdown +wget +seaborn diff --git a/third_party/Matcha-TTS/scripts/schedule.sh b/third_party/Matcha-TTS/scripts/schedule.sh new file mode 100644 index 0000000000000000000000000000000000000000..44b3da1116ef4d54e9acffee7d639d549e136d45 --- /dev/null +++ b/third_party/Matcha-TTS/scripts/schedule.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Schedule execution of many runs +# Run from root folder with: bash scripts/schedule.sh + +python src/train.py trainer.max_epochs=5 logger=csv + +python src/train.py trainer.max_epochs=10 logger=csv diff --git a/third_party/Matcha-TTS/setup.py b/third_party/Matcha-TTS/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c2ccd4b7eaa960f15cb682a40d8595101b2ab --- /dev/null +++ b/third_party/Matcha-TTS/setup.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +import os + +import numpy +from Cython.Build import cythonize +from setuptools import Extension, find_packages, setup + +exts = [ + Extension( + name="matcha.utils.monotonic_align.core", + sources=["matcha/utils/monotonic_align/core.pyx"], + ) +] + +with open("README.md", encoding="utf-8") as readme_file: + README = readme_file.read() + +cwd = os.path.dirname(os.path.abspath(__file__)) +with open(os.path.join(cwd, "matcha", "VERSION")) as fin: + version = fin.read().strip() + +setup( + name="matcha-tts", + version=version, + description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", + long_description=README, + long_description_content_type="text/markdown", + author="Shivam Mehta", + author_email="shivam.mehta25@gmail.com", + url="https://shivammehta25.github.io/Matcha-TTS", + install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], + include_dirs=[numpy.get_include()], + include_package_data=True, + packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), + # use this to customize global commands available in the terminal after installing the package + entry_points={ + "console_scripts": [ + "matcha-data-stats=matcha.utils.generate_data_statistics:main", + "matcha-tts=matcha.cli:cli", + "matcha-tts-app=matcha.app:main", + "matcha-tts-get-durations=matcha.utils.get_durations_from_trained_model:main", + ] + }, + ext_modules=cythonize(exts, language_level=3), + python_requires=">=3.9.0", +) diff --git a/third_party/Matcha-TTS/synthesis.ipynb b/third_party/Matcha-TTS/synthesis.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1e47c534f93b2901910386cd02a75018abfc2570 --- /dev/null +++ b/third_party/Matcha-TTS/synthesis.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f37f4e3b-f764-4502-a6a2-6417bd9bfab9", + "metadata": {}, + "source": [ + "# Matcha-TTS: A fast TTS architecture with conditional flow matching\n", + "---\n", + "[Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)\n", + "\n", + "We introduce Matcha-TTS, a new encoder-decoder architecture for speedy TTS acoustic modelling, trained using optimal-transport conditional flow matching (OT-CFM). 
This yields an ODE-based decoder capable of high output quality in fewer synthesis steps than models trained using score matching. Careful design choices additionally ensure each synthesis step is fast to run. The method is probabilistic, non-autoregressive, and learns to speak from scratch without external alignments. Compared to strong pre-trained baseline models, the Matcha-TTS system has the smallest memory footprint, rivals the speed of the fastest models on long utterances, and attains the highest mean opinion score in a listening test.\n", + "\n", + "Demo Page: https://shivammehta25.github.io/Matcha-TTS \\\n", + "Code: https://github.com/shivammehta25/Matcha-TTS\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: CUDA_VISIBLE_DEVICES=0\n" + ] + } + ], + "source": [ + "%env CUDA_VISIBLE_DEVICES=0" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d5876c0-b47e-4c80-9e9c-62550f81b64e", + "metadata": {}, + "outputs": [], + "source": [ + "import datetime as dt\n", + "from pathlib import Path\n", + "\n", + "import IPython.display as ipd\n", + "import numpy as np\n", + "import soundfile as sf\n", + "import torch\n", + "from tqdm.auto import tqdm\n", + "\n", + "# Hifigan imports\n", + "from matcha.hifigan.config import v1\n", + "from matcha.hifigan.denoiser import Denoiser\n", + "from matcha.hifigan.env import AttrDict\n", + "from matcha.hifigan.models import Generator as HiFiGAN\n", + "# Matcha imports\n", + "from matcha.models.matcha_tts import MatchaTTS\n", + "from matcha.text import sequence_to_text, text_to_sequence\n", + "from matcha.utils.model import denormalize\n", + "from matcha.utils.utils import get_user_data_dir, intersperse" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b1a30306-588c-4f22-8d9b-e2676880b0e5", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "# This allows for real time code changes being reflected in the notebook, no need to restart the kernel" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a312856b-01a9-4d75-a4c8-4666dffa0692", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "88f3b3c3-d014-443b-84eb-e143cdec3e21", + "metadata": {}, + "source": [ + "## Filepaths" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7640a4c1-44ce-447c-a8ff-45012fb7bddd", + "metadata": {}, + "outputs": [], + "source": [ + "MATCHA_CHECKPOINT = get_user_data_dir()/\"matcha_ljspeech.ckpt\"\n", + "HIFIGAN_CHECKPOINT = get_user_data_dir() / \"hifigan_T2_v1\"\n", + "OUTPUT_FOLDER = \"synth_output\"" + ] + }, + { + "cell_type": "markdown", + "id": "6477a3a9-71f2-4d2f-bb86-bdf3e31c2461", + "metadata": {}, + "source": [ + "## Load Matcha-TTS" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "26a16230-04ba-4825-a844-2fb5ab945e24", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model loaded! 
Parameter count: 18,204,193\n" + ] + } + ], + "source": [ + "def load_model(checkpoint_path):\n", + " model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)\n", + " model.eval()\n", + " return model\n", + "count_params = lambda x: f\"{sum(p.numel() for p in x.parameters()):,}\"\n", + "\n", + "\n", + "model = load_model(MATCHA_CHECKPOINT)\n", + "print(f\"Model loaded! Parameter count: {count_params(model)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3077b84b-e3b6-42e1-a84b-2f7084b13f92", + "metadata": {}, + "source": [ + "## Load HiFi-GAN (Vocoder)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f6b68184-968d-4868-9029-f0c40e9e68af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removing weight norm...\n" + ] + } + ], + "source": [ + "def load_vocoder(checkpoint_path):\n", + " h = AttrDict(v1)\n", + " hifigan = HiFiGAN(h).to(device)\n", + " hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])\n", + " _ = hifigan.eval()\n", + " hifigan.remove_weight_norm()\n", + " return hifigan\n", + "\n", + "vocoder = load_vocoder(HIFIGAN_CHECKPOINT)\n", + "denoiser = Denoiser(vocoder, mode='zeros')" + ] + }, + { + "cell_type": "markdown", + "id": "4cbc2ba0-09ff-40e2-9e60-6b77b534f9fb", + "metadata": {}, + "source": [ + "### Helper functions to synthesise" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "880a1879-24fd-4757-849c-850339120796", + "metadata": {}, + "outputs": [], + "source": [ + "@torch.inference_mode()\n", + "def process_text(text: str):\n", + " x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2'])[0], 0),dtype=torch.long, device=device)[None]\n", + " x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n", + " x_phones = sequence_to_text(x.squeeze(0).tolist())\n", + " return {\n", + " 'x_orig': text,\n", + " 'x': x,\n", + " 'x_lengths': x_lengths,\n", + " 'x_phones': x_phones\n", + " }\n", + "\n", + "\n", + "@torch.inference_mode()\n", + "def synthesise(text, spks=None):\n", + " text_processed = process_text(text)\n", + " start_t = dt.datetime.now()\n", + " output = model.synthesise(\n", + " text_processed['x'], \n", + " text_processed['x_lengths'],\n", + " n_timesteps=n_timesteps,\n", + " temperature=temperature,\n", + " spks=spks,\n", + " length_scale=length_scale\n", + " )\n", + " # merge everything to one dict \n", + " output.update({'start_t': start_t, **text_processed})\n", + " return output\n", + "\n", + "@torch.inference_mode()\n", + "def to_waveform(mel, vocoder):\n", + " audio = vocoder(mel).clamp(-1, 1)\n", + " audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()\n", + " return audio.cpu().squeeze()\n", + " \n", + "def save_to_folder(filename: str, output: dict, folder: str):\n", + " folder = Path(folder)\n", + " folder.mkdir(exist_ok=True, parents=True)\n", + " np.save(folder / f'{filename}', output['mel'].cpu().numpy())\n", + " sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')" + ] + }, + { + "cell_type": "markdown", + "id": "78f857e3-2ef7-4c86-b776-596c4d3cf875", + "metadata": {}, + "source": [ + "## Setup text to synthesise" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2e0a9acd-0845-4192-ba09-b9683e28a3ac", + "metadata": {}, + "outputs": [], + "source": [ + "texts = [\n", + " \"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though 
transparent.\"\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "a9da9e2d-99b9-4c6f-8a08-c828e2cba121", + "metadata": {}, + "source": [ + "### Hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f0d216e5-4895-4da8-9d24-9e61021d2556", + "metadata": {}, + "outputs": [], + "source": [ + "## Number of ODE Solver steps\n", + "n_timesteps = 10\n", + "\n", + "## Changes to the speaking rate\n", + "length_scale=1.0\n", + "\n", + "## Sampling temperature\n", + "temperature = 0.667" + ] + }, + { + "cell_type": "markdown", + "id": "b93aac89-c7f8-4975-8510-4e763c9689f4", + "metadata": {}, + "source": [ + "## Synthesis" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5a227963-aa12-43b9-a706-1168b6fc0ba5", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8342d12401c54017b0e19b8d293a06bf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of ODE steps: 10\n", + "Mean RTF:\t\t\t\t0.017228 ± 0.000000\n", + "Mean RTF Waveform (incl. vocoder):\t0.021445 ± 0.000000\n" + ] + } + ], + "source": [ + "outputs, rtfs = [], []\n", + "rtfs_w = []\n", + "for i, text in enumerate(tqdm(texts)):\n", + " output = synthesise(text) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))\n", + " output['waveform'] = to_waveform(output['mel'], vocoder)\n", + "\n", + " # Compute Real Time Factor (RTF) with HiFi-GAN\n", + " t = (dt.datetime.now() - output['start_t']).total_seconds()\n", + " rtf_w = t * 22050 / (output['waveform'].shape[-1])\n", + "\n", + " ## Pretty print\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"Input text - {i}\")\n", + " print(f\"{'-' * 53}\")\n", + " print(output['x_orig'])\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"Phonetised text - {i}\")\n", + " print(f\"{'-' * 53}\")\n", + " print(output['x_phones'])\n", + " print(f\"{'*' * 53}\")\n", + " print(f\"RTF:\\t\\t{output['rtf']:.6f}\")\n", + " print(f\"RTF Waveform:\\t{rtf_w:.6f}\")\n", + " rtfs.append(output['rtf'])\n", + " rtfs_w.append(rtf_w)\n", + "\n", + " ## Display the synthesised waveform\n", + " ipd.display(ipd.Audio(output['waveform'], rate=22050))\n", + "\n", + " ## Save the generated waveform\n", + " save_to_folder(i, output, OUTPUT_FOLDER)\n", + "\n", + "print(f\"Number of ODE steps: {n_timesteps}\")\n", + "print(f\"Mean RTF:\\t\\t\\t\\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}\")\n", + "print(f\"Mean RTF Waveform (incl. vocoder):\\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3e85c3f-1623-4647-b40c-fa96907656fc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
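For reference, a minimal usage sketch of the alignment utilities introduced in matcha/utils/model.py above, showing how per-phone durations are expanded into a hard 0/1 alignment path. This assumes the matcha package from this diff is importable; the toy duration values and variable names below are illustrative only, not taken from the upstream code.

    # Sketch: expand durations [2, 1, 3] for 3 phones into a 3 x 6 alignment path.
    import torch
    from matcha.utils.model import sequence_mask, generate_path

    durations = torch.tensor([[2, 1, 3]])        # [b, t_x]: frames assigned to each phone
    x_lengths = torch.tensor([3])                # number of phones per batch item
    y_lengths = durations.sum(dim=1)             # number of mel frames per batch item

    # Boolean length masks, cast to float so generate_path's subtraction is valid
    x_mask = sequence_mask(x_lengths).unsqueeze(1).float()   # [b, 1, t_x]
    y_mask = sequence_mask(y_lengths).unsqueeze(1).float()   # [b, 1, t_y]
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)   # [b, 1, t_x, t_y]

    path = generate_path(durations, attn_mask.squeeze(1))    # [b, t_x, t_y]
    print(path[0])
    # tensor([[1., 1., 0., 0., 0., 0.],
    #         [0., 0., 1., 0., 0., 0.],
    #         [0., 0., 0., 1., 1., 1.]])

The path is built from the cumulative sum of the durations: each phone occupies a contiguous block of output frames, which is the same monotonic alignment structure that maximum_path in matcha/utils/monotonic_align searches over during training.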