Clean up docker-compose and run Black formatter over entire codebase

This commit is contained in:
Keannu Bernasol 2024-10-30 22:09:58 +08:00
parent 6c232b3e89
commit 069aba80b1
60 changed files with 1946 additions and 1485 deletions

View file

@ -35,6 +35,7 @@ gunicorn = "*"
django-silk = "*" django-silk = "*"
django-redis = "*" django-redis = "*"
granian = "*" granian = "*"
black = "*"
[dev-packages] [dev-packages]

1492
Pipfile.lock generated

File diff suppressed because it is too large Load diff

View file

@ -6,11 +6,13 @@ from .models import CustomUser
class CustomUserAdmin(UserAdmin): class CustomUserAdmin(UserAdmin):
model = CustomUser model = CustomUser
list_display = ('id', 'is_active', 'user_group',) + UserAdmin.list_display list_display = (
"id",
"is_active",
"user_group",
) + UserAdmin.list_display
# Editable fields per instance # Editable fields per instance
fieldsets = UserAdmin.fieldsets + ( fieldsets = UserAdmin.fieldsets + ((None, {"fields": ("avatar",)}),)
(None, {'fields': ('avatar',)}),
)
admin.site.register(CustomUser, CustomUserAdmin) admin.site.register(CustomUser, CustomUserAdmin)

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class AccountsConfig(AppConfig): class AccountsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'accounts' name = "accounts"
def ready(self): def ready(self):
import accounts.signals import accounts.signals

View file

@ -13,38 +13,145 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
('auth', '0012_alter_user_first_name_max_length'), ("auth", "0012_alter_user_first_name_max_length"),
('user_groups', '0001_initial'), ("user_groups", "0001_initial"),
] ]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='CustomUser', name="CustomUser",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('password', models.CharField(max_length=128, verbose_name='password')), "id",
('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')), models.BigAutoField(
('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')), auto_created=True,
('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')), primary_key=True,
('first_name', models.CharField(blank=True, max_length=150, verbose_name='first name')), serialize=False,
('last_name', models.CharField(blank=True, max_length=150, verbose_name='last name')), verbose_name="ID",
('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')), ),
('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')), ),
('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. Unselect this instead of deleting accounts.', verbose_name='active')), ("password", models.CharField(max_length=128, verbose_name="password")),
('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')), (
('avatar', django_resized.forms.ResizedImageField(crop=None, force_format='WEBP', keep_meta=True, null=True, quality=100, scale=None, size=[1920, 1080], upload_to='avatars/')), "last_login",
('onboarding', models.BooleanField(default=True)), models.DateTimeField(
('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.group', verbose_name='groups')), blank=True, null=True, verbose_name="last login"
('user_group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='user_groups.usergroup')), ),
('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.permission', verbose_name='user permissions')), ),
(
"is_superuser",
models.BooleanField(
default=False,
help_text="Designates that this user has all permissions without explicitly assigning them.",
verbose_name="superuser status",
),
),
(
"username",
models.CharField(
error_messages={
"unique": "A user with that username already exists."
},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150,
unique=True,
validators=[
django.contrib.auth.validators.UnicodeUsernameValidator()
],
verbose_name="username",
),
),
(
"first_name",
models.CharField(
blank=True, max_length=150, verbose_name="first name"
),
),
(
"last_name",
models.CharField(
blank=True, max_length=150, verbose_name="last name"
),
),
(
"email",
models.EmailField(
blank=True, max_length=254, verbose_name="email address"
),
),
(
"is_staff",
models.BooleanField(
default=False,
help_text="Designates whether the user can log into this admin site.",
verbose_name="staff status",
),
),
(
"is_active",
models.BooleanField(
default=True,
help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
verbose_name="active",
),
),
(
"date_joined",
models.DateTimeField(
default=django.utils.timezone.now, verbose_name="date joined"
),
),
(
"avatar",
django_resized.forms.ResizedImageField(
crop=None,
force_format="WEBP",
keep_meta=True,
null=True,
quality=100,
scale=None,
size=[1920, 1080],
upload_to="avatars/",
),
),
("onboarding", models.BooleanField(default=True)),
(
"groups",
models.ManyToManyField(
blank=True,
help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.",
related_name="user_set",
related_query_name="user",
to="auth.group",
verbose_name="groups",
),
),
(
"user_group",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="user_groups.usergroup",
),
),
(
"user_permissions",
models.ManyToManyField(
blank=True,
help_text="Specific permissions for this user.",
related_name="user_set",
related_query_name="user",
to="auth.permission",
verbose_name="user permissions",
),
),
], ],
options={ options={
'verbose_name': 'user', "verbose_name": "user",
'verbose_name_plural': 'users', "verbose_name_plural": "users",
'abstract': False, "abstract": False,
}, },
managers=[ managers=[
('objects', django.contrib.auth.models.UserManager()), ("objects", django.contrib.auth.models.UserManager()),
], ],
), ),
] ]

View file

@ -15,14 +15,16 @@ class CustomUser(AbstractUser):
# is_admin inherited from base user class # is_admin inherited from base user class
avatar = ResizedImageField( avatar = ResizedImageField(
null=True, force_format="WEBP", quality=100, upload_to='avatars/') null=True, force_format="WEBP", quality=100, upload_to="avatars/"
)
# Used for onboarding processes # Used for onboarding processes
# Set this to False later on once the user makes actions # Set this to False later on once the user makes actions
onboarding = models.BooleanField(default=True) onboarding = models.BooleanField(default=True)
user_group = models.ForeignKey( user_group = models.ForeignKey(
'user_groups.UserGroup', on_delete=models.SET_NULL, null=True) "user_groups.UserGroup", on_delete=models.SET_NULL, null=True
)
@property @property
def group_member(self): def group_member(self):
@ -57,4 +59,4 @@ class CustomUser(AbstractUser):
@property @property
def admin_url(self): def admin_url(self):
return reverse('admin:users_customuser_change', args=(self.pk,)) return reverse("admin:users_customuser_change", args=(self.pk,))

View file

@ -8,13 +8,14 @@ from django.core.cache import cache
from django.core import exceptions as django_exceptions from django.core import exceptions as django_exceptions
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from django.contrib.auth.password_validation import validate_password from django.contrib.auth.password_validation import validate_password
# There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here # There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here
class SimpleCustomUserSerializer(ModelSerializer): class SimpleCustomUserSerializer(ModelSerializer):
class Meta(BaseUserSerializer.Meta): class Meta(BaseUserSerializer.Meta):
model = CustomUser model = CustomUser
fields = ('id', 'username', 'email', 'full_name') fields = ("id", "username", "email", "full_name")
class CustomUserSerializer(BaseUserSerializer): class CustomUserSerializer(BaseUserSerializer):
@ -22,19 +23,36 @@ class CustomUserSerializer(BaseUserSerializer):
class Meta(BaseUserSerializer.Meta): class Meta(BaseUserSerializer.Meta):
model = CustomUser model = CustomUser
fields = ('id', 'username', 'email', 'avatar', 'first_name', fields = (
'last_name', 'user_group', 'group_member', 'group_owner') "id",
read_only_fields = ('id', 'username', 'email', 'user_group', "username",
'group_member', 'group_owner') "email",
"avatar",
"first_name",
"is_new",
"last_name",
"user_group",
"group_member",
"group_owner",
)
read_only_fields = (
"id",
"username",
"email",
"user_group",
"group_member",
"group_owner",
)
def to_representation(self, instance): def to_representation(self, instance):
representation = super().to_representation(instance) representation = super().to_representation(instance)
representation['user_group'] = SimpleUserGroupSerializer( representation["user_group"] = SimpleUserGroupSerializer(
instance.user_group, many=False).data instance.user_group, many=False
).data
return representation return representation
def update(self, instance, validated_data): def update(self, instance, validated_data):
cache.delete(f'user:{instance.id}') cache.delete(f"user:{instance.id}")
return super().update(instance, validated_data) return super().update(instance, validated_data)
@ -42,16 +60,18 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
email = serializers.EmailField(required=True) email = serializers.EmailField(required=True)
username = serializers.CharField(required=True) username = serializers.CharField(required=True)
password = serializers.CharField( password = serializers.CharField(
write_only=True, style={'input_type': 'password', 'placeholder': 'Password'}) write_only=True, style={"input_type": "password", "placeholder": "Password"}
)
first_name = serializers.CharField( first_name = serializers.CharField(
required=True, allow_blank=False, allow_null=False) required=True, allow_blank=False, allow_null=False
)
last_name = serializers.CharField( last_name = serializers.CharField(
required=True, allow_blank=False, allow_null=False) required=True, allow_blank=False, allow_null=False
)
class Meta: class Meta:
model = CustomUser model = CustomUser
fields = ['email', 'username', 'password', fields = ["email", "username", "password", "first_name", "last_name"]
'first_name', 'last_name']
def validate(self, attrs): def validate(self, attrs):
user_attrs = attrs.copy() user_attrs = attrs.copy()
@ -69,14 +89,15 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
raise serializers.ValidationError({"password": errors}) raise serializers.ValidationError({"password": errors})
if self.Meta.model.objects.filter(username=attrs.get("username")).exists(): if self.Meta.model.objects.filter(username=attrs.get("username")).exists():
raise serializers.ValidationError( raise serializers.ValidationError(
"A user with that username already exists.") "A user with that username already exists."
)
return super().validate(attrs) return super().validate(attrs)
def create(self, validated_data): def create(self, validated_data):
user = self.Meta.model(**validated_data) user = self.Meta.model(**validated_data)
user.username = validated_data['username'] user.username = validated_data["username"]
user.is_active = False user.is_active = False
user.set_password(validated_data['password']) user.set_password(validated_data["password"])
user.save() user.save()
return user return user

View file

@ -12,38 +12,37 @@ import json
@receiver(post_migrate) @receiver(post_migrate)
def create_users(sender, **kwargs): def create_users(sender, **kwargs):
if sender.name == "accounts": if sender.name == "accounts":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f: with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read()) seed_data = json.loads(f.read())
for user in seed_data['users']: for user in seed_data["users"]:
USER = CustomUser.objects.filter( USER = CustomUser.objects.filter(email=user["email"]).first()
email=user['email']).first()
if not USER: if not USER:
if user['password'] == 'USE_REGULAR': if user["password"] == "USE_REGULAR":
password = get_secret('SEED_DATA_PASSWORD') password = get_secret("SEED_DATA_PASSWORD")
elif user['password'] == 'USE_ADMIN': elif user["password"] == "USE_ADMIN":
password = get_secret('SEED_DATA_ADMIN_PASSWORD') password = get_secret("SEED_DATA_ADMIN_PASSWORD")
else: else:
password = user['password'] password = user["password"]
if (user['is_superuser'] == True): if user["is_superuser"] == True:
# Admin users are created regardless of SEED_DATA value # Admin users are created regardless of SEED_DATA value
USER = CustomUser.objects.create_superuser( USER = CustomUser.objects.create_superuser(
username=user['username'], username=user["username"],
email=user['email'], email=user["email"],
password=password, password=password,
) )
print('Created Superuser:', user['email']) print("Created Superuser:", user["email"])
else: else:
# Only create non-admin users if SEED_DATA=True # Only create non-admin users if SEED_DATA=True
if SEED_DATA: if SEED_DATA:
USER = CustomUser.objects.create_user( USER = CustomUser.objects.create_user(
username=user['email'], username=user["email"],
email=user['email'], email=user["email"],
password=password, password=password,
) )
print('Created User:', user['email']) print("Created User:", user["email"])
USER.first_name = user['first_name'] USER.first_name = user["first_name"]
USER.last_name = user['last_name'] USER.last_name = user["last_name"]
USER.is_active = True USER.is_active = True
USER.save() USER.save()
@ -51,53 +50,57 @@ def create_users(sender, **kwargs):
@receiver(post_migrate) @receiver(post_migrate)
def create_celery_beat_schedules(sender, **kwargs): def create_celery_beat_schedules(sender, **kwargs):
if sender.name == "django_celery_beat": if sender.name == "django_celery_beat":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f: with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read()) seed_data = json.loads(f.read())
# Creating Schedules # Creating Schedules
for schedule in seed_data['schedules']: for schedule in seed_data["schedules"]:
if schedule['type'] == 'crontab': if schedule["type"] == "crontab":
# Check if Schedule already exists # Check if Schedule already exists
SCHEDULE = CrontabSchedule.objects.filter(minute=schedule['minute'], SCHEDULE = CrontabSchedule.objects.filter(
hour=schedule['hour'], minute=schedule["minute"],
day_of_week=schedule['day_of_week'], hour=schedule["hour"],
day_of_month=schedule['day_of_month'], day_of_week=schedule["day_of_week"],
month_of_year=schedule['month_of_year'], day_of_month=schedule["day_of_month"],
timezone=schedule['timezone'] month_of_year=schedule["month_of_year"],
timezone=schedule["timezone"],
).first() ).first()
# If it does not exist, create a new Schedule # If it does not exist, create a new Schedule
if not SCHEDULE: if not SCHEDULE:
SCHEDULE = CrontabSchedule.objects.create( SCHEDULE = CrontabSchedule.objects.create(
minute=schedule['minute'], minute=schedule["minute"],
hour=schedule['hour'], hour=schedule["hour"],
day_of_week=schedule['day_of_week'], day_of_week=schedule["day_of_week"],
day_of_month=schedule['day_of_month'], day_of_month=schedule["day_of_month"],
month_of_year=schedule['month_of_year'], month_of_year=schedule["month_of_year"],
timezone=schedule['timezone'] timezone=schedule["timezone"],
) )
print( print(
f'Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}') f"Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}"
)
else: else:
print( print(
f'Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists') f"Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists"
for task in seed_data['scheduled_tasks']: )
TASK = PeriodicTask.objects.filter(name=task['name']).first() for task in seed_data["scheduled_tasks"]:
TASK = PeriodicTask.objects.filter(name=task["name"]).first()
if not TASK: if not TASK:
if task['schedule']['type'] == 'crontab': if task["schedule"]["type"] == "crontab":
SCHEDULE = CrontabSchedule.objects.filter(minute=task['schedule']['minute'], SCHEDULE = CrontabSchedule.objects.filter(
hour=task['schedule']['hour'], minute=task["schedule"]["minute"],
day_of_week=task['schedule']['day_of_week'], hour=task["schedule"]["hour"],
day_of_month=task['schedule']['day_of_month'], day_of_week=task["schedule"]["day_of_week"],
month_of_year=task['schedule']['month_of_year'], day_of_month=task["schedule"]["day_of_month"],
timezone=task['schedule']['timezone'] month_of_year=task["schedule"]["month_of_year"],
timezone=task["schedule"]["timezone"],
).first() ).first()
TASK = PeriodicTask.objects.create( TASK = PeriodicTask.objects.create(
crontab=SCHEDULE, crontab=SCHEDULE,
name=task['name'], name=task["name"],
task=task['task'], task=task["task"],
enabled=task['enabled'] enabled=task["enabled"],
) )
print(f'Created Periodic Task: {TASK.name}') print(f"Created Periodic Task: {TASK.name}")
else: else:
raise Exception('Schedule for Periodic Task not found') raise Exception("Schedule for Periodic Task not found")
else: else:
print(f'Periodic Task: {TASK.name} already exists') print(f"Periodic Task: {TASK.name} already exists")

View file

