Saturday, April 10, 2021

Using BERT with Scikit-Learn to Do Text Classification

Soumil Nitin Shah

Bachelor's in Electronic Engineering | Master's in Electrical Engineering | Master's in Computer Engineering

Experienced in building scalable, high-performance software applications, combining skills in Internet of Things (IoT), Machine Learning and Full Stack Web Development in Python.

Step 1:

Define Imports

In [2]:
import warnings

import numpy as np
import pandas as pd
import torch
import transformers as ppb  # Hugging Face Transformers (formerly pytorch-transformers)

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder

import swifter  # registers the .swifter accessor used on pandas Series below
from tqdm import tqdm

# pandas() is a method of the tqdm class, not of the tqdm module
tqdm.pandas()

# Silence only FutureWarning noise; diagnostics such as ConvergenceWarning stay visible
warnings.filterwarnings('ignore', category=FutureWarning)

Reading Dataset

In [3]:
df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\t', header=None)
df = df.dropna(how='all')
In [7]:
df.head(2)
Out[7]:
                                                   0  1
0  a stirring , funny and finally transporting re...  1
1  apparently reassembled from the cutting room f...  0
In [20]:
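X = df[0]  # sentences
Y = df[1]  # sentiment labels (0 = negative, 1 = positive)
# LabelEncoder maps labels to 0..n_classes-1; the labels here are already 0/1
encoder = LabelEncoder()
Y = encoder.fit_transform(Y)
# Hold out 30% of the rows for testing; no random_state is set, so the split differs between runs
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)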

Preprocessing

In [21]:
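class BertTokenizer(object):
    """Tokenizes text and runs DistilBERT to produce one feature vector per sentence."""

    def __init__(self, text=None):
        # Avoid a mutable default argument; accept any sequence of strings
        self.text = [] if text is None else text

        # For DistilBERT:
        self.model_class, self.tokenizer_class, self.pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')

        # Load pretrained model/tokenizer
        self.tokenizer = self.tokenizer_class.from_pretrained(self.pretrained_weights)
        self.model = self.model_class.from_pretrained(self.pretrained_weights)

    def get(self):
        df = pd.DataFrame(data={"text": self.text})

        # Convert each sentence to token ids, including the [CLS] and [SEP] special tokens
        tokenized = df["text"].swifter.apply(lambda x: self.tokenizer.encode(x, add_special_tokens=True))

        # Right-pad every sequence with 0 (the [PAD] id) up to the longest sequence
        max_len = max(len(ids) for ids in tokenized.values)
        padded = np.array([ids + [0] * (max_len - len(ids)) for ids in tokenized.values])

        # 1 for real tokens, 0 for padding, so the model ignores the padded positions
        attention_mask = np.where(padded != 0, 1, 0)
        input_ids = torch.tensor(padded)
        attention_mask = torch.tensor(attention_mask)

        # Inference only: no gradients needed. Note that all sentences go through in a single batch.
        with torch.no_grad():
            last_hidden_states = self.model(input_ids, attention_mask=attention_mask)

        # Output [0] is the last hidden state (batch, seq_len, hidden); keep the vector at position 0 ([CLS])
        features = last_hidden_states[0][:, 0, :].numpy()

        return features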
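
The padding and attention-mask step inside get() is easier to follow on a toy input. The sketch below uses two made-up token-id sequences (101 and 102 stand in for [CLS] and [SEP]; the ids in between are arbitrary and purely illustrative) and needs no model download.

```python
import numpy as np

# Hypothetical token-id sequences of unequal length (illustrative values only)
tokenized = [[101, 2023, 102], [101, 2023, 2003, 2307, 102]]

# Right-pad every sequence with 0 up to the longest one
max_len = max(len(ids) for ids in tokenized)
padded = np.array([ids + [0] * (max_len - len(ids)) for ids in tokenized])

# 1 marks a real token, 0 marks padding that the model should ignore
attention_mask = np.where(padded != 0, 1, 0)

print(padded)
print(attention_mask)
```

Building the mask from `padded != 0` works here because 0 is used only as the padding id; a tokenizer with a different padding id would need the comparison adjusted.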
In [22]:
_instance = BertTokenizer(text=x_train)
tokens = _instance.get()  # despite the name, these are feature vectors: one row per training sentence
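
The call above pushes every training sentence through the model in one forward pass, which can use a lot of memory on larger datasets. A minimal sketch of batching is shown below. The helper `embed_in_batches` and the stand-in `toy_embed` are hypothetical names, not part of the original notebook, and the toy embedder only exists so the example runs without a model.

```python
import numpy as np

def embed_in_batches(texts, embed_fn, batch_size=32):
    """Embed `texts` in slices of `batch_size` and stack the per-batch results.

    `embed_fn` is a hypothetical callable: list of strings -> 2-D array
    with one row per string. Assumes `texts` is non-empty.
    """
    texts = list(texts)
    chunks = [embed_fn(texts[start:start + batch_size])
              for start in range(0, len(texts), batch_size)]
    return np.vstack(chunks)

# Stand-in embedder: each sentence becomes [character count, word count]
def toy_embed(batch):
    return np.array([[len(s), len(s.split())] for s in batch], dtype=float)

sentences = ["a stirring film", "flat", "funny and moving", "no", "fine"]
features = embed_in_batches(sentences, toy_embed, batch_size=2)
print(features.shape)
```

In the notebook's setting, `embed_fn` would wrap the tokenize, pad and forward-pass steps for one batch, with the tokenizer and model loaded once outside the loop.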

Model

In [23]:
lr_clf = LogisticRegression()
lr_clf.fit(tokens, y_train)
c:\python38\lib\site-packages\sklearn\linear_model\_logistic.py:762: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
Out[23]:
LogisticRegression()
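
The ConvergenceWarning above suggests two remedies: raise max_iter or scale the features. A minimal sketch combining both is shown below, run on synthetic data from make_classification as a stand-in for the DistilBERT features. The value max_iter=1000 is an assumption, not a tuned setting, and the accuracy reported later in this post comes from the default LogisticRegression(), not from this variant.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the feature matrix; this is not the SST-2 data
X_demo, y_demo = make_classification(n_samples=200, n_features=20, random_state=0)

# Standardize the features, then fit logistic regression with a higher iteration cap
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
clf.fit(X_demo, y_demo)

train_acc = clf.score(X_demo, y_demo)
n_iter = clf.named_steps["logisticregression"].n_iter_[0]
print(train_acc, n_iter)
```

Wrapping the scaler and the classifier in one pipeline means the same scaling is applied automatically when predict() is later called on test features.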

Test

In [25]:
_instance = BertTokenizer(text=x_test)
tokensTest = _instance.get()

In [26]:
predicted = lr_clf.predict(tokensTest)
In [27]:
np.mean(predicted == y_test)
Out[27]:
0.846820809248555
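
The expression np.mean(predicted == y_test) is plain accuracy: the fraction of predictions that match the labels. The sketch below, on small made-up arrays rather than the real predictions, shows that it agrees with scikit-learn's accuracy_score, and adds a confusion matrix to show which class the errors fall in.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Made-up labels and predictions, for illustration only
y_true = np.array([1, 0, 1, 1, 0, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1])

manual = np.mean(y_pred == y_true)     # fraction of matching entries
acc = accuracy_score(y_true, y_pred)   # same quantity via scikit-learn

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true, y_pred)

print(manual, acc)
print(cm)
```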
