From aea1afef64514e84eabdc7b72893072323ed7ed7 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 18 Jul 2025 16:22:15 +0000 Subject: [PATCH] Remove test_[tedlium3|mustc]_lightning.py --- .../emformer_rnnt/test_mustc_lightning.py | 76 ------------------ .../emformer_rnnt/test_tedlium3_lightning.py | 79 ------------------- 2 files changed, 155 deletions(-) delete mode 100644 test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py delete mode 100644 test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py diff --git a/test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py b/test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py deleted file mode 100644 index c3e42606bb..0000000000 --- a/test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py +++ /dev/null @@ -1,76 +0,0 @@ -from contextlib import contextmanager -from functools import partial -from unittest.mock import patch - -import torch -from parameterized import parameterized -from torchaudio._internal.module_utils import is_module_available -from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase - -from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor - -if is_module_available("pytorch_lightning", "sentencepiece"): - from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule - - -class MockMUSTC: - def __init__(self, *args, **kwargs): - pass - - def __getitem__(self, n: int): - return ( - torch.rand(1, 32640), - "sup", - ) - - def __len__(self): - return 10 - - -@contextmanager -def get_lightning_module(): - with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch( - "asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity - ), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch( - "asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset - ), patch( - "torch.utils.data.DataLoader", new=MockDataloader - ): - yield MuSTCRNNTModule( - mustc_path="mustc_path", - sp_model_path="sp_model_path", - global_stats_path="global_stats_path", - ) - - -@skipIfNoModule("pytorch_lightning") -@skipIfNoModule("sentencepiece") -class TestMuSTCRNNTModule(TorchaudioTestCase): - @classmethod - def setUpClass(cls) -> None: - super().setUpClass() - - @parameterized.expand( - [ - ("training_step", "train_dataloader"), - ("validation_step", "val_dataloader"), - ("test_step", "test_common_dataloader"), - ("test_step", "test_he_dataloader"), - ] - ) - def test_step(self, step_fname, dataloader_fname): - with get_lightning_module() as lightning_module: - dataloader = getattr(lightning_module, dataloader_fname)() - batch = next(iter(dataloader)) - getattr(lightning_module, step_fname)(batch, 0) - - @parameterized.expand( - [ - ("val_dataloader",), - ] - ) - def test_forward(self, dataloader_fname): - with get_lightning_module() as lightning_module: - dataloader = getattr(lightning_module, dataloader_fname)() - batch = next(iter(dataloader)) - lightning_module(batch) diff --git a/test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py b/test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py deleted file mode 100644 index e1804dfcfd..0000000000 --- a/test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py +++ /dev/null @@ -1,79 +0,0 @@ -from contextlib import contextmanager -from functools import partial -from unittest.mock import patch - -import torch -from parameterized import parameterized -from torchaudio._internal.module_utils import is_module_available -from torchaudio_unittest.common_utils import skipIfNoModule, TorchaudioTestCase - -from .utils import MockCustomDataset, MockDataloader, MockSentencePieceProcessor - -if is_module_available("pytorch_lightning", "sentencepiece"): - from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule - - -class MockTEDLIUM: - def __init__(self, *args, **kwargs): - pass - - def __getitem__(self, n: int): - return ( - torch.rand(1, 32640), - 16000, - "sup", - 2, - 3, - 4, - ) - - def __len__(self): - return 10 - - -@contextmanager -def get_lightning_module(): - with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch( - "asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity - ), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch( - "asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset - ), patch( - "torch.utils.data.DataLoader", new=MockDataloader - ): - yield TEDLIUM3RNNTModule( - tedlium_path="tedlium_path", - sp_model_path="sp_model_path", - global_stats_path="global_stats_path", - ) - - -@skipIfNoModule("pytorch_lightning") -@skipIfNoModule("sentencepiece") -class TestTEDLIUM3RNNTModule(TorchaudioTestCase): - @classmethod - def setUpClass(cls) -> None: - super().setUpClass() - - @parameterized.expand( - [ - ("training_step", "train_dataloader"), - ("validation_step", "val_dataloader"), - ("test_step", "test_dataloader"), - ] - ) - def test_step(self, step_fname, dataloader_fname): - with get_lightning_module() as lightning_module: - dataloader = getattr(lightning_module, dataloader_fname)() - batch = next(iter(dataloader)) - getattr(lightning_module, step_fname)(batch, 0) - - @parameterized.expand( - [ - ("val_dataloader",), - ] - ) - def test_forward(self, dataloader_fname): - with get_lightning_module() as lightning_module: - dataloader = getattr(lightning_module, dataloader_fname)() - batch = next(iter(dataloader)) - lightning_module(batch)