1
0
Fork 0
mirror of https://github.com/Findus23/se-simulator.git synced 2024-09-19 15:53:45 +02:00
se-simulator/text_generator.py

106 lines
3.2 KiB
Python
Raw Normal View History

2018-03-19 22:03:32 +01:00
import os
2018-03-10 19:28:02 +01:00
import jsonlines
import markovify
2018-03-16 20:31:43 +01:00
from markov import MarkovText, MarkovUserName
2018-03-16 18:48:54 +01:00
from utils import *
2018-03-10 19:28:02 +01:00
2018-03-16 20:31:43 +01:00
def get_markov(mode):
    """Return the Markov model class to use for the data set ``mode``.

    The ``"Usernames"`` data set gets ``MarkovUserName``; every other mode
    falls back to ``MarkovText``.
    """
    return MarkovUserName if mode == "Usernames" else MarkovText

2018-03-10 19:28:02 +01:00
2018-03-16 20:31:43 +01:00
def get_state_size(mode):
    """Return the state size handed to the Markov model constructor for ``mode``.

    Titles use 1, Usernames use 3, and every other data set uses 2.
    """
    if mode == "Titles":
        return 1
    if mode == "Usernames":
        return 3
    return 2


def load_chain(chainfile, mode):
    """Restore a Markov model that was previously saved as JSON.

    :param chainfile: path to a file containing the model's JSON serialization.
    :param mode: data set name; selects the model class via ``get_markov``.
    :return: the model rebuilt with ``<model class>.from_json``.
    """
    model_cls = get_markov(mode)
    with open(chainfile, 'r') as handle:
        serialized = handle.read()
        print("using existing file\n")
    return model_cls.from_json(serialized)
2018-03-10 19:28:02 +01:00
2018-03-19 22:03:32 +01:00
def generate_chain(sourcedir, chainfile, mode, batch_size=1000):
    """Build a Markov model from ``<sourcedir>/<mode>.jsonl`` and save it as JSON.

    One small model is built per entry of the JSON-lines file (each entry is
    ``strip()``-ed first, so entries are presumably strings -- TODO confirm
    against the producer of these files). Entries the model class rejects with
    ``KeyError`` are skipped. Per-entry models are merged in batches with
    ``markovify.combine`` so that the full list of models is never held at once.

    :param sourcedir: directory containing ``<mode>.jsonl``.
    :param chainfile: path the combined model is written to (``to_json()``).
    :param mode: data set name; selects model class and state size.
    :param batch_size: number of per-entry models merged in one step.
    :return: the combined model.
    :raises ValueError: if no entry could be turned into a model.
    """
    markov = get_markov(mode)
    state_size = get_state_size(mode)
    combined_chain = None
    pending = []
    processed = 0

    sourcefile = os.path.join(sourcedir, "{type}.jsonl".format(type=mode))
    with jsonlines.open(sourcefile, mode="r") as content:
        for text in content:
            text = text.strip()
            try:
                model = markov(text, state_size, retain_original=False)
            except KeyError:
                continue
            pending.append(model)
            # Progress output every 100 successfully built models.
            if processed % 100 == 0:
                print(processed)
            processed += 1
            if len(pending) >= batch_size:
                combined_chain = _merge_models(combined_chain, pending)
                pending = []

    # markovify.combine() cannot handle an empty list, so only merge a
    # non-empty remainder (the original crashed when the count hit a batch
    # boundary exactly, or when the file yielded nothing).
    if pending:
        combined_chain = _merge_models(combined_chain, pending)
    if combined_chain is None:
        raise ValueError("no usable texts found in {}".format(sourcefile))

    with open(chainfile, 'w') as outfile:
        outfile.write(combined_chain.to_json())
    print_ram()
    return combined_chain


def _merge_models(combined, batch):
    """Combine ``batch`` into one model and merge it into ``combined`` (may be None)."""
    subtotal = markovify.combine(batch)
    if combined is None:
        return subtotal
    return markovify.combine(models=[combined, subtotal])
2018-03-19 22:03:32 +01:00
def get_chain(url, mode):
    """Return the Markov model for site ``url`` and data set ``mode``.

    Loads the cached ``sites/<url>/<mode>.chain.json`` when it exists;
    otherwise builds the model from ``sites/<url>/<mode>.jsonl`` via
    ``generate_chain``, which also writes that cache file.

    :param url: site directory name under ``sites/`` (without the prefix).
    :param mode: data set name, e.g. ``"Titles"``.
    """
    sourcedir = os.path.join("sites", url)
    chainfile = os.path.join(sourcedir, "{type}.chain.json".format(type=mode))
    if os.path.exists(chainfile):
        return load_chain(chainfile, mode)
    return generate_chain(sourcedir, chainfile, mode)

def generate_text(chain: "markovify.Text", model):
    """Generate text from ``chain`` shaped according to the data set ``model``.

    - ``"Titles"``: one sentence of at most 70 characters.
    - ``"Usernames"``: one sentence of at most 36 characters.
    - ``"Questions"`` / ``"Answers"``: a random number (0-7) of sentences,
      randomly grouped into paragraphs joined by newlines.
    - any other value: a single sentence.

    markovify returns ``None`` when it fails to produce a sentence; such
    failures are skipped in the multi-sentence case instead of crashing
    ``str.join``. Empty paragraphs are dropped so no blank lines are emitted.
    """
    import random

    if model == "Titles":
        return chain.make_short_sentence(70)
    if model == "Usernames":
        return chain.make_short_sentence(36)
    # The original ``model == "Questions" or "Answers"`` was always true.
    if model in ("Questions", "Answers"):
        paragraphs = []
        sentences = []
        # Product of two 2..6 draws scaled by 1/5 -> 0..7 sentences, skewed low.
        count = int((random.randint(2, 6) * random.randint(2, 6) / 5))
        for _ in range(count):
            sentence = chain.make_sentence()
            if sentence is not None:
                sentences.append(sentence)
            # ~40% chance to end the current paragraph after each sentence.
            if random.random() < 0.4:
                paragraphs.append(sentences)
                sentences = []
        paragraphs.append(sentences)
        return "\n".join(" ".join(paragraph) for paragraph in paragraphs if paragraph)
    return chain.make_sentence()
2018-03-19 22:03:32 +01:00
2018-03-16 20:31:43 +01:00
if __name__ == "__main__":
    import sys

    # get_settings / print_ram are not defined here; presumably they come from
    # the ``from utils import *`` at the top of the file.
    basedir, mode = get_settings(2)
    if mode not in ["Questions", "Answers", "Titles", "Usernames"]:
        print("error")
        # Non-zero status so callers can detect the bad mode; sys.exit does not
        # depend on the interactive-only ``exit`` builtin.
        sys.exit(1)

    # get_chain() already prefixes "sites/", so pass only the site name
    # (passing "sites/..." resolved to "sites/sites/...").
    chain = get_chain("astronomy.stackexchange.com", mode)
    for _ in range(10):
        print(chain.make_sentence())
        print("-----------------------------------")

    print_ram()