@ -4,20 +4,26 @@ from celery import shared_task
@shared_task @shared_task
def get_paying_users(): def get_paying_users():
from subscriptions.models import UserSubscription from subscriptions.models import UserSubscription
# Get a list of user subscriptions # Get a list of user subscriptions
active_subscriptions = UserSubscription.objects.filter( active_subscriptions = UserSubscription.objects.filter(valid=True).distinct("user")
valid=True).distinct('user')
# Get paying users # Get paying users
active_users = [] active_users = []
# Paying regular users # Paying regular users
active_users += [ active_users += [
subscription.user.id for subscription in active_subscriptions if subscription.user is not None and subscription.user.user_group is None] subscription.user.id
for subscription in active_subscriptions
if subscription.user is not None and subscription.user.user_group is None
]
# Paying users within groups # Paying users within groups
active_users += [ active_users += [
subscription.user_group.members for subscription in active_subscriptions if subscription.user_group is not None and subscription.user is None] subscription.user_group.members
for subscription in active_subscriptions
if subscription.user_group is not None and subscription.user is None
]
# Return paying users # Return paying users
return active_users return active_users

View file

@ -3,10 +3,10 @@ from rest_framework.routers import DefaultRouter
from accounts import views from accounts import views
router = DefaultRouter() router = DefaultRouter()
router.register(r'users', views.CustomUserViewSet, basename='users') router.register(r"users", views.CustomUserViewSet, basename="users")
urlpatterns = [ urlpatterns = [
path('', include(router.urls)), path("", include(router.urls)),
path('', include('djoser.urls')), path("", include("djoser.urls")),
path('', include('djoser.urls.jwt')), path("", include("djoser.urls.jwt")),
] ]

View file

@ -1,4 +1,3 @@
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
import re import re
@ -6,9 +5,10 @@ import re
class UppercaseValidator(object): class UppercaseValidator(object):
def validate(self, password, user=None): def validate(self, password, user=None):
if not re.findall('[A-Z]', password): if not re.findall("[A-Z]", password):
raise ValidationError( raise ValidationError(
_("The password must contain at least 1 uppercase letter (A-Z).")) _("The password must contain at least 1 uppercase letter (A-Z).")
)
def get_help_text(self): def get_help_text(self):
return _("Your password must contain at least 1 uppercase letter (A-Z).") return _("Your password must contain at least 1 uppercase letter (A-Z).")
@ -16,9 +16,10 @@ class UppercaseValidator(object):
class LowercaseValidator(object): class LowercaseValidator(object):
def validate(self, password, user=None): def validate(self, password, user=None):
if not re.findall('[a-z]', password): if not re.findall("[a-z]", password):
raise ValidationError( raise ValidationError(
_("The password must contain at least 1 lowercase letter (a-z).")) _("The password must contain at least 1 lowercase letter (a-z).")
)
def get_help_text(self): def get_help_text(self):
return _("Your password must contain at least 1 lowercase letter (a-z).") return _("Your password must contain at least 1 lowercase letter (a-z).")
@ -26,19 +27,25 @@ class LowercaseValidator(object):
class SpecialCharacterValidator(object): class SpecialCharacterValidator(object):
def validate(self, password, user=None): def validate(self, password, user=None):
if not re.findall('[@#$%^&*()_+/\<>;:!?]', password): if not re.findall("[@#$%^&*()_+/\<>;:!?]", password):
raise ValidationError( raise ValidationError(
_("The password must contain at least 1 special character (@, #, $, etc.).")) _(
"The password must contain at least 1 special character (@, #, $, etc.)."
)
)
def get_help_text(self): def get_help_text(self):
return _("Your password must contain at least 1 special character (@, #, $, etc.).") return _(
"Your password must contain at least 1 special character (@, #, $, etc.)."
)
class NumberValidator(object): class NumberValidator(object):
def validate(self, password, user=None): def validate(self, password, user=None):
if not any(char.isdigit() for char in password): if not any(char.isdigit() for char in password):
raise ValidationError( raise ValidationError(
_("The password must contain at least one numerical digit (0-9).")) _("The password must contain at least one numerical digit (0-9).")
)
def get_help_text(self): def get_help_text(self):
return _("Your password must contain at least numerical digit (0-9).") return _("Your password must contain at least numerical digit (0-9).")

View file

@ -22,7 +22,7 @@ class CustomUserViewSet(DjoserUserViewSet):
user = self.request.user user = self.request.user
# If user is admin, show all active users # If user is admin, show all active users
if user.is_superuser: if user.is_superuser:
key = 'users' key = "users"
# Get cache # Get cache
queryset = cache.get(key) queryset = cache.get(key)
# Set cache if stale or does not exist # Set cache if stale or does not exist
@ -31,18 +31,17 @@ class CustomUserViewSet(DjoserUserViewSet):
cache.set(key, queryset, 60 * 60) cache.set(key, queryset, 60 * 60)
return queryset return queryset
elif not user.user_group: elif not user.user_group:
key = f'user:{user.id}' key = f"user:{user.id}"
queryset = cache.get(key) queryset = cache.get(key)
if not queryset: if not queryset:
queryset = CustomUser.objects.filter(is_active=True) queryset = CustomUser.objects.filter(is_active=True)
cache.set(key, queryset, 60 * 60) cache.set(key, queryset, 60 * 60)
return queryset return queryset
elif user.user_group: elif user.user_group:
key = f'usergroup_users:{user.user_group.id}' key = f"usergroup_users:{user.user_group.id}"
queryset = cache.get(key) queryset = cache.get(key)
if not queryset: if not queryset:
queryset = CustomUser.objects.filter( queryset = CustomUser.objects.filter(user_group=user.user_group)
user_group=user.user_group)
cache.set(key, queryset, 60 * 60) cache.set(key, queryset, 60 * 60)
return queryset return queryset
else: else:
@ -52,10 +51,10 @@ class CustomUserViewSet(DjoserUserViewSet):
user = self.request.user user = self.request.user
# Clear cache # Clear cache
cache.delete(f'users') cache.delete(f"users")
cache.delete(f'user:{user.id}') cache.delete(f"user:{user.id}")
if user.user_group: if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}') cache.delete(f"usergroup_users:{user.user_group.id}")
super().perform_update(serializer, *args, **kwargs) super().perform_update(serializer, *args, **kwargs)
user = serializer.instance user = serializer.instance
@ -84,16 +83,18 @@ class CustomUserViewSet(DjoserUserViewSet):
settings.EMAIL.confirmation(self.request, context).send(to) settings.EMAIL.confirmation(self.request, context).send(to)
# Clear cache # Clear cache
cache.delete('users') cache.delete("users")
cache.delete(f'user:{user.id}') cache.delete(f"user:{user.id}")
if user.user_group: if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}') cache.delete(f"usergroup_users:{user.user_group.id}")
except Exception as e: except Exception as e:
print('Warning: Unable to send email') print("Warning: Unable to send email")
print(e) print(e)
@action(methods=['post'], detail=False, url_path='activation', url_name='activation') @action(
methods=["post"], detail=False, url_path="activation", url_name="activation"
)
def activation(self, request, *args, **kwargs): def activation(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -103,16 +104,16 @@ class CustomUserViewSet(DjoserUserViewSet):
# Construct a response with user's first name, last name, and email # Construct a response with user's first name, last name, and email
user_data = { user_data = {
'first_name': user.first_name, "first_name": user.first_name,
'last_name': user.last_name, "last_name": user.last_name,
'email': user.email, "email": user.email,
'username': user.username "username": user.username,
} }
# Clear cache # Clear cache
cache.delete('users') cache.delete("users")
cache.delete(f'user:{user.id}') cache.delete(f"user:{user.id}")
if user.user_group: if user.user_group:
cache.delete(f'usergroup_users:{user.user_group.id}') cache.delete(f"usergroup_users:{user.user_group.id}")
return Response(user_data, status=status.HTTP_200_OK) return Response(user_data, status=status.HTTP_200_OK)

View file

@ -1,27 +1,31 @@
from django.conf.urls.static import static from django.conf.urls.static import static
from django.contrib.staticfiles.urls import staticfiles_urlpatterns from django.contrib.staticfiles.urls import staticfiles_urlpatterns
from django.urls import path, include from django.urls import path, include
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from drf_spectacular.views import (
SpectacularAPIView,
SpectacularRedocView,
SpectacularSwaggerView,
)
from django.contrib import admin from django.contrib import admin
from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT
urlpatterns = [ urlpatterns = [
path('accounts/', include('accounts.urls')), path("accounts/", include("accounts.urls")),
path('subscriptions/', include('subscriptions.urls')), path("subscriptions/", include("subscriptions.urls")),
path('notifications/', include('notifications.urls')), path("notifications/", include("notifications.urls")),
path('billing/', include('billing.urls')), path("billing/", include("billing.urls")),
path('stripe/', include('payments.urls')), path("stripe/", include("payments.urls")),
path('admin/', admin.site.urls), path("admin/", admin.site.urls),
path('schema/', SpectacularAPIView.as_view(), name='schema'), path("schema/", SpectacularAPIView.as_view(), name="schema"),
path('swagger/', path(
SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'), "swagger/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"
path('redoc/', ),
SpectacularRedocView.as_view(url_name='schema'), name='redoc'), path("redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"),
] ]
# URLs for local development # URLs for local development
if DEBUG and SERVE_MEDIA: if DEBUG and SERVE_MEDIA:
urlpatterns += staticfiles_urlpatterns() urlpatterns += staticfiles_urlpatterns()
urlpatterns += static( urlpatterns += static("media/", document_root=MEDIA_ROOT)
'media/', document_root=MEDIA_ROOT)
if DEBUG: if DEBUG:
urlpatterns += [path('silk/', include('silk.urls', namespace='silk'))] urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))]

View file

@ -2,6 +2,5 @@ from django.urls import path
from billing import views from billing import views
urlpatterns = [ urlpatterns = [
path('', path("", views.BillingHistoryView.as_view()),
views.BillingHistoryView.as_view()),
] ]

View file

@ -24,7 +24,7 @@ class BillingHistoryView(APIView):
email = requesting_user.email email = requesting_user.email
# Check cache # Check cache
key = f'billing_user:{requesting_user.id}' key = f"billing_user:{requesting_user.id}"
billing_history = cache.get(key) billing_history = cache.get(key)
if not billing_history: if not billing_history:
@ -39,21 +39,23 @@ class BillingHistoryView(APIView):
if len(customers.data) > 0: if len(customers.data) > 0:
# Retrieve the customer's charges (billing history) # Retrieve the customer's charges (billing history)
charges = stripe.Charge.list( charges = stripe.Charge.list(limit=10, customer=customer.id)
limit=10, customer=customer.id)
# Prepare the response # Prepare the response
billing_history = [ billing_history = [
{ {
'email': charge['billing_details']['email'], "email": charge["billing_details"]["email"],
'amount_charged': int(charge['amount']/100), "amount_charged": int(charge["amount"] / 100),
'paid': charge['paid'], "paid": charge["paid"],
'refunded': int(charge['amount_refunded']/100) > 0, "refunded": int(charge["amount_refunded"] / 100) > 0,
'amount_refunded': int(charge['amount_refunded']/100), "amount_refunded": int(charge["amount_refunded"] / 100),
'last_4': charge['payment_method_details']['card']['last4'], "last_4": charge["payment_method_details"]["card"]["last4"],
'receipt_link': charge['receipt_url'], "receipt_link": charge["receipt_url"],
'timestamp': datetime.fromtimestamp(charge['created']).strftime("%m-%d-%Y %I:%M %p"), "timestamp": datetime.fromtimestamp(
} for charge in charges.auto_paging_iter() charge["created"]
).strftime("%m-%d-%Y %I:%M %p"),
}
for charge in charges.auto_paging_iter()
] ]
cache.set(key, billing_history, 60 * 60) cache.set(key, billing_history, 60 * 60)

View file

@ -1,3 +1,3 @@
from .celery import app as celery_app from .celery import app as celery_app
__all__ = ('celery_app',) __all__ = ("celery_app",)

View file

@ -11,6 +11,6 @@ import os
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
application = get_asgi_application() application = get_asgi_application()

View file

@ -3,15 +3,15 @@ import os
# Set the default Django settings module for the 'celery' program. # Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
app = Celery('config') app = Celery("config")
# Using a string here means the worker doesn't have to serialize # Using a string here means the worker doesn't have to serialize
# the configuration object to child processes. # the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys # - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix. # should have a `CELERY_` prefix.
app.config_from_object('django.conf:settings', namespace='CELERY') app.config_from_object("django.conf:settings", namespace="CELERY")
# Load task modules from all registered Django apps. # Load task modules from all registered Django apps.
app.autodiscover_tasks() app.autodiscover_tasks()

View file

