"""Build parquet datasets with text embeddings from the raw CSV splits."""

import logging
from pathlib import Path
from typing import Union

import polars as pl

from utils.embed import embed as embed
from utils.paths import DATA

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_dataset(file_name: Union[str, Path]) -> pl.DataFrame:
    """Load one CSV split and reshape it into the columns used downstream.

    The CSV is scanned lazily, so only the final six-column projection is
    materialised by ``collect()``.

    Args:
        file_name: Location of the CSV file. Accepts a plain string or a
            path object (``main`` passes ``DATA / "<split>.csv"``).

    Returns:
        A DataFrame with the columns ``text``, ``is_news``, ``url``,
        ``date``, ``paragraphs`` and ``links``, where ``text`` is
        ``meta_title``, ``meta_description`` and ``content`` joined by a
        blank line and ``date`` is parsed from its string form.
    """
    features = ["meta_title", "meta_description", "content"]

    # NOTE(review): with polars' default null handling, concat_str presumably
    # yields a null ``text`` when any of the three source columns is null --
    # confirm that ``embed`` tolerates that or that the data has no nulls.
    text = pl.concat_str(
        [pl.col(c) for c in features], separator="\n\n"
    ).alias("text")
    date = pl.col("date").str.to_date().alias("date")

    return (
        pl.scan_csv(file_name)
        .with_columns(text, date)
        .rename(
            {
                "is_news_article": "is_news",
                "link_count": "links",
                "paragraph_count": "paragraphs",
            }
        )
        .select("text", "is_news", "url", "date", "paragraphs", "links")
        .collect()
    )


def main() -> None:
    """Embed the ``train`` and ``eval`` splits and write them as parquet.

    For each split, reads ``DATA/<name>.csv``, embeds the ``text`` column,
    attaches the result as an ``embeds`` column and writes
    ``DATA/<name>.parquet``.
    """
    for name in ("train", "eval"):
        source = DATA / (name + ".csv")
        target = DATA / (name + ".parquet")

        logger.info("Loading %s", source)
        df = load_dataset(source)

        logger.info("Embedding %d rows from the %s split", df.height, name)
        embeds = embed(df.get_column("text").to_list())

        # write_parquet returns None, so its result is deliberately not
        # bound to a name (the original rebound ``df`` to None here).
        df.with_columns(pl.Series(embeds).alias("embeds")).write_parquet(target)
        logger.info("Wrote %s", target)


if __name__ == "__main__":
    main()