赞
踩
get_experiment_logger里面放的是什么呢?
def get_experiment_logger(
    config: DictConfig | ListConfig,
) -> Logger | Iterable[Logger] | bool:
    """Build the experiment logger(s) selected in the config.

    Args:
        config (DictConfig | ListConfig): Config of the anomalib model. The
            function reads ``config.logging.logger`` (a logger name, a list of
            names, ``None`` or ``False``), ``config.logging.log_graph``,
            ``config.project.path``, ``config.dataset`` and
            ``config.model.name``. If the deprecated ``config.project.logger``
            key is present, its value is copied into ``config.logging.logger``
            (creating the ``logging`` section when it is missing).

    Raises:
        UnknownLogger: If a requested logger name is not one of
            ``tensorboard``, ``wandb``, ``comet`` or ``csv``.

    Returns:
        Logger | Iterable[Logger] | bool: ``False`` when the configured logger
        is ``None`` or ``False``; otherwise the list of instantiated loggers,
        which is empty when no logger names are configured.
    """
    logger.info("Loading the experiment logger(s)")

    # TODO remove when logger is deprecated from project
    if "logger" in config.project.keys():
        warnings.warn(
            "'logger' key will be deprecated from 'project' section of the config file."
            " Please use the logging section in config file.",
            DeprecationWarning,
        )
        # The legacy value wins: it overwrites whatever the logging section holds.
        if "logging" not in config:
            config.logging = {"logger": config.project.logger, "log_graph": False}
        else:
            config.logging.logger = config.project.logger

    if config.logging.logger in (None, False):
        return False

    # Tracing output kept from the walkthrough this snippet belongs to.
    print("-------------------hahahhahhah")
    print(config.logging.logger)
    print("-------------------hahahhahhah-end")

    logger_list: list[Logger] = []
    # A single logger name is normalised to a one-element list (stored back on the config).
    if isinstance(config.logging.logger, str):
        config.logging.logger = [config.logging.logger]

    print("------------------------gao1")
    for experiment_logger in config.logging.logger:
        print("-------------------------gao2")
        print(experiment_logger)
        # Every supported logger writes below <project.path>/logs.
        log_dir = os.path.join(config.project.path, "logs")
        if experiment_logger == "tensorboard":
            logger_list.append(
                AnomalibTensorBoardLogger(
                    name="Tensorboard Logs",
                    save_dir=log_dir,
                    log_graph=config.logging.log_graph,
                )
            )
        elif experiment_logger in ("wandb", "comet"):
            # Both loggers get a pre-created directory and the same run name:
            # "<category> <model>" when the dataset has a category, else "<model>".
            Path(log_dir).mkdir(parents=True, exist_ok=True)
            if "category" in config.dataset.keys():
                run_name = f"{config.dataset.category} {config.model.name}"
            else:
                run_name = config.model.name
            if experiment_logger == "wandb":
                logger_list.append(
                    AnomalibWandbLogger(
                        project=config.dataset.name,
                        name=run_name,
                        save_dir=log_dir,
                    )
                )
            else:
                logger_list.append(
                    AnomalibCometLogger(
                        project_name=config.dataset.name,
                        experiment_name=run_name,
                        save_dir=log_dir,
                    )
                )
        elif experiment_logger == "csv":
            logger_list.append(CSVLogger(save_dir=log_dir))
        else:
            # Report the entry that failed, not the whole configured list.
            raise UnknownLogger(
                f"Unknown logger type: {experiment_logger}. "
                f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
                f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
                f"To disable the logger, set `project.logger` to `false`."
            )
    print("-------------------------gao3")
    print(logger_list)
    print("-------------------------gao3-end")
    return logger_list