@ -10,9 +10,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = Path(__file__).resolve().parent.parent.parent ROOT_DIR = Path(__file__).resolve().parent.parent.parent
# If you're hosting this with a secret provider, have this set to True # If you're hosting this with a secret provider, have this set to True
USE_VAULT = bool(os.getenv('USE_VAULT', False) == 'True') USE_VAULT = bool(os.getenv("USE_VAULT", False) == "True")
# Have this set to True to serve media and static contents directly via Django # Have this set to True to serve media and static contents directly via Django
SERVE_MEDIA = bool(os.getenv('SERVE_MEDIA', False) == 'True') SERVE_MEDIA = bool(os.getenv("SERVE_MEDIA", False) == "True")
load_dotenv(find_dotenv()) load_dotenv(find_dotenv())
@ -35,98 +35,97 @@ def get_secret(secret_name):
# URL Prefixes # URL Prefixes
URL_SCHEME = 'https' if (get_secret('USE_HTTPS') == 'True') else 'http' URL_SCHEME = "https" if (get_secret("USE_HTTPS") == "True") else "http"
# Backend # Backend
BACKEND_ADDRESS = get_secret('BACKEND_ADDRESS') BACKEND_ADDRESS = get_secret("BACKEND_ADDRESS")
BACKEND_PORT = get_secret('BACKEND_PORT') BACKEND_PORT = get_secret("BACKEND_PORT")
# Frontend # Frontend
FRONTEND_ADDRESS = get_secret('FRONTEND_ADDRESS') FRONTEND_ADDRESS = get_secret("FRONTEND_ADDRESS")
FRONTEND_PORT = get_secret('FRONTEND_PORT') FRONTEND_PORT = get_secret("FRONTEND_PORT")
ALLOWED_HOSTS = ['*'] ALLOWED_HOSTS = ["*"]
CSRF_TRUSTED_ORIGINS = [ CSRF_TRUSTED_ORIGINS = [
# Frontend # Frontend
f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}', f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}",
f'{URL_SCHEME}://{FRONTEND_ADDRESS}', # For external domains f"{URL_SCHEME}://{FRONTEND_ADDRESS}", # For external domains
# Backend # Backend
f'{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}', f"{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}",
f'{URL_SCHEME}://{BACKEND_ADDRESS}' # For external domains f"{URL_SCHEME}://{BACKEND_ADDRESS}", # For external domains
# You can also set up https://*.name.xyz for wildcards here # You can also set up https://*.name.xyz for wildcards here
] ]
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = (get_secret('BACKEND_DEBUG') == 'True') DEBUG = get_secret("BACKEND_DEBUG") == "True"
# Determines whether or not to insert test data within tables # Determines whether or not to insert test data within tables
SEED_DATA = (get_secret('SEED_DATA') == 'True') SEED_DATA = get_secret("SEED_DATA") == "True"
# SECURITY WARNING: keep the secret key used in production secret! # SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = get_secret('SECRET_KEY') SECRET_KEY = get_secret("SECRET_KEY")
# Selenium Config # Selenium Config
# Initiate CAPTCHA solver in test mode # Initiate CAPTCHA solver in test mode
CAPTCHA_TESTING = (get_secret('CAPTCHA_TESTING') == 'True') CAPTCHA_TESTING = get_secret("CAPTCHA_TESTING") == "True"
# If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies # If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies
USE_PROXY = (get_secret('USE_PROXY') == 'True') USE_PROXY = get_secret("USE_PROXY") == "True"
# Stripe (For payments) # Stripe (For payments)
STRIPE_SECRET_KEY = get_secret( STRIPE_SECRET_KEY = get_secret("STRIPE_SECRET_KEY")
"STRIPE_SECRET_KEY") STRIPE_SECRET_WEBHOOK = get_secret("STRIPE_SECRET_WEBHOOK")
STRIPE_SECRET_WEBHOOK = get_secret('STRIPE_SECRET_WEBHOOK') STRIPE_CHECKOUT_URL = f""
STRIPE_CHECKOUT_URL = f''
# Email credentials # Email credentials
EMAIL_HOST = get_secret('EMAIL_HOST') EMAIL_HOST = get_secret("EMAIL_HOST")
EMAIL_HOST_USER = get_secret('EMAIL_HOST_USER') EMAIL_HOST_USER = get_secret("EMAIL_HOST_USER")
EMAIL_HOST_PASSWORD = get_secret('EMAIL_HOST_PASSWORD') EMAIL_HOST_PASSWORD = get_secret("EMAIL_HOST_PASSWORD")
EMAIL_PORT = get_secret('EMAIL_PORT') EMAIL_PORT = get_secret("EMAIL_PORT")
EMAIL_USE_TLS = (get_secret('EMAIL_USE_TLS') == 'True') EMAIL_USE_TLS = get_secret("EMAIL_USE_TLS") == "True"
EMAIL_ADDRESS = (get_secret('EMAIL_ADDRESS') == 'True') EMAIL_ADDRESS = get_secret("EMAIL_ADDRESS") == "True"
# Application definition # Application definition
INSTALLED_APPS = [ INSTALLED_APPS = [
'config', "config",
'unfold', "unfold",
'unfold.contrib.filters', "unfold.contrib.filters",
'unfold.contrib.simple_history', "unfold.contrib.simple_history",
'django.contrib.admin', "django.contrib.admin",
'django.contrib.auth', "django.contrib.auth",
'django.contrib.contenttypes', "django.contrib.contenttypes",
'django.contrib.sessions', "django.contrib.sessions",
'django.contrib.messages', "django.contrib.messages",
'django.contrib.staticfiles', "django.contrib.staticfiles",
'storages', "storages",
'django_extensions', "django_extensions",
'rest_framework', "rest_framework",
'rest_framework_simplejwt', "rest_framework_simplejwt",
'django_celery_results', "django_celery_results",
'django_celery_beat', "django_celery_beat",
'simple_history', "simple_history",
'djoser', "djoser",
'corsheaders', "corsheaders",
'drf_spectacular', "drf_spectacular",
'drf_spectacular_sidecar', "drf_spectacular_sidecar",
'webdriver', "webdriver",
'accounts', "accounts",
'user_groups', "user_groups",
'subscriptions', "subscriptions",
'payments', "payments",
'billing', "billing",
'emails', "emails",
'notifications', "notifications",
'search_results' "search_results",
] ]
if DEBUG: if DEBUG:
INSTALLED_APPS += ['silk'] INSTALLED_APPS += ["silk"]
MIDDLEWARE = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', "django.middleware.security.SecurityMiddleware",
"silk.middleware.SilkyMiddleware", "silk.middleware.SilkyMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware", "django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware", "corsheaders.middleware.CorsMiddleware",
'django.middleware.common.CommonMiddleware', "django.middleware.common.CommonMiddleware",
'django.middleware.csrf.CsrfViewMiddleware', "django.middleware.csrf.CsrfViewMiddleware",
'django.contrib.auth.middleware.AuthenticationMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware",
'django.contrib.messages.middleware.MessageMiddleware', "django.contrib.messages.middleware.MessageMiddleware",
'django.middleware.clickjacking.XFrameOptionsMiddleware', "django.middleware.clickjacking.XFrameOptionsMiddleware",
] ]
DJANGO_LOG_LEVEL = "DEBUG" DJANGO_LOG_LEVEL = "DEBUG"
# Enables VS Code debugger to break on raised exceptions # Enables VS Code debugger to break on raised exceptions
@ -146,111 +145,101 @@ if DEBUG:
} }
else: else:
MIDDLEWARE = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', "django.middleware.security.SecurityMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware", "whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware", "django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware", "corsheaders.middleware.CorsMiddleware",
'django.middleware.common.CommonMiddleware', "django.middleware.common.CommonMiddleware",
'django.middleware.csrf.CsrfViewMiddleware', "django.middleware.csrf.CsrfViewMiddleware",
'django.contrib.auth.middleware.AuthenticationMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware",
'django.contrib.messages.middleware.MessageMiddleware', "django.contrib.messages.middleware.MessageMiddleware",
'django.middleware.clickjacking.XFrameOptionsMiddleware', "django.middleware.clickjacking.XFrameOptionsMiddleware",
] ]
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/ # https://docs.djangoproject.com/en/4.2/howto/static-files/
ROOT_URLCONF = 'config.urls' ROOT_URLCONF = "config.urls"
if SERVE_MEDIA: if SERVE_MEDIA:
# Cloud Storage Settings # Cloud Storage Settings
# This is assuming you use the same bucket for media and static containers # This is assuming you use the same bucket for media and static containers
CLOUD_BUCKET = get_secret('CLOUD_BUCKET') CLOUD_BUCKET = get_secret("CLOUD_BUCKET")
MEDIA_CONTAINER = get_secret('MEDIA_CONTAINER') MEDIA_CONTAINER = get_secret("MEDIA_CONTAINER")
STATIC_CONTAINER = get_secret('STATIC_CONTAINER') STATIC_CONTAINER = get_secret("STATIC_CONTAINER")
MEDIA_URL = f'https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/' MEDIA_URL = f"https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/"
MEDIA_ROOT = f'https://{CLOUD_BUCKET}/' MEDIA_ROOT = f"https://{CLOUD_BUCKET}/"
STATIC_URL = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/' STATIC_URL = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
STATIC_ROOT = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/' STATIC_ROOT = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
# Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider # Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider
STORAGES = { STORAGES = {
'default': { "default": {
# TODO: Set this up here if you're using cloud storage # TODO: Set this up here if you're using cloud storage
'BACKEND': None, "BACKEND": None,
'OPTIONS': { "OPTIONS": {
# Optional parameters # Optional parameters
}, },
}, },
'staticfiles': { "staticfiles": {
# TODO: Set this up here if you're using cloud storage # TODO: Set this up here if you're using cloud storage
'BACKEND': None, "BACKEND": None,
'OPTIONS': { "OPTIONS": {
# Optional parameters # Optional parameters
}, },
}, },
} }
else: else:
STATIC_URL = 'static/' STATIC_URL = "static/"
STATIC_ROOT = os.path.join(BASE_DIR, 'static') STATIC_ROOT = os.path.join(BASE_DIR, "static")
STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage" STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage"
MEDIA_URL = 'api/v1/media/' MEDIA_URL = "api/v1/media/"
MEDIA_ROOT = os.path.join(BASE_DIR, 'media') MEDIA_ROOT = os.path.join(BASE_DIR, "media")
ROOT_URLCONF = 'config.urls' ROOT_URLCONF = "config.urls"
TEMPLATES = [ TEMPLATES = [
{ {
'BACKEND': 'django.template.backends.django.DjangoTemplates', "BACKEND": "django.template.backends.django.DjangoTemplates",
'DIRS': [ "DIRS": [
BASE_DIR / 'emails/templates/', BASE_DIR / "emails/templates/",
], ],
'APP_DIRS': True, "APP_DIRS": True,
'OPTIONS': { "OPTIONS": {
'context_processors': [ "context_processors": [
'django.template.context_processors.debug', "django.template.context_processors.debug",
'django.template.context_processors.request', "django.template.context_processors.request",
'django.contrib.auth.context_processors.auth', "django.contrib.auth.context_processors.auth",
'django.contrib.messages.context_processors.messages', "django.contrib.messages.context_processors.messages",
], ],
}, },
}, },
] ]
REST_FRAMEWORK = { REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': ( "DEFAULT_AUTHENTICATION_CLASSES": (
'rest_framework_simplejwt.authentication.JWTAuthentication', "rest_framework_simplejwt.authentication.JWTAuthentication",
), ),
'DEFAULT_THROTTLE_CLASSES': [ "DEFAULT_THROTTLE_CLASSES": [
"rest_framework.throttling.AnonRateThrottle",
'rest_framework.throttling.AnonRateThrottle', "rest_framework.throttling.UserRateThrottle",
'rest_framework.throttling.UserRateThrottle'
], ],
"DEFAULT_THROTTLE_RATES": {"anon": "360/min", "user": "1440/min"},
'DEFAULT_THROTTLE_RATES': { "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
'anon': '360/min',
'user': '1440/min'
},
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
} }
# DRF-Spectacular # DRF-Spectacular
SPECTACULAR_SETTINGS = { SPECTACULAR_SETTINGS = {
'TITLE': 'DRF-Template', "TITLE": "DRF-Template",
'DESCRIPTION': 'A Template Project by Keannu Bernasol', "DESCRIPTION": "A Template Project by Keannu Bernasol",
'VERSION': '1.0.0', "VERSION": "1.0.0",
'SERVE_INCLUDE_SCHEMA': False, "SERVE_INCLUDE_SCHEMA": False,
'SWAGGER_UI_DIST': 'SIDECAR', "SWAGGER_UI_DIST": "SIDECAR",
'SWAGGER_UI_FAVICON_HREF': 'SIDECAR', "SWAGGER_UI_FAVICON_HREF": "SIDECAR",
'REDOC_DIST': 'SIDECAR', "REDOC_DIST": "SIDECAR",
} }
WSGI_APPLICATION = 'config.wsgi.application' WSGI_APPLICATION = "config.wsgi.application"
# If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues # If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues
USE_BOUNCER = get_secret("USE_BOUNCER") USE_BOUNCER = get_secret("USE_BOUNCER")
@ -266,15 +255,13 @@ else:
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "django.db.backends.postgresql", "ENGINE": "django.db.backends.postgresql",
'DISABLE_SERVER_SIDE_CURSORS': DISABLE_SERVER_SIDE_CURSORS, "DISABLE_SERVER_SIDE_CURSORS": DISABLE_SERVER_SIDE_CURSORS,
"NAME": get_secret("DB_DATABASE"), "NAME": get_secret("DB_DATABASE"),
"USER": get_secret("DB_USERNAME"), "USER": get_secret("DB_USERNAME"),
"PASSWORD": get_secret("DB_PASSWORD"), "PASSWORD": get_secret("DB_PASSWORD"),
"HOST": DB_HOST, "HOST": DB_HOST,
"PORT": DB_PORT, "PORT": DB_PORT,
"OPTIONS": { "OPTIONS": {"sslmode": get_secret("DB_SSL_MODE")},
"sslmode": get_secret("DB_SSL_MODE")
},
} }
} }
# Django Cache # Django Cache
@ -284,34 +271,34 @@ CACHES = {
"LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2", "LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2",
"OPTIONS": { "OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient", "CLIENT_CLASS": "django_redis.client.DefaultClient",
} },
} }
} }
AUTH_USER_MODEL = 'accounts.CustomUser' AUTH_USER_MODEL = "accounts.CustomUser"
DJOSER = { DJOSER = {
'SEND_ACTIVATION_EMAIL': True, "SEND_ACTIVATION_EMAIL": True,
'SEND_CONFIRMATION_EMAIL': True, "SEND_CONFIRMATION_EMAIL": True,
'PASSWORD_RESET_CONFIRM_URL': 'reset_password_confirm/{uid}/{token}', "PASSWORD_RESET_CONFIRM_URL": "reset_password_confirm/{uid}/{token}",
'ACTIVATION_URL': 'activation/{uid}/{token}', "ACTIVATION_URL": "activation/{uid}/{token}",
'USER_AUTHENTICATION_RULES': ['djoser.authentication.TokenAuthenticationRule'], "USER_AUTHENTICATION_RULES": ["djoser.authentication.TokenAuthenticationRule"],
'EMAIL': { "EMAIL": {
'activation': 'emails.templates.ActivationEmail', "activation": "emails.templates.ActivationEmail",
'password_reset': 'emails.templates.PasswordResetEmail' "password_reset": "emails.templates.PasswordResetEmail",
}, },
'SERIALIZERS': { "SERIALIZERS": {
'user': 'accounts.serializers.CustomUserSerializer', "user": "accounts.serializers.CustomUserSerializer",
'current_user': 'accounts.serializers.CustomUserSerializer', "current_user": "accounts.serializers.CustomUserSerializer",
'user_create': 'accounts.serializers.UserRegistrationSerializer', "user_create": "accounts.serializers.UserRegistrationSerializer",
}, },
'PERMISSIONS': { "PERMISSIONS": {
# Disable some unneeded endpoints by setting them to admin only # Disable some unneeded endpoints by setting them to admin only
'username_reset': ['rest_framework.permissions.IsAdminUser'], "username_reset": ["rest_framework.permissions.IsAdminUser"],
'username_reset_confirm': ['rest_framework.permissions.IsAdminUser'], "username_reset_confirm": ["rest_framework.permissions.IsAdminUser"],
'set_username': ['rest_framework.permissions.IsAdminUser'], "set_username": ["rest_framework.permissions.IsAdminUser"],
'set_password': ['rest_framework.permissions.IsAdminUser'], "set_password": ["rest_framework.permissions.IsAdminUser"],
} },
} }
# Password validation # Password validation
@ -319,32 +306,32 @@ DJOSER = {
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ {
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
"OPTIONS": { "OPTIONS": {
"min_length": 8, "min_length": 8,
} },
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
}, },
{ {
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
}, },
# Additional password validators # Additional password validators
{ {
'NAME': 'accounts.validators.SpecialCharacterValidator', "NAME": "accounts.validators.SpecialCharacterValidator",
}, },
{ {
'NAME': 'accounts.validators.LowercaseValidator', "NAME": "accounts.validators.LowercaseValidator",
}, },
{ {
'NAME': 'accounts.validators.UppercaseValidator', "NAME": "accounts.validators.UppercaseValidator",
}, },
{ {
'NAME': 'accounts.validators.NumberValidator', "NAME": "accounts.validators.NumberValidator",
}, },
] ]
@ -352,9 +339,9 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/ # https://docs.djangoproject.com/en/4.2/topics/i18n/
LANGUAGE_CODE = 'en-us' LANGUAGE_CODE = "en-us"
TIME_ZONE = get_secret('TIMEZONE') TIME_ZONE = get_secret("TIMEZONE")
USE_I18N = True USE_I18N = True
@ -364,14 +351,14 @@ USE_TZ = True
# Default primary key field type # Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field # https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
SITE_NAME = 'DRF-Template' SITE_NAME = "DRF-Template"
# JWT Token Lifetimes # JWT Token Lifetimes
SIMPLE_JWT = { SIMPLE_JWT = {
"ACCESS_TOKEN_LIFETIME": timedelta(hours=1), "ACCESS_TOKEN_LIFETIME": timedelta(hours=1),
"REFRESH_TOKEN_LIFETIME": timedelta(days=3) "REFRESH_TOKEN_LIFETIME": timedelta(days=3),
} }
CORS_ALLOW_ALL_ORIGINS = True CORS_ALLOW_ALL_ORIGINS = True
@ -388,11 +375,19 @@ CELERY_RESULT_BACKEND = get_secret("CELERY_RESULT_BACKEND")
CELERY_RESULT_EXTENDED = True CELERY_RESULT_EXTENDED = True
# Celery Beat Options # Celery Beat Options
CELERY_BEAT_SCHEDULER = 'django_celery_beat.schedulers:DatabaseScheduler' CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler"
# Maximum number of rows that can be updated within the Django admin panel # Maximum number of rows that can be updated within the Django admin panel
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480 DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480
GRAPH_MODELS = { GRAPH_MODELS = {
'app_labels': ['accounts', 'user_groups', 'billing', 'emails', 'payments', 'subscriptions', 'search_results'] "app_labels": [
"accounts",
"user_groups",
"billing",
"emails",
"payments",
"subscriptions",
"search_results",
]
} }

