Skip to content

Commit

Permalink
Code updates
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Apr 14, 2023
1 parent db94bfe commit e6dcc7a
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 51 deletions.
16 changes: 3 additions & 13 deletions src/python/txtinstruct/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ def generate(self, rows):
texts = [row["text"] for row in rows]

# Generate statements
statements = self.statement([
self.formatter.format(self.sprompt, context=row["text"]) for row in rows],
truncation=True,
batch_size=len(rows)
)
statements = self.statement([self.formatter.format(self.sprompt, context=row["text"]) for row in rows], truncation=True, batch_size=len(rows))

# Generate template statements
templates = self.template(ids) if self.templates else []
Expand All @@ -122,20 +118,14 @@ def generate(self, rows):
for x, text in enumerate(texts):
output = {"context": text, "statements": []}
for question in queue[x]:
output["statements"].append({
"source": question,
"target": targets[index]
})
output["statements"].append({"source": question, "target": targets[index]})

index += 1

# Generate unanswerable statement
y = random.choice([i for i in range(0, len(texts)) if i != x])
statement = random.choice([statements[y], templates[y]]) if templates else statements[y]
output["statements"].append({
"source": statement,
"target": "I don't have data on that"
})
output["statements"].append({"source": statement, "target": "I don't have data on that"})

outputs.append(output)

Expand Down
8 changes: 2 additions & 6 deletions src/python/txtinstruct/models/bashsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,10 @@ def generate(self, path):
lang = ["ar", "en", "fr", "de", "hi", "it", "nl", "ro", "ru", "zh"]
lang1, lang2 = random.choice(lang), random.choice(lang)
self.append(
output,
f"{find} -translate {lang1}",
f"select id, translate(text, '{lang1}') text, score from txtai where similar('{query}')"
output, f"{find} -translate {lang1}", f"select id, translate(text, '{lang1}') text, score from txtai where similar('{query}')"
)
self.append(
output,
f"{find} -translate {lang2}",
f"select id, translate(text, '{lang2}') text, score from txtai where similar('{query}')"
output, f"{find} -translate {lang2}", f"select id, translate(text, '{lang2}') text, score from txtai where similar('{query}')"
)
self.append(
output,
Expand Down
16 changes: 3 additions & 13 deletions src/python/txtinstruct/models/instructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ def __call__(self, base, data, task, prompt=None, **kwargs):
prompt = prompt if prompt else self.defaultprompt(task)

# Build training dataset
train = Dataset.from_generator(self.generate, gen_kwargs=({
"data": data,
"task": task,
"prompt": prompt
})
)
train = Dataset.from_generator(self.generate, gen_kwargs=({"data": data, "task": task, "prompt": prompt}))

# Train model
trainer = HFTrainer()
Expand All @@ -60,14 +55,9 @@ def generate(self, data, task, prompt):
for row in data:
for statement in row["statements"]:
if task == "language-generation":
yield {
"text": formatter.format(prompt, statement=statement["source"], context=row["context"]) + statement["target"]
}
yield {"text": formatter.format(prompt, statement=statement["source"], context=row["context"]) + statement["target"]}
else:
yield {
"source": formatter.format(prompt, statement=statement["source"], context=row["context"]),
"target": statement["target"]
}
yield {"source": formatter.format(prompt, statement=statement["source"], context=row["context"]), "target": statement["target"]}

def defaultprompt(self, task):
"""
Expand Down
26 changes: 13 additions & 13 deletions src/python/txtinstruct/models/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ def __call__(self, base, data, task, prompt=None, **kwargs):
prompt = prompt if prompt else self.defaultprompt(task)

# Build training dataset
train = Dataset.from_generator(self.generate, gen_kwargs=({
"data": data,
"task": task,
"prompt": prompt
})
)
train = Dataset.from_generator(self.generate, gen_kwargs=({"data": data, "task": task, "prompt": prompt}))

# Train model
trainer = HFTrainer()
Expand All @@ -62,16 +57,21 @@ def generate(self, data, task, prompt):
context = row["context"]

if task == "language-generation":
yield {
"text": formatter.format(prompt, context=context) + row["question"]
}
yield {"text": formatter.format(prompt, context=context) + row["question"]}
else:
yield {
"source": formatter.format(prompt, context=context),
"target": row["question"]
}
yield {"source": formatter.format(prompt, context=context), "target": row["question"]}

def defaultprompt(self, task):
"""
Default model prompt.
Args:
task: model task
Returns:
default model prompt
"""

prompt = """Generate a question using the context below.
### Context:
{context}
Expand Down
8 changes: 2 additions & 6 deletions src/python/txtinstruct/models/txtsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def generate(self, path):
# Query by date and score
self.append(output, f"{query} since yesterday and score less than 0.5", f"{sql} and entry >= date('now', '-1 day') and score $= 0.5")
self.append(
output,
f"{query} with a score greater than 0.2 since yesterday",
f"{sql} and score >= 0.2 and entry >= date('now', '-1 day')"
output, f"{query} with a score greater than 0.2 since yesterday", f"{sql} and score >= 0.2 and entry >= date('now', '-1 day')"
)

# Query by text field
Expand All @@ -94,9 +92,7 @@ def generate(self, path):
# Query with OR
self.append(output, f"{query} having text equal data or field as snippet", f"{sql} and (text = 'data' or field = 'snippet')")
self.append(
output,
f"{query} having text as data or field equal snippet value",
f"{sql} and (text = 'data' or field = 'snippet value')"
output, f"{query} having text as data or field equal snippet value", f"{sql} and (text = 'data' or field = 'snippet value')"
)
self.append(output, f"{query} with field equal snippet or text as data", f"{sql} and (field = 'snippet' or text = 'data')")

Expand Down

0 comments on commit e6dcc7a

Please sign in to comment.