你看最后这几行,我打印出来的logger_list是一个空列表[]:说明配置里的logging.logger本身就是空的,for循环一次也没有执行。
白折腾了,呵呵。
callbacks = get_callbacks(config) 看看里面都是啥:
def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
    """Return base callbacks for all the lightning models.

    The list is built in this order: model checkpoint and timer, optional
    checkpoint loader, post-processing configuration, metrics configuration,
    optional score normalization, whatever ``add_visualizer_callback`` adds,
    optional NNCF and export callbacks, and an optional graph logger.

    Args:
        config (DictConfig | ListConfig): Model config.

    Raises:
        NotImplementedError: If CDF normalization is requested together with
            NNCF, or for a model other than ``padim`` / ``stfpm``.
        ValueError: If ``config.model.normalization_method`` is not one of
            ``none``, ``cdf`` or ``min_max``.

    Returns:
        list[Callback]: List of callbacks.
    """
    logger.info("Loading the callbacks")

    callbacks: list[Callback] = []

    # Without an early_stopping section the checkpoint monitors nothing, in "max" mode.
    if "early_stopping" in config.model.keys():
        monitor_metric = config.model.early_stopping.metric
        monitor_mode = config.model.early_stopping.mode
    else:
        monitor_metric = None
        monitor_mode = "max"

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )

    callbacks.extend([checkpoint, TimerCallback()])

    if "resume_from_checkpoint" in config.trainer.keys() and config.trainer.resume_from_checkpoint is not None:
        load_model = LoadModelCallback(config.trainer.resume_from_checkpoint)
        callbacks.append(load_model)

    # Add post-processing configurations to AnomalyModule.
    threshold_config = config.metrics.threshold
    image_threshold = threshold_config.manual_image if "manual_image" in threshold_config.keys() else None
    pixel_threshold = threshold_config.manual_pixel if "manual_pixel" in threshold_config.keys() else None
    post_processing_callback = PostProcessingConfigurationCallback(
        threshold_method=threshold_config.method,
        manual_image_threshold=image_threshold,
        manual_pixel_threshold=pixel_threshold,
    )
    callbacks.append(post_processing_callback)

    # Add metric configuration to the model via MetricsConfigurationCallback
    metrics_callback = MetricsConfigurationCallback(
        config.dataset.task,
        config.metrics.get("image", None),
        config.metrics.get("pixel", None),
    )
    callbacks.append(metrics_callback)

    if "normalization_method" in config.model.keys() and config.model.normalization_method != "none":
        if config.model.normalization_method == "cdf":
            if config.model.name in ("padim", "stfpm"):
                # Guarded lookup: a config without an optimization section
                # simply means NNCF is off.
                if _is_nncf_enabled(config):
                    raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
                callbacks.append(CdfNormalizationCallback())
            else:
                raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        elif config.model.normalization_method == "min_max":
            callbacks.append(MinMaxNormalizationCallback())
        else:
            raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

    add_visualizer_callback(callbacks, config)

    if "optimization" in config.keys():
        if _is_nncf_enabled(config):
            # NNCF wraps torch's jit which conflicts with kornia's jit calls.
            # Hence, nncf is imported only when required
            nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
            nncf_callback = getattr(nncf_module, "NNCFCallback")
            nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
            callbacks.append(
                nncf_callback(
                    config=nncf_config,
                    export_dir=os.path.join(config.project.path, "compressed"),
                )
            )
        if config.optimization.export_mode is not None:
            from .export import (  # pylint: disable=import-outside-toplevel
                ExportCallback,
            )

            logger.info("Setting model export to %s", config.optimization.export_mode)
            callbacks.append(
                ExportCallback(
                    input_size=config.model.input_size,
                    dirpath=config.project.path,
                    filename="model",
                    export_mode=ExportMode(config.optimization.export_mode),
                )
            )
        else:
            warnings.warn(f"Export option: {config.optimization.export_mode} not found. Defaulting to no model export")

    # Add callback to log graph to loggers
    if config.logging.log_graph not in (None, False):
        callbacks.append(GraphLogger())

    # Tracing output kept from the walkthrough this snippet belongs to.
    print("-------------gao callbacks")
    print(callbacks)
    print("-------------gao callbacks-end")

    return callbacks


def _is_nncf_enabled(config: DictConfig | ListConfig) -> bool:
    """Return True when the config has ``optimization.nncf`` with a truthy ``apply``."""
    return bool(
        "optimization" in config.keys()
        and "nncf" in config.optimization
        and config.optimization.nncf.apply
    )
看看我最后打印出来的callbacks,是啥样的?
终于来到了最重要的地方了:
Trainer的第一个参数,是下面这样婶儿的:
第二个参数,前面我们说了,就是[]
第三个参数,就是上面那一堆callbacks
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。