当前位置:   article > 正文

lora-scripts代码分析

lora-scripts代码分析

前后端都是用 FastAPI 写的,代码质量确实不错;没有再用 webui,这样在内网里就可以直接打开页面,非常实用。以往我用 FastAPI 封装接口都比较简单,这一块值得学习。

  1. prepare_environment()->
  2. - prepare_submodules()
  3. - check_dirs()
  4. - validate_requirements()
  5. - setup_windows_bitsandbytes()
  6. - setup_onnxruntime()
  7. run_tag_editor()->
  8. run_tensorboard()->
  9. uvicorn.run()->mikazuki.app->
  10. - mikazuki.app.application->
  11. -- app=FastAPI(lifespan)
  12. -- app.include_router(proxy_router) # 可以处理该路由器定义的http请求
  13. -- app.include_router(api_router,prefix="/api")->
  14. --- router = APIRouter()
  15. --- @router.post("/interrogate")
  16. ---- async def run_interrogate(req:TaggerInterrogateRequest,background_tasks)->
  17. ---- interrogator = available_interrogators.get(req.interrogator_model, available_interrogators["wd14-convnextv2-v2"])
  18. ---- background_tasks.add_task()
  19. ----- on_interrogate()->
  20. ----- ratings,tags = interrogator.interrogate(image)->
  21. ----- processed_tags = Interrogator.postprocess_tags(tags,*postprocess_opts)
  22. --- @router.post("/run")
  23. --- async def create_toml_file(request:Request)->
  24. --- json_data = await request.body()->
  25. --- config = json.loads(json_data.decode("utf-8"))->
  26. --- trainer_file = trainer_mapping[model_train_type]->
  27. --- validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"])
  28. --- result = process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads)
  29. ---- python -m accelerate.commands.launch --num_cpu_threads_per_process=4 \
  30. "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/tools/sd_lora/sd-scripts/train_network.py" \
  31. --config_file="E:\comprehensive_library\e_commerce_sd\tools\sd_lora\config\lora_sd_sciprts.toml"

主要是两个接口:/api/interrogate(图片打标)和 /api/run(启动训练),调用示例如下:

  1. import sys
  2. import json
  3. import requests
  4. from pathlib import Path
  5. root_dir = Path(__file__).resolve().parent.parent
  6. sys.path.append(str(root_dir))
  7. interrogator = False
  8. lora_new = True
  9. lora_expert = False
  10. input_json_interrogator = {
  11. "path": "/home/image_team/image_team_docker_home/lgd/common/lora-scripts-main/lion/",
  12. "interrogator_model": "wd14-convnextv2-v2",
  13. "threshold": 0.35,
  14. "additional_tags": "",
  15. "exclude_tags": "",
  16. "escape_tag": True,
  17. "batch_input_recursive": False,
  18. "batch_output_action_on_conflict": "ignore",
  19. "replace_underscore": True
  20. }
  21. input_json_lora_new = {
  22. 'pretrained_model_name_or_path': '/home/image_team/image_team_docker_home/lgd/e_commerce_sd/stable-diffusion-webui/models/Stable-diffusion/revAnimated_v121.safetensors',
  23. 'train_data_dir': '/home/image_team/image_team_docker_home/lgd/common/lora-scripts-main/lion/',
  24. 'resolution': '512,512',
  25. 'enable_bucket': True,
  26. 'min_bucket_reso': 256,
  27. 'max_bucket_reso': 1024,
  28. 'output_name': 'aki',
  29. 'output_dir': './output',
  30. 'save_model_as': 'safetensors',
  31. 'save_every_n_epochs': 2,
  32. 'max_train_epochs': 10,
  33. 'train_batch_size': 1,
  34. 'network_train_unet_only': False,
  35. 'network_train_text_encoder_only': False,
  36. 'learning_rate': 0.0001,
  37. 'unet_lr': 0.0001,
  38. 'text_encoder_lr': 1e-05,
  39. 'lr_scheduler': 'cosine_with_restarts',
  40. 'optimizer_type': 'AdamW8bit',
  41. 'lr_scheduler_num_cycles': 1,
  42. 'network_module': 'networks.lora',
  43. 'network_dim': 32,
  44. 'network_alpha': 32,
  45. 'logging_dir': './logs',
  46. 'caption_extension': '.txt',
  47. 'shuffle_caption': True,
  48. 'keep_tokens': 0,
  49. 'max_token_length': 255,
  50. 'seed': 1337,
  51. 'prior_loss_weight': 1,
  52. 'clip_skip': 2,
  53. 'mixed_precision': 'fp16',
  54. 'save_precision': 'fp16',
  55. 'xformers': True,
  56. 'cache_latents': True,
  57. 'persistent_data_loader_workers': True,
  58. 'lr_warmup_steps': 0,
  59. 'gpu_ids': ['0', '1', '2', '3']}
  60. # input_json_lora_expert = {
  61. #
  62. # }
  63. if interrogator:
  64.     results = requests.post("http://10.111.132.198:28000/api/interrogate", json=input_json_interrogator)
  65.     print(results.json())
  66. if lora_new:
  67.     results = requests.post("http://10.111.132.198:28000/api/run", json=input_json_lora_new)
  68.     print(results.json())

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/349546
推荐阅读
相关标签
  

闽ICP备14008679号