View file

@ -1,5 +1,5 @@
from django.urls import path, include from django.urls import path, include
urlpatterns = [ urlpatterns = [
path('api/v1/', include('api.urls')), path("api/v1/", include("api.urls")),
] ]

View file

@ -11,6 +11,6 @@ import os
from django.core.wsgi import get_wsgi_application from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
application = get_wsgi_application() application = get_wsgi_application()

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class EmailsConfig(AppConfig): class EmailsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'emails' name = "emails"

View file

@ -3,11 +3,11 @@ from django.utils import timezone
class ActivationEmail(email.ActivationEmail): class ActivationEmail(email.ActivationEmail):
template_name = 'email_activation.html' template_name = "email_activation.html"
class PasswordResetEmail(email.PasswordResetEmail): class PasswordResetEmail(email.PasswordResetEmail):
template_name = 'password_change.html' template_name = "password_change.html"
class SubscriptionAvailedEmail(email.BaseEmailMessage): class SubscriptionAvailedEmail(email.BaseEmailMessage):
@ -19,7 +19,7 @@ class SubscriptionAvailedEmail(email.BaseEmailMessage):
context["subscription_plan"] = context.get("subscription_plan") context["subscription_plan"] = context.get("subscription_plan")
context["subscription"] = context.get("subscription") context["subscription"] = context.get("subscription")
context["price_paid"] = context.get("price_paid") context["price_paid"] = context.get("price_paid")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p") context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context) context.update(self.context)
return context return context
@ -32,7 +32,7 @@ class SubscriptionRefundedEmail(email.BaseEmailMessage):
context["user"] = context.get("user") context["user"] = context.get("user")
context["subscription_plan"] = context.get("subscription_plan") context["subscription_plan"] = context.get("subscription_plan")
context["refund"] = context.get("refund") context["refund"] = context.get("refund")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p") context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context) context.update(self.context)
return context return context
@ -44,6 +44,6 @@ class SubscriptionCancelledEmail(email.BaseEmailMessage):
context = super().get_context_data() context = super().get_context_data()
context["user"] = context.get("user") context["user"] = context.get("user")
context["subscription_plan"] = context.get("subscription_plan") context["subscription_plan"] = context.get("subscription_plan")
context['date'] = timezone.now().strftime("%B %d, %I:%M %p") context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
context.update(self.context) context.update(self.context)
return context return context

View file

@ -6,7 +6,7 @@ import sys
def main(): def main():
"""Run administrative tasks.""" """Run administrative tasks."""
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
try: try:
from django.core.management import execute_from_command_line from django.core.management import execute_from_command_line
except ImportError as exc: except ImportError as exc:
@ -18,5 +18,5 @@ def main():
execute_from_command_line(sys.argv) execute_from_command_line(sys.argv)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -6,5 +6,5 @@ from .models import Notification
@admin.register(Notification) @admin.register(Notification)
class NotificationAdmin(ModelAdmin): class NotificationAdmin(ModelAdmin):
model = Notification model = Notification
search_fields = ('id', 'content') search_fields = ("id", "content")
list_display = ['id', 'dismissed'] list_display = ["id", "dismissed"]

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class NotificationsConfig(AppConfig): class NotificationsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'notifications' name = "notifications"
def ready(self): def ready(self):
import notifications.signals import notifications.signals

View file

@ -15,13 +15,27 @@ class Migration(migrations.Migration):
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Notification', name="Notification",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('content', models.CharField(max_length=1000, null=True)), "id",
('timestamp', models.DateTimeField(auto_now_add=True)), models.BigAutoField(
('dismissed', models.BooleanField(default=False)), auto_created=True,
('recipient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("content", models.CharField(max_length=1000, null=True)),
("timestamp", models.DateTimeField(auto_now_add=True)),
("dismissed", models.BooleanField(default=False)),
(
"recipient",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
], ],
), ),
] ]

View file

@ -2,8 +2,7 @@ from django.db import models
class Notification(models.Model): class Notification(models.Model):
recipient = models.ForeignKey( recipient = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
'accounts.CustomUser', on_delete=models.CASCADE)
content = models.CharField(max_length=1000, null=True) content = models.CharField(max_length=1000, null=True)
timestamp = models.DateTimeField(auto_now_add=True, editable=False) timestamp = models.DateTimeField(auto_now_add=True, editable=False)
dismissed = models.BooleanField(default=False) dismissed = models.BooleanField(default=False)

View file

