diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 3e11743ee..a9f7a3864 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -536,6 +536,20 @@ def set_embedding_config( # pyre-ignore [24] def cmd_conf(func: Callable) -> Callable: + def _load_config_file(config_path: str, is_json: bool = False) -> Dict[str, Any]: + if not config_path: + return {} + + try: + with open(config_path, "r") as f: + if is_json: + return json.load(f) or {} + else: + return yaml.safe_load(f) or {} + except Exception as e: + logger.warning(f"Failed to load config because {e}. Proceeding without it.") + return {} + # pyre-ignore [3] def wrapper() -> Any: sig = inspect.signature(func) @@ -548,6 +562,13 @@ def wrapper() -> Any: help="YAML config file for benchmarking", ) + parser.add_argument( + "--json_config", + type=str, + default=None, + help="JSON config file for benchmarking", + ) + # Add loglevel argument with current logger level as default parser.add_argument( "--loglevel", @@ -558,18 +579,18 @@ def wrapper() -> Any: pre_args, _ = parser.parse_known_args() - yaml_defaults: Dict[str, Any] = {} - if pre_args.yaml_config: - try: - with open(pre_args.yaml_config, "r") as f: - yaml_defaults = yaml.safe_load(f) or {} - logger.info( - f"Loaded YAML config from {pre_args.yaml_config}: {yaml_defaults}" - ) - except Exception as e: - logger.warning( - f"Failed to load YAML config because {e}. Proceeding without it." - ) + yaml_defaults: Dict[str, Any] = ( + _load_config_file(pre_args.yaml_config, is_json=False) + if pre_args.yaml_config + else {} + ) + json_defaults: Dict[str, Any] = ( + _load_config_file(pre_args.json_config, is_json=True) + if pre_args.json_config + else {} + ) + # Merge the two dictionaries, JSON overrides YAML + merged_defaults = {**yaml_defaults, **json_defaults} seen_args = set() # track all -- we've added @@ -595,10 +616,10 @@ def wrapper() -> Any: ftype = non_none[0] origin = get_origin(ftype) - # Handle default_factory value and allow YAML config to override it - default_value = yaml_defaults.get( + # Handle default_factory value and allow config to override + default_value = merged_defaults.get( arg_name, # flat lookup - yaml_defaults.get(cls.__name__, {}).get( # hierarchy lookup + merged_defaults.get(cls.__name__, {}).get( # hierarchy lookup arg_name, ( f.default_factory() # pyre-ignore [29]