mirror of
https://github.com/Findus23/se-simulator.git
synced 2024-09-19 15:53:45 +02:00
improve generator
This commit is contained in:
parent
6dd5438973
commit
bb3b691caa
5 changed files with 64 additions and 32 deletions
|
@ -7,7 +7,7 @@ from parsexml import parse_posts, parse_comments, parse_usernames
|
|||
from utils import *
|
||||
# os.chdir("/mydir")
|
||||
for file in glob.glob("downloads/**/*.7z"):
|
||||
if not "worldbuilding" in file:
|
||||
if not "raspberry" in file:
|
||||
continue
|
||||
code = os.path.basename(os.path.splitext(file)[0])
|
||||
print(code)
|
||||
|
|
10
markov.py
10
markov.py
|
@ -5,9 +5,17 @@ tokenizer = MosesTokenizer()
|
|||
detokenizer = MosesDetokenizer()
|
||||
|
||||
|
||||
class MarkovText(markovify.Text):
    """markovify.Text variant that tokenizes with the Moses tokenizer.

    Splitting/joining through MosesTokenizer/MosesDetokenizer keeps
    punctuation handling consistent when sentences are rebuilt from
    the chain.
    """

    def word_split(self, sentence):
        """Break *sentence* into Moses tokens for the chain."""
        return tokenizer.tokenize(sentence)

    def word_join(self, words):
        """Reassemble a token list into a plain sentence string."""
        return detokenizer.detokenize(words, return_str=True)
|
||||
|
||||
|
||||
class MarkovUserName(markovify.Text):
    """Character-level markovify.Text used for generating usernames.

    Each "word" handed to the chain is a single character, so the model
    learns letter transitions instead of word transitions.
    """

    def word_split(self, word):
        """Explode *word* into its individual characters."""
        return [*word]

    def word_join(self, characters):
        """Concatenate generated characters back into one name."""
        return "".join(characters)
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from xml.etree import ElementTree
|
||||
|
||||
import jsonlines
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from utils import *
|
||||
|
||||
|
@ -22,8 +21,7 @@ def parse_posts(inputdir, outputdir):
|
|||
titles.write(title)
|
||||
body = element.get('Body')
|
||||
if body:
|
||||
soup = BeautifulSoup(body, "lxml")
|
||||
text = soup.get_text()
|
||||
text = html2text(body)
|
||||
if element.get('PostTypeId') == "1":
|
||||
questions.write(text)
|
||||
else:
|
||||
|
|
|
@ -1,38 +1,41 @@
|
|||
import jsonlines
|
||||
import markovify
|
||||
import os
|
||||
from nltk.tokenize.moses import MosesDetokenizer, MosesTokenizer
|
||||
|
||||
from markov import POSifiedText
|
||||
from markov import MarkovText, MarkovUserName
|
||||
from utils import *
|
||||
|
||||
detokenizer = MosesDetokenizer()
|
||||
|
||||
BASEDIR, mode = get_settings(2)
|
||||
if mode not in ["Questions", "Answers", "Titles"]:
|
||||
print("error")
|
||||
exit()
|
||||
chainfile = BASEDIR + '/{type}.chain.json'.format(type=mode)
|
||||
def get_markov(mode):
    """Select the markovify model class for the generation *mode*.

    Usernames are modelled character-by-character; every other mode
    (Questions/Answers/Titles) uses the word-level text model.
    """
    return MarkovUserName if mode == "Usernames" else MarkovText
|
||||
|
||||
try:
|
||||
|
||||
def get_state_size(mode):
    """Return the Markov state size tuned for the generation *mode*.

    Titles are short, so a state size of 1 keeps output varied;
    character-level usernames need 3 for coherent names; the long-form
    modes (Questions/Answers) default to 2.
    """
    per_mode = {"Titles": 1, "Usernames": 3}
    return per_mode.get(mode, 2)
|
||||
|
||||
|
||||
def load_chain(chainfile, mode):
    """Deserialize a previously generated Markov model from *chainfile*.

    The model class is chosen by *mode* (see get_markov) so the JSON is
    rehydrated into the same subclass that produced it.
    """
    model_cls = get_markov(mode)
    with open(chainfile, 'r') as stored:
        serialized = stored.read()
    # raise FileNotFoundError
    print("using existing file\n")
    return model_cls.from_json(serialized)
|
||||
|
||||
except FileNotFoundError:
|
||||
tokenizer = MosesTokenizer()
|
||||
|
||||
def generate_chain(basedir, mode):
|
||||
combined_cains = None
|
||||
chainlist = []
|
||||
markov = get_markov(mode)
|
||||
i = 0
|
||||
with jsonlines.open(BASEDIR + "/{type}.jsonl".format(type=mode), mode="r") as content:
|
||||
with jsonlines.open(basedir + "/{type}.jsonl".format(type=mode), mode="r") as content:
|
||||
for text in content:
|
||||
text = text.strip()
|
||||
# tokens = tokenizer.tokenize(text=text.replace("\n", " THISISANEWLINE "))
|
||||
try:
|
||||
chain = POSifiedText(text, (1 if mode == "Titles" else 2), retain_original=False)
|
||||
# chain = markovify.Chain([tokens], (1 if mode == "Titles" else 2))
|
||||
chain = markov(text, get_state_size(mode), retain_original=False)
|
||||
except KeyError:
|
||||
continue
|
||||
chainlist.append(chain)
|
||||
|
@ -50,6 +53,20 @@ except FileNotFoundError:
|
|||
chain = markovify.combine([combined_cains, subtotal_chain])
|
||||
with open(chainfile, 'w') as outfile:
|
||||
outfile.write(chain.to_json())
|
||||
print_ram()
|
||||
return chain
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
basedir, mode = get_settings(2)
|
||||
if mode not in ["Questions", "Answers", "Titles", "Usernames"]:
|
||||
print("error")
|
||||
exit()
|
||||
chainfile = basedir + '/{type}.chain.json'.format(type=mode)
|
||||
if os.path.exists(chainfile):
|
||||
chain = load_chain(chainfile, mode)
|
||||
else:
|
||||
chain = generate_chain(basedir, mode)
|
||||
|
||||
for _ in range(10):
|
||||
# walk = []
|
||||
|
|
9
utils.py
9
utils.py
|
@ -2,6 +2,8 @@ import sys
|
|||
|
||||
import resource
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
def print_stats(i, skipped=None):
|
||||
print("{number} total entries".format(number=i))
|
||||
|
@ -14,6 +16,13 @@ def print_ram():
|
|||
print("used {mb}MB".format(mb=resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024))
|
||||
|
||||
|
||||
def html2text(body):
    """Render an HTML fragment as plain text, discarding <code> blocks.

    Code snippets would pollute the natural-language corpus, so every
    <code> element is removed from the parse tree before extracting text.
    """
    soup = BeautifulSoup(body, "lxml")
    for snippet in soup.find_all("code"):
        snippet.decompose()
    return soup.get_text()
|
||||
|
||||
|
||||
def get_settings(count):
|
||||
if len(sys.argv) != count + 1:
|
||||
if count == 1:
|
||||
|
|
Loading…
Reference in a new issue