@ -3,10 +3,9 @@ from notifications.models import Notification
class NotificationSerializer(serializers.ModelSerializer): class NotificationSerializer(serializers.ModelSerializer):
timestamp = serializers.DateTimeField( timestamp = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta: class Meta:
model = Notification model = Notification
fields = '__all__' fields = "__all__"
read_only_fields = ('id', 'recipient', 'content', 'timestamp') read_only_fields = ("id", "recipient", "content", "timestamp")

View file

@ -9,5 +9,5 @@ from django.core.cache import cache
@receiver(post_save, sender=Notification) @receiver(post_save, sender=Notification)
def clear_cache_after_notification_update(sender, instance, **kwargs): def clear_cache_after_notification_update(sender, instance, **kwargs):
# Clear cache # Clear cache
cache.delete('notifications') cache.delete("notifications")
cache.delete(f'notifications_user:{instance.recipient.id}') cache.delete(f"notifications_user:{instance.recipient.id}")

View file

@ -9,5 +9,4 @@ def cleanup_notifications():
three_days_ago = timezone.now() - timezone.timedelta(days=3) three_days_ago = timezone.now() - timezone.timedelta(days=3)
# Delete notifications that are older than 3 days and dismissed # Delete notifications that are older than 3 days and dismissed
Notification.objects.filter( Notification.objects.filter(dismissed=True, timestamp__lte=three_days_ago).delete()
dismissed=True, timestamp__lte=three_days_ago).delete()

View file

@ -3,8 +3,7 @@ from notifications.views import NotificationViewSet
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
router = DefaultRouter() router = DefaultRouter()
router.register(r'', NotificationViewSet, router.register(r"", NotificationViewSet, basename="Notifications")
basename="Notifications")
urlpatterns = [ urlpatterns = [
path('', include(router.urls)), path("", include(router.urls)),
] ]

View file

@ -6,17 +6,18 @@ from django.core.cache import cache
class NotificationViewSet(viewsets.ModelViewSet): class NotificationViewSet(viewsets.ModelViewSet):
http_method_names = ['get', 'patch', 'delete'] http_method_names = ["get", "patch", "delete"]
serializer_class = NotificationSerializer serializer_class = NotificationSerializer
queryset = Notification.objects.all() queryset = Notification.objects.all()
def get_queryset(self): def get_queryset(self):
user = self.request.user user = self.request.user
key = f'notifications_user:{user.id}' key = f"notifications_user:{user.id}"
queryset = cache.get(key) queryset = cache.get(key)
if not queryset: if not queryset:
queryset = Notification.objects.filter( queryset = Notification.objects.filter(recipient=user).order_by(
recipient=user).order_by('-timestamp') "-timestamp"
)
cache.set(key, queryset, 60 * 60) cache.set(key, queryset, 60 * 60)
return queryset return queryset
@ -24,12 +25,14 @@ class NotificationViewSet(viewsets.ModelViewSet):
instance = self.get_object() instance = self.get_object()
if instance.recipient != request.user: if instance.recipient != request.user:
raise PermissionDenied( raise PermissionDenied(
"You do not have permission to update this notification.") "You do not have permission to update this notification."
)
return super().update(request, *args, **kwargs) return super().update(request, *args, **kwargs)
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
if instance.recipient != request.user: if instance.recipient != request.user:
raise PermissionDenied( raise PermissionDenied(
"You do not have permission to delete this notification.") "You do not have permission to delete this notification."
)
return super().destroy(request, *args, **kwargs) return super().destroy(request, *args, **kwargs)

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class PaymentsConfig(AppConfig): class PaymentsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'payments' name = "payments"

View file

@ -3,6 +3,6 @@ from payments import views
urlpatterns = [ urlpatterns = [
path('checkout_session/', views.StripeCheckoutView.as_view()), path("checkout_session/", views.StripeCheckoutView.as_view()),
path('webhook/', views.stripe_webhook_view, name='Stripe Webhook'), path("webhook/", views.stripe_webhook_view, name="Stripe Webhook"),
] ]

View file

@ -1,4 +1,10 @@
from config.settings import STRIPE_SECRET_KEY, STRIPE_SECRET_WEBHOOK, URL_SCHEME, FRONTEND_ADDRESS, FRONTEND_PORT from config.settings import (
STRIPE_SECRET_KEY,
STRIPE_SECRET_WEBHOOK,
URL_SCHEME,
FRONTEND_ADDRESS,
FRONTEND_PORT,
)
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.response import Response from rest_framework.response import Response
@ -12,16 +18,19 @@ from accounts.models import CustomUser
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from subscriptions.tasks import get_user_subscription from subscriptions.tasks import get_user_subscription
import json import json
from emails.templates import SubscriptionAvailedEmail, SubscriptionRefundedEmail, SubscriptionCancelledEmail from emails.templates import (
SubscriptionAvailedEmail,
SubscriptionRefundedEmail,
SubscriptionCancelledEmail,
)
from django.core.cache import cache from django.core.cache import cache
from payments.serializers import CheckoutSerializer from payments.serializers import CheckoutSerializer
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
stripe.api_key = STRIPE_SECRET_KEY stripe.api_key = STRIPE_SECRET_KEY
@extend_schema( @extend_schema(request=CheckoutSerializer)
request=CheckoutSerializer
)
class StripeCheckoutView(APIView): class StripeCheckoutView(APIView):
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
@ -30,41 +39,46 @@ class StripeCheckoutView(APIView):
# Get subscription ID from POST # Get subscription ID from POST
USER = CustomUser.objects.get(id=self.request.user.id) USER = CustomUser.objects.get(id=self.request.user.id)
data = json.loads(request.body) data = json.loads(request.body)
subscription_id = data.get('subscription_id') subscription_id = data.get("subscription_id")
annual = data.get('annual') annual = data.get("annual")
# Validation for subscription_id field # Validation for subscription_id field
try: try:
subscription_id = int(subscription_id) subscription_id = int(subscription_id)
except: except:
return Response({ return Response(
'error': 'Invalid value specified in subscription_id field' {"error": "Invalid value specified in subscription_id field"},
}, status=status.HTTP_403_FORBIDDEN) status=status.HTTP_403_FORBIDDEN,
)
# Validation for annual field # Validation for annual field
try: try:
annual = bool(annual) annual = bool(annual)
except: except:
return Response({ return Response(
'error': 'Invalid value specified in annual field' {"error": "Invalid value specified in annual field"},
}, status=status.HTTP_403_FORBIDDEN) status=status.HTTP_403_FORBIDDEN,
)
# Return an error if the user already has an active subscription # Return an error if the user already has an active subscription
EXISTING_SUBSCRIPTION = get_user_subscription(USER.id) EXISTING_SUBSCRIPTION = get_user_subscription(USER.id)
if EXISTING_SUBSCRIPTION: if EXISTING_SUBSCRIPTION:
return Response({ return Response(
'error': f'User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}' {
}, status=status.HTTP_403_FORBIDDEN) "error": f"User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}"
},
status=status.HTTP_403_FORBIDDEN,
)
# Attempt to query the subscription # Attempt to query the subscription
SUBSCRIPTION = SubscriptionPlan.objects.filter( SUBSCRIPTION = SubscriptionPlan.objects.filter(id=subscription_id).first()
id=subscription_id).first()
# Return an error if the plan does not exist # Return an error if the plan does not exist
if not SUBSCRIPTION: if not SUBSCRIPTION:
return Response({ return Response(
'error': 'Subscription plan not found' {"error": "Subscription plan not found"},
}, status=status.HTTP_404_NOT_FOUND) status=status.HTTP_404_NOT_FOUND,
)
# Get the stripe_price_id from the related StripePrice instances # Get the stripe_price_id from the related StripePrice instances
if annual: if annual:
@ -74,52 +88,58 @@ class StripeCheckoutView(APIView):
# Return 404 if no price is set # Return 404 if no price is set
if not PRICE: if not PRICE:
return Response({ return Response(
'error': 'Specified price does not exist for plan' {"error": "Specified price does not exist for plan"},
}, status=status.HTTP_404_NOT_FOUND) status=status.HTTP_404_NOT_FOUND,
)
PRICE_ID = PRICE.stripe_price_id PRICE_ID = PRICE.stripe_price_id
prorated = PRICE.prorated prorated = PRICE.prorated
# Return an error if a user is in a user_group and is availing pro-rated plans # Return an error if a user is in a user_group and is availing pro-rated plans
if not USER.user_group and SUBSCRIPTION.group_exclusive: if not USER.user_group and SUBSCRIPTION.group_exclusive:
return Response({ return Response(
'error': 'Regular users cannot avail prorated plans' {"error": "Regular users cannot avail prorated plans"},
}, status=status.HTTP_403_FORBIDDEN) status=status.HTTP_403_FORBIDDEN,
)
success_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \ success_url = (
'/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}' f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
cancel_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \ + "/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}"
'/user/subscription/payment?success=false&user_group=False' )
cancel_url = (
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
+ "/user/subscription/payment?success=false&user_group=False"
)
checkout_session = stripe.checkout.Session.create( checkout_session = stripe.checkout.Session.create(
line_items=[ line_items=[
{ (
'price': PRICE_ID, {"price": PRICE_ID, "quantity": 1}
'quantity': 1 if not prorated
} if not prorated else else {
{ "price": PRICE_ID,
'price': PRICE_ID,
} }
)
], ],
mode='subscription', mode="subscription",
payment_method_types=['card'], payment_method_types=["card"],
success_url=success_url, success_url=success_url,
cancel_url=cancel_url, cancel_url=cancel_url,
) )
return Response({"url": checkout_session.url}) return Response({"url": checkout_session.url})
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
return Response({ return Response(
'error': str(e) {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) )
@ api_view(['POST']) @api_view(["POST"])
@csrf_exempt @csrf_exempt
def stripe_webhook_view(request): def stripe_webhook_view(request):
payload = request.body payload = request.body
sig_header = request.META['HTTP_STRIPE_SIGNATURE'] sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
event = None event = None
try: try:
@ -133,12 +153,12 @@ def stripe_webhook_view(request):
# Invalid signature # Invalid signature
return Response(status=401) return Response(status=401)
if event['type'] == 'customer.subscription.created': if event["type"] == "customer.subscription.created":
subscription = event['data']['object'] subscription = event["data"]["object"]
# Get the Invoice object from the Subscription object # Get the Invoice object from the Subscription object
invoice = stripe.Invoice.retrieve(subscription['latest_invoice']) invoice = stripe.Invoice.retrieve(subscription["latest_invoice"])
# Get the Charge object from the Invoice object # Get the Charge object from the Invoice object
charge = stripe.Charge.retrieve(invoice['charge']) charge = stripe.Charge.retrieve(invoice["charge"])
# Get paying user # Get paying user
customer = stripe.Customer.retrieve(subscription["customer"]) customer = stripe.Customer.retrieve(subscription["customer"])
@ -146,18 +166,20 @@ def stripe_webhook_view(request):
product = subscription["items"]["data"][0] product = subscription["items"]["data"][0]
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get( SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
stripe_product_id=product["plan"]["product"]) stripe_product_id=product["plan"]["product"]
)
SUBSCRIPTION = UserSubscription.objects.create( SUBSCRIPTION = UserSubscription.objects.create(
subscription=SUBSCRIPTION_PLAN, subscription=SUBSCRIPTION_PLAN,
annual=product["plan"]["interval"] == "year", annual=product["plan"]["interval"] == "year",
valid=True, valid=True,
user=USER, user=USER,
stripe_id=subscription['id']) stripe_id=subscription["id"],
)
email = SubscriptionAvailedEmail() email = SubscriptionAvailedEmail()
paid = { paid = {
"amount": charge['amount']/100, "amount": charge["amount"] / 100,
"currency": str(charge['currency']).upper() "currency": str(charge["currency"]).upper(),
} }
email.context = { email.context = {
@ -169,19 +191,20 @@ def stripe_webhook_view(request):
email.send(to=[customer.email]) email.send(to=[customer.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
cache.delete(f'subscriptions_user:{USER.id}') cache.delete(f"subscriptions_user:{USER.id}")
# On chargebacks/refunds, invalidate the subscription # On chargebacks/refunds, invalidate the subscription
elif event['type'] == 'charge.refunded': elif event["type"] == "charge.refunded":
charge = event['data']['object'] charge = event["data"]["object"]
# Get the Invoice object from the Charge object # Get the Invoice object from the Charge object
invoice = stripe.Invoice.retrieve(charge['invoice']) invoice = stripe.Invoice.retrieve(charge["invoice"])
# Check if the subscription exists # Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter( SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=invoice['subscription']).first() stripe_id=invoice["subscription"]
).first()
if not (SUBSCRIPTION): if not (SUBSCRIPTION):
return HttpResponse(status=404) return HttpResponse(status=404)
@ -196,8 +219,8 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
refund = { refund = {
"amount": charge['amount_refunded']/100, "amount": charge["amount_refunded"] / 100,
"currency": str(charge['currency']).upper() "currency": str(charge["currency"]).upper(),
} }
# Send an email # Send an email
@ -206,13 +229,13 @@ def stripe_webhook_view(request):
email.context = { email.context = {
"user": USER, "user": USER,
"subscription_plan": SUBSCRIPTION_PLAN, "subscription_plan": SUBSCRIPTION_PLAN,
"refund": refund "refund": refund,
} }
email.send(to=[USER.email]) email.send(to=[USER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
elif SUBSCRIPTION.user_group: elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner OWNER = SUBSCRIPTION.user_group.owner
@ -223,8 +246,8 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
refund = { refund = {
"amount": charge['amount_refunded']/100, "amount": charge["amount_refunded"] / 100,
"currency": str(charge['currency']).upper() "currency": str(charge["currency"]).upper(),
} }
# Send en email # Send en email
@ -233,36 +256,38 @@ def stripe_webhook_view(request):
email.context = { email.context = {
"user": OWNER, "user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN, "subscription_plan": SUBSCRIPTION_PLAN,
"refund": refund "refund": refund,
} }
email.send(to=[OWNER.email]) email.send(to=[OWNER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
cache.delete(f'subscriptions_user:{USER.id}') cache.delete(f"subscriptions_user:{USER.id}")
elif event['type'] == 'customer.subscription.updated': elif event["type"] == "customer.subscription.updated":
subscription = event['data']['object'] subscription = event["data"]["object"]
# Check if the subscription exists # Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter( SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=subscription['id']).first() stripe_id=subscription["id"]
).first()
if not (SUBSCRIPTION): if not (SUBSCRIPTION):
return HttpResponse(status=404) return HttpResponse(status=404)
# Check if a subscription has been upgraded/downgraded # Check if a subscription has been upgraded/downgraded
new_stripe_product_id = subscription['items']['data'][0]['plan']['product'] new_stripe_product_id = subscription["items"]["data"][0]["plan"]["product"]
current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id
if new_stripe_product_id != current_stripe_product_id: if new_stripe_product_id != current_stripe_product_id:
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get( SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
stripe_product_id=new_stripe_product_id) stripe_product_id=new_stripe_product_id
)
SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN
SUBSCRIPTION.save() SUBSCRIPTION.save()
# TODO: Add a plan upgraded email message here # TODO: Add a plan upgraded email message here
# Subscription activation/reactivation # Subscription activation/reactivation
if subscription['status'] == 'active': if subscription["status"] == "active":
SUBSCRIPTION.valid = True SUBSCRIPTION.valid = True
SUBSCRIPTION.save() SUBSCRIPTION.save()
@ -270,26 +295,24 @@ def stripe_webhook_view(request):
USER = SUBSCRIPTION.user USER = SUBSCRIPTION.user
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
cache.delete( cache.delete(f"subscriptions_user:{USER.id}")
f'subscriptions_user:{USER.id}')
elif SUBSCRIPTION.user_group: elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner OWNER = SUBSCRIPTION.user_group.owner
# Clear cache # Clear cache
cache.delete(f'billing_user:{OWNER.id}') cache.delete(f"billing_user:{OWNER.id}")
cache.delete( cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
# TODO: Add notification here to inform users if their plan has been reactivated # TODO: Add notification here to inform users if their plan has been reactivated
elif subscription['status'] == 'past_due': elif subscription["status"] == "past_due":
# TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing # TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing
pass pass
# If subscriptions get cancelled due to non-payment, invalidate the UserSubscription # If subscriptions get cancelled due to non-payment, invalidate the UserSubscription
elif subscription['status'] == 'cancelled': elif subscription["status"] == "cancelled":
if SUBSCRIPTION.user: if SUBSCRIPTION.user:
USER = SUBSCRIPTION.user USER = SUBSCRIPTION.user
@ -310,8 +333,8 @@ def stripe_webhook_view(request):
email.send(to=[USER.email]) email.send(to=[USER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
cache.delete(f'subscriptions_user:{USER.id}') cache.delete(f"subscriptions_user:{USER.id}")
elif SUBSCRIPTION.user_group: elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner OWNER = SUBSCRIPTION.user_group.owner
@ -325,24 +348,21 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
email.context = { email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
"user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN
}
email.send(to=[OWNER.email]) email.send(to=[OWNER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{OWNER.id}') cache.delete(f"billing_user:{OWNER.id}")
cache.delete( cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
# If a subscription gets cancelled, invalidate it # If a subscription gets cancelled, invalidate it
elif event['type'] == 'customer.subscription.deleted': elif event["type"] == "customer.subscription.deleted":
subscription = event['data']['object'] subscription = event["data"]["object"]
# Check if the subscription exists # Check if the subscription exists
SUBSCRIPTION = UserSubscription.objects.filter( SUBSCRIPTION = UserSubscription.objects.filter(
stripe_id=subscription['id']).first() stripe_id=subscription["id"]
).first()
if not (SUBSCRIPTION): if not (SUBSCRIPTION):
return HttpResponse(status=404) return HttpResponse(status=404)
@ -367,7 +387,7 @@ def stripe_webhook_view(request):
email.send(to=[USER.email]) email.send(to=[USER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{USER.id}') cache.delete(f"billing_user:{USER.id}")
elif SUBSCRIPTION.user_group: elif SUBSCRIPTION.user_group:
OWNER = SUBSCRIPTION.user_group.owner OWNER = SUBSCRIPTION.user_group.owner
@ -381,14 +401,11 @@ def stripe_webhook_view(request):
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
email.context = { email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
"user": OWNER,
"subscription_plan": SUBSCRIPTION_PLAN
}
email.send(to=[OWNER.email]) email.send(to=[OWNER.email])
# Clear cache # Clear cache
cache.delete(f'billing_user:{OWNER.id}') cache.delete(f"billing_user:{OWNER.id}")
# Passed signature verification # Passed signature verification
return HttpResponse(status=200) return HttpResponse(status=200)

View file

@ -805,6 +805,9 @@ components:
first_name: first_name:
type: string type: string
maxLength: 150 maxLength: 150
is_new:
type: string
readOnly: true
last_name: last_name:
type: string type: string
maxLength: 150 maxLength: 150
@ -824,6 +827,7 @@ components:
- group_member - group_member
- group_owner - group_owner
- id - id
- is_new
- user_group - user_group
- username - username
Notification: Notification:
@ -885,6 +889,9 @@ components:
first_name: first_name:
type: string type: string
maxLength: 150 maxLength: 150
is_new:
type: string
readOnly: true
last_name: last_name:
type: string type: string
maxLength: 150 maxLength: 150

View file

@ -7,12 +7,11 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(SearchResult) @admin.register(SearchResult)
class SearchResultAdmin(ModelAdmin): class SearchResultAdmin(ModelAdmin):
model = SearchResult model = SearchResult
search_fields = ('id', 'title', 'link') search_fields = ("id", "title", "link")
list_display = ['id', 'title', 'timestamp'] list_display = ["id", "title", "timestamp"]
list_filter_submit = True list_filter_submit = True
list_filter = (( list_filter = (
"timestamp", RangeDateFilter ("timestamp", RangeDateFilter),
), ( ("timestamp", RangeDateFilter),
"timestamp", RangeDateFilter )
),)

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class SearchResultsConfig(AppConfig): class SearchResultsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'search_results' name = "search_results"

View file

@ -7,17 +7,24 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='SearchResult', name="SearchResult",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('title', models.CharField(max_length=1000)), "id",
('link', models.CharField(max_length=1000)), models.BigAutoField(
('timestamp', models.DateTimeField(auto_now_add=True)), auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("title", models.CharField(max_length=1000)),
("link", models.CharField(max_length=1000)),
("timestamp", models.DateTimeField(auto_now_add=True)),
], ],
), ),
] ]

View file

@ -1,16 +1,13 @@
from celery import shared_task from celery import shared_task
from .models import SearchResult from .models import SearchResult
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 0, 'countdown': 5}) @shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 0, "countdown": 5}
)
def create_search_result(title, link): def create_search_result(title, link):
if SearchResult.objects.filter(title=title, link=link).exists(): if SearchResult.objects.filter(title=title, link=link).exists():
return ("SearchResult entry already exists") return "SearchResult entry already exists"
else: else:
SearchResult.objects.create( SearchResult.objects.create(title=title, link=link)
title=title,
link=link
)
return f"Created new SearchResult entry titled: {title}" return f"Created new SearchResult entry titled: {title}"

View file

@ -6,10 +6,24 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(StripePrice) @admin.register(StripePrice)
class StripePriceAdmin(ModelAdmin): class StripePriceAdmin(ModelAdmin):
search_fields = ["id", "lookup_key", search_fields = [
"stripe_price_id","price","currency", "prorated", "annual"] "id",
list_display = ["id", "lookup_key", "lookup_key",
"stripe_price_id", "price", "currency", "prorated", "annual"] "stripe_price_id",
"price",
"currency",
"prorated",
"annual",
]
list_display = [
"id",
"lookup_key",
"stripe_price_id",
"price",
"currency",
"prorated",
"annual",
]
@admin.register(SubscriptionPlan) @admin.register(SubscriptionPlan)
@ -21,9 +35,6 @@ class SubscriptionPlanAdmin(ModelAdmin):
@admin.register(UserSubscription) @admin.register(UserSubscription)
class UserSubscriptionAdmin(ModelAdmin): class UserSubscriptionAdmin(ModelAdmin):
list_filter_submit = True list_filter_submit = True
list_filter = (( list_filter = (("date", RangeDateFilter),)
"date", RangeDateFilter list_display = ["id", "__str__", "valid", "annual", "date"]
),)
list_display = ["id", "__str__", "valid", "annual",
"date"]
search_fields = ["id", "date"] search_fields = ["id", "date"]

View file

@ -2,8 +2,8 @@ from django.apps import AppConfig
class SubscriptionConfig(AppConfig): class SubscriptionConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'subscriptions' name = "subscriptions"
def ready(self): def ready(self):
import subscriptions.signals import subscriptions.signals

View file

@ -11,46 +11,118 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
('user_groups', '0001_initial'), ("user_groups", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL),
] ]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='StripePrice', name="StripePrice",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('annual', models.BooleanField(default=False)), "id",
('stripe_price_id', models.CharField(max_length=100)), models.BigAutoField(
('price', models.DecimalField(decimal_places=2, default=0.0, max_digits=10)), auto_created=True,
('currency', models.CharField(max_length=20)), primary_key=True,
('lookup_key', models.CharField(blank=True, max_length=100, null=True)), serialize=False,
('prorated', models.BooleanField(default=False)), verbose_name="ID",
),
),
("annual", models.BooleanField(default=False)),
("stripe_price_id", models.CharField(max_length=100)),
(
"price",
models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
),
("currency", models.CharField(max_length=20)),
("lookup_key", models.CharField(blank=True, max_length=100, null=True)),
("prorated", models.BooleanField(default=False)),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='SubscriptionPlan', name="SubscriptionPlan",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=100)), "id",
('description', models.TextField(max_length=1024, null=True)), models.BigAutoField(
('stripe_product_id', models.CharField(max_length=100)), auto_created=True,
('group_exclusive', models.BooleanField(default=False)), primary_key=True,
('annual_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='annual_plan', to='subscriptions.stripeprice')), serialize=False,
('monthly_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='monthly_plan', to='subscriptions.stripeprice')), verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
("description", models.TextField(max_length=1024, null=True)),
("stripe_product_id", models.CharField(max_length=100)),
("group_exclusive", models.BooleanField(default=False)),
(
"annual_price",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="annual_plan",
to="subscriptions.stripeprice",
),
),
(
"monthly_price",
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="monthly_plan",
to="subscriptions.stripeprice",
),
),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='UserSubscription', name="UserSubscription",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('stripe_id', models.CharField(max_length=100)), "id",
('date', models.DateTimeField(default=django.utils.timezone.now, editable=False)), models.BigAutoField(
('valid', models.BooleanField()), auto_created=True,
('annual', models.BooleanField()), primary_key=True,
('subscription', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='subscriptions.subscriptionplan')), serialize=False,
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), verbose_name="ID",
('user_group', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='user_groups.usergroup')), ),
),
("stripe_id", models.CharField(max_length=100)),
(
"date",
models.DateTimeField(
default=django.utils.timezone.now, editable=False
),
),
("valid", models.BooleanField()),
("annual", models.BooleanField()),
(
"subscription",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="subscriptions.subscriptionplan",
),
),
(
"user",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
(
"user_group",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="user_groups.usergroup",
),
),
], ],
), ),
] ]

View file

