Save every 50 instead of all at once, so it can be cancelled
This commit is contained in:
parent
8200cd9af1
commit
eac7aee82c
@ -107,19 +107,27 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
|
|||||||
previous_statement_search_text = ""
|
previous_statement_search_text = ""
|
||||||
statements_from_file = []
|
statements_from_file = []
|
||||||
|
|
||||||
|
save_every = 50
|
||||||
|
count = 0
|
||||||
|
|
||||||
async for row in AsyncIter(reader):
|
async for row in AsyncIter(reader):
|
||||||
dialogue_id = row["dialogueID"]
|
dialogue_id = row["dialogueID"]
|
||||||
if dialogue_id != last_dialogue_id:
|
if dialogue_id != last_dialogue_id:
|
||||||
previous_statement_text = None
|
previous_statement_text = None
|
||||||
previous_statement_search_text = ""
|
previous_statement_search_text = ""
|
||||||
last_dialogue_id = dialogue_id
|
last_dialogue_id = dialogue_id
|
||||||
|
count += 1
|
||||||
|
if count >= save_every:
|
||||||
|
if statements_from_file:
|
||||||
|
self.chatbot.storage.create_many(statements_from_file)
|
||||||
|
count = 0
|
||||||
|
|
||||||
if len(row) > 0:
|
if len(row) > 0:
|
||||||
statement = Statement(
|
statement = Statement(
|
||||||
text=row["text"],
|
text=row["text"],
|
||||||
in_response_to=previous_statement_text,
|
in_response_to=previous_statement_text,
|
||||||
conversation="training",
|
conversation="training",
|
||||||
created_at=date_parser.parse(row["date"]),
|
# created_at=date_parser.parse(row["date"]),
|
||||||
persona=row["from"],
|
persona=row["from"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user