build bert: build does not load model #2379

Closed
2 of 4 tasks
Alireza3242 opened this issue Oct 26, 2024 · 3 comments
Labels
bug Something isn't working triaged Issue has been triaged by maintainers

Comments

@Alireza3242

Alireza3242 commented Oct 26, 2024

System Info

NVIDIA A100 GPU

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

My BERT model is stored as model.safetensors, but examples/bert/build.py can't read it.
I changed this part of the code, and it worked:

if args.model_dir is not None and os.path.exists(
        os.path.join(args.model_dir, "pytorch_model.bin")):
    state_dict = torch.load(
        os.path.join(args.model_dir, "pytorch_model.bin"))
    hf_bert.load_state_dict(state_dict, strict=False)

to

import safetensors.torch  # needed for safetensors.torch.load_file below

# Original path: load a PyTorch pickle checkpoint if it exists.
if args.model_dir is not None and os.path.exists(
        os.path.join(args.model_dir, "pytorch_model.bin")):
    state_dict = torch.load(
        os.path.join(args.model_dir, "pytorch_model.bin"))
    hf_bert.load_state_dict(state_dict, strict=False)
# Added path: fall back to a safetensors checkpoint.
elif args.model_dir is not None and os.path.exists(
        os.path.join(args.model_dir, "model.safetensors")):
    state_dict = safetensors.torch.load_file(
        os.path.join(args.model_dir, "model.safetensors"))
    hf_bert.load_state_dict(state_dict, strict=False)

Expected behavior

build.py reads the model weights from model.safetensors.

Actual behavior

build.py does not read the model weights.

Additional notes

None.

@symphonylyh
Collaborator

symphonylyh commented Oct 28, 2024

Fixed together with your other issue, #2373. Thank you!!
I used a try/except instead:

from safetensors.torch import load_file
...
if args.model_dir is not None:
    try:
        state_dict = torch.load(
            os.path.join(args.model_dir, "pytorch_model.bin"))
    except FileNotFoundError:
        state_dict = load_file(
            os.path.join(args.model_dir, "model.safetensors"))
    hf_bert.load_state_dict(state_dict, strict=False)
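
For readers comparing the two approaches in this thread, here is a minimal sketch of a third variant that checks for each file explicitly and fails with a clear message when neither exists. It is not the code that was merged. The names find_checkpoint, load_state_dict_from_dir, and CHECKPOINT_CANDIDATES are hypothetical; the file names and their order come from the snippets above.

```python
import os

# File names taken from this thread; the order mirrors the snippets above
# (pytorch_model.bin is tried first). The tuple name is hypothetical.
CHECKPOINT_CANDIDATES = (
    ("pytorch_model.bin", "torch"),
    ("model.safetensors", "safetensors"),
)


def find_checkpoint(model_dir):
    """Return (path, fmt) for the first checkpoint file found, else None."""
    if model_dir is None:
        return None
    for name, fmt in CHECKPOINT_CANDIDATES:
        path = os.path.join(model_dir, name)
        if os.path.isfile(path):
            return path, fmt
    return None


def load_state_dict_from_dir(model_dir):
    """Load a state dict from model_dir, whichever format is present."""
    found = find_checkpoint(model_dir)
    if found is None:
        names = ", ".join(name for name, _ in CHECKPOINT_CANDIDATES)
        raise FileNotFoundError(
            f"No checkpoint found in {model_dir!r} (looked for: {names})")
    path, fmt = found
    if fmt == "safetensors":
        # Imported lazily so the lookup above works without safetensors.
        from safetensors.torch import load_file
        return load_file(path)
    import torch
    # map_location="cpu" is an assumption; the thread's snippets omit it.
    return torch.load(path, map_location="cpu")
```

Usage would then look like `hf_bert.load_state_dict(load_state_dict_from_dir(args.model_dir), strict=False)`. Compared with the try/except version, a directory with neither file produces an error naming both expected files rather than only model.safetensors.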

@tkhanipov

@symphonylyh
Note: the change only covers the case elif args.model == 'BertForSequenceClassification' or args.model == 'RobertaForSequenceClassification'. The if args.model == 'BertModel' or args.model == 'RobertaModel' case remains buggy. Please have another look at #2187; it contains a fix for the latter case (the corresponding bug is #2197).
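
One way to avoid this kind of per-branch gap is to construct the model in the branch and load the weights once, after the branch. The sketch below is only an illustration of that structure, not the layout of examples/bert/build.py or the fix in #2187. The function build_hf_model and its constructors and load_weights parameters are hypothetical; the four model names are the ones quoted above.

```python
# Model names quoted in this thread.
MODEL_NAMES = (
    "BertModel",
    "RobertaModel",
    "BertForSequenceClassification",
    "RobertaForSequenceClassification",
)


def build_hf_model(model_name, model_dir, constructors, load_weights):
    """Construct the requested model, then load weights in one shared step.

    constructors: dict mapping a model name to a zero-argument factory
                  (hypothetical; stands in for the per-model branches).
    load_weights: callable(model, model_dir) that loads a checkpoint
                  (hypothetical; stands in for the state-dict loading code).
    """
    if model_name not in constructors:
        raise ValueError(f"Unsupported model: {model_name!r}")
    model = constructors[model_name]()
    # Single loading path: every model type gets the same checkpoint handling,
    # so a fix to the loader cannot miss one of the branches.
    if model_dir is not None:
        load_weights(model, model_dir)
    return model
```

With this shape, adding safetensors support means changing only the function passed as load_weights.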

@tkhanipov

@symphonylyh Here you said that you were planning to merge the PR