@ -1,4 +1,3 @@
from django.db import models from django.db import models
from accounts.models import CustomUser from accounts.models import CustomUser
from user_groups.models import UserGroup from user_groups.models import UserGroup
@ -25,9 +24,11 @@ class SubscriptionPlan(models.Model):
description = models.TextField(max_length=1024, null=True) description = models.TextField(max_length=1024, null=True)
stripe_product_id = models.CharField(max_length=100) stripe_product_id = models.CharField(max_length=100)
annual_price = models.ForeignKey( annual_price = models.ForeignKey(
StripePrice, on_delete=models.SET_NULL, related_name='annual_plan', null=True) StripePrice, on_delete=models.SET_NULL, related_name="annual_plan", null=True
)
monthly_price = models.ForeignKey( monthly_price = models.ForeignKey(
StripePrice, on_delete=models.SET_NULL, related_name='monthly_plan', null=True) StripePrice, on_delete=models.SET_NULL, related_name="monthly_plan", null=True
)
group_exclusive = models.BooleanField(default=False) group_exclusive = models.BooleanField(default=False)
def __str__(self): def __str__(self):
@ -39,11 +40,14 @@ class SubscriptionPlan(models.Model):
class UserSubscription(models.Model): class UserSubscription(models.Model):
user = models.ForeignKey( user = models.ForeignKey(
CustomUser, on_delete=models.CASCADE, blank=True, null=True) CustomUser, on_delete=models.CASCADE, blank=True, null=True
)
user_group = models.ForeignKey( user_group = models.ForeignKey(
UserGroup, on_delete=models.CASCADE, blank=True, null=True) UserGroup, on_delete=models.CASCADE, blank=True, null=True
)
subscription = models.ForeignKey( subscription = models.ForeignKey(
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True) SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True
)
stripe_id = models.CharField(max_length=100) stripe_id = models.CharField(max_length=100)
date = models.DateTimeField(default=now, editable=False) date = models.DateTimeField(default=now, editable=False)
valid = models.BooleanField() valid = models.BooleanField()
@ -51,6 +55,6 @@ class UserSubscription(models.Model):
def __str__(self): def __str__(self):
if self.user: if self.user:
return f'Subscription {self.subscription.name} for {self.user}' return f"Subscription {self.subscription.name} for {self.user}"
else: else:
return f'Subscription {self.subscription.name} for {self.user_group}' return f"Subscription {self.subscription.name} for {self.user_group}"

View file

@ -7,38 +7,46 @@ class SimpleStripePriceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = StripePrice model = StripePrice
fields = ['price', 'currency', 'prorated'] fields = ["price", "currency", "prorated"]
class SubscriptionPlanSerializer(serializers.ModelSerializer): class SubscriptionPlanSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SubscriptionPlan model = SubscriptionPlan
fields = ['id', 'name', 'description', fields = [
'annual_price', 'monthly_price', 'group_exclusive'] "id",
"name",
"description",
"annual_price",
"monthly_price",
"group_exclusive",
]
def to_representation(self, instance): def to_representation(self, instance):
representation = super().to_representation(instance) representation = super().to_representation(instance)
representation['annual_price'] = SimpleStripePriceSerializer( representation["annual_price"] = SimpleStripePriceSerializer(
instance.annual_price, many=False).data instance.annual_price, many=False
representation['monthly_price'] = SimpleStripePriceSerializer( ).data
instance.monthly_price, many=False).data representation["monthly_price"] = SimpleStripePriceSerializer(
instance.monthly_price, many=False
).data
return representation return representation
class UserSubscriptionSerializer(serializers.ModelSerializer): class UserSubscriptionSerializer(serializers.ModelSerializer):
date = serializers.DateTimeField( date = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta: class Meta:
model = UserSubscription model = UserSubscription
fields = ['id', 'user', 'user_group', 'subscription', fields = ["id", "user", "user_group", "subscription", "date", "valid", "annual"]
'date', 'valid', 'annual']
def to_representation(self, instance): def to_representation(self, instance):
representation = super().to_representation(instance) representation = super().to_representation(instance)
representation['user'] = SimpleCustomUserSerializer( representation["user"] = SimpleCustomUserSerializer(
instance.user, many=False).data instance.user, many=False
representation['subscription'] = SubscriptionPlanSerializer( ).data
instance.subscription, many=False).data representation["subscription"] = SubscriptionPlanSerializer(
instance.subscription, many=False
).data
return representation return representation

View file

@ -4,6 +4,7 @@ from .models import UserSubscription, StripePrice, SubscriptionPlan
from django.core.cache import cache from django.core.cache import cache
from config.settings import STRIPE_SECRET_KEY from config.settings import STRIPE_SECRET_KEY
import stripe import stripe
stripe.api_key = STRIPE_SECRET_KEY stripe.api_key = STRIPE_SECRET_KEY
# Template for running actions after user have paid for a subscription # Template for running actions after user have paid for a subscription
@ -12,7 +13,7 @@ stripe.api_key = STRIPE_SECRET_KEY
@receiver(post_save, sender=SubscriptionPlan) @receiver(post_save, sender=SubscriptionPlan)
def clear_cache_after_plan_updates(sender, instance, **kwargs): def clear_cache_after_plan_updates(sender, instance, **kwargs):
# Clear cache # Clear cache
cache.delete('subscriptionplans') cache.delete("subscriptionplans")
@receiver(post_save, sender=UserSubscription) @receiver(post_save, sender=UserSubscription)
@ -25,8 +26,8 @@ def scan_after_payment(sender, instance, **kwargs):
@receiver(post_migrate) @receiver(post_migrate)
def create_subscriptions(sender, **kwargs): def create_subscriptions(sender, **kwargs):
if sender.name == 'subscriptions': if sender.name == "subscriptions":
print('Importing data from Stripe') print("Importing data from Stripe")
created_prices = 0 created_prices = 0
created_plans = 0 created_plans = 0
skipped_prices = 0 skipped_prices = 0
@ -35,16 +36,19 @@ def create_subscriptions(sender, **kwargs):
prices = stripe.Price.list(expand=["data.tiers"], active=True) prices = stripe.Price.list(expand=["data.tiers"], active=True)
# Create the StripePrice # Create the StripePrice
for price in prices['data']: for price in prices["data"]:
annual = (price['recurring']['interval'] == annual = (
'year') if price['recurring'] else False (price["recurring"]["interval"] == "year")
if price["recurring"]
else False
)
STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create( STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create(
stripe_price_id=price['id'], stripe_price_id=price["id"],
price=price['unit_amount'] / 100, price=price["unit_amount"] / 100,
annual=annual, annual=annual,
lookup_key=price['lookup_key'], lookup_key=price["lookup_key"],
prorated=price['recurring']['usage_type'] == 'metered', prorated=price["recurring"]["usage_type"] == "metered",
currency=price['currency'] currency=price["currency"],
) )
if CREATED: if CREATED:
created_prices += 1 created_prices += 1
@ -52,13 +56,13 @@ def create_subscriptions(sender, **kwargs):
skipped_prices += 1 skipped_prices += 1
# Create the SubscriptionPlan # Create the SubscriptionPlan
for product in products['data']: for product in products["data"]:
ANNUAL_PRICE = None ANNUAL_PRICE = None
MONTHLY_PRICE = None MONTHLY_PRICE = None
for price in prices['data']: for price in prices["data"]:
if price['product'] == product['id']: if price["product"] == product["id"]:
STRIPE_PRICE = StripePrice.objects.get( STRIPE_PRICE = StripePrice.objects.get(
stripe_price_id=price['id'], stripe_price_id=price["id"],
) )
if STRIPE_PRICE.annual: if STRIPE_PRICE.annual:
ANNUAL_PRICE = STRIPE_PRICE ANNUAL_PRICE = STRIPE_PRICE
@ -66,12 +70,12 @@ def create_subscriptions(sender, **kwargs):
MONTHLY_PRICE = STRIPE_PRICE MONTHLY_PRICE = STRIPE_PRICE
if ANNUAL_PRICE or MONTHLY_PRICE: if ANNUAL_PRICE or MONTHLY_PRICE:
SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create( SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create(
name=product['name'], name=product["name"],
description=product['description'], description=product["description"],
stripe_product_id=product['id'], stripe_product_id=product["id"],
annual_price=ANNUAL_PRICE, annual_price=ANNUAL_PRICE,
monthly_price=MONTHLY_PRICE, monthly_price=MONTHLY_PRICE,
group_exclusive=product['metadata']['group_exclusive'] == 'True' group_exclusive=product["metadata"]["group_exclusive"] == "True",
) )
if CREATED: if CREATED:
created_plans += 1 created_plans += 1
@ -79,13 +83,12 @@ def create_subscriptions(sender, **kwargs):
skipped_plans += 1 skipped_plans += 1
# Skip over plans with missing pricing rates # Skip over plans with missing pricing rates
else: else:
print('Skipping plan' + print("Skipping plan" + product["name"] + "with missing pricing data")
product['name'] + 'with missing pricing data')
# Assign the StripePrice to the SubscriptionPlan # Assign the StripePrice to the SubscriptionPlan
SUBSCRIPTION_PLAN.save() SUBSCRIPTION_PLAN.save()
print('Created', created_plans, 'new plans') print("Created", created_plans, "new plans")
print('Skipped', skipped_plans, 'existing plans') print("Skipped", skipped_plans, "existing plans")
print('Created', created_prices, 'new prices') print("Created", created_prices, "new prices")
print('Skipped', skipped_prices, 'existing prices') print("Skipped", skipped_prices, "existing prices")

View file

@ -12,10 +12,10 @@ def get_user_subscription(user_id):
active_subscriptions = None active_subscriptions = None
if USER.user_group: if USER.user_group:
active_subscriptions = UserSubscription.objects.filter( active_subscriptions = UserSubscription.objects.filter(
user_group=USER.user_group, valid=True) user_group=USER.user_group, valid=True
)
else: else:
active_subscriptions = UserSubscription.objects.filter( active_subscriptions = UserSubscription.objects.filter(user=USER, valid=True)
user=USER, valid=True)
# Return first valid subscription if there is one # Return first valid subscription if there is one
if len(active_subscriptions) > 0: if len(active_subscriptions) > 0:
@ -33,7 +33,8 @@ def get_user_group_subscription(user_group):
# Get a list of subscriptions for the specified user # Get a list of subscriptions for the specified user
active_subscriptions = None active_subscriptions = None
active_subscriptions = UserSubscription.objects.filter( active_subscriptions = UserSubscription.objects.filter(
user_group=USER_GROUP, valid=True) user_group=USER_GROUP, valid=True
)
# Return first valid subscription if there is one # Return first valid subscription if there is one
if len(active_subscriptions) > 0: if len(active_subscriptions) > 0:

View file

@ -3,12 +3,11 @@ from subscriptions import views
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
router = DefaultRouter() router = DefaultRouter()
router.register(r'plans', views.SubscriptionPlanViewset, router.register(r"plans", views.SubscriptionPlanViewset, basename="Subscription Plans")
basename="Subscription Plans") router.register(r"self", views.UserSubscriptionViewset, basename="Self Subscriptions")
router.register(r'self', views.UserSubscriptionViewset, router.register(
basename="Self Subscriptions") r"user_group", views.UserGroupSubscriptionViewet, basename="Group Subscriptions"
router.register(r'user_group', views.UserGroupSubscriptionViewet, )
basename="Group Subscriptions")
urlpatterns = [ urlpatterns = [
path('', include(router.urls)), path("", include(router.urls)),
] ]

View file

@ -1,4 +1,7 @@
from subscriptions.serializers import SubscriptionPlanSerializer, UserSubscriptionSerializer from subscriptions.serializers import (
SubscriptionPlanSerializer,
UserSubscriptionSerializer,
)
from subscriptions.models import SubscriptionPlan, UserSubscription from subscriptions.models import SubscriptionPlan, UserSubscription
from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework import viewsets from rest_framework import viewsets
@ -6,13 +9,13 @@ from django.core.cache import cache
class SubscriptionPlanViewset(viewsets.ModelViewSet): class SubscriptionPlanViewset(viewsets.ModelViewSet):
http_method_names = ['get'] http_method_names = ["get"]
serializer_class = SubscriptionPlanSerializer serializer_class = SubscriptionPlanSerializer
permission_classes = [AllowAny] permission_classes = [AllowAny]
queryset = SubscriptionPlan.objects.all() queryset = SubscriptionPlan.objects.all()
def get_queryset(self): def get_queryset(self):
key = 'subscriptionplans' key = "subscriptionplans"
queryset = cache.get(key) queryset = cache.get(key)
if not queryset: if not queryset:
queryset = super().get_queryset() queryset = super().get_queryset()
@ -21,14 +24,14 @@ class SubscriptionPlanViewset(viewsets.ModelViewSet):
class UserSubscriptionViewset(viewsets.ModelViewSet): class UserSubscriptionViewset(viewsets.ModelViewSet):
http_method_names = ['get'] http_method_names = ["get"]
serializer_class = UserSubscriptionSerializer serializer_class = UserSubscriptionSerializer
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
queryset = UserSubscription.objects.all() queryset = UserSubscription.objects.all()
def get_queryset(self): def get_queryset(self):
user = self.request.user user = self.request.user
key = f'subscriptions_user:{user.id}' key = f"subscriptions_user:{user.id}"
queryset = cache.get(key) queryset = cache.get(key)
if not queryset: if not queryset:
queryset = UserSubscription.objects.filter(user=user) queryset = UserSubscription.objects.filter(user=user)
@ -37,7 +40,7 @@ class UserSubscriptionViewset(viewsets.ModelViewSet):
class UserGroupSubscriptionViewet(viewsets.ModelViewSet): class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
http_method_names = ['get'] http_method_names = ["get"]
serializer_class = UserSubscriptionSerializer serializer_class = UserSubscriptionSerializer
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
queryset = UserSubscription.objects.all() queryset = UserSubscription.objects.all()
@ -47,10 +50,9 @@ class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
if not user.user_group: if not user.user_group:
return UserSubscription.objects.none() return UserSubscription.objects.none()
else: else:
key = f'subscriptions_usergroup:{user.user_group.id}' key = f"subscriptions_usergroup:{user.user_group.id}"
queryset = cache.get(key) queryset = cache.get(key)
if not cache: if not cache:
queryset = UserSubscription.objects.filter( queryset = UserSubscription.objects.filter(user_group=user.user_group)
user_group=user.user_group)
cache.set(key, queryset, 60 * 60) cache.set(key, queryset, 60 * 60)
return queryset return queryset

View file

@ -7,9 +7,7 @@ from unfold.contrib.filters.admin import RangeDateFilter
@admin.register(UserGroup) @admin.register(UserGroup)
class UserGroupAdmin(ModelAdmin): class UserGroupAdmin(ModelAdmin):
list_filter_submit = True list_filter_submit = True
list_filter = (( list_filter = (("date_created", RangeDateFilter),)
"date_created", RangeDateFilter
),)
list_display = ['id', 'name'] list_display = ["id", "name"]
search_fields = ['id', 'name'] search_fields = ["id", "name"]

View file

@ -8,16 +8,28 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='UserGroup', name="UserGroup",
fields=[ fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=128)), "id",
('date_created', models.DateTimeField(default=django.utils.timezone.now, editable=False)), models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=128)),
(
"date_created",
models.DateTimeField(
default=django.utils.timezone.now, editable=False
),
),
], ],
), ),
] ]

View file

