
Classifying text with fastText in pySpark

Background

We are working on a customer project where we need to classify hundreds of millions of messages by language. Python libraries such as langdetect.py fail to detect the language of these messages. Language detection libraries in general perform poorly here because they are trained on longer texts, so they do not work well in our special and rather challenging use case.

We chose fastText [1] because it has shown excellent performance in text classification [2] and in language detection [3]. However, running fastText in pySpark is not trivial, so we wrote this guide.

Setting up pySpark, fastText and Jupyter notebooks

To run the example, you need Apache Spark running either locally (e.g. on your laptop) or in the cloud (e.g. on AWS EMR). You can install Apache Spark with pySpark by executing

    pip3 install pyspark

You may also need to install Jupyter notebooks:

    pip3 install jupyter

We also need the fastText Python wrapper:

    pip3 install fasttext

Data

Before we can run a language classifier in Spark, we must train a classifier model. To train a model, we need known samples for each language we are interested in. In this experiment, we used the Sentences dataset from https://tatoeba.org/eng/downloads. Download http://downloads.tatoeba.org/exports/sentences.tar.bz2, store it in data/ and decompress it there, e.g. with the command

    tar jxf data/sentences.tar.bz2 -C data/
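If the tar command is not available (for example on Windows), the archive can also be unpacked with Python's standard tarfile module. This is a minimal sketch, assuming the archive has been saved as data/sentences.tar.bz2; the helper name extract_archive is our own:

```python
import os
import tarfile


def extract_archive(archive_path, dest_dir):
    """Unpack a .tar.bz2 archive into dest_dir and return the member names."""
    os.makedirs(dest_dir, exist_ok=True)
    # Mode 'r:bz2' reads the archive with bzip2 decompression, like `tar jxf`
    with tarfile.open(archive_path, 'r:bz2') as tar:
        names = tar.getnames()
        tar.extractall(dest_dir)
    return names
```

Calling extract_archive('data/sentences.tar.bz2', 'data') should leave sentences.csv in data/. Only extract archives from sources you trust, since tar members can point outside the destination directory.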

Training a fastText classifier

Train/test split data generation

Whenever we build a model to make predictions, we need to evaluate its performance. Here is one way to divide the known data into train and test splits:

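    from pyspark.sql import SparkSession
    from pyspark.sql.types import (IntegerType, StringType, StructField,
                                   StructType)
    
    # Reuse the running Spark session (or start a local one)
    spark = SparkSession.builder.getOrCreate()
    
    # Define schema for the tab-separated sentences file
    schema = StructType([StructField("sentence_id", IntegerType(), True),
                         StructField("language_code", StringType(), True),
                         StructField("text", StringType(), True)])
    
    # Read CSV into a Spark DataFrame; quote='' turns off quote handling,
    # because many sentences contain quotation marks
    spark_df = spark.read.csv('data/sentences.csv',
                              schema=schema, sep='\t', quote='')
    
    # Drop malformed rows with a missing language code or text
    spark_df = spark_df.dropna(subset=['language_code', 'text'])
    
    # Split data into train (80 %) and test (20 %) sets with a fixed seed
    train_df, test_df = spark_df.randomSplit([0.8, 0.2], seed=42)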
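randomSplit normalizes the weights and assigns each row to a split at random, so the resulting sizes are only approximately 80/20. The following plain-Python sketch illustrates the idea on an ordinary list; it is a conceptual illustration with a hypothetical random_split helper, not Spark's actual implementation (which samples each partition separately):

```python
import random


def random_split(rows, weights, seed=None):
    """Assign each row to one split with probability proportional to its weight."""
    total = float(sum(weights))
    # Cumulative upper bounds of each split in [0, 1), e.g. [0.8, 1.0]
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        x = rng.random()
        # Pick the first split whose upper bound exceeds the draw;
        # the last split also catches floating-point rounding at 1.0
        for i, upper in enumerate(bounds):
            if x < upper or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


train, test = random_split(range(10000), [0.8, 0.2], seed=42)
print(len(train), len(test))  # close to, but usually not exactly, 8000 and 2000
```

Because every row is drawn independently, the same seed reproduces the same split only when the rows arrive in the same order, which is also why Spark's split can change if the input partitioning changes.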

Training a model

After splitting the data into train and test sets, we can train models. The Python fastText wrapper reads its training data from a file, so we need to store the training data in a file with the following format:

    __label__<class> <text>
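fastText treats every line of the file as one example, so a message containing a newline would turn into two broken examples. A small helper (the name to_fasttext_line is ours) that collapses all whitespace before formatting a line makes this explicit:

```python
def to_fasttext_line(label, text):
    """Format one training example as '__label__<class> <text>' on a single line."""
    # str.split() without arguments splits on any whitespace, including
    # newlines and tabs, so re-joining with single spaces collapses them
    clean_text = ' '.join(text.split())
    return '__label__%s %s\n' % (label, clean_text)


print(to_fasttext_line('eng', 'Hello\nWorld!'), end='')
# prints: __label__eng Hello World!
```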

We can use the following code to write the train and test samples to files on the master (driver) node:

    import fasttext
    
    TRAIN_FILE = 'data/fasttext_train.txt'
    TEST_FILE = 'data/fasttext_test.txt'
    MODEL_FILE = 'data/model_fasttext.bin'
    
    def write_fasttext_file(df, filename):
        """Write a DataFrame to a file in '__label__<class> <text>' format."""
        # toLocalIterator() streams one partition at a time to the driver,
        # so the whole dataset never has to fit in driver memory at once
        with open(filename, 'w', encoding='utf-8') as fp:
            for i, row in enumerate(df.toLocalIterator()):
                # fastText reads one example per line, so collapse newlines
                text = ' '.join((row['text'] or '').split())
                fp.write('__label__%s %s\n' % (row['language_code'], text))
                if i % 1000 == 0:
                    print(i)  # progress indicator
    
    # Storing train and test samples
    write_fasttext_file(train_df, TRAIN_FILE)
    write_fasttext_file(test_df, TEST_FILE)
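Before training, it is worth checking that the written files look as expected, for example that every language has a reasonable number of samples. This is a minimal sketch using only the standard library; the function name count_labels is our own:

```python
from collections import Counter


def count_labels(path, prefix='__label__'):
    """Count examples per label in a fastText-format file."""
    counts = Counter()
    with open(path, encoding='utf-8') as fp:
        for line in fp:
            # The label is the first whitespace-separated token of each line
            parts = line.split(maxsplit=1)
            if parts and parts[0].startswith(prefix):
                counts[parts[0][len(prefix):]] += 1
    return counts
```

For instance, count_labels(TRAIN_FILE).most_common(10) lists the ten most frequent languages in the training file.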

After saving the training samples, we can train the model with the fasttext Python module:

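    # We are training a supervised model with default parameters.
    # This is the same as running the fastText CLI:
    # ./fasttext supervised -input <trainfile> -output <modelfile>
    model = fasttext.train_supervised(input=TRAIN_FILE)
    # Save the trained model so it can be loaded later and shipped to workers
    model.save_model(MODEL_FILE)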
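The test file written earlier can then be used to measure how often the top prediction is correct. With the official fasttext package, model.test(TEST_FILE) returns the number of test examples together with precision and recall at one. To make the computation explicit, here is a plain-Python sketch that computes the accuracy of the top prediction for any predict callable mapping a text to a bare label; the name accuracy_at_1 is our own:

```python
def accuracy_at_1(path, predict, prefix='__label__'):
    """Fraction of lines in a fastText-format file whose label equals predict(text)."""
    total = correct = 0
    with open(path, encoding='utf-8') as fp:
        for line in fp:
            parts = line.rstrip('\n').split(' ', 1)
            if len(parts) < 2 or not parts[0].startswith(prefix):
                continue  # skip empty or malformed lines
            label, text = parts[0][len(prefix):], parts[1]
            total += 1
            correct += int(predict(text) == label)
    return correct / total if total else 0.0
```

With the trained model, a call could look like accuracy_at_1(TEST_FILE, lambda text: model.predict(text)[0][0].replace('__label__', '')).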

Classifying messages

To classify messages using the previously trained model we can write:

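    import fasttext
    model = fasttext.load_model(MODEL_FILE)
    # Given a list of messages, predict() returns a list of label lists
    # and a list of probability arrays, one entry per message
    labels, probabilities = model.predict(['Hello World!'])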

To classify messages stored in a Spark DataFrame, we need a Spark SQL User Defined Function (UDF). Spark's udf function takes a Python function as an argument. Therefore we need to build a wrapper around the fastText classifier that contains the trained model (model) and the classification function (model.predict), and that returns only the class label as a string instead of the full prediction output with labels and probabilities.
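With the official fasttext package, predicting a single string returns a pair: a tuple of labels such as ('__label__eng',) and an array of probabilities. The helper below (our own name, top_label) sketches the conversion the wrapper has to perform; here it is exercised with a plain tuple standing in for a real model's output:

```python
def top_label(prediction, prefix='__label__'):
    """Turn fastText's (labels, probabilities) output into a bare label string."""
    labels, _probabilities = prediction
    if not labels:
        return None
    # Labels come back as e.g. '__label__eng'; keep only the class name
    label = labels[0]
    return label[len(prefix):] if label.startswith(prefix) else label


print(top_label((('__label__eng',), [0.98])))  # prints: eng
```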

Language classifier wrapper

Here is our wrapper module, which loads the fastText model and provides a function that returns the predicted language for a given message:

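    # filename: fasttext_lang_classifier.py
    
    # We need fasttext to load the model and make predictions
    import fasttext
    # SparkFiles resolves the local path of files distributed with sc.addFile()
    from pyspark import SparkFiles
    
    # The model is loaded lazily on first use, once per Python worker process,
    # from the copy of the file that sc.addFile() shipped to the node
    _model = None
    
    def _get_model():
        global _model
        if _model is None:
            _model = fasttext.load_model(SparkFiles.get('model_fasttext.bin'))
        return _model
    
    # This is the function we use in UDF to predict the language of a given msg
    def predict_language(msg):
        if msg is None:
            return None
        # predict() accepts a single line only, so collapse any newlines
        labels, _ = _get_model().predict(' '.join(msg.split()))
        return labels[0].replace('__label__', '')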

Running the language classifier

To use the custom fastText language wrapper library in Spark, we wrap its predict_language function in a UDF:

    # Import needed libraries
    from pyspark.sql.functions import col, udf
    # Import our custom fastText language classifier lib
    import fasttext_lang_classifier
    # Create a udf language classifier function
    udf_predict_language = udf(fasttext_lang_classifier.predict_language)

We then add our custom wrapper library and the trained model to the Spark context:

    sc.addFile('data/model_fasttext.bin')
    sc.addPyFile('fasttext_lang_classifier.py')

Adding the files to the Spark context is especially important when you run Spark jobs on many worker nodes (e.g. on AWS EMR). Spark copies the files to every worker node, so that each worker can use the model to classify its own partition of rows.
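Conceptually, each worker then loads the model once and classifies every row of its partition with it. The sketch below is plain Python with a hypothetical loader and a fake model, so that it runs without Spark or fastText; the returned function takes an iterator of texts and yields results, which is the shape Spark's RDD.mapPartitions expects. It is an illustration of the idea, not the UDF mechanism used above:

```python
def make_partition_classifier(load_model, prefix='__label__'):
    """Return a function that classifies all texts of one partition with one model."""
    def classify_partition(texts):
        model = load_model()  # loaded once per partition, not once per row
        for text in texts:
            labels, _ = model.predict(' '.join(text.split()))
            yield text, labels[0][len(prefix):]
    return classify_partition


class FakeModel:
    """Hypothetical stand-in with a fastText-like predict(); always answers 'eng'."""
    def predict(self, text):
        return ('__label__eng',), [1.0]


loads = []


def fake_loader():
    # Records every load so that we can see how often the model is loaded
    loads.append(1)
    return FakeModel()


classify = make_partition_classifier(fake_loader)
print(list(classify(['Hello', 'World'])))  # prints: [('Hello', 'eng'), ('World', 'eng')]
```

Loading once per partition matters because fastText models can be large; loading one for every row would dominate the run time.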

Finally, we have the trained model wrapped in a UDF that uses our custom library. Given a Spark DataFrame messages with the messages in a text column, we can predict their languages as follows:

    messages = messages.withColumn('predicted_lang',
                                   udf_predict_language(col('text')))

The DataFrame.withColumn method returns a new DataFrame with an additional column, predicted_lang. Our udf_predict_language function takes the column with the messages to be classified, col('text'), as input and returns a column of strings with the predicted languages, which is stored in predicted_lang. Like other DataFrame transformations, withColumn is lazy: the predictions are computed only when an action such as show() or count() runs, or when the DataFrame is written out.

References

  1. fastText website: https://fasttext.cc/
  2. Bag of Tricks for Efficient Text Classification, Armand Joulin, Edouard Grave, Piotr Bojanowski, Tomas Mikolov, arXiv:1607.01759 [cs.CL] (2016)
  3. Which Encoding is the Best for Text Classification in Chinese, English, Japanese and Korean?, Xiang Zhang, Yann LeCun, arXiv:1708.02657 [cs.CL] (2017)

Author

  • Teemu Kinnunen
    Data Scientist