【start:20231103】
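When timm cannot reach Hugging Face, huggingface_hub keeps raising network errors. To use pretrained weights in that situation, they have to be loaded from local files, as described below.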
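【Hugging Face model page】https://huggingface.co/timm/tf_efficientnetv2_s.in1k
【timm GitHub repository】https://github.com/huggingface/pytorch-image-models#introduction
【ref】python timm library, python timm library download
Introduces the timm GitHub project.
【ref】timm: downloading models manually
Points out that timm has two kinds of download links, new (Hugging Face) and old (torch), but does not solve downloading weights via the torch link.
【ref】Solutions for failed timm pretrained-weight downloads~
Describes changing the pretrained_cfg_overlay parameter, which solves the problem for weights downloaded from the Hugging Face link.
【ref】[Pytorch] timm.create_model(): loading a pretrained model from local files by specifying pretrained_cfg
Solves the problem for weights downloaded from the torch link.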
C:\Users\lenovo>ping huggingface.co
Pinging huggingface.co [31.13.83.34] with 32 bytes of data:
Request timed out.
Request timed out.
Request timed out.
Request timed out.
So the connection to Hugging Face really is unreliable from this network.
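`ping` uses ICMP, which some networks block even when HTTPS works, so it only gives a rough signal. A more direct check is to open a TCP connection to port 443, the port that the errors below time out on. The following is a minimal stdlib sketch. The function name `can_reach` is hypothetical:

```python
import socket


def can_reach(host: str, port: int = 443, timeout: float = 5.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # create_connection resolves the host name and tries each address in turn.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Timeouts, refused connections and DNS failures are all OSError subclasses.
        return False
```

Usage: `can_reach("huggingface.co")` returning False confirms that HTTPS traffic to the Hub cannot get through, not just ICMP.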
On the Linux server, the problem "timm cannot download pretrained weights when Hugging Face is unreachable" first showed up in a specific project:
# Define the model with the function API.
model = cppnet_base_multiclass(
    enc_name="tf_efficientnetv2_s",
    n_rays=32,  # number of predicted rays
    type_classes=len(pannuke_module.type_classes),
)
It then fails with the following errors.
Error with the proxy (access to the outside network) turned on:
TimeoutError: timed out
...
1367 ) from head_call_error
1369 # From now on, etag and commit_hash are not None.
1370 assert etag is not None, "etag must have been retrieved from server"
LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
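As shown above, the call still fails with the proxy turned on, and the error does not include the download addresses of the key files.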
MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443):
Error with the proxy turned off:
Seed set to 42
**kwargs : {'enc_name': 'tf_efficientnetv2_s'}
kwargs.get("checkpoint_path", None): None
checkpoint_path: None
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/model.safetensors (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fd5a10b61a0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 1792c043-b41c-4c0a-9deb-eda1a8438775)')' thrown while requesting HEAD https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/model.safetensors
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/pytorch_model.bin (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fd5a10b7700>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: ee241300-1909-4254-aaad-e6a073600cdb)')' thrown while requesting HEAD https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/pytorch_model.bin
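As shown above, without the proxy huggingface_hub fails even harder (MaxRetryError). The good news is that the messages include the download addresses of the two key files (the safetensors file and the bin file).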
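For now, instead of wrestling with the third-party code, let's make things easier and run the generic timm API directly.
For example, to download the tf_efficientnetv2_s pretrained weights, run: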
import timm
print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)
Returns:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
This gives us the location of the weights: 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k'
Combining it with the domain gives the full address: https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k
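The same combination can be expressed as a small helper that reproduces the model page URL and the two file URLs from the error log above. This is a sketch: the helper names are hypothetical, and the `resolve/main/` pattern is copied from those log lines:

```python
HF_ENDPOINT = "https://huggingface.co"


def hub_page_url(hf_hub_id: str) -> str:
    """Model page on the Hub (browse its 'Files and versions' tab)."""
    return f"{HF_ENDPOINT}/{hf_hub_id}"


def hub_file_url(hf_hub_id: str, filename: str, revision: str = "main") -> str:
    """Direct download URL, same pattern as the URLs in the MaxRetryError log."""
    return f"{HF_ENDPOINT}/{hf_hub_id}/resolve/{revision}/{filename}"


hub_id = "timm/tf_efficientnetv2_s.in21k_ft_in1k"  # value of default_cfg['hf_hub_id']
print(hub_page_url(hub_id))
for name in ("model.safetensors", "pytorch_model.bin"):
    print(hub_file_url(hub_id, name))
```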
Open the page of the key files given above:
【Hugging Face model page】https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k
and download the bin file.
(Note: once the API has been called, the models--timm--tf_efficientnetv2_s.in21k_ft_in1k folder is created automatically, whether or not the download succeeded.)
Given the links printed by the code, in theory we only need to download the bin file and put it under "/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/".
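In theory, placing the weights downloaded from HF in the matching path (xxx/.cache/huggingface/hub/) should be enough.
However, even with the weight file in place, timm still raises Hugging Face's "LocalEntryNotFoundError" network error (possibly because huggingface_hub still runs its usual online check):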
LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
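There is another likely reason the manually placed file is ignored: the Hugging Face cache does not look for files directly inside `models--timm--…/`. As far as I know, huggingface_hub reads a commit hash from `refs/main` and then looks for `snapshots/<commit hash>/<filename>`. The real cache stores the bytes in `blobs/` and symlinks them into `snapshots/`. The sketch below mirrors only that lookup layout, using plain copies. The function name and the all-zero placeholder hash are hypothetical, and I have not verified this against every huggingface_hub or timm version. The `pretrained_cfg_overlay` route described next is the more robust option:

```python
import shutil
from pathlib import Path


def place_in_hf_cache(src: str, repo_id: str, filename: str, cache_root: str,
                      revision_hash: str = "0" * 40) -> Path:
    """Copy a manually downloaded file into a Hub-style cache layout.

    revision_hash is a hypothetical placeholder; a real cache uses the
    40-character commit hash of the downloaded revision.
    """
    repo_dir = Path(cache_root) / ("models--" + repo_id.replace("/", "--"))
    # refs/main records which commit hash the "main" revision points to.
    (repo_dir / "refs").mkdir(parents=True, exist_ok=True)
    (repo_dir / "refs" / "main").write_text(revision_hash)
    # The file itself is looked up under snapshots/<commit hash>/<filename>.
    target = repo_dir / "snapshots" / revision_hash / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, target)
    return target
```

If you try this, combine it with the environment variable `HF_HUB_OFFLINE=1`, which tells huggingface_hub to skip network calls. Treat the combination as an experiment.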
In this case, you can pass timm a pretrained_cfg_overlay parameter:
import timm
print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)
pretrained_cfg_overlay = {'file' : r"/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin"}
model = timm.models.create_model('tf_efficientnetv2_s', pretrained=True, pretrained_cfg_overlay=pretrained_cfg_overlay, num_classes=6)
print(model)
If the model is printed, it worked! (See the hands-on example later for use in a real project.)
pretrained_cfg if pretrained: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
pretrained_cfg getattr: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
EfficientNet(
  (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNormAct2d(
    24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (drop_path): Identity()
      )
      (1): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          ...
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
  (classifier): Linear(in_features=1280, out_features=6, bias=True)
)
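The logged `pretrained_cfg` above is the default config with the overlay's `file` key added. Conceptually, the overlay acts like a dict merge in which the overlay's keys win. The sketch below illustrates only that idea, using plain dicts. timm applies the overlay internally, so this is not its actual code:

```python
def apply_overlay(default_cfg: dict, overlay: dict) -> dict:
    """Return a new config in which keys from `overlay` override `default_cfg`."""
    merged = dict(default_cfg)  # copy so the default config itself is not mutated
    merged.update(overlay)
    return merged


# Abbreviated default config (values copied from the printed default_cfg above).
default_cfg = {
    "url": "https://github.com/rwightman/pytorch-image-models/releases/download/"
           "v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
    "hf_hub_id": "timm/tf_efficientnetv2_s.in21k_ft_in1k",
    "num_classes": 1000,
}
overlay = {"file": "/home/linxq/.cache/huggingface/hub/"
                   "models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin"}

cfg = apply_overlay(default_cfg, overlay)
print(cfg["file"])
```

`num_classes=6` is a separate `create_model` argument, not part of the overlay. That is why the logged config still says 1000 while the printed classifier has `out_features=6`.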
I also tried git, but it fails the same way because of the network, so I dropped that approach:
C:\Users\lenovo>git clone https://huggingface.co/timm/tf_efficientnetv2_s.in1k.git
Cloning into 'tf_efficientnetv2_s.in1k'...
fatal: unable to access 'https://huggingface.co/timm/tf_efficientnetv2_s.in1k.git/': Failed to connect to huggingface.co port 443 after 21082 ms: Couldn't connect to server
【ref】How to batch ** Hugging Face model and dataset files
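If you have no machine at all that can reach the Hugging Face site, can torch-side resources replace the Hugging Face ones?
In many cases, yes. Many models still list a legacy download link (the url field), and timm caches those files in torch's hub directory.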
So drop the Hugging Face files and download the torch-side file tf_efficientnetv2_s_21ft1k-d7dafa41.pth instead.
Temporarily set pretrained from True to False and run:
import timm
model = timm.create_model('tf_efficientnetv2_s', pretrained=False, num_classes=6)
print(model.default_cfg)  # prints the url!
Returns:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
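Here, url is the link used by older timm versions, while hf_hub_id is the Hugging Face location added in newer versions, and it takes priority over url.
The url points to the torch-side source. Download the pth file from it and save it to the target path.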
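On a machine that can reach GitHub, you can script the download with the standard library. The sketch makes two choices. First, it assumes the default torch cache location (`~/.cache/torch/hub/checkpoints`), which you can override with the `target_dir` argument. Second, it names the file after the last path component of the URL, as torch's own downloader does. The helper names are hypothetical. Copy the file to the offline server afterwards if needed:

```python
import os
import urllib.request
from urllib.parse import urlparse

DEFAULT_DIR = os.path.join(os.path.expanduser("~"), ".cache", "torch", "hub", "checkpoints")


def checkpoint_path_for(url: str, target_dir: str = DEFAULT_DIR) -> str:
    """Local path for the checkpoint: <target_dir>/<file name at the end of the URL>."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(target_dir, filename)


def download_checkpoint(url: str, target_dir: str = DEFAULT_DIR) -> str:
    """Download `url` into `target_dir` unless the file is already there."""
    dst = checkpoint_path_for(url, target_dir)
    if not os.path.exists(dst):
        os.makedirs(target_dir, exist_ok=True)
        tmp = dst + ".partial"
        urllib.request.urlretrieve(url, tmp)  # download under a temporary name first
        os.replace(tmp, dst)                  # then move it into place
    return dst


URL = ("https://github.com/rwightman/pytorch-image-models/releases/download/"
       "v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth")
# download_checkpoint(URL)  # run this on a machine with network access
```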
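There are two goals here. The first is to fix the problem above: Hugging Face still raises a network error even though the timm weights are present. The second is to avoid Hugging Face links entirely, because even when you don't download from Hugging Face, timm may still use it for checks, which again runs into the network problem. Both are achieved by putting the pth file downloaded via the torch link directly into pretrained_cfg, as the weights in pretrained_cfg['file']: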
import timm
print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)
pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
print(pretrained_cfg)
model = timm.models.create_model('tf_efficientnetv2_s', pretrained=True, pretrained_cfg=pretrained_cfg)
print(model)
If the model is printed, it worked! (See the hands-on example later for use in a real project.)
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth'}
pretrained_cfg if pretrained: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
pretrained_cfg getattr: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}
EfficientNet(
  (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNormAct2d(
    24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (drop_path): Identity()
      )
      (1): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          ...
  )
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
  (classifier): Linear(in_features=1280, out_features=1000, bias=True)
)
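Why does adding `file` sidestep Hugging Face? In the timm versions I have looked at, the loader picks one source per model, roughly in this order:

1. an explicit `file`
2. `hf_hub_id` (when huggingface_hub is installed)
3. the legacy `url`

The function below is a simplified illustration of that order, not timm's actual code, and the exact rules may differ between timm versions:

```python
def pick_weight_source(cfg: dict, hf_hub_available: bool = True) -> tuple:
    """Simplified illustration of how a loader might choose a weight source."""
    if cfg.get("file"):
        return "file", cfg["file"]           # local file: no network access needed
    if cfg.get("hf_hub_id") and hf_hub_available:
        return "hf-hub", cfg["hf_hub_id"]    # newer Hugging Face location
    if cfg.get("url"):
        return "url", cfg["url"]             # legacy link, cached in torch's hub dir
    return "", None                          # nothing to load from


cfg = {
    "url": "https://github.com/rwightman/pytorch-image-models/releases/download/"
           "v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
    "hf_hub_id": "timm/tf_efficientnetv2_s.in21k_ft_in1k",
}
print(pick_weight_source(cfg))  # hf-hub wins over url, so the Hub is contacted
cfg["file"] = "/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
print(pick_weight_source(cfg))  # file wins, so no Hub request is needed
```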
【ref】[Pytorch] timm.create_model(): loading a pretrained model from local files by specifying pretrained_cfg
Run:
import timm
model = timm.create_model('convnext_small', pretrained=False, num_classes=6)
print(model.default_cfg)
Returns:
{'url': '', 'hf_hub_id': 'timm/convnext_small.in12k_ft_in1k', 'architecture': 'convnext_small', 'tag': 'in12k_ft_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'stem.0', 'classifier': 'head.fc'}
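As shown above, the returned url is empty, so no pth file can be downloaded from it. The convnext_small model is one example of a model with an empty url.
In that case you have to find the pth file elsewhere, for example:
https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
Reference: 【Resource post】❀Resource post❀ Pretrained models for ResNet, ConvNeXt, Transformer, etc.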
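Checkpoints found elsewhere do not always have the layout timm expects. Two assumptions are worth checking:

- Training scripts often wrap the weights in a dict under a key such as `model` or `state_dict`.
- Models trained with `DataParallel` prefix every key with `module.`.

Whether either applies to a given file has to be checked by inspecting it, e.g. `torch.load(path, map_location='cpu').keys()`. The sketch below shows only the unwrapping step on plain dicts. The wrapper keys it tries are assumptions, and the toy values stand in for tensors:

```python
def unwrap_checkpoint(ckpt: dict, wrapper_keys=("state_dict", "model")) -> dict:
    """Return the inner weight dict with any leading 'module.' prefix removed."""
    for key in wrapper_keys:
        inner = ckpt.get(key)
        if isinstance(inner, dict):  # tensors are not dicts, so plain state dicts pass through
            ckpt = inner
            break
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in ckpt.items()}


raw = {"model": {"module.stem.0.weight": 1, "module.head.fc.bias": 2}, "epoch": 300}
print(unwrap_checkpoint(raw))  # {'stem.0.weight': 1, 'head.fc.bias': 2}
```

Even after unwrapping, key names from an original research repository may differ from timm's module names. Compare them with `model.state_dict().keys()` before loading.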
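To summarize where pretrained weights are stored on different machines:
With pretrained=True, timm.models.create_model first looks in a local directory for a matching pretrained weight file. If none is found, it downloads the file there. For weights from the legacy url, the default directory is:
Windows: C:\Users\<username>\.cache\torch\hub\checkpoints
Linux: /home/<username>/.cache/torch/hub/checkpoints
In general, there are two default weight directories on a server (useful to know when uploading weights manually):
huggingface: xxx/.cache/huggingface/hub/
torch: xxx/.cache/torch/hub/checkpoints/
They hold, respectively, weights downloaded from Hugging Face and weights downloaded from the old links.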
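Both locations can be redirected with environment variables, so on a shared server it is worth checking which directory is actually in effect. The helper below restates the defaults as I understand them:

- torch: `$TORCH_HOME/hub/checkpoints`, with `TORCH_HOME` defaulting to `~/.cache/torch`.
- Hugging Face: `$HF_HUB_CACHE`, else `$HF_HOME/hub`, with `HF_HOME` defaulting to `~/.cache/huggingface`.
- Both defaults move if `XDG_CACHE_HOME` is set.

Inside an environment with torch installed, `torch.hub.get_dir()` reports the torch hub directory authoritatively:

```python
import os


def _cache_home(env: dict, home: str) -> str:
    # XDG_CACHE_HOME, if set, replaces the default ~/.cache
    return env.get("XDG_CACHE_HOME", os.path.join(home, ".cache"))


def torch_checkpoints_dir(env: dict, home: str) -> str:
    torch_home = env.get("TORCH_HOME", os.path.join(_cache_home(env, home), "torch"))
    return os.path.join(torch_home, "hub", "checkpoints")


def hf_hub_cache_dir(env: dict, home: str) -> str:
    if "HF_HUB_CACHE" in env:
        return env["HF_HUB_CACHE"]
    hf_home = env.get("HF_HOME", os.path.join(_cache_home(env, home), "huggingface"))
    return os.path.join(hf_home, "hub")


home = os.path.expanduser("~")
print("torch:", torch_checkpoints_dir(dict(os.environ), home))
print("huggingface:", hf_hub_cache_dir(dict(os.environ), home))
```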
【code link】https://github.com/okunator/cellseg_models.pytorch/tree/main
Location:
/home/linxq/code/cell_seg_workflow/src_cite/segment/cellseg_models.pytorch-main/examples/pannuke_nuclei_segmentation_cppnet.ipynb
In the outermost application code, just set enc_name. No other parameters need changing; what needs changing is the inner timm.create_model call:
import timm
# pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
# pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
# Define the model with the function API.
model = cppnet_base_multiclass(
    enc_name="tf_efficientnetv2_s",
    n_rays=32,  # number of predicted rays
    type_classes=len(pannuke_module.type_classes),
    pretrained=True,
    # pretrained_cfg=pretrained_cfg,
)
/home/linxq/code/cell_seg_workflow/src_cite/segment/cellseg_models.pytorch-main/cellseg_models_pytorch/encoders/timm_encoder.py
Find the inner timm.create_model call and pass it a pretrained_cfg argument that contains the file entry:
# create the timm model
# Default cfg of tf_efficientnetv2_s with its 'file' entry pointed at the locally
# downloaded checkpoint. The model name and the path are both hard-coded for this
# backbone; change them together if enc_name is changed.
pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
try:
    self.backbone = timm.create_model(
        name,
        pretrained=pretrained,
        pretrained_cfg=pretrained_cfg,
        checkpoint_path=checkpoint_path,
        in_chans=in_channels,
        features_only=True,
        out_indices=self.out_indices,
        **kwargs,
    )
except (AttributeError, RuntimeError) as err:
    print(err)
    raise RuntimeError(
        f"timm backbone: {name} is not supported due to missing "
        "features_only argument implementation in timm-package."
    )
except IndexError as err:
    print(err)
    raise IndexError(
        f"It's possible that the given depth: {depth} is too large for "
        f"the given backbone: {name}. Try passing a smaller `depth` argument "
        "or a different backbone."
    )
After this change, the notebook runs successfully.
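The change to timm_encoder.py hard-codes one backbone and one file. If several backbones are used, one option is a small lookup table that returns a `pretrained_cfg_overlay` only when a local file exists for the requested name, and otherwise falls back to timm's normal behaviour. The table, its path and the helper name are hypothetical. The returned value would be passed as `pretrained_cfg_overlay=` to `timm.create_model`, the same parameter used earlier:

```python
import os
from typing import Optional

# Hypothetical mapping from timm model name to a locally downloaded checkpoint.
LOCAL_WEIGHTS = {
    "tf_efficientnetv2_s": "/home/linxq/.cache/torch/hub/checkpoints/"
                           "tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
}


def local_overlay(name: str, table: dict = LOCAL_WEIGHTS) -> Optional[dict]:
    """Return {'file': path} if a local checkpoint for `name` exists, else None."""
    path = table.get(name)
    if path and os.path.isfile(path):
        return {"file": path}
    return None


# Inside the encoder this could become, e.g.:
# self.backbone = timm.create_model(name, pretrained=pretrained,
#                                   pretrained_cfg_overlay=local_overlay(name), ...)
```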
To load only certain feature layers, modify the model's state dict (state_dict):
'''
Load only selected layers from a pretrained checkpoint
by editing the model's state dict (state_dict)
'''
import torch
import timm
# Print the default config
print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)
# Path of the pretrained weight file
pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
pretrained_cfg['file'] = r"C:\Users\lenovo\.cache\torch\hub\checkpoints\tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
print(pretrained_cfg)
# Create the model without pretrained weights
model = timm.models.create_model('tf_efficientnetv2_s', pretrained=False)
# Load the pretrained state dict (map_location='cpu' also works on machines without a GPU)
pretrained_state_dict = torch.load(pretrained_cfg['file'], map_location='cpu')
# Current state dict of the model
model_state_dict = model.state_dict()
# Copy only selected layers (here: every parameter whose name contains 'conv')
for key in list(pretrained_state_dict.keys()):
    # Choose which keys (layers) to load according to your needs
    if 'conv' in key:
        model_state_dict[key] = pretrained_state_dict[key]
# Load the modified state dict into the model
model.load_state_dict(model_state_dict)
# Print the model
print(model)
Printed model structure:
EfficientNet(
  (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNormAct2d(
    24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (drop_path): Identity()
      )
      (1): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          ...
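The selective-loading loop copies tensors by name only. A safer variant also checks two things: that the key exists in the model, and that the shapes agree (they can differ, for example, after changing `in_chans`). With torch, `model.load_state_dict(..., strict=False)` also returns the missing and unexpected keys for inspection.

The selection logic is sketched below on plain `name -> shape` dicts so it runs without torch. In real code the shapes would come from `tensor.shape`. The function name, the `extra.conv.weight` key and the `in_chans=1` scenario are hypothetical:

```python
def select_keys(pretrained_shapes: dict, model_shapes: dict, keyword: str = "conv") -> list:
    """Names to copy: contain `keyword`, exist in the model, and have the same shape."""
    selected = []
    for name, shape in pretrained_shapes.items():
        if keyword in name and model_shapes.get(name) == shape:
            selected.append(name)
    return selected


pretrained = {
    "conv_stem.weight": (24, 3, 3, 3),
    "bn1.weight": (24,),
    "conv_head.weight": (1280, 256, 1, 1),
    "extra.conv.weight": (8, 8, 1, 1),   # hypothetical key the model does not have
}
model = {
    "conv_stem.weight": (24, 1, 3, 3),   # e.g. a model created with in_chans=1
    "bn1.weight": (24,),
    "conv_head.weight": (1280, 256, 1, 1),
}
print(select_keys(pretrained, model))  # ['conv_head.weight']
```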
To freeze certain layers, set requires_grad to False on their parameters:
# Freeze some layers, e.g. all convolution layers
for name, param in model.named_parameters():
    if 'conv' in name:
        param.requires_grad = False
# Check which layers are frozen
for name, param in model.named_parameters():
    print(f'{name}: requires_grad={param.requires_grad}')
Status of each layer:
conv_stem.weight: requires_grad=False
bn1.weight: requires_grad=True
bn1.bias: requires_grad=True
blocks.0.0.conv.weight: requires_grad=False
blocks.0.0.bn1.weight: requires_grad=True
blocks.0.0.bn1.bias: requires_grad=True
blocks.0.1.conv.weight: requires_grad=False
...
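Matching on the substring `'conv'` freezes every parameter whose name contains it. That includes `conv_stem`, `blocks.*.conv` and other conv-named layers in later blocks. The BatchNorm parameters stay trainable, as the log shows. The pure-Python sketch below reproduces that selection on parameter names taken from the log, so you can check what a pattern selects before touching the model. The helper name is hypothetical:

```python
def frozen_names(param_names, pattern: str = "conv") -> list:
    """Names that `if pattern in name: param.requires_grad = False` would freeze."""
    return [name for name in param_names if pattern in name]


names = [
    "conv_stem.weight",
    "bn1.weight",
    "bn1.bias",
    "blocks.0.0.conv.weight",
    "blocks.0.0.bn1.weight",
    "blocks.0.0.bn1.bias",
    "blocks.0.1.conv.weight",
]
print(frozen_names(names))
# ['conv_stem.weight', 'blocks.0.0.conv.weight', 'blocks.0.1.conv.weight']
```

With a real model, use `frozen_names(n for n, _ in model.named_parameters())`. Also note that `requires_grad=False` does not stop BatchNorm layers from updating their running statistics in `model.train()` mode. To freeze those statistics as well, put the BatchNorm modules in `.eval()` mode after each `model.train()` call.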
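With the pretrained-weight problem solved, let's look at checkpoints:
The former are pretrained weights, which initialize only part of the network (the encoder backbone). The latter are weights from your own training run and cover every layer.
The first training run completes successfully and produces a checkpoint. However, when the model for the second training run is defined with checkpoint_path added, it fails with: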
RuntimeError: timm backbone: tf_efficientnetv2_s is not supported due to missing features_only argument implementation in timm-package.
Still unresolved…
RuntimeError: timm backbone: convnext_small is not supported due to missing features_only argument implementation in timm-package.
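features_only is a timm.create_model argument that controls what the model outputs. With features_only=True, the model returns only feature maps from intermediate stages instead of running the final classification or regression head. This is useful when we care about intermediate feature representations rather than final predictions, e.g. for further analysis or as the encoder of another model, as in the segmentation network here.
If you hit the error above, the message suggests that timm does not currently support convnext_small, perhaps because the implementation of the features_only argument is missing for it. The timm developers may not have implemented that feature yet, or may not have integrated convnext_small properly. Treat that reading with caution, though. The wrapper in timm_encoder.py raises this message for any AttributeError or RuntimeError. Also, tf_efficientnetv2_s worked with features_only=True in the first training run and only fails once checkpoint_path is added. The message may therefore be hiding a different error. The line printed by print(err) just before it shows the actual exception.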
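In timm_encoder.py, every AttributeError and RuntimeError raised by `timm.create_model` is printed and then re-raised as a new RuntimeError carrying the fixed features_only text. The stdlib sketch below reproduces that wrapping pattern with a hypothetical failing loader (a simulated state_dict mismatch). It shows that the original error stays reachable through `__cause__` when the re-raise uses `from err`:

```python
def build_backbone(create):
    """Mimics the wrapper: any AttributeError/RuntimeError gets a generic message."""
    try:
        return create()
    except (AttributeError, RuntimeError) as err:
        raise RuntimeError(
            "timm backbone is not supported due to missing features_only "
            "argument implementation in timm-package."
        ) from err  # keep the original exception attached as __cause__


def failing_create():
    # Hypothetical stand-in for a checkpoint whose keys do not match the backbone.
    raise RuntimeError("simulated state_dict key mismatch")


try:
    build_backbone(failing_create)
except RuntimeError as exc:
    print("reported:", exc)
    print("actual cause:", exc.__cause__)
```

The original code does not use `from err`. Even so, Python's traceback shows the original exception under "During handling of the above exception, another exception occurred", so the untruncated traceback also reveals the real cause.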