@ -8,24 +8,33 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('user_groups', '0001_initial'), ("user_groups", "0001_initial"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL),
] ]
operations = [ operations = [
migrations.AddField( migrations.AddField(
model_name='usergroup', model_name="usergroup",
name='managers', name="managers",
field=models.ManyToManyField(related_name='usergroup_managers', to=settings.AUTH_USER_MODEL), field=models.ManyToManyField(
related_name="usergroup_managers", to=settings.AUTH_USER_MODEL
),
), ),
migrations.AddField( migrations.AddField(
model_name='usergroup', model_name="usergroup",
name='members', name="members",
field=models.ManyToManyField(related_name='usergroup_members', to=settings.AUTH_USER_MODEL), field=models.ManyToManyField(
related_name="usergroup_members", to=settings.AUTH_USER_MODEL
),
), ),
migrations.AddField( migrations.AddField(
model_name='usergroup', model_name="usergroup",
name='owner', name="owner",
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='usergroup_owner', to=settings.AUTH_USER_MODEL), field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="usergroup_owner",
to=settings.AUTH_USER_MODEL,
),
), ),
] ]

View file

@ -2,17 +2,24 @@ from django.db import models
from django.utils.timezone import now from django.utils.timezone import now
from config.settings import STRIPE_SECRET_KEY from config.settings import STRIPE_SECRET_KEY
import stripe import stripe
stripe.api_key = STRIPE_SECRET_KEY stripe.api_key = STRIPE_SECRET_KEY
class UserGroup(models.Model): class UserGroup(models.Model):
name = models.CharField(max_length=128, null=False) name = models.CharField(max_length=128, null=False)
owner = models.ForeignKey( owner = models.ForeignKey(
'accounts.CustomUser', on_delete=models.SET_NULL, null=True, related_name='usergroup_owner') "accounts.CustomUser",
on_delete=models.SET_NULL,
null=True,
related_name="usergroup_owner",
)
managers = models.ManyToManyField( managers = models.ManyToManyField(
'accounts.CustomUser', related_name='usergroup_managers') "accounts.CustomUser", related_name="usergroup_managers"
)
members = models.ManyToManyField( members = models.ManyToManyField(
'accounts.CustomUser', related_name='usergroup_members') "accounts.CustomUser", related_name="usergroup_members"
)
date_created = models.DateTimeField(default=now, editable=False) date_created = models.DateTimeField(default=now, editable=False)
# Derived from email of owner, may be used for billing # Derived from email of owner, may be used for billing

View file

@ -3,10 +3,9 @@ from .models import UserGroup
class SimpleUserGroupSerializer(serializers.ModelSerializer): class SimpleUserGroupSerializer(serializers.ModelSerializer):
date_created = serializers.DateTimeField( date_created = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
format="%m-%d-%Y %I:%M %p", read_only=True)
class Meta: class Meta:
model = UserGroup model = UserGroup
fields = ['id', 'name', 'date_created'] fields = ["id", "name", "date_created"]
read_only_fields = ['id', 'name', 'date_created'] read_only_fields = ["id", "name", "date_created"]

View file

@ -8,15 +8,16 @@ from config.settings import STRIPE_SECRET_KEY, ROOT_DIR
import os import os
import json import json
import stripe import stripe
stripe.api_key = STRIPE_SECRET_KEY stripe.api_key = STRIPE_SECRET_KEY
@receiver(m2m_changed, sender=UserGroup.managers.through) @receiver(m2m_changed, sender=UserGroup.managers.through)
def update_group_managers(sender, instance, action, **kwargs): def update_group_managers(sender, instance, action, **kwargs):
# When adding new managers to a UserGroup, associate them with it # When adding new managers to a UserGroup, associate them with it
if action == 'post_add': if action == "post_add":
# Get the newly added managers # Get the newly added managers
new_managers = kwargs.get('pk_set', set()) new_managers = kwargs.get("pk_set", set())
for manager in new_managers: for manager in new_managers:
# Retrieve the member # Retrieve the member
USER = CustomUser.objects.get(pk=manager) USER = CustomUser.objects.get(pk=manager)
@ -27,8 +28,8 @@ def update_group_managers(sender, instance, action, **kwargs):
if USER not in instance.members.all(): if USER not in instance.members.all():
instance.members.add(USER) instance.members.add(USER)
# When removing managers from a UserGroup, remove their association with it # When removing managers from a UserGroup, remove their association with it
elif action == 'post_remove': elif action == "post_remove":
for manager in kwargs['pk_set']: for manager in kwargs["pk_set"]:
# Retrieve the manager # Retrieve the manager
USER = CustomUser.objects.get(pk=manager) USER = CustomUser.objects.get(pk=manager)
if USER not in instance.members.all(): if USER not in instance.members.all():
@ -39,9 +40,9 @@ def update_group_managers(sender, instance, action, **kwargs):
@receiver(m2m_changed, sender=UserGroup.members.through) @receiver(m2m_changed, sender=UserGroup.members.through)
def update_group_members(sender, instance, action, **kwargs): def update_group_members(sender, instance, action, **kwargs):
# When adding new members to a UserGroup, associate them with it # When adding new members to a UserGroup, associate them with it
if action == 'post_add': if action == "post_add":
# Get the newly added members # Get the newly added members
new_members = kwargs.get('pk_set', set()) new_members = kwargs.get("pk_set", set())
for member in new_members: for member in new_members:
# Retrieve the member # Retrieve the member
USER = CustomUser.objects.get(pk=member) USER = CustomUser.objects.get(pk=member)
@ -50,10 +51,13 @@ def update_group_members(sender, instance, action, **kwargs):
USER.user_group = instance USER.user_group = instance
USER.save() USER.save()
# When removing members from a UserGroup, remove their association with it # When removing members from a UserGroup, remove their association with it
elif action == 'post_remove': elif action == "post_remove":
for client in kwargs['pk_set']: for client in kwargs["pk_set"]:
USER = CustomUser.objects.get(pk=client) USER = CustomUser.objects.get(pk=client)
if USER not in instance.members.all() and USER not in instance.managers.all(): if (
USER not in instance.members.all()
and USER not in instance.managers.all()
):
USER.user_group = None USER.user_group = None
USER.save() USER.save()
# Update usage records # Update usage records
@ -66,42 +70,42 @@ def update_group_members(sender, instance, action, **kwargs):
stripe.SubscriptionItem.create_usage_record( stripe.SubscriptionItem.create_usage_record(
SUBSCRIPTION_ITEM.stripe_id, SUBSCRIPTION_ITEM.stripe_id,
quantity=len(instance.members.all()), quantity=len(instance.members.all()),
action="set" action="set",
) )
except: except:
print( print(
f'Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}') f"Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}"
)
@receiver(post_migrate) @receiver(post_migrate)
def create_groups(sender, **kwargs): def create_groups(sender, **kwargs):
if sender.name == "agencies": if sender.name == "agencies":
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f: with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
seed_data = json.loads(f.read()) seed_data = json.loads(f.read())
for user_group in seed_data['user_groups']: for user_group in seed_data["user_groups"]:
OWNER = CustomUser.objects.filter( OWNER = CustomUser.objects.filter(email=user_group["owner"]).first()
email=user_group['owner']).first()
USER_GROUP, CREATED = UserGroup.objects.get_or_create( USER_GROUP, CREATED = UserGroup.objects.get_or_create(
owner=OWNER, owner=OWNER,
agency_name=user_group['name'], agency_name=user_group["name"],
) )
if CREATED: if CREATED:
print(f"Created UserGroup {USER_GROUP.agency_name}") print(f"Created UserGroup {USER_GROUP.agency_name}")
# Add managers # Add managers
USERS = CustomUser.objects.filter( USERS = CustomUser.objects.filter(email__in=user_group["managers"])
email__in=user_group['managers'])
for USER in USERS: for USER in USERS:
if USER not in USER_GROUP.managers.all(): if USER not in USER_GROUP.managers.all():
print( print(
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}") f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}"
)
USER_GROUP.managers.add(USER) USER_GROUP.managers.add(USER)
# Add members # Add members
USERS = CustomUser.objects.filter( USERS = CustomUser.objects.filter(email__in=user_group["members"])
email__in=user_group['members'])
for USER in USERS: for USER in USERS:
if USER not in USER_GROUP.members.all(): if USER not in USER_GROUP.members.all():
print( print(
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}") f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}"
)
USER_GROUP.clients.add(USER) USER_GROUP.clients.add(USER)
USER_GROUP.save() USER_GROUP.save()

View file

@ -2,5 +2,5 @@ from django.apps import AppConfig
class EmailsConfig(AppConfig): class EmailsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField' default_auto_field = "django.db.models.BigAutoField"
name = 'webdriver' name = "webdriver"

View file

