Skip to content

Commit

Permalink
update pipeline & readme
Browse files Browse the repository at this point in the history
  • Loading branch information
SivilTaram committed Nov 13, 2022
1 parent c25f350 commit 2f9777e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
4 changes: 4 additions & 0 deletions awakening_latent_grounding/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# README

This code is still incomplete. We will actively update it when the code is totally ready.

# Environment Setup

```
Expand Down
16 changes: 9 additions & 7 deletions awakening_latent_grounding/inference/pipeline_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@
import torch
from .pipeline_base import NLBindingInferencePipeline


class NLBindingTorchScriptPipeline(NLBindingInferencePipeline):
def __init__(self,
model_dir: str,
greedy_linking: bool,
threshold: float=0.2,
num_threads: int=8,
use_gpu: bool = torch.cuda.is_available()
) -> None:
model_dir: str,
greedy_linking: bool,
threshold: float = 0.2,
num_threads: int = 8,
use_gpu: bool = torch.cuda.is_available()
) -> None:
super().__init__(model_dir, greedy_linking=greedy_linking, threshold=threshold)

self.device = torch.device('cuda') if use_gpu else torch.device('cpu')
model_file = "nl_binding.script.bin"

torch.set_num_interop_threads(2)
torch.set_num_threads(num_threads)
print('Torch model Threads: {}, {}, {}'.format(torch.get_num_interop_threads(), torch.get_num_threads(), self.device))
print('Torch model Threads: {}, {}, {}'.format(torch.get_num_interop_threads(), torch.get_num_threads(),
self.device))
model_ckpt_path = os.path.join(model_dir, model_file)
self.model = torch.jit.load(model_ckpt_path, map_location=self.device)
self.model.eval()
Expand Down

0 comments on commit 2f9777e

Please sign in to comment.