@ -1,11 +1,19 @@
from celery import shared_task from celery import shared_task
from webdriver.utils import setup_webdriver, selenium_action_template, google_search, get_element, get_elements from webdriver.utils import (
setup_webdriver,
selenium_action_template,
google_search,
get_element,
get_elements,
)
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from search_results.tasks import create_search_result from search_results.tasks import create_search_result
# Task template # Task template
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5}) @shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
)
def sample_selenium_task(): def sample_selenium_task():
driver = setup_webdriver(use_proxy=False, use_saved_session=False) driver = setup_webdriver(use_proxy=False, use_saved_session=False)
@ -18,27 +26,29 @@ def sample_selenium_task():
driver.close() driver.close()
driver.quit() driver.quit()
# Sample task to scrape Google for search results based on a keyword # Sample task to scrape Google for search results based on a keyword
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5}) @shared_task(
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
)
def simple_google_search(): def simple_google_search():
driver = setup_webdriver(driver_type="firefox", driver = setup_webdriver(
use_proxy=False, use_saved_session=False) driver_type="firefox", use_proxy=False, use_saved_session=False
)
driver.get(f"https://google.com/") driver.get(f"https://google.com/")
google_search(driver, search_term="cat blog posts") google_search(driver, search_term="cat blog posts")
# Count number of Google search results # Count number of Google search results
search_items = get_elements( search_items = get_elements(driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
for item in search_items: for item in search_items:
title = item.find_element(By.TAG_NAME, 'h3').text title = item.find_element(By.TAG_NAME, "h3").text
link = item.find_element(By.TAG_NAME, 'a').get_attribute('href') link = item.find_element(By.TAG_NAME, "a").get_attribute("href")
create_search_result.apply_async( create_search_result.apply_async(kwargs={"title": title, "link": link})
kwargs={"title": title, "link": link})
driver.close() driver.close()
driver.quit() driver.quit()

View file

@ -1,6 +1,7 @@
""" """
Settings file to hold constants and functions Settings file to hold constants and functions
""" """
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from config.settings import get_secret from config.settings import get_secret
@ -18,24 +19,26 @@ import os
import random import random
def take_snapshot(driver, filename='dump.png'): def take_snapshot(driver, filename="dump.png"):
# Set window size # Set window size
required_width = driver.execute_script( required_width = driver.execute_script(
'return document.body.parentNode.scrollWidth') "return document.body.parentNode.scrollWidth"
)
required_height = driver.execute_script( required_height = driver.execute_script(
'return document.body.parentNode.scrollHeight') "return document.body.parentNode.scrollHeight"
driver.set_window_size( )
required_width, required_height+(required_height*0.05)) driver.set_window_size(required_width, required_height + (required_height * 0.05))
# Take the snapshot # Take the snapshot
driver.find_element(By.TAG_NAME, driver.find_element(By.TAG_NAME, "body").screenshot(
'body').screenshot('/dumps/'+filename) # avoids any scrollbars "/dumps/" + filename
print('Snapshot saved') ) # avoids any scrollbars
print("Snapshot saved")
def dump_html(driver, filename='dump.html'): def dump_html(driver, filename="dump.html"):
# Save the page source to error.html # Save the page source to error.html
with open(('/dumps/'+filename), 'w', encoding='utf-8') as file: with open(("/dumps/" + filename), "w", encoding="utf-8") as file:
file.write(driver.page_source) file.write(driver.page_source)
@ -44,83 +47,83 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
if not USE_PROXY: if not USE_PROXY:
use_proxy = False use_proxy = False
if use_proxy: if use_proxy:
print('Running driver with proxy enabled') print("Running driver with proxy enabled")
else: else:
print('Running driver with proxy disabled') print("Running driver with proxy disabled")
if use_saved_session: if use_saved_session:
print('Running with saved session') print("Running with saved session")
else: else:
print('Running without using saved session') print("Running without using saved session")
if driver_type == "chrome": if driver_type == "chrome":
print('Using Chrome driver') print("Using Chrome driver")
opts = uc.ChromeOptions() opts = uc.ChromeOptions()
if use_saved_session: if use_saved_session:
if os.path.exists("/tmp_chrome_profile"): if os.path.exists("/tmp_chrome_profile"):
print('Existing Chrome ephemeral profile found') print("Existing Chrome ephemeral profile found")
else: else:
print('No existing Chrome ephemeral profile found') print("No existing Chrome ephemeral profile found")
os.system("mkdir /tmp_chrome_profile") os.system("mkdir /tmp_chrome_profile")
if os.path.exists('/chrome'): if os.path.exists("/chrome"):
print('Copying Chrome Profile to ephemeral directory') print("Copying Chrome Profile to ephemeral directory")
# Flush any non-essential cache directories from the existing profile as they may balloon in size overtime # Flush any non-essential cache directories from the existing profile as they may balloon in size overtime
os.system( os.system('rm -rf "/chrome/Selenium Profile/Code Cache/*"')
'rm -rf "/chrome/Selenium Profile/Code Cache/*"')
# Create a copy of the Chrome Profile # Create a copy of the Chrome Profile
os.system("cp -r /chrome/* /tmp_chrome_profile") os.system("cp -r /chrome/* /tmp_chrome_profile")
try: try:
# Remove some items related to file locks # Remove some items related to file locks
os.remove('/tmp_chrome_profile/SingletonLock') os.remove("/tmp_chrome_profile/SingletonLock")
os.remove('/tmp_chrome_profile/SingletonSocket') os.remove("/tmp_chrome_profile/SingletonSocket")
os.remove('/tmp_chrome_profile/SingletonLock') os.remove("/tmp_chrome_profile/SingletonLock")
except: except:
pass pass
else: else:
print('No existing Chrome Profile found. Creating one from scratch') print("No existing Chrome Profile found. Creating one from scratch")
if use_saved_session: if use_saved_session:
# Specify the user data directory # Specify the user data directory
opts.add_argument(f'--user-data-dir=/tmp_chrome_profile') opts.add_argument(f"--user-data-dir=/tmp_chrome_profile")
opts.add_argument('--profile-directory=Selenium Profile') opts.add_argument("--profile-directory=Selenium Profile")
# Set proxy # Set proxy
if use_proxy: if use_proxy:
opts.add_argument( opts.add_argument(
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}') f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}'
)
opts.add_argument("--disable-extensions") opts.add_argument("--disable-extensions")
opts.add_argument('--disable-application-cache') opts.add_argument("--disable-application-cache")
opts.add_argument("--disable-setuid-sandbox") opts.add_argument("--disable-setuid-sandbox")
opts.add_argument('--disable-dev-shm-usage') opts.add_argument("--disable-dev-shm-usage")
opts.add_argument("--disable-gpu") opts.add_argument("--disable-gpu")
opts.add_argument("--no-sandbox") opts.add_argument("--no-sandbox")
opts.add_argument("--headless=new") opts.add_argument("--headless=new")
driver = uc.Chrome(options=opts) driver = uc.Chrome(options=opts)
elif driver_type == "firefox": elif driver_type == "firefox":
print('Using firefox driver') print("Using firefox driver")
opts = FirefoxOptions() opts = FirefoxOptions()
if use_saved_session: if use_saved_session:
if not os.path.exists("/firefox"): if not os.path.exists("/firefox"):
print('No profile found') print("No profile found")
os.makedirs("/firefox") os.makedirs("/firefox")
else: else:
print('Existing profile found') print("Existing profile found")
# Specify a profile if it exists # Specify a profile if it exists
opts.profile = "/firefox" opts.profile = "/firefox"
# Set proxy # Set proxy
if use_proxy: if use_proxy:
opts.set_preference('network.proxy.type', 1) opts.set_preference("network.proxy.type", 1)
opts.set_preference('network.proxy.socks', opts.set_preference("network.proxy.socks", get_secret("PROXY_IP"))
get_secret('PROXY_IP')) opts.set_preference(
opts.set_preference('network.proxy.socks_port', "network.proxy.socks_port", int(get_secret("PROXY_PORT_IP_AUTH"))
int(get_secret('PROXY_PORT_IP_AUTH'))) )
opts.set_preference('network.proxy.socks_remote_dns', False) opts.set_preference("network.proxy.socks_remote_dns", False)
opts.add_argument('--disable-dev-shm-usage') opts.add_argument("--disable-dev-shm-usage")
opts.add_argument("--headless") opts.add_argument("--headless")
opts.add_argument("--disable-gpu") opts.add_argument("--disable-gpu")
driver = webdriver.Firefox(options=opts) driver = webdriver.Firefox(options=opts)
@ -128,13 +131,15 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
driver.maximize_window() driver.maximize_window()
# Check if proxy is working # Check if proxy is working
driver.get('https://api.ipify.org/') driver.get("https://api.ipify.org/")
body = WebDriverWait(driver, 10).until( body = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))) EC.presence_of_element_located((By.TAG_NAME, "body"))
)
ip_address = body.text ip_address = body.text
print(f'External IP: {ip_address}') print(f"External IP: {ip_address}")
return driver return driver
# These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.) # These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.)
# Depending on your use case, you may have to opt out of using this # Depending on your use case, you may have to opt out of using this
@ -151,10 +156,11 @@ def get_element(driver, by, key, hidden_element=False, timeout=8):
wait = WebDriverWait(driver, timeout=timeout) wait = WebDriverWait(driver, timeout=timeout)
if not hidden_element: if not hidden_element:
element = wait.until( element = wait.until(
EC.element_to_be_clickable((by, key)) and EC.visibility_of_element_located((by, key))) EC.element_to_be_clickable((by, key))
and EC.visibility_of_element_located((by, key))
)
else: else:
element = wait.until(EC.presence_of_element_located( element = wait.until(EC.presence_of_element_located((by, key)))
(by, key)))
return element return element
except Exception: except Exception:
dump_html(driver) dump_html(driver)
@ -173,13 +179,12 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
wait = WebDriverWait(driver, timeout=timeout) wait = WebDriverWait(driver, timeout=timeout)
if hidden_element: if hidden_element:
elements = wait.until( elements = wait.until(EC.presence_of_all_elements_located((by, key)))
EC.presence_of_all_elements_located((by, key)))
else: else:
visible_elements = wait.until( visible_elements = wait.until(
EC.visibility_of_any_elements_located((by, key))) EC.visibility_of_any_elements_located((by, key))
elements = [ )
element for element in visible_elements if element.is_enabled()] elements = [element for element in visible_elements if element.is_enabled()]
return elements return elements
except Exception: except Exception:
@ -193,17 +198,22 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
def execute_selenium_elements(driver, timeout, elements): def execute_selenium_elements(driver, timeout, elements):
try: try:
for index, element in enumerate(elements): for index, element in enumerate(elements):
print('Waiting...') print("Waiting...")
# Element may have a keyword specified, check if that exists before running any actions # Element may have a keyword specified, check if that exists before running any actions
if "keyword" in element: if "keyword" in element:
# Skip a step if the keyword does not exist # Skip a step if the keyword does not exist
if element['keyword'] not in driver.page_source: if element["keyword"] not in driver.page_source:
print( print(
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}') f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}'
)
continue continue
elif element['keyword'] in driver.page_source and element['type'] == 'skip': elif (
element["keyword"] in driver.page_source
and element["type"] == "skip"
):
print( print(
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}') f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}'
)
break break
print(f'Step: {index+1} - {element["name"]}') print(f'Step: {index+1} - {element["name"]}')
# Revert to default iframe action # Revert to default iframe action
@ -217,31 +227,47 @@ def execute_selenium_elements(driver, timeout, elements):
else: else:
values = element["input"] values = element["input"]
if type(values) is list: if type(values) is list:
raise Exception( raise Exception('Invalid input value specified for "callback" type')
'Invalid input value specified for "callback" type')
else: else:
# For single input values # For single input values
driver.execute_script( driver.execute_script(f'onRecaptcha("{values}");')
f'onRecaptcha("{values}");')
continue continue
try: try:
# Try to get default element # Try to get default element
if "hidden" in element: if "hidden" in element:
site_element = get_element( site_element = get_element(
driver, element["default"]["type"], element["default"]["key"], hidden_element=True, timeout=timeout) driver,
element["default"]["type"],
element["default"]["key"],
hidden_element=True,
timeout=timeout,
)
else: else:
site_element = get_element( site_element = get_element(
driver, element["default"]["type"], element["default"]["key"], timeout=timeout) driver,
element["default"]["type"],
element["default"]["key"],
timeout=timeout,
)
except Exception as e: except Exception as e:
print(f'Failed to find primary element') print(f"Failed to find primary element")
# If that fails, try to get the failover one # If that fails, try to get the failover one
print('Trying to find legacy element') print("Trying to find legacy element")
if "hidden" in element: if "hidden" in element:
site_element = get_element( site_element = get_element(
driver, element["failover"]["type"], element["failover"]["key"], hidden_element=True, timeout=timeout) driver,
element["failover"]["type"],
element["failover"]["key"],
hidden_element=True,
timeout=timeout,
)
else: else:
site_element = get_element( site_element = get_element(
driver, element["failover"]["type"], element["failover"]["key"], timeout=timeout) driver,
element["failover"]["type"],
element["failover"]["key"],
timeout=timeout,
)
# Clicking an element # Clicking an element
if element["type"] == "click": if element["type"] == "click":
site_element.click() site_element.click()
@ -272,11 +298,13 @@ def execute_selenium_elements(driver, timeout, elements):
values = element["input"] values = element["input"]
if type(values) is list: if type(values) is list:
raise Exception( raise Exception(
'Invalid input value specified for "input_replace" type') 'Invalid input value specified for "input_replace" type'
)
else: else:
# For single input values # For single input values
driver.execute_script( driver.execute_script(
f'arguments[0].value = "{values}";', site_element) f'arguments[0].value = "{values}";', site_element
)
except Exception as e: except Exception as e:
take_snapshot(driver) take_snapshot(driver)
dump_html(driver) dump_html(driver)
@ -285,30 +313,33 @@ def execute_selenium_elements(driver, timeout, elements):
raise Exception(e) raise Exception(e)
def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=False, use_proxy=True): def solve_captcha(
site_key, url, retry_attempts=3, version="v2", enterprise=False, use_proxy=True
):
# Manual proxy override set via $ENV # Manual proxy override set via $ENV
if not USE_PROXY: if not USE_PROXY:
use_proxy = False use_proxy = False
if CAPTCHA_TESTING: if CAPTCHA_TESTING:
print('Initializing CAPTCHA solver in dummy mode') print("Initializing CAPTCHA solver in dummy mode")
code = random.randint() code = random.randint()
print("CAPTCHA Successful") print("CAPTCHA Successful")
return code return code
elif use_proxy: elif use_proxy:
print('Using CAPTCHA solver with proxy') print("Using CAPTCHA solver with proxy")
else: else:
print('Using CAPTCHA solver without proxy') print("Using CAPTCHA solver without proxy")
captcha_params = { captcha_params = {
"url": url, "url": url,
"sitekey": site_key, "sitekey": site_key,
"version": version, "version": version,
"enterprise": 1 if enterprise else 0, "enterprise": 1 if enterprise else 0,
"proxy": { "proxy": (
'type': 'socks5', {"type": "socks5", "uri": get_secret("PROXY_USER_AUTH")}
'uri': get_secret('PROXY_USER_AUTH') if use_proxy
} if use_proxy else None else None
),
} }
# Keep retrying until max attempts is reached # Keep retrying until max attempts is reached
@ -316,12 +347,12 @@ def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=Fals
# Solver uses 2CAPTCHA by default # Solver uses 2CAPTCHA by default
solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY")) solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY"))
try: try:
print('Waiting for CAPTCHA code...') print("Waiting for CAPTCHA code...")
code = solver.recaptcha(**captcha_params)["code"] code = solver.recaptcha(**captcha_params)["code"]
print("CAPTCHA Successful") print("CAPTCHA Successful")
return code return code
except Exception as e: except Exception as e:
print(f'CAPTCHA Failed! {e}') print(f"CAPTCHA Failed! {e}")
raise Exception(f"CAPTCHA API Failed!") raise Exception(f"CAPTCHA API Failed!")
@ -339,13 +370,12 @@ def save_browser_session(driver):
# Copy over the profile once we finish logging in # Copy over the profile once we finish logging in
if isinstance(driver, webdriver.Firefox): if isinstance(driver, webdriver.Firefox):
# Copy process for Firefox # Copy process for Firefox
print('Updating saved Firefox profile') print("Updating saved Firefox profile")
# Get the current profile directory from about:support page # Get the current profile directory from about:support page
driver.get("about:support") driver.get("about:support")
box = get_element( box = get_element(driver, "id", "profile-dir-box", timeout=4)
driver, "id", "profile-dir-box", timeout=4)
temp_profile_path = os.path.join(os.getcwd(), box.text) temp_profile_path = os.path.join(os.getcwd(), box.text)
profile_path = '/firefox' profile_path = "/firefox"
# Create the command # Create the command
copy_command = "cp -r " + temp_profile_path + "/* " + profile_path copy_command = "cp -r " + temp_profile_path + "/* " + profile_path
# Copy over the Firefox profile # Copy over the Firefox profile
@ -353,13 +383,13 @@ def save_browser_session(driver):
print("Firefox profile saved") print("Firefox profile saved")
elif isinstance(driver, uc.Chrome): elif isinstance(driver, uc.Chrome):
# Copy the Chrome profile # Copy the Chrome profile
print('Updating non-ephemeral Chrome profile') print("Updating non-ephemeral Chrome profile")
# Flush Code Cache again to speed up copy # Flush Code Cache again to speed up copy
os.system( os.system('rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
'rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
if os.system("cp -r /tmp_chrome_profile/* /chrome"): if os.system("cp -r /tmp_chrome_profile/* /chrome"):
print("Chrome profile saved") print("Chrome profile saved")
# Sample function # Sample function
# Call this within a Celery task # Call this within a Celery task
# TODO: Modify as needed to your needs # TODO: Modify as needed to your needs
@ -370,7 +400,7 @@ def selenium_action_template(driver):
info = { info = {
"sample_field1": "sample_data", "sample_field1": "sample_data",
"sample_field2": "sample_data", "sample_field2": "sample_data",
"captcha_code": lambda: solve_captcha('SITE_KEY', 'SITE_URL') "captcha_code": lambda: solve_captcha("SITE_KEY", "SITE_URL"),
} }
elements = [ elements = [
@ -382,13 +412,10 @@ def selenium_action_template(driver):
"default": { "default": {
# See get_element() for possible selector types # See get_element() for possible selector types
"type": "xpath", "type": "xpath",
"key": '' "key": "",
}, },
# If a site implements canary design releases, you can place the ID for the element in the old design here # If a site implements canary design releases, you can place the ID for the element in the old design here
"failover": { "failover": {"type": "xpath", "key": ""},
"type": "xpath",
"key": ''
}
}, },
] ]
@ -398,8 +425,8 @@ def selenium_action_template(driver):
# Fill in final fstring values in elements # Fill in final fstring values in elements
for element in elements: for element in elements:
if 'input' in element and '{' in element['input']: if "input" in element and "{" in element["input"]:
a = element['input'].strip('{}') a = element["input"].strip("{}")
if a in info: if a in info:
value = info[a] value = info[a]
# Check if the value is a callable (a lambda function) and call it if so # Check if the value is a callable (a lambda function) and call it if so
@ -411,11 +438,12 @@ def selenium_action_template(driver):
# Use the stored value # Use the stored value
value = site_form_values[a] value = site_form_values[a]
# Replace the placeholder with the actual value # Replace the placeholder with the actual value
element['input'] = str(value) element["input"] = str(value)
# Execute the selenium actions # Execute the selenium actions
execute_selenium_elements(driver, 8, elements) execute_selenium_elements(driver, 8, elements)
# Sample task for Google search # Sample task for Google search
@ -429,40 +457,28 @@ def google_search(driver, search_term):
"name": "Type in search term", "name": "Type in search term",
"type": "input", "type": "input",
"input": "{search_term}", "input": "{search_term}",
"default": { "default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"type": "xpath", "failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"key": '//*[@id="APjFqb"]'
},
"failover": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
}
}, },
{ {
"name": "Press enter", "name": "Press enter",
"type": "input_enter", "type": "input_enter",
"default": { "default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"type": "xpath", "failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
"key": '//*[@id="APjFqb"]'
},
"failover": {
"type": "xpath",
"key": '//*[@id="APjFqb"]'
}
}, },
] ]
site_form_values = {} site_form_values = {}
for element in elements: for element in elements:
if 'input' in element and '{' in element['input']: if "input" in element and "{" in element["input"]:
a = element['input'].strip('{}') a = element["input"].strip("{}")
if a in info: if a in info:
value = info[a] value = info[a]
if callable(value): if callable(value):
if a not in site_form_values: if a not in site_form_values:
site_form_values[a] = value() site_form_values[a] = value()
value = site_form_values[a] value = site_form_values[a]
element['input'] = str(value) element["input"] = str(value)
execute_selenium_elements(driver, 8, elements) execute_selenium_elements(driver, 8, elements)

View file

@ -5,7 +5,7 @@ services:
build: build:
context: . context: .
dockerfile: Dockerfile dockerfile: Dockerfile
image: drf_template:latest image: drf_template
ports: ports:
- "${BACKEND_PORT}:${BACKEND_PORT}" - "${BACKEND_PORT}:${BACKEND_PORT}"
environment: environment:
@ -23,7 +23,7 @@ services:
env_file: .env env_file: .env
environment: environment:
- RUN_TYPE=worker - RUN_TYPE=worker
image: drf_template:latest image: drf_template
volumes: volumes:
- .:/code - .:/code
- ./chrome:/chrome - ./chrome:/chrome
@ -42,7 +42,7 @@ services:
env_file: .env env_file: .env
environment: environment:
- RUN_TYPE=beat - RUN_TYPE=beat
image: drf_template:latest image: drf_template
volumes: volumes:
- .:/code - .:/code
depends_on: depends_on:
@ -58,7 +58,7 @@ services:
env_file: .env env_file: .env
environment: environment:
- RUN_TYPE=monitor - RUN_TYPE=monitor
image: drf_template:latest image: drf_template
ports: ports:
- "${CELERY_FLOWER_PORT}:5555" - "${CELERY_FLOWER_PORT}:5555"
volumes: volumes: